+/* Assign a transport to a socket and call the .init transport callback.
+ *
+ * Note: for stream socket this must be called when vsk->remote_addr is set
+ * (e.g. during the connect() or when a connection request on a listener
+ * socket is received).
+ * The vsk->remote_addr is used to decide which transport to use:
+ * - remote CID <= VMADDR_CID_HOST will use guest->host transport;
+ * - remote CID == local_cid (guest->host transport) will use guest->host
+ * transport for loopback (host->guest transports don't support loopback);
+ * - remote CID > VMADDR_CID_HOST will use host->guest transport;
+ */
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+{
+ const struct vsock_transport *new_transport;
+ struct sock *sk = sk_vsock(vsk);
+ unsigned int remote_cid = vsk->remote_addr.svm_cid;
+
+ switch (sk->sk_type) {
+ case SOCK_DGRAM:
+ new_transport = transport_dgram;
+ break;
+ case SOCK_STREAM:
+ if (remote_cid <= VMADDR_CID_HOST ||
+ (transport_g2h &&
+ remote_cid == transport_g2h->get_local_cid()))
+ new_transport = transport_g2h;
+ else
+ new_transport = transport_h2g;
+ break;
+ default:
+ return -ESOCKTNOSUPPORT;
+ }
+
+ if (vsk->transport) {
+ if (vsk->transport == new_transport)
+ return 0;
+
+ vsk->transport->release(vsk);
+ vsk->transport->destruct(vsk);
+ }
+
+ if (!new_transport)
+ return -ENODEV;
+
+ vsk->transport = new_transport;
+
+ return vsk->transport->init(vsk, psk);
+}
+EXPORT_SYMBOL_GPL(vsock_assign_transport);
+
+bool vsock_find_cid(unsigned int cid)
+{
+ if (transport_g2h && cid == transport_g2h->get_local_cid())
+ return true;
+
+ if (transport_h2g && cid == VMADDR_CID_HOST)
+ return true;
+
+ return false;
+}
+EXPORT_SYMBOL_GPL(vsock_find_cid);
+