KEYS: remove a bogus NULL check
[cascardo/linux.git] / net / socket.c
index fe20c31..8809afc 100644 (file)
@@ -651,7 +651,8 @@ static inline int __sock_sendmsg(struct kiocb *iocb, struct socket *sock,
        return err ?: __sock_sendmsg_nosec(iocb, sock, msg, size);
 }
 
-int sock_sendmsg(struct socket *sock, struct msghdr *msg, size_t size)
+static int do_sock_sendmsg(struct socket *sock, struct msghdr *msg,
+                          size_t size, bool nosec)
 {
        struct kiocb iocb;
        struct sock_iocb siocb;
@@ -659,25 +660,22 @@ int sock_sendmsg(struct socket *sock, struct msghdr *msg, size_t size)
 
        init_sync_kiocb(&iocb, NULL);
        iocb.private = &siocb;
-       ret = __sock_sendmsg(&iocb, sock, msg, size);
+       ret = nosec ? __sock_sendmsg_nosec(&iocb, sock, msg, size) :
+                     __sock_sendmsg(&iocb, sock, msg, size);
        if (-EIOCBQUEUED == ret)
                ret = wait_on_sync_kiocb(&iocb);
        return ret;
 }
+
+int sock_sendmsg(struct socket *sock, struct msghdr *msg, size_t size)
+{
+       return do_sock_sendmsg(sock, msg, size, false);
+}
 EXPORT_SYMBOL(sock_sendmsg);
 
 static int sock_sendmsg_nosec(struct socket *sock, struct msghdr *msg, size_t size)
 {
-       struct kiocb iocb;
-       struct sock_iocb siocb;
-       int ret;
-
-       init_sync_kiocb(&iocb, NULL);
-       iocb.private = &siocb;
-       ret = __sock_sendmsg_nosec(&iocb, sock, msg, size);
-       if (-EIOCBQUEUED == ret)
-               ret = wait_on_sync_kiocb(&iocb);
-       return ret;
+       return do_sock_sendmsg(sock, msg, size, true);
 }
 
 int kernel_sendmsg(struct socket *sock, struct msghdr *msg,
@@ -691,8 +689,7 @@ int kernel_sendmsg(struct socket *sock, struct msghdr *msg,
         * the following is safe, since for compiler definitions of kvec and
         * iovec are identical, yielding the same in-core layout and alignment
         */
-       msg->msg_iov = (struct iovec *)vec;
-       msg->msg_iovlen = num;
+       iov_iter_init(&msg->msg_iter, WRITE, (struct iovec *)vec, num, size);
        result = sock_sendmsg(sock, msg, size);
        set_fs(oldfs);
        return result;
@@ -855,7 +852,7 @@ int kernel_recvmsg(struct socket *sock, struct msghdr *msg,
         * the following is safe, since for compiler definitions of kvec and
         * iovec are identical, yielding the same in-core layout and alignment
         */
-       msg->msg_iov = (struct iovec *)vec, msg->msg_iovlen = num;
+       iov_iter_init(&msg->msg_iter, READ, (struct iovec *)vec, num, size);
        result = sock_recvmsg(sock, msg, size, flags);
        set_fs(oldfs);
        return result;
@@ -915,8 +912,7 @@ static ssize_t do_sock_read(struct msghdr *msg, struct kiocb *iocb,
        msg->msg_namelen = 0;
        msg->msg_control = NULL;
        msg->msg_controllen = 0;
-       msg->msg_iov = (struct iovec *)iov;
-       msg->msg_iovlen = nr_segs;
+       iov_iter_init(&msg->msg_iter, READ, iov, nr_segs, size);
        msg->msg_flags = (file->f_flags & O_NONBLOCK) ? MSG_DONTWAIT : 0;
 
        return __sock_recvmsg(iocb, sock, msg, size, msg->msg_flags);
@@ -955,8 +951,7 @@ static ssize_t do_sock_write(struct msghdr *msg, struct kiocb *iocb,
        msg->msg_namelen = 0;
        msg->msg_control = NULL;
        msg->msg_controllen = 0;
-       msg->msg_iov = (struct iovec *)iov;
-       msg->msg_iovlen = nr_segs;
+       iov_iter_init(&msg->msg_iter, WRITE, iov, nr_segs, size);
        msg->msg_flags = (file->f_flags & O_NONBLOCK) ? MSG_DONTWAIT : 0;
        if (sock->type == SOCK_SEQPACKET)
                msg->msg_flags |= MSG_EOR;
@@ -1800,8 +1795,7 @@ SYSCALL_DEFINE6(sendto, int, fd, void __user *, buff, size_t, len,
        iov.iov_base = buff;
        iov.iov_len = len;
        msg.msg_name = NULL;
-       msg.msg_iov = &iov;
-       msg.msg_iovlen = 1;
+       iov_iter_init(&msg.msg_iter, WRITE, &iov, 1, len);
        msg.msg_control = NULL;
        msg.msg_controllen = 0;
        msg.msg_namelen = 0;
@@ -1858,10 +1852,9 @@ SYSCALL_DEFINE6(recvfrom, int, fd, void __user *, ubuf, size_t, size,
 
        msg.msg_control = NULL;
        msg.msg_controllen = 0;
-       msg.msg_iovlen = 1;
-       msg.msg_iov = &iov;
        iov.iov_len = size;
        iov.iov_base = ubuf;
+       iov_iter_init(&msg.msg_iter, READ, &iov, 1, size);
        /* Save some cycles and don't copy the address if not needed */
        msg.msg_name = addr ? (struct sockaddr *)&address : NULL;
        /* We assume all kernel code knows the size of sockaddr_storage */
@@ -1988,13 +1981,27 @@ struct used_address {
        unsigned int name_len;
 };
 
-static int copy_msghdr_from_user(struct msghdr *kmsg,
-                                struct msghdr __user *umsg)
+static ssize_t copy_msghdr_from_user(struct msghdr *kmsg,
+                                    struct user_msghdr __user *umsg,
+                                    struct sockaddr __user **save_addr,
+                                    struct iovec **iov)
 {
-       if (copy_from_user(kmsg, umsg, sizeof(struct msghdr)))
+       struct sockaddr __user *uaddr;
+       struct iovec __user *uiov;
+       size_t nr_segs;
+       ssize_t err;
+
+       if (!access_ok(VERIFY_READ, umsg, sizeof(*umsg)) ||
+           __get_user(uaddr, &umsg->msg_name) ||
+           __get_user(kmsg->msg_namelen, &umsg->msg_namelen) ||
+           __get_user(uiov, &umsg->msg_iov) ||
+           __get_user(nr_segs, &umsg->msg_iovlen) ||
+           __get_user(kmsg->msg_control, &umsg->msg_control) ||
+           __get_user(kmsg->msg_controllen, &umsg->msg_controllen) ||
+           __get_user(kmsg->msg_flags, &umsg->msg_flags))
                return -EFAULT;
 
-       if (kmsg->msg_name == NULL)
+       if (!uaddr)
                kmsg->msg_namelen = 0;
 
        if (kmsg->msg_namelen < 0)
@@ -2002,10 +2009,35 @@ static int copy_msghdr_from_user(struct msghdr *kmsg,
 
        if (kmsg->msg_namelen > sizeof(struct sockaddr_storage))
                kmsg->msg_namelen = sizeof(struct sockaddr_storage);
-       return 0;
+
+       if (save_addr)
+               *save_addr = uaddr;
+
+       if (uaddr && kmsg->msg_namelen) {
+               if (!save_addr) {
+                       err = move_addr_to_kernel(uaddr, kmsg->msg_namelen,
+                                                 kmsg->msg_name);
+                       if (err < 0)
+                               return err;
+               }
+       } else {
+               kmsg->msg_name = NULL;
+               kmsg->msg_namelen = 0;
+       }
+
+       if (nr_segs > UIO_MAXIOV)
+               return -EMSGSIZE;
+
+       err = rw_copy_check_uvector(save_addr ? READ : WRITE,
+                                   uiov, nr_segs,
+                                   UIO_FASTIOV, *iov, iov);
+       if (err >= 0)
+               iov_iter_init(&kmsg->msg_iter, save_addr ? READ : WRITE,
+                             *iov, nr_segs, err);
+       return err;
 }
 
-static int ___sys_sendmsg(struct socket *sock, struct msghdr __user *msg,
+static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
                         struct msghdr *msg_sys, unsigned int flags,
                         struct used_address *used_address)
 {
@@ -2017,34 +2049,15 @@ static int ___sys_sendmsg(struct socket *sock, struct msghdr __user *msg,
            __attribute__ ((aligned(sizeof(__kernel_size_t))));
        /* 20 is size of ipv6_pktinfo */
        unsigned char *ctl_buf = ctl;
-       int err, ctl_len, total_len;
-
-       err = -EFAULT;
-       if (MSG_CMSG_COMPAT & flags) {
-               if (get_compat_msghdr(msg_sys, msg_compat))
-                       return -EFAULT;
-       } else {
-               err = copy_msghdr_from_user(msg_sys, msg);
-               if (err)
-                       return err;
-       }
+       int ctl_len, total_len;
+       ssize_t err;
 
-       if (msg_sys->msg_iovlen > UIO_FASTIOV) {
-               err = -EMSGSIZE;
-               if (msg_sys->msg_iovlen > UIO_MAXIOV)
-                       goto out;
-               err = -ENOMEM;
-               iov = kmalloc(msg_sys->msg_iovlen * sizeof(struct iovec),
-                             GFP_KERNEL);
-               if (!iov)
-                       goto out;
-       }
+       msg_sys->msg_name = &address;
 
-       /* This will also move the address data into kernel space */
-       if (MSG_CMSG_COMPAT & flags) {
-               err = verify_compat_iovec(msg_sys, iov, &address, VERIFY_READ);
-       } else
-               err = verify_iovec(msg_sys, iov, &address, VERIFY_READ);
+       if (MSG_CMSG_COMPAT & flags)
+               err = get_compat_msghdr(msg_sys, msg_compat, NULL, &iov);
+       else
+               err = copy_msghdr_from_user(msg_sys, msg, NULL, &iov);
        if (err < 0)
                goto out_freeiov;
        total_len = err;
@@ -2115,7 +2128,6 @@ out_freectl:
 out_freeiov:
        if (iov != iovstack)
                kfree(iov);
-out:
        return err;
 }
 
@@ -2123,7 +2135,7 @@ out:
  *     BSD sendmsg interface
  */
 
-long __sys_sendmsg(int fd, struct msghdr __user *msg, unsigned flags)
+long __sys_sendmsg(int fd, struct user_msghdr __user *msg, unsigned flags)
 {
        int fput_needed, err;
        struct msghdr msg_sys;
@@ -2140,7 +2152,7 @@ out:
        return err;
 }
 
-SYSCALL_DEFINE3(sendmsg, int, fd, struct msghdr __user *, msg, unsigned int, flags)
+SYSCALL_DEFINE3(sendmsg, int, fd, struct user_msghdr __user *, msg, unsigned int, flags)
 {
        if (flags & MSG_CMSG_COMPAT)
                return -EINVAL;
@@ -2177,7 +2189,7 @@ int __sys_sendmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen,
 
        while (datagrams < vlen) {
                if (MSG_CMSG_COMPAT & flags) {
-                       err = ___sys_sendmsg(sock, (struct msghdr __user *)compat_entry,
+                       err = ___sys_sendmsg(sock, (struct user_msghdr __user *)compat_entry,
                                             &msg_sys, flags, &used_address);
                        if (err < 0)
                                break;
@@ -2185,7 +2197,7 @@ int __sys_sendmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen,
                        ++compat_entry;
                } else {
                        err = ___sys_sendmsg(sock,
-                                            (struct msghdr __user *)entry,
+                                            (struct user_msghdr __user *)entry,
                                             &msg_sys, flags, &used_address);
                        if (err < 0)
                                break;
@@ -2215,7 +2227,7 @@ SYSCALL_DEFINE4(sendmmsg, int, fd, struct mmsghdr __user *, mmsg,
        return __sys_sendmmsg(fd, mmsg, vlen, flags);
 }
 
-static int ___sys_recvmsg(struct socket *sock, struct msghdr __user *msg,
+static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
                         struct msghdr *msg_sys, unsigned int flags, int nosec)
 {
        struct compat_msghdr __user *msg_compat =
@@ -2223,44 +2235,22 @@ static int ___sys_recvmsg(struct socket *sock, struct msghdr __user *msg,
        struct iovec iovstack[UIO_FASTIOV];
        struct iovec *iov = iovstack;
        unsigned long cmsg_ptr;
-       int err, total_len, len;
+       int total_len, len;
+       ssize_t err;
 
        /* kernel mode address */
        struct sockaddr_storage addr;
 
        /* user mode address pointers */
        struct sockaddr __user *uaddr;
-       int __user *uaddr_len;
-
-       if (MSG_CMSG_COMPAT & flags) {
-               if (get_compat_msghdr(msg_sys, msg_compat))
-                       return -EFAULT;
-       } else {
-               err = copy_msghdr_from_user(msg_sys, msg);
-               if (err)
-                       return err;
-       }
+       int __user *uaddr_len = COMPAT_NAMELEN(msg);
 
-       if (msg_sys->msg_iovlen > UIO_FASTIOV) {
-               err = -EMSGSIZE;
-               if (msg_sys->msg_iovlen > UIO_MAXIOV)
-                       goto out;
-               err = -ENOMEM;
-               iov = kmalloc(msg_sys->msg_iovlen * sizeof(struct iovec),
-                             GFP_KERNEL);
-               if (!iov)
-                       goto out;
-       }
+       msg_sys->msg_name = &addr;
 
-       /* Save the user-mode address (verify_iovec will change the
-        * kernel msghdr to use the kernel address space)
-        */
-       uaddr = (__force void __user *)msg_sys->msg_name;
-       uaddr_len = COMPAT_NAMELEN(msg);
        if (MSG_CMSG_COMPAT & flags)
-               err = verify_compat_iovec(msg_sys, iov, &addr, VERIFY_WRITE);
+               err = get_compat_msghdr(msg_sys, msg_compat, &uaddr, &iov);
        else
-               err = verify_iovec(msg_sys, iov, &addr, VERIFY_WRITE);
+               err = copy_msghdr_from_user(msg_sys, msg, &uaddr, &iov);
        if (err < 0)
                goto out_freeiov;
        total_len = err;
@@ -2303,7 +2293,6 @@ static int ___sys_recvmsg(struct socket *sock, struct msghdr __user *msg,
 out_freeiov:
        if (iov != iovstack)
                kfree(iov);
-out:
        return err;
 }
 
@@ -2311,7 +2300,7 @@ out:
  *     BSD recvmsg interface
  */
 
-long __sys_recvmsg(int fd, struct msghdr __user *msg, unsigned flags)
+long __sys_recvmsg(int fd, struct user_msghdr __user *msg, unsigned flags)
 {
        int fput_needed, err;
        struct msghdr msg_sys;
@@ -2328,7 +2317,7 @@ out:
        return err;
 }
 
-SYSCALL_DEFINE3(recvmsg, int, fd, struct msghdr __user *, msg,
+SYSCALL_DEFINE3(recvmsg, int, fd, struct user_msghdr __user *, msg,
                unsigned int, flags)
 {
        if (flags & MSG_CMSG_COMPAT)
@@ -2373,7 +2362,7 @@ int __sys_recvmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen,
                 * No need to ask LSM for more than the first datagram.
                 */
                if (MSG_CMSG_COMPAT & flags) {
-                       err = ___sys_recvmsg(sock, (struct msghdr __user *)compat_entry,
+                       err = ___sys_recvmsg(sock, (struct user_msghdr __user *)compat_entry,
                                             &msg_sys, flags & ~MSG_WAITFORONE,
                                             datagrams);
                        if (err < 0)
@@ -2382,7 +2371,7 @@ int __sys_recvmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen,
                        ++compat_entry;
                } else {
                        err = ___sys_recvmsg(sock,
-                                            (struct msghdr __user *)entry,
+                                            (struct user_msghdr __user *)entry,
                                             &msg_sys, flags & ~MSG_WAITFORONE,
                                             datagrams);
                        if (err < 0)
@@ -2571,13 +2560,13 @@ SYSCALL_DEFINE2(socketcall, int, call, unsigned long __user *, args)
                                   (int __user *)a[4]);
                break;
        case SYS_SENDMSG:
-               err = sys_sendmsg(a0, (struct msghdr __user *)a1, a[2]);
+               err = sys_sendmsg(a0, (struct user_msghdr __user *)a1, a[2]);
                break;
        case SYS_SENDMMSG:
                err = sys_sendmmsg(a0, (struct mmsghdr __user *)a1, a[2], a[3]);
                break;
        case SYS_RECVMSG:
-               err = sys_recvmsg(a0, (struct msghdr __user *)a1, a[2]);
+               err = sys_recvmsg(a0, (struct user_msghdr __user *)a1, a[2]);
                break;
        case SYS_RECVMMSG:
                err = sys_recvmmsg(a0, (struct mmsghdr __user *)a1, a[2], a[3],