sock: convert sk_peek_offset functions to WRITE_ONCE
[cascardo/linux.git] / include / net / sock.h
index 255d3e0..09aec75 100644 (file)
@@ -178,7 +178,7 @@ struct sock_common {
        int                     skc_bound_dev_if;
        union {
                struct hlist_node       skc_bind_node;
-               struct hlist_nulls_node skc_portaddr_node;
+               struct hlist_node       skc_portaddr_node;
        };
        struct proto            *skc_prot;
        possible_net_t          skc_net;
@@ -438,6 +438,7 @@ struct sock {
                                                  struct sk_buff *skb);
        void                    (*sk_destruct)(struct sock *sk);
        struct sock_reuseport __rcu     *sk_reuseport_cb;
+       struct rcu_head         sk_rcu;
 };
 
 #define __sk_user_data(sk) ((*((void __rcu **)&(sk)->sk_user_data)))
@@ -458,26 +459,28 @@ struct sock {
 
 static inline int sk_peek_offset(struct sock *sk, int flags)
 {
-       if ((flags & MSG_PEEK) && (sk->sk_peek_off >= 0))
-               return sk->sk_peek_off;
-       else
-               return 0;
+       if (unlikely(flags & MSG_PEEK)) {
+               s32 off = READ_ONCE(sk->sk_peek_off);
+               if (off >= 0)
+                       return off;
+       }
+
+       return 0;
 }
 
 static inline void sk_peek_offset_bwd(struct sock *sk, int val)
 {
-       if (sk->sk_peek_off >= 0) {
-               if (sk->sk_peek_off >= val)
-                       sk->sk_peek_off -= val;
-               else
-                       sk->sk_peek_off = 0;
+       s32 off = READ_ONCE(sk->sk_peek_off);
+
+       if (unlikely(off >= 0)) {
+               off = max_t(s32, off - val, 0);
+               WRITE_ONCE(sk->sk_peek_off, off);
        }
 }
 
 static inline void sk_peek_offset_fwd(struct sock *sk, int val)
 {
-       if (sk->sk_peek_off >= 0)
-               sk->sk_peek_off += val;
+       sk_peek_offset_bwd(sk, -val);
 }
 
 /*
@@ -669,18 +672,18 @@ static inline void sk_add_bind_node(struct sock *sk,
        hlist_for_each_entry(__sk, list, sk_bind_node)
 
 /**
- * sk_nulls_for_each_entry_offset - iterate over a list at a given struct offset
+ * sk_for_each_entry_offset_rcu - iterate over a list at a given struct offset
  * @tpos:      the type * to use as a loop cursor.
  * @pos:       the &struct hlist_node to use as a loop cursor.
  * @head:      the head for your list.
  * @offset:    offset of hlist_node within the struct.
  *
  */
-#define sk_nulls_for_each_entry_offset(tpos, pos, head, offset)                       \
-       for (pos = (head)->first;                                              \
-            (!is_a_nulls(pos)) &&                                             \
+#define sk_for_each_entry_offset_rcu(tpos, pos, head, offset)                 \
+       for (pos = rcu_dereference((head)->first);                             \
+            pos != NULL &&                                                    \
                ({ tpos = (typeof(*tpos) *)((void *)pos - offset); 1;});       \
-            pos = pos->next)
+            pos = rcu_dereference(pos->next))
 
 static inline struct user_namespace *sk_user_ns(struct sock *sk)
 {
@@ -720,6 +723,7 @@ enum sock_flags {
                     */
        SOCK_FILTER_LOCKED, /* Filter cannot be changed anymore */
        SOCK_SELECT_ERR_QUEUE, /* Wake select on error queue */
+       SOCK_RCU_FREE, /* wait rcu grace period in sk_destruct() */
 };
 
 #define SK_FLAGS_TIMESTAMP ((1UL << SOCK_TIMESTAMP) | (1UL << SOCK_TIMESTAMPING_RX_SOFTWARE))
@@ -1418,8 +1422,11 @@ void sk_send_sigurg(struct sock *sk);
 
 struct sockcm_cookie {
        u32 mark;
+       u16 tsflags;
 };
 
+int __sock_cmsg_send(struct sock *sk, struct msghdr *msg, struct cmsghdr *cmsg,
+                    struct sockcm_cookie *sockc);
 int sock_cmsg_send(struct sock *sk, struct msghdr *msg,
                   struct sockcm_cookie *sockc);
 
@@ -2007,6 +2014,13 @@ sock_skb_set_dropcount(const struct sock *sk, struct sk_buff *skb)
        SOCK_SKB_CB(skb)->dropcount = atomic_read(&sk->sk_drops);
 }
 
+static inline void sk_drops_add(struct sock *sk, const struct sk_buff *skb)
+{
+       int segs = max_t(u16, 1, skb_shinfo(skb)->gso_segs);
+
+       atomic_add(segs, &sk->sk_drops);
+}
+
 void __sock_recv_timestamp(struct msghdr *msg, struct sock *sk,
                           struct sk_buff *skb);
 void __sock_recv_wifi_status(struct msghdr *msg, struct sock *sk,
@@ -2054,19 +2068,21 @@ static inline void sock_recv_ts_and_drops(struct msghdr *msg, struct sock *sk,
                sk->sk_stamp = skb->tstamp;
 }
 
-void __sock_tx_timestamp(const struct sock *sk, __u8 *tx_flags);
+void __sock_tx_timestamp(__u16 tsflags, __u8 *tx_flags);
 
 /**
  * sock_tx_timestamp - checks whether the outgoing packet is to be time stamped
  * @sk:                socket sending this packet
+ * @tsflags:   timestamping flags to use
  * @tx_flags:  completed with instructions for time stamping
  *
  * Note : callers should take care of initial *tx_flags value (usually 0)
  */
-static inline void sock_tx_timestamp(const struct sock *sk, __u8 *tx_flags)
+static inline void sock_tx_timestamp(const struct sock *sk, __u16 tsflags,
+                                    __u8 *tx_flags)
 {
-       if (unlikely(sk->sk_tsflags))
-               __sock_tx_timestamp(sk, tx_flags);
+       if (unlikely(tsflags))
+               __sock_tx_timestamp(tsflags, tx_flags);
        if (unlikely(sock_flag(sk, SOCK_WIFI_STATUS)))
                *tx_flags |= SKBTX_WIFI_STATUS;
 }