]> asedeno.scripts.mit.edu Git - PuTTY.git/blobdiff - ssh.c
Make asynchronous agent_query() requests cancellable.
[PuTTY.git] / ssh.c
diff --git a/ssh.c b/ssh.c
index 7e74fb44915ef161f5abf8ede3bae514cdb0dd10..2212c2b77764cc55505363352fa054ba239ab846 100644 (file)
--- a/ssh.c
+++ b/ssh.c
@@ -577,6 +577,7 @@ struct ssh_channel {
            unsigned char msglen[4];
            unsigned lensofar, totallen;
             int outstanding_requests;
+            agent_pending_query *pending;
        } a;
        struct ssh_x11_channel {
            struct X11Connection *xconn;
@@ -985,6 +986,13 @@ struct ssh_tag {
      * with a newly cross-certified host key.
      */
     int cross_certifying;
+
+    /*
+     * Any asynchronous query to our SSH agent that we might have in
+     * flight from the main authentication loop. (Queries from
+     * agent-forwarding channels live in their channel structure.)
+     */
+    agent_pending_query *auth_agent_query;
 };
 
 static const char *ssh_pkt_type(Ssh ssh, int type)
@@ -3811,6 +3819,8 @@ static void ssh_agent_callback(void *sshv, void *reply, int replylen)
 {
     Ssh ssh = (Ssh) sshv;
 
+    ssh->auth_agent_query = NULL;
+
     ssh->agent_response = reply;
     ssh->agent_response_len = replylen;
 
@@ -3843,6 +3853,7 @@ static void ssh_agentf_callback(void *cv, void *reply, int replylen)
     struct ssh_channel *c = (struct ssh_channel *)cv;
     const void *sentreply = reply;
 
+    c->u.a.pending = NULL;
     c->u.a.outstanding_requests--;
     if (!sentreply) {
        /* Fake SSH_AGENT_FAILURE. */
@@ -4362,8 +4373,9 @@ static int do_ssh1_login(Ssh ssh, const unsigned char *in, int inlen,
            /* Request the keys held by the agent. */
            PUT_32BIT(s->request, 1);
            s->request[4] = SSH1_AGENTC_REQUEST_RSA_IDENTITIES;
-           if (!agent_query(s->request, 5, &r, &s->responselen,
-                            ssh_agent_callback, ssh)) {
+            ssh->auth_agent_query = agent_query(
+                s->request, 5, &r, &s->responselen, ssh_agent_callback, ssh);
+           if (ssh->auth_agent_query) {
                do {
                    crReturn(0);
                    if (pktin) {
@@ -4468,8 +4480,10 @@ static int do_ssh1_login(Ssh ssh, const unsigned char *in, int inlen,
                        memcpy(q, s->session_id, 16);
                        q += 16;
                        PUT_32BIT(q, 1);        /* response format */
-                       if (!agent_query(agentreq, len + 4, &vret, &retlen,
-                                        ssh_agent_callback, ssh)) {
+                        ssh->auth_agent_query = agent_query(
+                            agentreq, len + 4, &vret, &retlen,
+                            ssh_agent_callback, ssh);
+                       if (ssh->auth_agent_query) {
                            sfree(agentreq);
                            do {
                                crReturn(0);
@@ -5541,6 +5555,7 @@ static void ssh1_smsg_agent_open(Ssh ssh, struct Packet *pktin)
        c->type = CHAN_AGENT;   /* identify channel type */
        c->u.a.lensofar = 0;
        c->u.a.message = NULL;
+       c->u.a.pending = NULL;
        c->u.a.outstanding_requests = 0;
        send_packet(ssh, SSH1_MSG_CHANNEL_OPEN_CONFIRMATION,
                    PKT_INT, c->remoteid, PKT_INT, c->localid,
@@ -5707,9 +5722,11 @@ static int ssh_agent_channel_data(struct ssh_channel *c, char *data,
            void *reply;
            int replylen;
             c->u.a.outstanding_requests++;
-           if (agent_query(c->u.a.message, c->u.a.totallen, &reply, &replylen,
-                           ssh_agentf_callback, c))
-               ssh_agentf_callback(c, reply, replylen);
+            c->u.a.pending = agent_query(
+                c->u.a.message, c->u.a.totallen, &reply, &replylen,
+                ssh_agentf_callback, c);
+            if (!c->u.a.pending)
+                ssh_agentf_callback(c, reply, replylen);
            sfree(c->u.a.message);
             c->u.a.message = NULL;
            c->u.a.lensofar = 0;
@@ -8141,6 +8158,8 @@ static void ssh_channel_close_local(struct ssh_channel *c, char const *reason)
         msg = "Forwarded X11 connection terminated";
         break;
       case CHAN_AGENT:
+        if (c->u.a.pending)
+            agent_cancel_query(c->u.a.pending);
         sfree(c->u.a.message);
         break;
       case CHAN_SOCKDATA:
@@ -8788,6 +8807,7 @@ static void ssh2_msg_channel_open(Ssh ssh, struct Packet *pktin)
            c->type = CHAN_AGENT;       /* identify channel type */
            c->u.a.lensofar = 0;
             c->u.a.message = NULL;
+            c->u.a.pending = NULL;
             c->u.a.outstanding_requests = 0;
        }
     } else {
@@ -9291,8 +9311,10 @@ static void do_ssh2_authconn(Ssh ssh, const unsigned char *in, int inlen,
            /* Request the keys held by the agent. */
            PUT_32BIT(s->agent_request, 1);
            s->agent_request[4] = SSH2_AGENTC_REQUEST_IDENTITIES;
-           if (!agent_query(s->agent_request, 5, &r, &s->agent_responselen,
-                            ssh_agent_callback, ssh)) {
+            ssh->auth_agent_query = agent_query(
+                s->agent_request, 5, &r, &s->agent_responselen,
+                ssh_agent_callback, ssh);
+           if (ssh->auth_agent_query) {
                do {
                    crReturnV;
                    if (pktin) {
@@ -9730,9 +9752,10 @@ static void do_ssh2_authconn(Ssh ssh, const unsigned char *in, int inlen,
                    s->q += s->pktout->length - 5;
                    /* And finally the (zero) flags word. */
                    PUT_32BIT(s->q, 0);
-                   if (!agent_query(s->agentreq, s->len + 4,
-                                    &vret, &s->retlen,
-                                    ssh_agent_callback, ssh)) {
+                    ssh->auth_agent_query = agent_query(
+                        s->agentreq, s->len + 4, &vret, &s->retlen,
+                        ssh_agent_callback, ssh);
+                    if (ssh->auth_agent_query) {
                        do {
                            crReturnV;
                            if (pktin) {
@@ -11168,6 +11191,8 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle,
                                                      CONF_ssh_rekey_data));
     ssh->kex_in_progress = FALSE;
 
+    ssh->auth_agent_query = NULL;
+
 #ifndef NO_GSSAPI
     ssh->gsslibs = NULL;
 #endif
@@ -11283,6 +11308,10 @@ static void ssh_free(void *handle)
     bufchain_clear(&ssh->queued_incoming_data);
     sfree(ssh->username);
     conf_free(ssh->conf);
+
+    if (ssh->auth_agent_query)
+        agent_cancel_query(ssh->auth_agent_query);
+
 #ifndef NO_GSSAPI
     if (ssh->gsslibs)
        ssh_gss_cleanup(ssh->gsslibs);