]> asedeno.scripts.mit.edu Git - linux.git/blobdiff - net/ipv4/esp4.c
Merge tag 'tag-chrome-platform-for-v5.6' of git://git.kernel.org/pub/scm/linux/kernel...
[linux.git] / net / ipv4 / esp4.c
index 033c61d27148ffd31b31f4b40251a1ba2e7ff98a..103c7d599a3c907ed55d9eac8cbb288a2aaecae0 100644 (file)
@@ -18,6 +18,8 @@
 #include <net/icmp.h>
 #include <net/protocol.h>
 #include <net/udp.h>
+#include <net/tcp.h>
+#include <net/espintcp.h>
 
 #include <linux/highmem.h>
 
@@ -117,6 +119,132 @@ static void esp_ssg_unref(struct xfrm_state *x, void *tmp)
                        put_page(sg_page(sg));
 }
 
+#ifdef CONFIG_INET_ESPINTCP
+struct esp_tcp_sk {
+       struct sock *sk;
+       struct rcu_head rcu;
+};
+
+static void esp_free_tcp_sk(struct rcu_head *head)
+{
+       struct esp_tcp_sk *esk = container_of(head, struct esp_tcp_sk, rcu);
+
+       sock_put(esk->sk);
+       kfree(esk);
+}
+
+static struct sock *esp_find_tcp_sk(struct xfrm_state *x)
+{
+       struct xfrm_encap_tmpl *encap = x->encap;
+       struct esp_tcp_sk *esk;
+       __be16 sport, dport;
+       struct sock *nsk;
+       struct sock *sk;
+
+       sk = rcu_dereference(x->encap_sk);
+       if (sk && sk->sk_state == TCP_ESTABLISHED)
+               return sk;
+
+       spin_lock_bh(&x->lock);
+       sport = encap->encap_sport;
+       dport = encap->encap_dport;
+       nsk = rcu_dereference_protected(x->encap_sk,
+                                       lockdep_is_held(&x->lock));
+       if (sk && sk == nsk) {
+               esk = kmalloc(sizeof(*esk), GFP_ATOMIC);
+               if (!esk) {
+                       spin_unlock_bh(&x->lock);
+                       return ERR_PTR(-ENOMEM);
+               }
+               RCU_INIT_POINTER(x->encap_sk, NULL);
+               esk->sk = sk;
+               call_rcu(&esk->rcu, esp_free_tcp_sk);
+       }
+       spin_unlock_bh(&x->lock);
+
+       sk = inet_lookup_established(xs_net(x), &tcp_hashinfo, x->id.daddr.a4,
+                                    dport, x->props.saddr.a4, sport, 0);
+       if (!sk)
+               return ERR_PTR(-ENOENT);
+
+       if (!tcp_is_ulp_esp(sk)) {
+               sock_put(sk);
+               return ERR_PTR(-EINVAL);
+       }
+
+       spin_lock_bh(&x->lock);
+       nsk = rcu_dereference_protected(x->encap_sk,
+                                       lockdep_is_held(&x->lock));
+       if (encap->encap_sport != sport ||
+           encap->encap_dport != dport) {
+               sock_put(sk);
+               sk = nsk ?: ERR_PTR(-EREMCHG);
+       } else if (sk == nsk) {
+               sock_put(sk);
+       } else {
+               rcu_assign_pointer(x->encap_sk, sk);
+       }
+       spin_unlock_bh(&x->lock);
+
+       return sk;
+}
+
+static int esp_output_tcp_finish(struct xfrm_state *x, struct sk_buff *skb)
+{
+       struct sock *sk;
+       int err;
+
+       rcu_read_lock();
+
+       sk = esp_find_tcp_sk(x);
+       err = PTR_ERR_OR_ZERO(sk);
+       if (err)
+               goto out;
+
+       bh_lock_sock(sk);
+       if (sock_owned_by_user(sk))
+               err = espintcp_queue_out(sk, skb);
+       else
+               err = espintcp_push_skb(sk, skb);
+       bh_unlock_sock(sk);
+
+out:
+       rcu_read_unlock();
+       return err;
+}
+
+static int esp_output_tcp_encap_cb(struct net *net, struct sock *sk,
+                                  struct sk_buff *skb)
+{
+       struct dst_entry *dst = skb_dst(skb);
+       struct xfrm_state *x = dst->xfrm;
+
+       return esp_output_tcp_finish(x, skb);
+}
+
+static int esp_output_tail_tcp(struct xfrm_state *x, struct sk_buff *skb)
+{
+       int err;
+
+       local_bh_disable();
+       err = xfrm_trans_queue_net(xs_net(x), skb, esp_output_tcp_encap_cb);
+       local_bh_enable();
+
+       /* EINPROGRESS just happens to do the right thing.  It
+        * actually means that the skb has been consumed and
+        * isn't coming back.
+        */
+       return err ?: -EINPROGRESS;
+}
+#else
+static int esp_output_tail_tcp(struct xfrm_state *x, struct sk_buff *skb)
+{
+       kfree_skb(skb);
+
+       return -EOPNOTSUPP;
+}
+#endif
+
 static void esp_output_done(struct crypto_async_request *base, int err)
 {
        struct sk_buff *skb = base->data;
@@ -147,7 +275,11 @@ static void esp_output_done(struct crypto_async_request *base, int err)
                secpath_reset(skb);
                xfrm_dev_resume(skb);
        } else {
-               xfrm_output_resume(skb, err);
+               if (!err &&
+                   x->encap && x->encap->encap_type == TCP_ENCAP_ESPINTCP)
+                       esp_output_tail_tcp(x, skb);
+               else
+                       xfrm_output_resume(skb, err);
        }
 }
 
@@ -236,7 +368,7 @@ static struct ip_esp_hdr *esp_output_udp_encap(struct sk_buff *skb,
        unsigned int len;
 
        len = skb->len + esp->tailen - skb_transport_offset(skb);
-       if (len + sizeof(struct iphdr) >= IP_MAX_MTU)
+       if (len + sizeof(struct iphdr) > IP_MAX_MTU)
                return ERR_PTR(-EMSGSIZE);
 
        uh = (struct udphdr *)esp->esph;
@@ -256,6 +388,41 @@ static struct ip_esp_hdr *esp_output_udp_encap(struct sk_buff *skb,
        return (struct ip_esp_hdr *)(uh + 1);
 }
 
+#ifdef CONFIG_INET_ESPINTCP
+static struct ip_esp_hdr *esp_output_tcp_encap(struct xfrm_state *x,
+                                                   struct sk_buff *skb,
+                                                   struct esp_info *esp)
+{
+       __be16 *lenp = (void *)esp->esph;
+       struct ip_esp_hdr *esph;
+       unsigned int len;
+       struct sock *sk;
+
+       len = skb->len + esp->tailen - skb_transport_offset(skb);
+       if (len > IP_MAX_MTU)
+               return ERR_PTR(-EMSGSIZE);
+
+       rcu_read_lock();
+       sk = esp_find_tcp_sk(x);
+       rcu_read_unlock();
+
+       if (IS_ERR(sk))
+               return ERR_CAST(sk);
+
+       *lenp = htons(len);
+       esph = (struct ip_esp_hdr *)(lenp + 1);
+
+       return esph;
+}
+#else
+static struct ip_esp_hdr *esp_output_tcp_encap(struct xfrm_state *x,
+                                                   struct sk_buff *skb,
+                                                   struct esp_info *esp)
+{
+       return ERR_PTR(-EOPNOTSUPP);
+}
+#endif
+
 static int esp_output_encap(struct xfrm_state *x, struct sk_buff *skb,
                            struct esp_info *esp)
 {
@@ -276,6 +443,9 @@ static int esp_output_encap(struct xfrm_state *x, struct sk_buff *skb,
        case UDP_ENCAP_ESPINUDP_NON_IKE:
                esph = esp_output_udp_encap(skb, encap_type, esp, sport, dport);
                break;
+       case TCP_ENCAP_ESPINTCP:
+               esph = esp_output_tcp_encap(x, skb, esp);
+               break;
        }
 
        if (IS_ERR(esph))
@@ -296,7 +466,7 @@ int esp_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *
        struct sk_buff *trailer;
        int tailen = esp->tailen;
 
-       /* this is non-NULL only with UDP Encapsulation */
+       /* this is non-NULL only with TCP/UDP Encapsulation */
        if (x->encap) {
                int err = esp_output_encap(x, skb, esp);
 
@@ -491,6 +661,9 @@ int esp_output_tail(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *
        if (sg != dsg)
                esp_ssg_unref(x, tmp);
 
+       if (!err && x->encap && x->encap->encap_type == TCP_ENCAP_ESPINTCP)
+               err = esp_output_tail_tcp(x, skb);
+
 error_free:
        kfree(tmp);
 error:
@@ -617,10 +790,14 @@ int esp_input_done2(struct sk_buff *skb, int err)
 
        if (x->encap) {
                struct xfrm_encap_tmpl *encap = x->encap;
+               struct tcphdr *th = (void *)(skb_network_header(skb) + ihl);
                struct udphdr *uh = (void *)(skb_network_header(skb) + ihl);
                __be16 source;
 
                switch (x->encap->encap_type) {
+               case TCP_ENCAP_ESPINTCP:
+                       source = th->source;
+                       break;
                case UDP_ENCAP_ESPINUDP:
                case UDP_ENCAP_ESPINUDP_NON_IKE:
                        source = uh->source;
@@ -1017,6 +1194,14 @@ static int esp_init_state(struct xfrm_state *x)
                case UDP_ENCAP_ESPINUDP_NON_IKE:
                        x->props.header_len += sizeof(struct udphdr) + 2 * sizeof(u32);
                        break;
+#ifdef CONFIG_INET_ESPINTCP
+               case TCP_ENCAP_ESPINTCP:
+                       /* only the length field, TCP encap is done by
+                        * the socket
+                        */
+                       x->props.header_len += 2;
+                       break;
+#endif
                }
        }