vti: use right inner_mode for inbound inter address family policy checks
[cascardo/linux.git] / net / ipv4 / ip_vti.c
index a917903..5d7944f 100644 (file)
@@ -88,6 +88,7 @@ static int vti_rcv_cb(struct sk_buff *skb, int err)
        struct net_device *dev;
        struct pcpu_sw_netstats *tstats;
        struct xfrm_state *x;
+       struct xfrm_mode *inner_mode;
        struct ip_tunnel *tunnel = XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4;
        u32 orig_mark = skb->mark;
        int ret;
@@ -105,7 +106,19 @@ static int vti_rcv_cb(struct sk_buff *skb, int err)
        }
 
        x = xfrm_input_state(skb);
-       family = x->inner_mode->afinfo->family;
+
+       inner_mode = x->inner_mode;
+
+       if (x->sel.family == AF_UNSPEC) {
+               inner_mode = xfrm_ip2inner_mode(x, XFRM_MODE_SKB_CB(skb)->protocol);
+               if (inner_mode == NULL) {
+                       XFRM_INC_STATS(dev_net(skb->dev),
+                                      LINUX_MIB_XFRMINSTATEMODEERROR);
+                       return -EINVAL;
+               }
+       }
+
+       family = inner_mode->afinfo->family;
 
        skb->mark = be32_to_cpu(tunnel->parms.i_key);
        ret = xfrm_policy_check(NULL, XFRM_POLICY_IN, skb, family);
@@ -557,6 +570,33 @@ static struct rtnl_link_ops vti_link_ops __read_mostly = {
        .get_link_net   = ip_tunnel_get_link_net,
 };
 
+static bool is_vti_tunnel(const struct net_device *dev)
+{
+       return dev->netdev_ops == &vti_netdev_ops;
+}
+
+static int vti_device_event(struct notifier_block *unused,
+                           unsigned long event, void *ptr)
+{
+       struct net_device *dev = netdev_notifier_info_to_dev(ptr);
+       struct ip_tunnel *tunnel = netdev_priv(dev);
+
+       if (!is_vti_tunnel(dev))
+               return NOTIFY_DONE;
+
+       switch (event) {
+       case NETDEV_DOWN:
+               if (!net_eq(tunnel->net, dev_net(dev)))
+                       xfrm_garbage_collect(tunnel->net);
+               break;
+       }
+       return NOTIFY_DONE;
+}
+
+static struct notifier_block vti_notifier_block __read_mostly = {
+       .notifier_call = vti_device_event,
+};
+
 static int __init vti_init(void)
 {
        const char *msg;
@@ -564,6 +604,8 @@ static int __init vti_init(void)
 
        pr_info("IPv4 over IPsec tunneling driver\n");
 
+       register_netdevice_notifier(&vti_notifier_block);
+
        msg = "tunnel device";
        err = register_pernet_device(&vti_net_ops);
        if (err < 0)
@@ -596,6 +638,7 @@ xfrm_proto_ah_failed:
 xfrm_proto_esp_failed:
        unregister_pernet_device(&vti_net_ops);
 pernet_dev_failed:
+       unregister_netdevice_notifier(&vti_notifier_block);
        pr_err("vti init: failed to register %s\n", msg);
        return err;
 }
@@ -607,6 +650,7 @@ static void __exit vti_fini(void)
        xfrm4_protocol_deregister(&vti_ah4_protocol, IPPROTO_AH);
        xfrm4_protocol_deregister(&vti_esp4_protocol, IPPROTO_ESP);
        unregister_pernet_device(&vti_net_ops);
+       unregister_netdevice_notifier(&vti_notifier_block);
 }
 
 module_init(vti_init);