/*
 * Patch summary (kernel/audit.c): turn the auditd connection state into an
 * RCU-managed pointer.
 *
 * - struct auditd_connection drops its embedded spinlock and gains an
 *   rcu_head; the single static instance becomes the pointer auditd_conn
 *   plus a static null_conn, and writers are serialized by the new
 *   auditd_conn_lock spinlock.
 * - auditd_set() now takes a prepared connection: it takes a reference on
 *   the connection's net (when non-NULL), publishes the connection with
 *   rcu_assign_pointer() under the lock, and passes the previous connection
 *   to call_rcu().  The callback, auditd_conn_free(), drops the net
 *   reference and kfree()s the object, returning early for null_conn.
 * - auditd_reset() installs null_conn instead of zeroing the fields.
 * - Readers replace direct auditd_conn field accesses with
 *   rcu_dereference(auditd_conn) inside their rcu_read_lock() sections.
 * - The pid-change path in audit_receive_msg() kmalloc()s the new connection
 *   up front and kfree()s it on the -EACCES and -EEXIST returns;
 *   audit_init() kzalloc()s the initial connection.
 *
 * NOTE(review): auditd_send_unicast_skb() and kauditd_thread() call
 * rcu_dereference(auditd_conn) separately for pid, net and portid; confirm
 * these cannot observe two different connection objects if auditd_set()
 * runs in between.
 * NOTE(review): the branch taken when new_pid is 0 lies outside these hunks;
 * confirm the connection allocated up front is released on that path.
 * NOTE(review): audit_init() returns before the kzalloc() when
 * audit_initialized == AUDIT_DISABLED; confirm no reader dereferences
 * auditd_conn in that configuration.
 */
diff --git a/kernel/audit.c b/kernel/audit.c index a871bf8..9953dbe 100644 --- a/kernel/audit.c +++ b/kernel/audit.c @@ -110,7 +110,7 @@ struct audit_net { * @pid: auditd PID * @portid: netlink portid * @net: the associated network namespace - * @lock: spinlock to protect write access + * @rcu: rcu head * * Description: * This struct is RCU protected; you must either hold the RCU lock for reading @@ -120,8 +120,9 @@ static struct auditd_connection { int pid; u32 portid; struct net *net; - spinlock_t lock; -} auditd_conn; + struct rcu_head rcu; +} *auditd_conn, null_conn; +static DEFINE_SPINLOCK(auditd_conn_lock); /* If audit_rate_limit is non-zero, limit the rate of sending audit records * to that number per second. This prevents DoS attacks, but results in @@ -223,9 +224,11 @@ struct audit_reply { int auditd_test_task(const struct task_struct *task) { int rc; + pid_t pid; rcu_read_lock(); - rc = (auditd_conn.pid && task->tgid == auditd_conn.pid ? 1 : 0); + pid = rcu_dereference(auditd_conn)->pid; + rc = (pid && task->tgid == pid ? 1 : 0); rcu_read_unlock(); return rc; @@ -426,30 +429,39 @@ static int audit_set_failure(u32 state) return audit_do_config_change("audit_failure", &audit_failure, state); } +static void auditd_conn_free(struct rcu_head *rcu) +{ + struct auditd_connection *cn = container_of(rcu, struct auditd_connection, rcu); + + if (cn == &null_conn) + return; + if (cn->net) + put_net(cn->net); + kfree(cn); +} + /** - * auditd_set - Set/Reset the auditd connection state - * @pid: auditd PID - * @portid: auditd netlink portid - * @net: auditd network namespace pointer + * auditd_set - Set the auditd connection state + * @new_conn: the new auditd connection * * Description: * This function will obtain and drop network namespace references as * necessary. 
*/ -static void auditd_set(int pid, u32 portid, struct net *net) +static void auditd_set(struct auditd_connection *new_conn) { + struct auditd_connection *old_conn; unsigned long flags; - spin_lock_irqsave(&auditd_conn.lock, flags); - auditd_conn.pid = pid; - auditd_conn.portid = portid; - if (auditd_conn.net) - put_net(auditd_conn.net); - if (net) - auditd_conn.net = get_net(net); - else - auditd_conn.net = NULL; - spin_unlock_irqrestore(&auditd_conn.lock, flags); + if (new_conn->net) + get_net(new_conn->net); + spin_lock_irqsave(&auditd_conn_lock, flags); + old_conn = rcu_dereference_protected(auditd_conn, + lockdep_is_held(&auditd_conn_lock)); + rcu_assign_pointer(auditd_conn, new_conn); + spin_unlock_irqrestore(&auditd_conn_lock, flags); + + call_rcu(&old_conn->rcu, auditd_conn_free); } /** @@ -548,8 +560,8 @@ static void auditd_reset(void) /* if it isn't already broken, break the connection */ rcu_read_lock(); - if (auditd_conn.pid) - auditd_set(0, 0, NULL); + if (rcu_dereference(auditd_conn)->pid) + auditd_set(&null_conn); rcu_read_unlock(); /* flush all of the main and retry queues to the hold queue */ @@ -585,15 +597,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb) * section netlink_unicast() should safely return an error */ rcu_read_lock(); - if (!auditd_conn.pid) { + if (!rcu_dereference(auditd_conn)->pid) { rcu_read_unlock(); rc = -ECONNREFUSED; goto err; } - net = auditd_conn.net; + net = rcu_dereference(auditd_conn)->net; get_net(net); sk = audit_get_sk(net); - portid = auditd_conn.portid; + portid = rcu_dereference(auditd_conn)->portid; rcu_read_unlock(); rc = netlink_unicast(sk, skb, portid, 0); @@ -735,14 +747,14 @@ static int kauditd_thread(void *dummy) while (!kthread_should_stop()) { /* NOTE: see the lock comments in auditd_send_unicast_skb() */ rcu_read_lock(); - if (!auditd_conn.pid) { + if (!rcu_dereference(auditd_conn)->pid) { rcu_read_unlock(); goto main_queue; } - net = auditd_conn.net; + net = 
rcu_dereference(auditd_conn)->net; get_net(net); sk = audit_get_sk(net); - portid = auditd_conn.portid; + portid = rcu_dereference(auditd_conn)->portid; rcu_read_unlock(); /* attempt to flush the hold queue */ @@ -1103,7 +1115,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) s.enabled = audit_enabled; s.failure = audit_failure; rcu_read_lock(); - s.pid = auditd_conn.pid; + s.pid = rcu_dereference(auditd_conn)->pid; rcu_read_unlock(); s.rate_limit = audit_rate_limit; s.backlog_limit = audit_backlog_limit; @@ -1139,17 +1151,22 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) int new_pid = s.pid; pid_t auditd_pid; pid_t requesting_pid = task_tgid_vnr(current); + struct auditd_connection *new; + new = kmalloc(sizeof(*new), GFP_KERNEL); + if (!new) + return -ENOMEM; /* test the auditd connection */ audit_replace(requesting_pid); rcu_read_lock(); - auditd_pid = auditd_conn.pid; + auditd_pid = rcu_dereference(auditd_conn)->pid; /* only the current auditd can unregister itself */ if ((!new_pid) && (requesting_pid != auditd_pid)) { rcu_read_unlock(); audit_log_config_change("audit_pid", new_pid, auditd_pid, 0); + kfree(new); return -EACCES; } /* replacing a healthy auditd is not allowed */ @@ -1157,6 +1174,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) rcu_read_unlock(); audit_log_config_change("audit_pid", new_pid, auditd_pid, 0); + kfree(new); return -EEXIST; } rcu_read_unlock(); @@ -1166,10 +1184,11 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) auditd_pid, 1); if (new_pid) { + new->pid = new_pid; + new->portid = NETLINK_CB(skb).portid; + new->net = sock_net(NETLINK_CB(skb).sk); /* register a new auditd connection */ - auditd_set(new_pid, - NETLINK_CB(skb).portid, - sock_net(NETLINK_CB(skb).sk)); + auditd_set(new); /* try to process any backlog */ wake_up_interruptible(&kauditd_wait); } else @@ -1448,7 +1467,7 @@ static void __net_exit audit_net_exit(struct net 
*net) struct audit_net *aunet = net_generic(net, audit_net_id); rcu_read_lock(); - if (net == auditd_conn.net) + if (net == rcu_dereference(auditd_conn)->net) auditd_reset(); rcu_read_unlock(); @@ -1470,8 +1489,9 @@ static int __init audit_init(void) if (audit_initialized == AUDIT_DISABLED) return 0; - memset(&auditd_conn, 0, sizeof(auditd_conn)); - spin_lock_init(&auditd_conn.lock); + auditd_conn = kzalloc(sizeof(*auditd_conn), GFP_KERNEL); + if (!auditd_conn) + return -ENOMEM; skb_queue_head_init(&audit_queue); skb_queue_head_init(&audit_retry_queue);