]> asedeno.scripts.mit.edu Git - linux.git/blobdiff - net/ipv4/tcp.c
Merge branch 'tcp-address-KCSAN-reports-in-tcp_poll-part-I'
[linux.git] / net / ipv4 / tcp.c
index c59d0bd29c5c6fcbe38edb12b37c47ba4ed68899..b2ac4f074e2da21db57923fda722b6d23f170de9 100644 (file)
@@ -450,8 +450,8 @@ void tcp_init_sock(struct sock *sk)
 
        icsk->icsk_sync_mss = tcp_sync_mss;
 
-       sk->sk_sndbuf = sock_net(sk)->ipv4.sysctl_tcp_wmem[1];
-       sk->sk_rcvbuf = sock_net(sk)->ipv4.sysctl_tcp_rmem[1];
+       WRITE_ONCE(sk->sk_sndbuf, sock_net(sk)->ipv4.sysctl_tcp_wmem[1]);
+       WRITE_ONCE(sk->sk_rcvbuf, sock_net(sk)->ipv4.sysctl_tcp_rmem[1]);
 
        sk_sockets_allocated_inc(sk);
        sk->sk_route_forced_caps = NETIF_F_GSO;
@@ -477,7 +477,7 @@ static void tcp_tx_timestamp(struct sock *sk, u16 tsflags)
 static inline bool tcp_stream_is_readable(const struct tcp_sock *tp,
                                          int target, struct sock *sk)
 {
-       return (tp->rcv_nxt - tp->copied_seq >= target) ||
+       return (READ_ONCE(tp->rcv_nxt) - READ_ONCE(tp->copied_seq) >= target) ||
                (sk->sk_prot->stream_memory_read ?
                sk->sk_prot->stream_memory_read(sk) : false);
 }
@@ -546,7 +546,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait)
            (state != TCP_SYN_RECV || rcu_access_pointer(tp->fastopen_rsk))) {
                int target = sock_rcvlowat(sk, 0, INT_MAX);
 
-               if (tp->urg_seq == tp->copied_seq &&
+               if (READ_ONCE(tp->urg_seq) == READ_ONCE(tp->copied_seq) &&
                    !sock_flag(sk, SOCK_URGINLINE) &&
                    tp->urg_data)
                        target++;
@@ -607,7 +607,8 @@ int tcp_ioctl(struct sock *sk, int cmd, unsigned long arg)
                unlock_sock_fast(sk, slow);
                break;
        case SIOCATMARK:
-               answ = tp->urg_data && tp->urg_seq == tp->copied_seq;
+               answ = tp->urg_data &&
+                      READ_ONCE(tp->urg_seq) == READ_ONCE(tp->copied_seq);
                break;
        case SIOCOUTQ:
                if (sk->sk_state == TCP_LISTEN)
@@ -616,7 +617,7 @@ int tcp_ioctl(struct sock *sk, int cmd, unsigned long arg)
                if ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV))
                        answ = 0;
                else
-                       answ = tp->write_seq - tp->snd_una;
+                       answ = READ_ONCE(tp->write_seq) - tp->snd_una;
                break;
        case SIOCOUTQNSD:
                if (sk->sk_state == TCP_LISTEN)
@@ -625,7 +626,8 @@ int tcp_ioctl(struct sock *sk, int cmd, unsigned long arg)
                if ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV))
                        answ = 0;
                else
-                       answ = tp->write_seq - tp->snd_nxt;
+                       answ = READ_ONCE(tp->write_seq) -
+                              READ_ONCE(tp->snd_nxt);
                break;
        default:
                return -ENOIOCTLCMD;
@@ -657,7 +659,7 @@ static void skb_entail(struct sock *sk, struct sk_buff *skb)
        tcb->sacked  = 0;
        __skb_header_release(skb);
        tcp_add_write_queue_tail(sk, skb);
-       sk->sk_wmem_queued += skb->truesize;
+       sk_wmem_queued_add(sk, skb->truesize);
        sk_mem_charge(sk, skb->truesize);
        if (tp->nonagle & TCP_NAGLE_PUSH)
                tp->nonagle &= ~TCP_NAGLE_PUSH;
@@ -1032,10 +1034,10 @@ ssize_t do_tcp_sendpages(struct sock *sk, struct page *page, int offset,
                skb->len += copy;
                skb->data_len += copy;
                skb->truesize += copy;
-               sk->sk_wmem_queued += copy;
+               sk_wmem_queued_add(sk, copy);
                sk_mem_charge(sk, copy);
                skb->ip_summed = CHECKSUM_PARTIAL;
-               tp->write_seq += copy;
+               WRITE_ONCE(tp->write_seq, tp->write_seq + copy);
                TCP_SKB_CB(skb)->end_seq += copy;
                tcp_skb_pcount_set(skb, 0);
 
@@ -1362,7 +1364,7 @@ int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size)
                if (!copied)
                        TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_PSH;
 
-               tp->write_seq += copy;
+               WRITE_ONCE(tp->write_seq, tp->write_seq + copy);
                TCP_SKB_CB(skb)->end_seq += copy;
                tcp_skb_pcount_set(skb, 0);
 
@@ -1668,9 +1670,9 @@ int tcp_read_sock(struct sock *sk, read_descriptor_t *desc,
                sk_eat_skb(sk, skb);
                if (!desc->count)
                        break;
-               tp->copied_seq = seq;
+               WRITE_ONCE(tp->copied_seq, seq);
        }
-       tp->copied_seq = seq;
+       WRITE_ONCE(tp->copied_seq, seq);
 
        tcp_rcv_space_adjust(sk);
 
@@ -1709,7 +1711,7 @@ int tcp_set_rcvlowat(struct sock *sk, int val)
 
        val <<= 1;
        if (val > sk->sk_rcvbuf) {
-               sk->sk_rcvbuf = val;
+               WRITE_ONCE(sk->sk_rcvbuf, val);
                tcp_sk(sk)->window_clamp = tcp_win_from_space(sk, val);
        }
        return 0;
@@ -1819,7 +1821,7 @@ static int tcp_zerocopy_receive(struct sock *sk,
 out:
        up_read(&current->mm->mmap_sem);
        if (length) {
-               tp->copied_seq = seq;
+               WRITE_ONCE(tp->copied_seq, seq);
                tcp_rcv_space_adjust(sk);
 
                /* Clean up data we have read: This will do ACK frames. */
@@ -2117,7 +2119,7 @@ int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock,
                        if (urg_offset < used) {
                                if (!urg_offset) {
                                        if (!sock_flag(sk, SOCK_URGINLINE)) {
-                                               ++*seq;
+                                               WRITE_ONCE(*seq, *seq + 1);
                                                urg_hole++;
                                                offset++;
                                                used--;
@@ -2139,7 +2141,7 @@ int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock,
                        }
                }
 
-               *seq += used;
+               WRITE_ONCE(*seq, *seq + used);
                copied += used;
                len -= used;
 
@@ -2166,7 +2168,7 @@ int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int nonblock,
 
 found_fin_ok:
                /* Process the FIN. */
-               ++*seq;
+               WRITE_ONCE(*seq, *seq + 1);
                if (!(flags & MSG_PEEK))
                        sk_eat_skb(sk, skb);
                break;
@@ -2562,6 +2564,7 @@ int tcp_disconnect(struct sock *sk, int flags)
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct tcp_sock *tp = tcp_sk(sk);
        int old_state = sk->sk_state;
+       u32 seq;
 
        if (old_state != TCP_CLOSE)
                tcp_set_state(sk, TCP_CLOSE);
@@ -2588,7 +2591,7 @@ int tcp_disconnect(struct sock *sk, int flags)
                __kfree_skb(sk->sk_rx_skb_cache);
                sk->sk_rx_skb_cache = NULL;
        }
-       tp->copied_seq = tp->rcv_nxt;
+       WRITE_ONCE(tp->copied_seq, tp->rcv_nxt);
        tp->urg_data = 0;
        tcp_write_queue_purge(sk);
        tcp_fastopen_active_disable_ofo_check(sk);
@@ -2604,9 +2607,12 @@ int tcp_disconnect(struct sock *sk, int flags)
        tp->srtt_us = 0;
        tp->mdev_us = jiffies_to_usecs(TCP_TIMEOUT_INIT);
        tp->rcv_rtt_last_tsecr = 0;
-       tp->write_seq += tp->max_window + 2;
-       if (tp->write_seq == 0)
-               tp->write_seq = 1;
+
+       seq = tp->write_seq + tp->max_window + 2;
+       if (!seq)
+               seq = 1;
+       WRITE_ONCE(tp->write_seq, seq);
+
        icsk->icsk_backoff = 0;
        tp->snd_cwnd = 2;
        icsk->icsk_probes_out = 0;
@@ -2933,9 +2939,9 @@ static int do_tcp_setsockopt(struct sock *sk, int level,
                if (sk->sk_state != TCP_CLOSE)
                        err = -EPERM;
                else if (tp->repair_queue == TCP_SEND_QUEUE)
-                       tp->write_seq = val;
+                       WRITE_ONCE(tp->write_seq, val);
                else if (tp->repair_queue == TCP_RECV_QUEUE)
-                       tp->rcv_nxt = val;
+                       WRITE_ONCE(tp->rcv_nxt, val);
                else
                        err = -EINVAL;
                break;