Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[cascardo/linux.git] / net / rxrpc / af_rxrpc.c
index 88effad..2d59c9b 100644 (file)
 #include <linux/net.h>
 #include <linux/slab.h>
 #include <linux/skbuff.h>
+#include <linux/random.h>
 #include <linux/poll.h>
 #include <linux/proc_fs.h>
 #include <linux/key-type.h>
 #include <net/net_namespace.h>
 #include <net/sock.h>
 #include <net/af_rxrpc.h>
+#define CREATE_TRACE_POINTS
 #include "ar-internal.h"
 
 MODULE_DESCRIPTION("RxRPC network protocol");
@@ -43,7 +45,7 @@ u32 rxrpc_epoch;
 atomic_t rxrpc_debug_id;
 
 /* count of skbs currently in use */
-atomic_t rxrpc_n_skbs;
+atomic_t rxrpc_n_tx_skbs, rxrpc_n_rx_skbs;
 
 struct workqueue_struct *rxrpc_workqueue;
 
@@ -104,19 +106,25 @@ static int rxrpc_validate_address(struct rxrpc_sock *rx,
        case AF_INET:
                if (srx->transport_len < sizeof(struct sockaddr_in))
                        return -EINVAL;
-               _debug("INET: %x @ %pI4",
-                      ntohs(srx->transport.sin.sin_port),
-                      &srx->transport.sin.sin_addr);
                tail = offsetof(struct sockaddr_rxrpc, transport.sin.__pad);
                break;
 
+#ifdef CONFIG_AF_RXRPC_IPV6
        case AF_INET6:
+               if (srx->transport_len < sizeof(struct sockaddr_in6))
+                       return -EINVAL;
+               tail = offsetof(struct sockaddr_rxrpc, transport) +
+                       sizeof(struct sockaddr_in6);
+               break;
+#endif
+
        default:
                return -EAFNOSUPPORT;
        }
 
        if (tail < len)
                memset((void *)srx + tail, 0, len - tail);
+       _debug("INET: %pISp", &srx->transport);
        return 0;
 }
 
@@ -128,7 +136,8 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len)
        struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *)saddr;
        struct sock *sk = sock->sk;
        struct rxrpc_local *local;
-       struct rxrpc_sock *rx = rxrpc_sk(sk), *prx;
+       struct rxrpc_sock *rx = rxrpc_sk(sk);
+       u16 service_id = srx->srx_service;
        int ret;
 
        _enter("%p,%p,%d", rx, saddr, len);
@@ -152,16 +161,13 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len)
                goto error_unlock;
        }
 
-       if (rx->srx.srx_service) {
-               write_lock_bh(&local->services_lock);
-               list_for_each_entry(prx, &local->services, listen_link) {
-                       if (prx->srx.srx_service == rx->srx.srx_service)
-                               goto service_in_use;
-               }
-
+       if (service_id) {
+               write_lock(&local->services_lock);
+               if (rcu_access_pointer(local->service))
+                       goto service_in_use;
                rx->local = local;
-               list_add_tail(&rx->listen_link, &local->services);
-               write_unlock_bh(&local->services_lock);
+               rcu_assign_pointer(local->service, rx);
+               write_unlock(&local->services_lock);
 
                rx->sk.sk_state = RXRPC_SERVER_BOUND;
        } else {
@@ -174,7 +180,7 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len)
        return 0;
 
 service_in_use:
-       write_unlock_bh(&local->services_lock);
+       write_unlock(&local->services_lock);
        rxrpc_put_local(local);
        ret = -EADDRINUSE;
 error_unlock:
@@ -191,7 +197,7 @@ static int rxrpc_listen(struct socket *sock, int backlog)
 {
        struct sock *sk = sock->sk;
        struct rxrpc_sock *rx = rxrpc_sk(sk);
-       unsigned int max;
+       unsigned int max, old;
        int ret;
 
        _enter("%p,%d", rx, backlog);
@@ -210,9 +216,13 @@ static int rxrpc_listen(struct socket *sock, int backlog)
                        backlog = max;
                else if (backlog < 0 || backlog > max)
                        break;
+               old = sk->sk_max_ack_backlog;
                sk->sk_max_ack_backlog = backlog;
-               rx->sk.sk_state = RXRPC_SERVER_LISTENING;
-               ret = 0;
+               ret = rxrpc_service_prealloc(rx, GFP_KERNEL);
+               if (ret == 0)
+                       rx->sk.sk_state = RXRPC_SERVER_LISTENING;
+               else
+                       sk->sk_max_ack_backlog = old;
                break;
        default:
                ret = -EBUSY;
@@ -230,6 +240,8 @@ static int rxrpc_listen(struct socket *sock, int backlog)
  * @srx: The address of the peer to contact
  * @key: The security context to use (defaults to socket setting)
  * @user_call_ID: The ID to use
+ * @gfp: The allocation constraints
+ * @notify_rx: Where to send notifications instead of socket queue
  *
  * Allow a kernel service to begin a call on the nominated socket.  This just
  * sets up all the internal tracking structures and allocates connection and
@@ -242,7 +254,8 @@ struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock,
                                           struct sockaddr_rxrpc *srx,
                                           struct key *key,
                                           unsigned long user_call_ID,
-                                          gfp_t gfp)
+                                          gfp_t gfp,
+                                          rxrpc_notify_rx_t notify_rx)
 {
        struct rxrpc_conn_parameters cp;
        struct rxrpc_call *call;
@@ -269,6 +282,8 @@ struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock,
        cp.exclusive            = false;
        cp.service_id           = srx->srx_service;
        call = rxrpc_new_client_call(rx, &cp, srx, user_call_ID, gfp);
+       if (!IS_ERR(call))
+               call->notify_rx = notify_rx;
 
        release_sock(&rx->sk);
        _leave(" = %p", call);
@@ -278,40 +293,39 @@ EXPORT_SYMBOL(rxrpc_kernel_begin_call);
 
 /**
  * rxrpc_kernel_end_call - Allow a kernel service to end a call it was using
+ * @sock: The socket the call is on
  * @call: The call to end
  *
  * Allow a kernel service to end a call it was using.  The call must be
  * complete before this is called (the call should be aborted if necessary).
  */
-void rxrpc_kernel_end_call(struct rxrpc_call *call)
+void rxrpc_kernel_end_call(struct socket *sock, struct rxrpc_call *call)
 {
        _enter("%d{%d}", call->debug_id, atomic_read(&call->usage));
-       rxrpc_remove_user_ID(call->socket, call);
-       rxrpc_put_call(call);
+       rxrpc_release_call(rxrpc_sk(sock->sk), call);
+       rxrpc_put_call(call, rxrpc_call_put_kernel);
 }
 EXPORT_SYMBOL(rxrpc_kernel_end_call);
 
 /**
- * rxrpc_kernel_intercept_rx_messages - Intercept received RxRPC messages
+ * rxrpc_kernel_new_call_notification - Get notifications of new calls
  * @sock: The socket to intercept received messages on
- * @interceptor: The function to pass the messages to
+ * @notify_new_call: Function to be called when new calls appear
+ * @discard_new_call: Function to discard preallocated calls
  *
- * Allow a kernel service to intercept messages heading for the Rx queue on an
- * RxRPC socket.  They get passed to the specified function instead.
- * @interceptor should free the socket buffers it is given.  @interceptor is
- * called with the socket receive queue spinlock held and softirqs disabled -
- * this ensures that the messages will be delivered in the right order.
+ * Allow a kernel service to be given notifications about new calls.
  */
-void rxrpc_kernel_intercept_rx_messages(struct socket *sock,
-                                       rxrpc_interceptor_t interceptor)
+void rxrpc_kernel_new_call_notification(
+       struct socket *sock,
+       rxrpc_notify_new_call_t notify_new_call,
+       rxrpc_discard_new_call_t discard_new_call)
 {
        struct rxrpc_sock *rx = rxrpc_sk(sock->sk);
 
-       _enter("");
-       rx->interceptor = interceptor;
+       rx->notify_new_call = notify_new_call;
+       rx->discard_new_call = discard_new_call;
 }
-
-EXPORT_SYMBOL(rxrpc_kernel_intercept_rx_messages);
+EXPORT_SYMBOL(rxrpc_kernel_new_call_notification);
 
 /*
  * connect an RxRPC socket
@@ -391,6 +405,23 @@ static int rxrpc_sendmsg(struct socket *sock, struct msghdr *m, size_t len)
 
        switch (rx->sk.sk_state) {
        case RXRPC_UNBOUND:
+               rx->srx.srx_family = AF_RXRPC;
+               rx->srx.srx_service = 0;
+               rx->srx.transport_type = SOCK_DGRAM;
+               rx->srx.transport.family = rx->family;
+               switch (rx->family) {
+               case AF_INET:
+                       rx->srx.transport_len = sizeof(struct sockaddr_in);
+                       break;
+#ifdef CONFIG_AF_RXRPC_IPV6
+               case AF_INET6:
+                       rx->srx.transport_len = sizeof(struct sockaddr_in6);
+                       break;
+#endif
+               default:
+                       ret = -EAFNOSUPPORT;
+                       goto error_unlock;
+               }
                local = rxrpc_lookup_local(&rx->srx);
                if (IS_ERR(local)) {
                        ret = PTR_ERR(local);
@@ -505,15 +536,16 @@ error:
 static unsigned int rxrpc_poll(struct file *file, struct socket *sock,
                               poll_table *wait)
 {
-       unsigned int mask;
        struct sock *sk = sock->sk;
+       struct rxrpc_sock *rx = rxrpc_sk(sk);
+       unsigned int mask;
 
        sock_poll_wait(file, sk_sleep(sk), wait);
        mask = 0;
 
        /* the socket is readable if there are any messages waiting on the Rx
         * queue */
-       if (!skb_queue_empty(&sk->sk_receive_queue))
+       if (!list_empty(&rx->recvmsg_q))
                mask |= POLLIN | POLLRDNORM;
 
        /* the socket is writable if there is space to add new data to the
@@ -540,7 +572,8 @@ static int rxrpc_create(struct net *net, struct socket *sock, int protocol,
                return -EAFNOSUPPORT;
 
        /* we support transport protocol UDP/UDP6 only */
-       if (protocol != PF_INET)
+       if (protocol != PF_INET &&
+           IS_ENABLED(CONFIG_AF_RXRPC_IPV6) && protocol != PF_INET6)
                return -EPROTONOSUPPORT;
 
        if (sock->type != SOCK_DGRAM)
@@ -554,6 +587,7 @@ static int rxrpc_create(struct net *net, struct socket *sock, int protocol,
                return -ENOMEM;
 
        sock_init_data(sock, sk);
+       sock_set_flag(sk, SOCK_RCU_FREE);
        sk->sk_state            = RXRPC_UNBOUND;
        sk->sk_write_space      = rxrpc_write_space;
        sk->sk_max_ack_backlog  = 0;
@@ -563,9 +597,11 @@ static int rxrpc_create(struct net *net, struct socket *sock, int protocol,
        rx->family = protocol;
        rx->calls = RB_ROOT;
 
-       INIT_LIST_HEAD(&rx->listen_link);
-       INIT_LIST_HEAD(&rx->secureq);
-       INIT_LIST_HEAD(&rx->acceptq);
+       spin_lock_init(&rx->incoming_lock);
+       INIT_LIST_HEAD(&rx->sock_calls);
+       INIT_LIST_HEAD(&rx->to_be_accepted);
+       INIT_LIST_HEAD(&rx->recvmsg_q);
+       rwlock_init(&rx->recvmsg_lock);
        rwlock_init(&rx->call_lock);
        memset(&rx->srx, 0, sizeof(rx->srx));
 
@@ -573,6 +609,39 @@ static int rxrpc_create(struct net *net, struct socket *sock, int protocol,
        return 0;
 }
 
+/*
+ * Kill all the calls on a socket and shut it down.
+ */
+static int rxrpc_shutdown(struct socket *sock, int flags)
+{
+       struct sock *sk = sock->sk;
+       struct rxrpc_sock *rx = rxrpc_sk(sk);
+       int ret = 0;
+
+       _enter("%p,%d", sk, flags);
+
+       if (flags != SHUT_RDWR)
+               return -EOPNOTSUPP;
+       if (sk->sk_state == RXRPC_CLOSE)
+               return -ESHUTDOWN;
+
+       lock_sock(sk);
+
+       spin_lock_bh(&sk->sk_receive_queue.lock);
+       if (sk->sk_state < RXRPC_CLOSE) {
+               sk->sk_state = RXRPC_CLOSE;
+               sk->sk_shutdown = SHUTDOWN_MASK;
+       } else {
+               ret = -ESHUTDOWN;
+       }
+       spin_unlock_bh(&sk->sk_receive_queue.lock);
+
+       rxrpc_discard_prealloc(rx);
+
+       release_sock(sk);
+       return ret;
+}
+
 /*
  * RxRPC socket destructor
  */
@@ -609,15 +678,14 @@ static int rxrpc_release_sock(struct sock *sk)
        sk->sk_state = RXRPC_CLOSE;
        spin_unlock_bh(&sk->sk_receive_queue.lock);
 
-       ASSERTCMP(rx->listen_link.next, !=, LIST_POISON1);
-
-       if (!list_empty(&rx->listen_link)) {
-               write_lock_bh(&rx->local->services_lock);
-               list_del(&rx->listen_link);
-               write_unlock_bh(&rx->local->services_lock);
+       if (rx->local && rcu_access_pointer(rx->local->service) == rx) {
+               write_lock(&rx->local->services_lock);
+               rcu_assign_pointer(rx->local->service, NULL);
+               write_unlock(&rx->local->services_lock);
        }
 
        /* try to flush out this socket */
+       rxrpc_discard_prealloc(rx);
        rxrpc_release_calls_on_socket(rx);
        flush_workqueue(rxrpc_workqueue);
        rxrpc_purge_queue(&sk->sk_receive_queue);
@@ -666,7 +734,7 @@ static const struct proto_ops rxrpc_rpc_ops = {
        .poll           = rxrpc_poll,
        .ioctl          = sock_no_ioctl,
        .listen         = rxrpc_listen,
-       .shutdown       = sock_no_shutdown,
+       .shutdown       = rxrpc_shutdown,
        .setsockopt     = rxrpc_setsockopt,
        .getsockopt     = sock_no_getsockopt,
        .sendmsg        = rxrpc_sendmsg,
@@ -697,7 +765,13 @@ static int __init af_rxrpc_init(void)
 
        BUILD_BUG_ON(sizeof(struct rxrpc_skb_priv) > FIELD_SIZEOF(struct sk_buff, cb));
 
-       rxrpc_epoch = get_seconds();
+       get_random_bytes(&rxrpc_epoch, sizeof(rxrpc_epoch));
+       rxrpc_epoch |= RXRPC_RANDOM_EPOCH;
+       get_random_bytes(&rxrpc_client_conn_ids.cur,
+                        sizeof(rxrpc_client_conn_ids.cur));
+       rxrpc_client_conn_ids.cur &= 0x3fffffff;
+       if (rxrpc_client_conn_ids.cur == 0)
+               rxrpc_client_conn_ids.cur = 1;
 
        ret = -ENOMEM;
        rxrpc_call_jar = kmem_cache_create(
@@ -788,7 +862,8 @@ static void __exit af_rxrpc_exit(void)
        proto_unregister(&rxrpc_proto);
        rxrpc_destroy_all_calls();
        rxrpc_destroy_all_connections();
-       ASSERTCMP(atomic_read(&rxrpc_n_skbs), ==, 0);
+       ASSERTCMP(atomic_read(&rxrpc_n_tx_skbs), ==, 0);
+       ASSERTCMP(atomic_read(&rxrpc_n_rx_skbs), ==, 0);
        rxrpc_destroy_all_locals();
 
        remove_proc_entry("rxrpc_conns", init_net.proc_net);