geneve: avoid using stale geneve socket.
[cascardo/linux.git] / drivers / net / geneve.c
index 9b3dc3c..42edd7b 100644 (file)
@@ -12,7 +12,6 @@
 
 #include <linux/kernel.h>
 #include <linux/module.h>
-#include <linux/netdevice.h>
 #include <linux/etherdevice.h>
 #include <linux/hash.h>
 #include <net/dst_metadata.h>
@@ -59,9 +58,9 @@ struct geneve_dev {
        struct hlist_node  hlist;       /* vni hash table */
        struct net         *net;        /* netns for packet i/o */
        struct net_device  *dev;        /* netdev for geneve tunnel */
-       struct geneve_sock *sock4;      /* IPv4 socket used for geneve tunnel */
+       struct geneve_sock __rcu *sock4;        /* IPv4 socket used for geneve tunnel */
 #if IS_ENABLED(CONFIG_IPV6)
-       struct geneve_sock *sock6;      /* IPv6 socket used for geneve tunnel */
+       struct geneve_sock __rcu *sock6;        /* IPv6 socket used for geneve tunnel */
 #endif
        u8                 vni[3];      /* virtual network ID for tunnel */
        u8                 ttl;         /* TTL override */
@@ -397,23 +396,6 @@ static struct socket *geneve_create_sock(struct net *net, bool ipv6,
        return sock;
 }
 
-static void geneve_notify_add_rx_port(struct geneve_sock *gs)
-{
-       struct net_device *dev;
-       struct sock *sk = gs->sock->sk;
-       struct net *net = sock_net(sk);
-       sa_family_t sa_family = geneve_get_sk_family(gs);
-       __be16 port = inet_sk(sk)->inet_sport;
-
-       rcu_read_lock();
-       for_each_netdev_rcu(net, dev) {
-               if (dev->netdev_ops->ndo_add_geneve_port)
-                       dev->netdev_ops->ndo_add_geneve_port(dev, sa_family,
-                                                            port);
-       }
-       rcu_read_unlock();
-}
-
 static int geneve_hlen(struct genevehdr *gh)
 {
        return sizeof(*gh) + gh->opt_len * 4;
@@ -471,7 +453,7 @@ static struct sk_buff **geneve_gro_receive(struct sock *sk,
 
        skb_gro_pull(skb, gh_len);
        skb_gro_postpull_rcsum(skb, gh, gh_len);
-       pp = ptype->callbacks.gro_receive(head, skb);
+       pp = call_gro_receive(ptype->callbacks.gro_receive, head, skb);
        flush = 0;
 
 out_unlock:
@@ -533,7 +515,7 @@ static struct geneve_sock *geneve_socket_create(struct net *net, __be16 port,
                INIT_HLIST_HEAD(&gs->vni_list[h]);
 
        /* Initialize the geneve udp offloads structure */
-       geneve_notify_add_rx_port(gs);
+       udp_tunnel_notify_add_rx_port(gs->sock, UDP_TUNNEL_TYPE_GENEVE);
 
        /* Mark socket as an encapsulation socket */
        memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
@@ -548,40 +530,32 @@ static struct geneve_sock *geneve_socket_create(struct net *net, __be16 port,
        return gs;
 }
 
-static void geneve_notify_del_rx_port(struct geneve_sock *gs)
-{
-       struct net_device *dev;
-       struct sock *sk = gs->sock->sk;
-       struct net *net = sock_net(sk);
-       sa_family_t sa_family = geneve_get_sk_family(gs);
-       __be16 port = inet_sk(sk)->inet_sport;
-
-       rcu_read_lock();
-       for_each_netdev_rcu(net, dev) {
-               if (dev->netdev_ops->ndo_del_geneve_port)
-                       dev->netdev_ops->ndo_del_geneve_port(dev, sa_family,
-                                                            port);
-       }
-
-       rcu_read_unlock();
-}
-
 static void __geneve_sock_release(struct geneve_sock *gs)
 {
        if (!gs || --gs->refcnt)
                return;
 
        list_del(&gs->list);
-       geneve_notify_del_rx_port(gs);
+       udp_tunnel_notify_del_rx_port(gs->sock, UDP_TUNNEL_TYPE_GENEVE);
        udp_tunnel_sock_release(gs->sock);
        kfree_rcu(gs, rcu);
 }
 
 static void geneve_sock_release(struct geneve_dev *geneve)
 {
-       __geneve_sock_release(geneve->sock4);
+       struct geneve_sock *gs4 = rtnl_dereference(geneve->sock4);
+#if IS_ENABLED(CONFIG_IPV6)
+       struct geneve_sock *gs6 = rtnl_dereference(geneve->sock6);
+
+       rcu_assign_pointer(geneve->sock6, NULL);
+#endif
+
+       rcu_assign_pointer(geneve->sock4, NULL);
+       synchronize_net();
+
+       __geneve_sock_release(gs4);
 #if IS_ENABLED(CONFIG_IPV6)
-       __geneve_sock_release(geneve->sock6);
+       __geneve_sock_release(gs6);
 #endif
 }
 
@@ -622,10 +596,10 @@ out:
        gs->flags = geneve->flags;
 #if IS_ENABLED(CONFIG_IPV6)
        if (ipv6)
-               geneve->sock6 = gs;
+               rcu_assign_pointer(geneve->sock6, gs);
        else
 #endif
-               geneve->sock4 = gs;
+               rcu_assign_pointer(geneve->sock4, gs);
 
        hash = geneve_net_vni_hash(geneve->vni);
        hlist_add_head_rcu(&geneve->hlist, &gs->vni_list[hash]);
@@ -639,9 +613,7 @@ static int geneve_open(struct net_device *dev)
        bool metadata = geneve->collect_md;
        int ret = 0;
 
-       geneve->sock4 = NULL;
 #if IS_ENABLED(CONFIG_IPV6)
-       geneve->sock6 = NULL;
        if (ipv6 || metadata)
                ret = geneve_sock_add(geneve, true);
 #endif
@@ -756,6 +728,9 @@ static struct rtable *geneve_get_v4_rt(struct sk_buff *skb,
        struct rtable *rt = NULL;
        __u8 tos;
 
+       if (!rcu_dereference(geneve->sock4))
+               return ERR_PTR(-EIO);
+
        memset(fl4, 0, sizeof(*fl4));
        fl4->flowi4_mark = skb->mark;
        fl4->flowi4_proto = IPPROTO_UDP;
@@ -808,11 +783,15 @@ static struct dst_entry *geneve_get_v6_dst(struct sk_buff *skb,
 {
        bool use_cache = ip_tunnel_dst_cache_usable(skb, info);
        struct geneve_dev *geneve = netdev_priv(dev);
-       struct geneve_sock *gs6 = geneve->sock6;
        struct dst_entry *dst = NULL;
        struct dst_cache *dst_cache;
+       struct geneve_sock *gs6;
        __u8 prio;
 
+       gs6 = rcu_dereference(geneve->sock6);
+       if (!gs6)
+               return ERR_PTR(-EIO);
+
        memset(fl6, 0, sizeof(*fl6));
        fl6->flowi6_mark = skb->mark;
        fl6->flowi6_proto = IPPROTO_UDP;
@@ -878,7 +857,7 @@ static netdev_tx_t geneve_xmit_skb(struct sk_buff *skb, struct net_device *dev,
                                   struct ip_tunnel_info *info)
 {
        struct geneve_dev *geneve = netdev_priv(dev);
-       struct geneve_sock *gs4 = geneve->sock4;
+       struct geneve_sock *gs4;
        struct rtable *rt = NULL;
        const struct iphdr *iip; /* interior IP header */
        int err = -EINVAL;
@@ -889,6 +868,10 @@ static netdev_tx_t geneve_xmit_skb(struct sk_buff *skb, struct net_device *dev,
        bool xnet = !net_eq(geneve->net, dev_net(geneve->dev));
        u32 flags = geneve->flags;
 
+       gs4 = rcu_dereference(geneve->sock4);
+       if (!gs4)
+               goto tx_error;
+
        if (geneve->collect_md) {
                if (unlikely(!info || !(info->mode & IP_TUNNEL_INFO_TX))) {
                        netdev_dbg(dev, "no tunnel metadata\n");
@@ -968,9 +951,9 @@ static netdev_tx_t geneve6_xmit_skb(struct sk_buff *skb, struct net_device *dev,
                                    struct ip_tunnel_info *info)
 {
        struct geneve_dev *geneve = netdev_priv(dev);
-       struct geneve_sock *gs6 = geneve->sock6;
        struct dst_entry *dst = NULL;
        const struct iphdr *iip; /* interior IP header */
+       struct geneve_sock *gs6;
        int err = -EINVAL;
        struct flowi6 fl6;
        __u8 prio, ttl;
@@ -979,6 +962,10 @@ static netdev_tx_t geneve6_xmit_skb(struct sk_buff *skb, struct net_device *dev,
        bool xnet = !net_eq(geneve->net, dev_net(geneve->dev));
        u32 flags = geneve->flags;
 
+       gs6 = rcu_dereference(geneve->sock6);
+       if (!gs6)
+               goto tx_error;
+
        if (geneve->collect_md) {
                if (unlikely(!info || !(info->mode & IP_TUNNEL_INFO_TX))) {
                        netdev_dbg(dev, "no tunnel metadata\n");
@@ -1170,29 +1157,20 @@ static struct device_type geneve_type = {
        .name = "geneve",
 };
 
-/* Calls the ndo_add_geneve_port of the caller in order to
+/* Calls the ndo_udp_tunnel_add of the caller in order to
  * supply the listening GENEVE udp ports. Callers are expected
- * to implement the ndo_add_geneve_port.
+ * to implement the ndo_udp_tunnel_add.
  */
 static void geneve_push_rx_ports(struct net_device *dev)
 {
        struct net *net = dev_net(dev);
        struct geneve_net *gn = net_generic(net, geneve_net_id);
        struct geneve_sock *gs;
-       sa_family_t sa_family;
-       struct sock *sk;
-       __be16 port;
-
-       if (!dev->netdev_ops->ndo_add_geneve_port)
-               return;
 
        rcu_read_lock();
-       list_for_each_entry_rcu(gs, &gn->sock_list, list) {
-               sk = gs->sock->sk;
-               sa_family = sk->sk_family;
-               port = inet_sk(sk)->inet_sport;
-               dev->netdev_ops->ndo_add_geneve_port(dev, sa_family, port);
-       }
+       list_for_each_entry_rcu(gs, &gn->sock_list, list)
+               udp_tunnel_push_rx_port(dev, gs->sock,
+                                       UDP_TUNNEL_TYPE_GENEVE);
        rcu_read_unlock();
 }
 
@@ -1555,7 +1533,7 @@ static int geneve_netdevice_event(struct notifier_block *unused,
 {
        struct net_device *dev = netdev_notifier_info_to_dev(ptr);
 
-       if (event == NETDEV_OFFLOAD_PUSH_GENEVE)
+       if (event == NETDEV_UDP_TUNNEL_PUSH_INFO)
                geneve_push_rx_ports(dev);
 
        return NOTIFY_DONE;