]> asedeno.scripts.mit.edu Git - PuTTY.git/blobdiff - ssh.c
Rewrite agent forwarding to serialise requests.
[PuTTY.git] / ssh.c
diff --git a/ssh.c b/ssh.c
index 2212c2b77764cc55505363352fa054ba239ab846..a214aa1f9b7985f16eee25168d8c3d0874b2eca7 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -573,10 +573,7 @@ struct ssh_channel {
     } v;
     union {
        struct ssh_agent_channel {
-           unsigned char *message;
-           unsigned char msglen[4];
-           unsigned lensofar, totallen;
-            int outstanding_requests;
+            bufchain inbuffer;
             agent_pending_query *pending;
        } a;
        struct ssh_x11_channel {
@@ -3780,6 +3777,8 @@ static void ssh_throttle_conn(Ssh ssh, int adjust)
     }
 }
 
+static void ssh_agentf_try_forward(struct ssh_channel *c);
+
 /*
  * Throttle or unthrottle _all_ local data streams (for when sends
  * on the SSH connection itself back up).
@@ -3806,7 +3805,12 @@ static void ssh_throttle_all(Ssh ssh, int enable, int bufsize)
            x11_override_throttle(c->u.x11.xconn, enable);
            break;
          case CHAN_AGENT:
-           /* Agent channels require no buffer management. */
+           /* Agent forwarding channels are buffer-managed by
+             * checking ssh->throttled_all in ssh_agentf_try_forward.
+             * So at the moment we _un_throttle again, we must make an
+             * attempt to do something. */
+            if (!enable)
+                ssh_agentf_try_forward(c);
            break;
          case CHAN_SOCKDATA:
            pfd_override_throttle(c->u.pfd.pf, enable);
@@ -3848,29 +3852,113 @@ static void ssh_dialog_callback(void *sshv, int ret)
     ssh_process_queued_incoming_data(ssh);
 }
 
-static void ssh_agentf_callback(void *cv, void *reply, int replylen)
+static void ssh_agentf_got_response(struct ssh_channel *c,
+                                    void *reply, int replylen)
 {
-    struct ssh_channel *c = (struct ssh_channel *)cv;
-    const void *sentreply = reply;
-
     c->u.a.pending = NULL;
-    c->u.a.outstanding_requests--;
-    if (!sentreply) {
-       /* Fake SSH_AGENT_FAILURE. */
-       sentreply = "\0\0\0\1\5";
+
+    if (!reply) {
+       /* The real agent didn't send any kind of reply at all for
+         * some reason, so fake an SSH_AGENT_FAILURE. */
+       reply = "\0\0\0\1\5";
        replylen = 5;
     }
-    ssh_send_channel_data(c, sentreply, replylen);
-    if (reply)
-       sfree(reply);
+
+    ssh_send_channel_data(c, reply, replylen);
+}
+
+static void ssh_agentf_callback(void *cv, void *reply, int replylen);
+
+static void ssh_agentf_try_forward(struct ssh_channel *c)
+{
+    unsigned datalen, lengthfield, messagelen;
+    unsigned char *message;
+    unsigned char msglen[4];
+    void *reply;
+    int replylen;
+
     /*
-     * If we've already seen an incoming EOF but haven't sent an
-     * outgoing one, this may be the moment to send it.
+     * Don't try to parallelise agent requests. Wait for each one to
+     * return before attempting the next.
      */
-    if (c->u.a.outstanding_requests == 0 && (c->closes & CLOSES_RCVD_EOF))
+    if (c->u.a.pending)
+        return;
+
+    /*
+     * If the outgoing side of the channel connection is currently
+     * throttled (for any reason, either that channel's window size or
+     * the entire SSH connection being throttled), don't submit any
+     * new forwarded requests to the real agent. This causes the input
+     * side of the agent forwarding not to be emptied, exerting the
+     * required back-pressure on the remote client, and encouraging it
+     * to read our responses before sending too many more requests.
+     */
+    if (c->ssh->throttled_all ||
+        (c->ssh->version == 2 && c->v.v2.remwindow == 0))
+        return;
+
+    while (1) {
+        /*
+         * Try to extract a complete message from the input buffer.
+         */
+        datalen = bufchain_size(&c->u.a.inbuffer);
+        if (datalen < 4)
+            break;         /* not even a length field available yet */
+
+        bufchain_fetch(&c->u.a.inbuffer, msglen, 4);
+        lengthfield = GET_32BIT(msglen);
+        if (lengthfield > datalen - 4)
+            break;          /* a whole message is not yet available */
+
+        messagelen = lengthfield + 4;
+
+        message = snewn(messagelen, unsigned char);
+        bufchain_fetch(&c->u.a.inbuffer, message, messagelen);
+        bufchain_consume(&c->u.a.inbuffer, messagelen);
+        c->u.a.pending = agent_query(
+            message, messagelen, &reply, &replylen, ssh_agentf_callback, c);
+        sfree(message);
+
+        if (c->u.a.pending)
+            return;   /* agent_query promised to reply in due course */
+
+        /*
+         * If the agent gave us an answer immediately, pass it
+         * straight on and go round this loop again.
+         */
+        ssh_agentf_got_response(c, reply, replylen);
+    }
+
+    /*
+     * If we get here (i.e. we left the above while loop via 'break'
+     * rather than 'return'), that means we've determined that the
+     * input buffer for the agent forwarding connection doesn't
+     * contain a complete request.
+     *
+     * So if there's potentially more data to come, we can return now,
+     * and wait for the remote client to send it. But if the remote
+     * has sent EOF, it would be a mistake to do that, because we'd be
+     * waiting a long time. So this is the moment to check for EOF,
+     * and respond appropriately.
+     */
+    if (c->closes & CLOSES_RCVD_EOF)
         sshfwd_write_eof(c);
 }
 
+static void ssh_agentf_callback(void *cv, void *reply, int replylen)
+{
+    struct ssh_channel *c = (struct ssh_channel *)cv;
+
+    ssh_agentf_got_response(c, reply, replylen);
+    sfree(reply);
+
+    /*
+     * Now try to extract and send further messages from the channel's
+     * input-side buffer.
+     */
+    ssh_agentf_try_forward(c);
+}
+
 /*
  * Client-initiated disconnection. Send a DISCONNECT if `wire_reason'
  * non-NULL, otherwise just close the connection. `client_reason' == NULL
@@ -5553,10 +5641,8 @@ static void ssh1_smsg_agent_open(Ssh ssh, struct Packet *pktin)
        c->remoteid = remoteid;
        c->halfopen = FALSE;
        c->type = CHAN_AGENT;   /* identify channel type */
-       c->u.a.lensofar = 0;
-       c->u.a.message = NULL;
        c->u.a.pending = NULL;
-       c->u.a.outstanding_requests = 0;
+        bufchain_init(&c->u.a.inbuffer);
        send_packet(ssh, SSH1_MSG_CHANNEL_OPEN_CONFIRMATION,
                    PKT_INT, c->remoteid, PKT_INT, c->localid,
                    PKT_END);
@@ -5697,42 +5783,18 @@ static void ssh1_msg_channel_close(Ssh ssh, struct Packet *pktin)
 static int ssh_agent_channel_data(struct ssh_channel *c, char *data,
                                  int length)
 {
-    while (length > 0) {
-       if (c->u.a.lensofar < 4) {
-           unsigned int l = min(4 - c->u.a.lensofar, (unsigned)length);
-           memcpy(c->u.a.msglen + c->u.a.lensofar, data, l);
-           data += l;
-           length -= l;
-           c->u.a.lensofar += l;
-       }
-       if (c->u.a.lensofar == 4) {
-           c->u.a.totallen = 4 + GET_32BIT(c->u.a.msglen);
-           c->u.a.message = snewn(c->u.a.totallen, unsigned char);
-           memcpy(c->u.a.message, c->u.a.msglen, 4);
-       }
-       if (c->u.a.lensofar >= 4 && length > 0) {
-           unsigned int l = min(c->u.a.totallen - c->u.a.lensofar,
-                                (unsigned)length);
-           memcpy(c->u.a.message + c->u.a.lensofar, data, l);
-           data += l;
-           length -= l;
-           c->u.a.lensofar += l;
-       }
-       if (c->u.a.lensofar == c->u.a.totallen) {
-           void *reply;
-           int replylen;
-            c->u.a.outstanding_requests++;
-            c->u.a.pending = agent_query(
-                c->u.a.message, c->u.a.totallen, &reply, &replylen,
-                ssh_agentf_callback, c);
-            if (!c->u.a.pending)
-                ssh_agentf_callback(c, reply, replylen);
-           sfree(c->u.a.message);
-            c->u.a.message = NULL;
-           c->u.a.lensofar = 0;
-       }
-    }
-    return 0;   /* agent channels never back up */
+    bufchain_add(&c->u.a.inbuffer, data, length);
+    ssh_agentf_try_forward(c);
+
+    /*
+     * We exert back-pressure on an agent forwarding client if and
+     * only if we're waiting for the response to an asynchronous agent
+     * request. This prevents the client running out of window while
+     * receiving the _first_ message, but means that if any message
+     * takes time to process, the client will be discouraged from
+     * sending an endless stream of further ones after it.
+     */
+    return (c->u.a.pending ? bufchain_size(&c->u.a.inbuffer) : 0);
 }
 
 static int ssh_channel_data(struct ssh_channel *c, int is_stderr,
@@ -7733,8 +7795,9 @@ static void ssh2_try_send_and_unthrottle(Ssh ssh, struct ssh_channel *c)
            x11_unthrottle(c->u.x11.xconn);
            break;
          case CHAN_AGENT:
-           /* agent sockets are request/response and need no
-            * buffer management */
+            /* Now that we've successfully sent all the outgoing
+             * replies we had, try to process more incoming data. */
+            ssh_agentf_try_forward(c);
            break;
          case CHAN_SOCKDATA:
            pfd_unthrottle(c->u.pfd.pf);
@@ -8160,7 +8223,8 @@ static void ssh_channel_close_local(struct ssh_channel *c, char const *reason)
       case CHAN_AGENT:
         if (c->u.a.pending)
             agent_cancel_query(c->u.a.pending);
-        sfree(c->u.a.message);
+        bufchain_clear(&c->u.a.inbuffer);
+       msg = "Agent-forwarding connection closed";
         break;
       case CHAN_SOCKDATA:
         assert(c->u.pfd.pf != NULL);
@@ -8248,10 +8312,10 @@ static void ssh_channel_got_eof(struct ssh_channel *c)
        assert(c->u.x11.xconn != NULL);
        x11_send_eof(c->u.x11.xconn);
     } else if (c->type == CHAN_AGENT) {
-        if (c->u.a.outstanding_requests == 0) {
-            /* Manufacture an outgoing EOF in response to the incoming one. */
-            sshfwd_write_eof(c);
-        }
+        /* Just call try_forward, which will respond to the EOF now if
+         * appropriate, or wait until the queue of outstanding
+         * requests is dealt with if not */
+        ssh_agentf_try_forward(c);
     } else if (c->type == CHAN_SOCKDATA) {
        assert(c->u.pfd.pf != NULL);
        pfd_send_eof(c->u.pfd.pf);
@@ -8805,10 +8869,8 @@ static void ssh2_msg_channel_open(Ssh ssh, struct Packet *pktin)
            error = "Agent forwarding is not enabled";
        else {
            c->type = CHAN_AGENT;       /* identify channel type */
-           c->u.a.lensofar = 0;
-            c->u.a.message = NULL;
+            bufchain_init(&c->u.a.inbuffer);
             c->u.a.pending = NULL;
-            c->u.a.outstanding_requests = 0;
        }
     } else {
        error = "Unsupported channel type requested";