inet: add RCU protection to inet->opt
[cascardo/linux.git] / net / ipv4 / cipso_ipv4.c
index 094e150..2b3c23c 100644 (file)
@@ -112,7 +112,7 @@ int cipso_v4_rbm_strictvalid = 1;
 /* The maximum number of category ranges permitted in the ranged category tag
  * (tag #5).  You may note that the IETF draft states that the maximum number
  * of category ranges is 7, but if the low end of the last category range is
- * zero then it is possibile to fit 8 category ranges because the zero should
+ * zero then it is possible to fit 8 category ranges because the zero should
  * be omitted. */
 #define CIPSO_V4_TAG_RNG_CAT_MAX      8
 
@@ -438,7 +438,7 @@ cache_add_failure:
  *
  * Description:
  * Search the DOI definition list for a DOI definition with a DOI value that
- * matches @doi.  The caller is responsibile for calling rcu_read_[un]lock().
+ * matches @doi.  The caller is responsible for calling rcu_read_[un]lock().
  * Returns a pointer to the DOI definition on success and NULL on failure.
  */
 static struct cipso_v4_doi *cipso_v4_doi_search(u32 doi)
@@ -1293,7 +1293,7 @@ static int cipso_v4_gentag_rbm(const struct cipso_v4_doi *doi_def,
                        return ret_val;
 
                /* This will send packets using the "optimized" format when
-                * possibile as specified in  section 3.4.2.6 of the
+                * possible as specified in  section 3.4.2.6 of the
                 * CIPSO draft. */
                if (cipso_v4_rbm_optfmt && ret_val > 0 && ret_val <= 10)
                        tag_len = 14;
@@ -1752,7 +1752,7 @@ validate_return:
 }
 
 /**
- * cipso_v4_error - Send the correct reponse for a bad packet
+ * cipso_v4_error - Send the correct response for a bad packet
  * @skb: the packet
  * @error: the error code
  * @gateway: CIPSO gateway flag
@@ -1857,6 +1857,11 @@ static int cipso_v4_genopt(unsigned char *buf, u32 buf_len,
        return CIPSO_V4_HDR_LEN + ret_val;
 }
 
+static void opt_kfree_rcu(struct rcu_head *head)
+{
+       kfree(container_of(head, struct ip_options_rcu, rcu));
+}
+
 /**
  * cipso_v4_sock_setattr - Add a CIPSO option to a socket
  * @sk: the socket
@@ -1879,7 +1884,7 @@ int cipso_v4_sock_setattr(struct sock *sk,
        unsigned char *buf = NULL;
        u32 buf_len;
        u32 opt_len;
-       struct ip_options *opt = NULL;
+       struct ip_options_rcu *old, *opt = NULL;
        struct inet_sock *sk_inet;
        struct inet_connection_sock *sk_conn;
 
@@ -1915,22 +1920,25 @@ int cipso_v4_sock_setattr(struct sock *sk,
                ret_val = -ENOMEM;
                goto socket_setattr_failure;
        }
-       memcpy(opt->__data, buf, buf_len);
-       opt->optlen = opt_len;
-       opt->cipso = sizeof(struct iphdr);
+       memcpy(opt->opt.__data, buf, buf_len);
+       opt->opt.optlen = opt_len;
+       opt->opt.cipso = sizeof(struct iphdr);
        kfree(buf);
        buf = NULL;
 
        sk_inet = inet_sk(sk);
+
+       old = rcu_dereference_protected(sk_inet->inet_opt, sock_owned_by_user(sk));
        if (sk_inet->is_icsk) {
                sk_conn = inet_csk(sk);
-               if (sk_inet->opt)
-                       sk_conn->icsk_ext_hdr_len -= sk_inet->opt->optlen;
-               sk_conn->icsk_ext_hdr_len += opt->optlen;
+               if (old)
+                       sk_conn->icsk_ext_hdr_len -= old->opt.optlen;
+               sk_conn->icsk_ext_hdr_len += opt->opt.optlen;
                sk_conn->icsk_sync_mss(sk, sk_conn->icsk_pmtu_cookie);
        }
-       opt = xchg(&sk_inet->opt, opt);
-       kfree(opt);
+       rcu_assign_pointer(sk_inet->inet_opt, opt);
+       if (old)
+               call_rcu(&old->rcu, opt_kfree_rcu);
 
        return 0;
 
@@ -1960,7 +1968,7 @@ int cipso_v4_req_setattr(struct request_sock *req,
        unsigned char *buf = NULL;
        u32 buf_len;
        u32 opt_len;
-       struct ip_options *opt = NULL;
+       struct ip_options_rcu *opt = NULL;
        struct inet_request_sock *req_inet;
 
        /* We allocate the maximum CIPSO option size here so we are probably
@@ -1988,15 +1996,16 @@ int cipso_v4_req_setattr(struct request_sock *req,
                ret_val = -ENOMEM;
                goto req_setattr_failure;
        }
-       memcpy(opt->__data, buf, buf_len);
-       opt->optlen = opt_len;
-       opt->cipso = sizeof(struct iphdr);
+       memcpy(opt->opt.__data, buf, buf_len);
+       opt->opt.optlen = opt_len;
+       opt->opt.cipso = sizeof(struct iphdr);
        kfree(buf);
        buf = NULL;
 
        req_inet = inet_rsk(req);
        opt = xchg(&req_inet->opt, opt);
-       kfree(opt);
+       if (opt)
+               call_rcu(&opt->rcu, opt_kfree_rcu);
 
        return 0;
 
@@ -2016,34 +2025,34 @@ req_setattr_failure:
  * values on failure.
  *
  */
-static int cipso_v4_delopt(struct ip_options **opt_ptr)
+static int cipso_v4_delopt(struct ip_options_rcu **opt_ptr)
 {
        int hdr_delta = 0;
-       struct ip_options *opt = *opt_ptr;
+       struct ip_options_rcu *opt = *opt_ptr;
 
-       if (opt->srr || opt->rr || opt->ts || opt->router_alert) {
+       if (opt->opt.srr || opt->opt.rr || opt->opt.ts || opt->opt.router_alert) {
                u8 cipso_len;
                u8 cipso_off;
                unsigned char *cipso_ptr;
                int iter;
                int optlen_new;
 
-               cipso_off = opt->cipso - sizeof(struct iphdr);
-               cipso_ptr = &opt->__data[cipso_off];
+               cipso_off = opt->opt.cipso - sizeof(struct iphdr);
+               cipso_ptr = &opt->opt.__data[cipso_off];
                cipso_len = cipso_ptr[1];
 
-               if (opt->srr > opt->cipso)
-                       opt->srr -= cipso_len;
-               if (opt->rr > opt->cipso)
-                       opt->rr -= cipso_len;
-               if (opt->ts > opt->cipso)
-                       opt->ts -= cipso_len;
-               if (opt->router_alert > opt->cipso)
-                       opt->router_alert -= cipso_len;
-               opt->cipso = 0;
+               if (opt->opt.srr > opt->opt.cipso)
+                       opt->opt.srr -= cipso_len;
+               if (opt->opt.rr > opt->opt.cipso)
+                       opt->opt.rr -= cipso_len;
+               if (opt->opt.ts > opt->opt.cipso)
+                       opt->opt.ts -= cipso_len;
+               if (opt->opt.router_alert > opt->opt.cipso)
+                       opt->opt.router_alert -= cipso_len;
+               opt->opt.cipso = 0;
 
                memmove(cipso_ptr, cipso_ptr + cipso_len,
-                       opt->optlen - cipso_off - cipso_len);
+                       opt->opt.optlen - cipso_off - cipso_len);
 
                /* determining the new total option length is tricky because of
                 * the padding necessary, the only thing i can think to do at
@@ -2052,21 +2061,21 @@ static int cipso_v4_delopt(struct ip_options **opt_ptr)
                 * from there we can determine the new total option length */
                iter = 0;
                optlen_new = 0;
-               while (iter < opt->optlen)
-                       if (opt->__data[iter] != IPOPT_NOP) {
-                               iter += opt->__data[iter + 1];
+               while (iter < opt->opt.optlen)
+                       if (opt->opt.__data[iter] != IPOPT_NOP) {
+                               iter += opt->opt.__data[iter + 1];
                                optlen_new = iter;
                        } else
                                iter++;
-               hdr_delta = opt->optlen;
-               opt->optlen = (optlen_new + 3) & ~3;
-               hdr_delta -= opt->optlen;
+               hdr_delta = opt->opt.optlen;
+               opt->opt.optlen = (optlen_new + 3) & ~3;
+               hdr_delta -= opt->opt.optlen;
        } else {
                /* only the cipso option was present on the socket so we can
                 * remove the entire option struct */
                *opt_ptr = NULL;
-               hdr_delta = opt->optlen;
-               kfree(opt);
+               hdr_delta = opt->opt.optlen;
+               call_rcu(&opt->rcu, opt_kfree_rcu);
        }
 
        return hdr_delta;
@@ -2083,15 +2092,15 @@ static int cipso_v4_delopt(struct ip_options **opt_ptr)
 void cipso_v4_sock_delattr(struct sock *sk)
 {
        int hdr_delta;
-       struct ip_options *opt;
+       struct ip_options_rcu *opt;
        struct inet_sock *sk_inet;
 
        sk_inet = inet_sk(sk);
-       opt = sk_inet->opt;
-       if (opt == NULL || opt->cipso == 0)
+       opt = rcu_dereference_protected(sk_inet->inet_opt, 1);
+       if (opt == NULL || opt->opt.cipso == 0)
                return;
 
-       hdr_delta = cipso_v4_delopt(&sk_inet->opt);
+       hdr_delta = cipso_v4_delopt(&sk_inet->inet_opt);
        if (sk_inet->is_icsk && hdr_delta > 0) {
                struct inet_connection_sock *sk_conn = inet_csk(sk);
                sk_conn->icsk_ext_hdr_len -= hdr_delta;
@@ -2109,12 +2118,12 @@ void cipso_v4_sock_delattr(struct sock *sk)
  */
 void cipso_v4_req_delattr(struct request_sock *req)
 {
-       struct ip_options *opt;
+       struct ip_options_rcu *opt;
        struct inet_request_sock *req_inet;
 
        req_inet = inet_rsk(req);
        opt = req_inet->opt;
-       if (opt == NULL || opt->cipso == 0)
+       if (opt == NULL || opt->opt.cipso == 0)
                return;
 
        cipso_v4_delopt(&req_inet->opt);
@@ -2184,14 +2193,18 @@ getattr_return:
  */
 int cipso_v4_sock_getattr(struct sock *sk, struct netlbl_lsm_secattr *secattr)
 {
-       struct ip_options *opt;
+       struct ip_options_rcu *opt;
+       int res = -ENOMSG;
 
-       opt = inet_sk(sk)->opt;
-       if (opt == NULL || opt->cipso == 0)
-               return -ENOMSG;
-
-       return cipso_v4_getattr(opt->__data + opt->cipso - sizeof(struct iphdr),
-                               secattr);
+       rcu_read_lock();
+       opt = rcu_dereference(inet_sk(sk)->inet_opt);
+       if (opt && opt->opt.cipso)
+               res = cipso_v4_getattr(opt->opt.__data +
+                                               opt->opt.cipso -
+                                               sizeof(struct iphdr),
+                                      secattr);
+       rcu_read_unlock();
+       return res;
 }
 
 /**