]> asedeno.scripts.mit.edu Git - linux.git/blobdiff - net/socket.c
Merge tag 'for-linus' of git://git.kernel.org/pub/scm/virt/kvm/kvm
[linux.git] / net / socket.c
index 0fb0820edeec14fa076ce3105f2a15c057d2a731..50623218747f067c0ecf33a51d8d6b61e39ce139 100644 (file)
@@ -957,7 +957,7 @@ static ssize_t sock_read_iter(struct kiocb *iocb, struct iov_iter *to)
                             .msg_iocb = iocb};
        ssize_t res;
 
-       if (file->f_flags & O_NONBLOCK)
+       if (file->f_flags & O_NONBLOCK || (iocb->ki_flags & IOCB_NOWAIT))
                msg.msg_flags = MSG_DONTWAIT;
 
        if (iocb->ki_pos != 0)
@@ -982,7 +982,7 @@ static ssize_t sock_write_iter(struct kiocb *iocb, struct iov_iter *from)
        if (iocb->ki_pos != 0)
                return -ESPIPE;
 
-       if (file->f_flags & O_NONBLOCK)
+       if (file->f_flags & O_NONBLOCK || (iocb->ki_flags & IOCB_NOWAIT))
                msg.msg_flags = MSG_DONTWAIT;
 
        if (sock->type == SOCK_SEQPACKET)
@@ -1826,26 +1826,22 @@ SYSCALL_DEFINE3(accept, int, fd, struct sockaddr __user *, upeer_sockaddr,
  *     include the -EINPROGRESS status for such sockets.
  */
 
-int __sys_connect_file(struct file *file, struct sockaddr __user *uservaddr,
+int __sys_connect_file(struct file *file, struct sockaddr_storage *address,
                       int addrlen, int file_flags)
 {
        struct socket *sock;
-       struct sockaddr_storage address;
        int err;
 
        sock = sock_from_file(file, &err);
        if (!sock)
                goto out;
-       err = move_addr_to_kernel(uservaddr, addrlen, &address);
-       if (err < 0)
-               goto out;
 
        err =
-           security_socket_connect(sock, (struct sockaddr *)&address, addrlen);
+           security_socket_connect(sock, (struct sockaddr *)address, addrlen);
        if (err)
                goto out;
 
-       err = sock->ops->connect(sock, (struct sockaddr *)&address, addrlen,
+       err = sock->ops->connect(sock, (struct sockaddr *)address, addrlen,
                                 sock->file->f_flags | file_flags);
 out:
        return err;
@@ -1858,7 +1854,11 @@ int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen)
 
        f = fdget(fd);
        if (f.file) {
-               ret = __sys_connect_file(f.file, uservaddr, addrlen, 0);
+               struct sockaddr_storage address;
+
+               ret = move_addr_to_kernel(uservaddr, addrlen, &address);
+               if (!ret)
+                       ret = __sys_connect_file(f.file, &address, addrlen, 0);
                if (f.flags)
                        fput(f.file);
        }
@@ -2546,7 +2546,12 @@ static int ____sys_recvmsg(struct socket *sock, struct msghdr *msg_sys,
 
        if (sock->file->f_flags & O_NONBLOCK)
                flags |= MSG_DONTWAIT;
-       err = (nosec ? sock_recvmsg_nosec : sock_recvmsg)(sock, msg_sys, flags);
+
+       if (unlikely(nosec))
+               err = sock_recvmsg_nosec(sock, msg_sys, flags);
+       else
+               err = sock_recvmsg(sock, msg_sys, flags);
+
        if (err < 0)
                goto out;
        len = err;