]> asedeno.scripts.mit.edu Git - linux.git/blobdiff - net/ipv4/af_inet.c
Merge tag 'tty-5.3-rc2' of git://git.kernel.org/pub/scm/linux/kernel/git/gregkh/tty
[linux.git] / net / ipv4 / af_inet.c
index 52bdb881a5060fc7032ef892b6028f2bead3b19d..ed2301ef872e4c5d42539fa490169ab8fcaea44e 100644 (file)
@@ -784,10 +784,8 @@ int inet_getname(struct socket *sock, struct sockaddr *uaddr,
 }
 EXPORT_SYMBOL(inet_getname);
 
-int inet_sendmsg(struct socket *sock, struct msghdr *msg, size_t size)
+int inet_send_prepare(struct sock *sk)
 {
-       struct sock *sk = sock->sk;
-
        sock_rps_record_flow(sk);
 
        /* We may need to bind the socket. */
@@ -795,7 +793,19 @@ int inet_sendmsg(struct socket *sock, struct msghdr *msg, size_t size)
            inet_autobind(sk))
                return -EAGAIN;
 
-       return sk->sk_prot->sendmsg(sk, msg, size);
+       return 0;
+}
+EXPORT_SYMBOL_GPL(inet_send_prepare);
+
+int inet_sendmsg(struct socket *sock, struct msghdr *msg, size_t size)
+{
+       struct sock *sk = sock->sk;
+
+       if (unlikely(inet_send_prepare(sk)))
+               return -EAGAIN;
+
+       return INDIRECT_CALL_2(sk->sk_prot->sendmsg, tcp_sendmsg, udp_sendmsg,
+                              sk, msg, size);
 }
 EXPORT_SYMBOL(inet_sendmsg);
 
@@ -804,11 +814,7 @@ ssize_t inet_sendpage(struct socket *sock, struct page *page, int offset,
 {
        struct sock *sk = sock->sk;
 
-       sock_rps_record_flow(sk);
-
-       /* We may need to bind the socket. */
-       if (!inet_sk(sk)->inet_num && !sk->sk_prot->no_autobind &&
-           inet_autobind(sk))
+       if (unlikely(inet_send_prepare(sk)))
                return -EAGAIN;
 
        if (sk->sk_prot->sendpage)
@@ -817,6 +823,8 @@ ssize_t inet_sendpage(struct socket *sock, struct page *page, int offset,
 }
 EXPORT_SYMBOL(inet_sendpage);
 
+INDIRECT_CALLABLE_DECLARE(int udp_recvmsg(struct sock *, struct msghdr *,
+                                         size_t, int, int, int *));
 int inet_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
                 int flags)
 {
@@ -827,8 +835,9 @@ int inet_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
        if (likely(!(flags & MSG_ERRQUEUE)))
                sock_rps_record_flow(sk);
 
-       err = sk->sk_prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
-                                  flags & ~MSG_DONTWAIT, &addr_len);
+       err = INDIRECT_CALL_2(sk->sk_prot->recvmsg, tcp_recvmsg, udp_recvmsg,
+                             sk, msg, size, flags & MSG_DONTWAIT,
+                             flags & ~MSG_DONTWAIT, &addr_len);
        if (err >= 0)
                msg->msg_namelen = addr_len;
        return err;