]> asedeno.scripts.mit.edu Git - PuTTY.git/blobdiff - sshrsa.c
first pass
[PuTTY.git] / sshrsa.c
index 0c1b2ef5a421d642a5fd57e26e26ad02a1902216..e565a64ac791ff7be104a17f27814f4962f32fc7 100644 (file)
--- a/sshrsa.c
+++ b/sshrsa.c
 #include "ssh.h"
 #include "misc.h"
 
-int makekey(unsigned char *data, int len, struct RSAKey *result,
-           unsigned char **keystr, int order)
+int makekey(const unsigned char *data, int len, struct RSAKey *result,
+           const unsigned char **keystr, int order)
 {
-    unsigned char *p = data;
+    const unsigned char *p = data;
     int i, n;
 
     if (len < 4)
@@ -59,7 +59,7 @@ int makekey(unsigned char *data, int len, struct RSAKey *result,
     return p - data;
 }
 
-int makeprivate(unsigned char *data, int len, struct RSAKey *result)
+int makeprivate(const unsigned char *data, int len, struct RSAKey *result)
 {
     return ssh1_read_bignum(data, len, &result->private_exponent);
 }
@@ -110,7 +110,7 @@ static void sha512_mpint(SHA512_State * s, Bignum b)
        lenbuf[0] = bignum_byte(b, len);
        SHA512_Bytes(s, lenbuf, 1);
     }
-    memset(lenbuf, 0, sizeof(lenbuf));
+    smemclr(lenbuf, sizeof(lenbuf));
 }
 
 /*
@@ -264,6 +264,7 @@ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key)
            bitsleft--;
            bignum_set_bit(random, bits, v);
        }
+        bn_restore_invariant(random);
 
        /*
         * Now check that this number is strictly greater than
@@ -273,9 +274,18 @@ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key)
            bignum_cmp(random, key->modulus) >= 0) {
            freebn(random);
            continue;
-       } else {
-           break;
        }
+
+        /*
+         * Also, make sure it has an inverse mod modulus.
+         */
+        random_inverse = modinv(random, key->modulus);
+        if (!random_inverse) {
+           freebn(random);
+           continue;
+        }
+
+        break;
     }
 
     /*
@@ -294,7 +304,6 @@ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key)
      */
     random_encrypted = crt_modpow(random, key->exponent,
                                   key->modulus, key->p, key->q, key->iqmp);
-    random_inverse = modinv(random, key->modulus);
     input_blinded = modmul(input, random_encrypted, key->modulus);
     ret_blinded = crt_modpow(input_blinded, key->private_exponent,
                              key->modulus, key->p, key->q, key->iqmp);
@@ -413,16 +422,18 @@ int rsa_verify(struct RSAKey *key)
     pm1 = copybn(key->p);
     decbn(pm1);
     ed = modmul(key->exponent, key->private_exponent, pm1);
+    freebn(pm1);
     cmp = bignum_cmp(ed, One);
-    sfree(ed);
+    freebn(ed);
     if (cmp != 0)
        return 0;
 
     qm1 = copybn(key->q);
     decbn(qm1);
     ed = modmul(key->exponent, key->private_exponent, qm1);
+    freebn(qm1);
     cmp = bignum_cmp(ed, One);
-    sfree(ed);
+    freebn(ed);
     if (cmp != 0)
        return 0;
 
@@ -441,6 +452,8 @@ int rsa_verify(struct RSAKey *key)
 
        freebn(key->iqmp);
        key->iqmp = modinv(key->q, key->p);
+        if (!key->iqmp)
+            return 0;
     }
 
     /*
@@ -448,7 +461,7 @@ int rsa_verify(struct RSAKey *key)
      */
     n = modmul(key->iqmp, key->q, key->p);
     cmp = bignum_cmp(n, One);
-    sfree(n);
+    freebn(n);
     if (cmp != 0)
        return 0;
 
@@ -520,12 +533,15 @@ void freersakey(struct RSAKey *key)
  * Implementation of the ssh-rsa signing key type. 
  */
 
-static void getstring(char **data, int *datalen, char **p, int *length)
+static void getstring(const char **data, int *datalen,
+                      const char **p, int *length)
 {
     *p = NULL;
     if (*datalen < 4)
        return;
-    *length = GET_32BIT(*data);
+    *length = toint(GET_32BIT(*data));
+    if (*length < 0)
+        return;
     *datalen -= 4;
     *data += 4;
     if (*datalen < *length)
@@ -534,9 +550,9 @@ static void getstring(char **data, int *datalen, char **p, int *length)
     *data += *length;
     *datalen -= *length;
 }
-static Bignum getmp(char **data, int *datalen)
+static Bignum getmp(const char **data, int *datalen)
 {
-    char *p;
+    const char *p;
     int length;
     Bignum b;
 
@@ -547,15 +563,16 @@ static Bignum getmp(char **data, int *datalen)
     return b;
 }
 
-static void *rsa2_newkey(char *data, int len)
+static void rsa2_freekey(void *key);   /* forward reference */
+
+static void *rsa2_newkey(const struct ssh_signkey *self,
+                         const char *data, int len)
 {
-    char *p;
+    const char *p;
     int slen;
     struct RSAKey *rsa;
 
     rsa = snew(struct RSAKey);
-    if (!rsa)
-       return NULL;
     getstring(&data, &len, &p, &slen);
 
     if (!p || slen != 7 || memcmp(p, "ssh-rsa", 7)) {
@@ -568,6 +585,11 @@ static void *rsa2_newkey(char *data, int len)
     rsa->p = rsa->q = rsa->iqmp = NULL;
     rsa->comment = NULL;
 
+    if (!rsa->exponent || !rsa->modulus) {
+        rsa2_freekey(rsa);
+        return NULL;
+    }
+
     return rsa;
 }
 
@@ -664,13 +686,14 @@ static unsigned char *rsa2_private_blob(void *key, int *len)
     return blob;
 }
 
-static void *rsa2_createkey(unsigned char *pub_blob, int pub_len,
-                           unsigned char *priv_blob, int priv_len)
+static void *rsa2_createkey(const struct ssh_signkey *self,
+                            const unsigned char *pub_blob, int pub_len,
+                           const unsigned char *priv_blob, int priv_len)
 {
     struct RSAKey *rsa;
-    char *pb = (char *) priv_blob;
+    const char *pb = (const char *) priv_blob;
 
-    rsa = rsa2_newkey((char *) pub_blob, pub_len);
+    rsa = rsa2_newkey(self, (char *) pub_blob, pub_len);
     rsa->private_exponent = getmp(&pb, &priv_len);
     rsa->p = getmp(&pb, &priv_len);
     rsa->q = getmp(&pb, &priv_len);
@@ -684,14 +707,13 @@ static void *rsa2_createkey(unsigned char *pub_blob, int pub_len,
     return rsa;
 }
 
-static void *rsa2_openssh_createkey(unsigned char **blob, int *len)
+static void *rsa2_openssh_createkey(const struct ssh_signkey *self,
+                                    const unsigned char **blob, int *len)
 {
-    char **b = (char **) blob;
+    const char **b = (const char **) blob;
     struct RSAKey *rsa;
 
     rsa = snew(struct RSAKey);
-    if (!rsa)
-       return NULL;
     rsa->comment = NULL;
 
     rsa->modulus = getmp(b, len);
@@ -703,13 +725,12 @@ static void *rsa2_openssh_createkey(unsigned char **blob, int *len)
 
     if (!rsa->modulus || !rsa->exponent || !rsa->private_exponent ||
        !rsa->iqmp || !rsa->p || !rsa->q) {
-       sfree(rsa->modulus);
-       sfree(rsa->exponent);
-       sfree(rsa->private_exponent);
-       sfree(rsa->iqmp);
-       sfree(rsa->p);
-       sfree(rsa->q);
-       sfree(rsa);
+        rsa2_freekey(rsa);
+       return NULL;
+    }
+
+    if (!rsa_verify(rsa)) {
+       rsa2_freekey(rsa);
        return NULL;
     }
 
@@ -745,53 +766,21 @@ static int rsa2_openssh_fmtkey(void *key, unsigned char *blob, int len)
     return bloblen;
 }
 
-static int rsa2_pubkey_bits(void *blob, int len)
+static int rsa2_pubkey_bits(const struct ssh_signkey *self,
+                            const void *blob, int len)
 {
     struct RSAKey *rsa;
     int ret;
 
-    rsa = rsa2_newkey((char *) blob, len);
+    rsa = rsa2_newkey(self, (const char *) blob, len);
+    if (!rsa)
+       return -1;
     ret = bignum_bitcount(rsa->modulus);
     rsa2_freekey(rsa);
 
     return ret;
 }
 
-static char *rsa2_fingerprint(void *key)
-{
-    struct RSAKey *rsa = (struct RSAKey *) key;
-    struct MD5Context md5c;
-    unsigned char digest[16], lenbuf[4];
-    char buffer[16 * 3 + 40];
-    char *ret;
-    int numlen, i;
-
-    MD5Init(&md5c);
-    MD5Update(&md5c, (unsigned char *)"\0\0\0\7ssh-rsa", 11);
-
-#define ADD_BIGNUM(bignum) \
-    numlen = (bignum_bitcount(bignum)+8)/8; \
-    PUT_32BIT(lenbuf, numlen); MD5Update(&md5c, lenbuf, 4); \
-    for (i = numlen; i-- ;) { \
-        unsigned char c = bignum_byte(bignum, i); \
-        MD5Update(&md5c, &c, 1); \
-    }
-    ADD_BIGNUM(rsa->exponent);
-    ADD_BIGNUM(rsa->modulus);
-#undef ADD_BIGNUM
-
-    MD5Final(digest, &md5c);
-
-    sprintf(buffer, "ssh-rsa %d ", bignum_bitcount(rsa->modulus));
-    for (i = 0; i < 16; i++)
-       sprintf(buffer + strlen(buffer), "%s%02x", i ? ":" : "",
-               digest[i]);
-    ret = snewn(strlen(buffer) + 1, char);
-    if (ret)
-       strcpy(ret, buffer);
-    return ret;
-}
-
 /*
  * This is the magic ASN.1/DER prefix that goes in the decoded
  * signature, between the string of FFs and the actual SHA hash
@@ -823,12 +812,12 @@ static const unsigned char asn1_weird_stuff[] = {
 
 #define ASN1_LEN ( (int) sizeof(asn1_weird_stuff) )
 
-static int rsa2_verifysig(void *key, char *sig, int siglen,
-                         char *data, int datalen)
+static int rsa2_verifysig(void *key, const char *sig, int siglen,
+                         const char *data, int datalen)
 {
     struct RSAKey *rsa = (struct RSAKey *) key;
     Bignum in, out;
-    char *p;
+    const char *p;
     int slen;
     int bytes, i, j, ret;
     unsigned char hash[20];
@@ -838,6 +827,8 @@ static int rsa2_verifysig(void *key, char *sig, int siglen,
        return 0;
     }
     in = getmp(&sig, &siglen);
+    if (!in)
+        return 0;
     out = modpow(in, rsa->exponent, rsa->modulus);
     freebn(in);
 
@@ -871,7 +862,7 @@ static int rsa2_verifysig(void *key, char *sig, int siglen,
     return ret;
 }
 
-static unsigned char *rsa2_sign(void *key, char *data, int datalen,
+static unsigned char *rsa2_sign(void *key, const char *data, int datalen,
                                int *siglen)
 {
     struct RSAKey *rsa = (struct RSAKey *) key;
@@ -923,17 +914,18 @@ const struct ssh_signkey ssh_rsa = {
     rsa2_createkey,
     rsa2_openssh_createkey,
     rsa2_openssh_fmtkey,
+    6 /* n,e,d,iqmp,q,p */,
     rsa2_pubkey_bits,
-    rsa2_fingerprint,
     rsa2_verifysig,
     rsa2_sign,
     "ssh-rsa",
-    "rsa2"
+    "rsa2",
+    NULL,
 };
 
 void *ssh_rsakex_newkey(char *data, int len)
 {
-    return rsa2_newkey(data, len);
+    return rsa2_newkey(&ssh_rsa, data, len);
 }
 
 void ssh_rsakex_freekey(void *key)
@@ -1068,11 +1060,11 @@ void ssh_rsakex_encrypt(const struct ssh_hash *h, unsigned char *in, int inlen,
 }
 
 static const struct ssh_kex ssh_rsa_kex_sha1 = {
-    "rsa1024-sha1", NULL, KEXTYPE_RSA, NULL, NULL, 0, 0, &ssh_sha1
+    "rsa1024-sha1", NULL, KEXTYPE_RSA, &ssh_sha1, NULL,
 };
 
 static const struct ssh_kex ssh_rsa_kex_sha256 = {
-    "rsa2048-sha256", NULL, KEXTYPE_RSA, NULL, NULL, 0, 0, &ssh_sha256
+    "rsa2048-sha256", NULL, KEXTYPE_RSA, &ssh_sha256, NULL,
 };
 
 static const struct ssh_kex *const rsa_kex_list[] = {