Merge tag 'armsoc-arm64' of git://git.kernel.org/pub/scm/linux/kernel/git/arm/arm-soc
[cascardo/linux.git] / net / rxrpc / af_rxrpc.c
index 1e8cf3d..44c9c2b 100644 (file)
@@ -45,7 +45,7 @@ u32 rxrpc_epoch;
 atomic_t rxrpc_debug_id;
 
 /* count of skbs currently in use */
-atomic_t rxrpc_n_skbs;
+atomic_t rxrpc_n_tx_skbs, rxrpc_n_rx_skbs;
 
 struct workqueue_struct *rxrpc_workqueue;
 
@@ -106,19 +106,25 @@ static int rxrpc_validate_address(struct rxrpc_sock *rx,
        case AF_INET:
                if (srx->transport_len < sizeof(struct sockaddr_in))
                        return -EINVAL;
-               _debug("INET: %x @ %pI4",
-                      ntohs(srx->transport.sin.sin_port),
-                      &srx->transport.sin.sin_addr);
                tail = offsetof(struct sockaddr_rxrpc, transport.sin.__pad);
                break;
 
+#ifdef CONFIG_AF_RXRPC_IPV6
        case AF_INET6:
+               if (srx->transport_len < sizeof(struct sockaddr_in6))
+                       return -EINVAL;
+               tail = offsetof(struct sockaddr_rxrpc, transport) +
+                       sizeof(struct sockaddr_in6);
+               break;
+#endif
+
        default:
                return -EAFNOSUPPORT;
        }
 
        if (tail < len)
                memset((void *)srx + tail, 0, len - tail);
+       _debug("INET: %pISp", &srx->transport);
        return 0;
 }
 
@@ -130,7 +136,8 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len)
        struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *)saddr;
        struct sock *sk = sock->sk;
        struct rxrpc_local *local;
-       struct rxrpc_sock *rx = rxrpc_sk(sk), *prx;
+       struct rxrpc_sock *rx = rxrpc_sk(sk);
+       u16 service_id = srx->srx_service;
        int ret;
 
        _enter("%p,%p,%d", rx, saddr, len);
@@ -154,16 +161,13 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len)
                goto error_unlock;
        }
 
-       if (rx->srx.srx_service) {
-               write_lock_bh(&local->services_lock);
-               hlist_for_each_entry(prx, &local->services, listen_link) {
-                       if (prx->srx.srx_service == rx->srx.srx_service)
-                               goto service_in_use;
-               }
-
+       if (service_id) {
+               write_lock(&local->services_lock);
+               if (rcu_access_pointer(local->service))
+                       goto service_in_use;
                rx->local = local;
-               hlist_add_head_rcu(&rx->listen_link, &local->services);
-               write_unlock_bh(&local->services_lock);
+               rcu_assign_pointer(local->service, rx);
+               write_unlock(&local->services_lock);
 
                rx->sk.sk_state = RXRPC_SERVER_BOUND;
        } else {
@@ -176,7 +180,7 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len)
        return 0;
 
 service_in_use:
-       write_unlock_bh(&local->services_lock);
+       write_unlock(&local->services_lock);
        rxrpc_put_local(local);
        ret = -EADDRINUSE;
 error_unlock:
@@ -299,7 +303,7 @@ void rxrpc_kernel_end_call(struct socket *sock, struct rxrpc_call *call)
 {
        _enter("%d{%d}", call->debug_id, atomic_read(&call->usage));
        rxrpc_release_call(rxrpc_sk(sock->sk), call);
-       rxrpc_put_call(call, rxrpc_call_put);
+       rxrpc_put_call(call, rxrpc_call_put_kernel);
 }
 EXPORT_SYMBOL(rxrpc_kernel_end_call);
 
@@ -401,6 +405,23 @@ static int rxrpc_sendmsg(struct socket *sock, struct msghdr *m, size_t len)
 
        switch (rx->sk.sk_state) {
        case RXRPC_UNBOUND:
+               rx->srx.srx_family = AF_RXRPC;
+               rx->srx.srx_service = 0;
+               rx->srx.transport_type = SOCK_DGRAM;
+               rx->srx.transport.family = rx->family;
+               switch (rx->family) {
+               case AF_INET:
+                       rx->srx.transport_len = sizeof(struct sockaddr_in);
+                       break;
+#ifdef CONFIG_AF_RXRPC_IPV6
+               case AF_INET6:
+                       rx->srx.transport_len = sizeof(struct sockaddr_in6);
+                       break;
+#endif
+               default:
+                       ret = -EAFNOSUPPORT;
+                       goto error_unlock;
+               }
                local = rxrpc_lookup_local(&rx->srx);
                if (IS_ERR(local)) {
                        ret = PTR_ERR(local);
@@ -515,15 +536,16 @@ error:
 static unsigned int rxrpc_poll(struct file *file, struct socket *sock,
                               poll_table *wait)
 {
-       unsigned int mask;
        struct sock *sk = sock->sk;
+       struct rxrpc_sock *rx = rxrpc_sk(sk);
+       unsigned int mask;
 
        sock_poll_wait(file, sk_sleep(sk), wait);
        mask = 0;
 
        /* the socket is readable if there are any messages waiting on the Rx
         * queue */
-       if (!skb_queue_empty(&sk->sk_receive_queue))
+       if (!list_empty(&rx->recvmsg_q))
                mask |= POLLIN | POLLRDNORM;
 
        /* the socket is writable if there is space to add new data to the
@@ -550,7 +572,8 @@ static int rxrpc_create(struct net *net, struct socket *sock, int protocol,
                return -EAFNOSUPPORT;
 
        /* we support transport protocol UDP/UDP6 only */
-       if (protocol != PF_INET)
+       if (protocol != PF_INET &&
+           IS_ENABLED(CONFIG_AF_RXRPC_IPV6) && protocol != PF_INET6)
                return -EPROTONOSUPPORT;
 
        if (sock->type != SOCK_DGRAM)
@@ -574,9 +597,11 @@ static int rxrpc_create(struct net *net, struct socket *sock, int protocol,
        rx->family = protocol;
        rx->calls = RB_ROOT;
 
-       INIT_HLIST_NODE(&rx->listen_link);
-       INIT_LIST_HEAD(&rx->secureq);
-       INIT_LIST_HEAD(&rx->acceptq);
+       spin_lock_init(&rx->incoming_lock);
+       INIT_LIST_HEAD(&rx->sock_calls);
+       INIT_LIST_HEAD(&rx->to_be_accepted);
+       INIT_LIST_HEAD(&rx->recvmsg_q);
+       rwlock_init(&rx->recvmsg_lock);
        rwlock_init(&rx->call_lock);
        memset(&rx->srx, 0, sizeof(rx->srx));
 
@@ -584,6 +609,39 @@ static int rxrpc_create(struct net *net, struct socket *sock, int protocol,
        return 0;
 }
 
+/*
+ * Kill all the calls on a socket and shut it down.
+ */
+static int rxrpc_shutdown(struct socket *sock, int flags)
+{
+       struct sock *sk = sock->sk;
+       struct rxrpc_sock *rx = rxrpc_sk(sk);
+       int ret = 0;
+
+       _enter("%p,%d", sk, flags);
+
+       if (flags != SHUT_RDWR)
+               return -EOPNOTSUPP;
+       if (sk->sk_state == RXRPC_CLOSE)
+               return -ESHUTDOWN;
+
+       lock_sock(sk);
+
+       spin_lock_bh(&sk->sk_receive_queue.lock);
+       if (sk->sk_state < RXRPC_CLOSE) {
+               sk->sk_state = RXRPC_CLOSE;
+               sk->sk_shutdown = SHUTDOWN_MASK;
+       } else {
+               ret = -ESHUTDOWN;
+       }
+       spin_unlock_bh(&sk->sk_receive_queue.lock);
+
+       rxrpc_discard_prealloc(rx);
+
+       release_sock(sk);
+       return ret;
+}
+
 /*
  * RxRPC socket destructor
  */
@@ -620,12 +678,10 @@ static int rxrpc_release_sock(struct sock *sk)
        sk->sk_state = RXRPC_CLOSE;
        spin_unlock_bh(&sk->sk_receive_queue.lock);
 
-       ASSERTCMP(rx->listen_link.next, !=, LIST_POISON1);
-
-       if (!hlist_unhashed(&rx->listen_link)) {
-               write_lock_bh(&rx->local->services_lock);
-               hlist_del_rcu(&rx->listen_link);
-               write_unlock_bh(&rx->local->services_lock);
+       if (rx->local && rx->local->service == rx) {
+               write_lock(&rx->local->services_lock);
+               rx->local->service = NULL;
+               write_unlock(&rx->local->services_lock);
        }
 
        /* try to flush out this socket */
@@ -678,7 +734,7 @@ static const struct proto_ops rxrpc_rpc_ops = {
        .poll           = rxrpc_poll,
        .ioctl          = sock_no_ioctl,
        .listen         = rxrpc_listen,
-       .shutdown       = sock_no_shutdown,
+       .shutdown       = rxrpc_shutdown,
        .setsockopt     = rxrpc_setsockopt,
        .getsockopt     = sock_no_getsockopt,
        .sendmsg        = rxrpc_sendmsg,
@@ -806,7 +862,8 @@ static void __exit af_rxrpc_exit(void)
        proto_unregister(&rxrpc_proto);
        rxrpc_destroy_all_calls();
        rxrpc_destroy_all_connections();
-       ASSERTCMP(atomic_read(&rxrpc_n_skbs), ==, 0);
+       ASSERTCMP(atomic_read(&rxrpc_n_tx_skbs), ==, 0);
+       ASSERTCMP(atomic_read(&rxrpc_n_rx_skbs), ==, 0);
        rxrpc_destroy_all_locals();
 
        remove_proc_entry("rxrpc_conns", init_net.proc_net);