X-Git-Url: https://asedeno.scripts.mit.edu/gitweb/?a=blobdiff_plain;f=sshrsa.c;h=e565a64ac791ff7be104a17f27814f4962f32fc7;hb=510f49e405e71ba5c97875e7a019364e1ef5fac9;hp=3c0feafe5b5700d56bacb5b782400af0692c7df6;hpb=ff294f4ffd306558dade2aa1192ac29ab440d454;p=PuTTY.git diff --git a/sshrsa.c b/sshrsa.c index 3c0feafe..e565a64a 100644 --- a/sshrsa.c +++ b/sshrsa.c @@ -10,10 +10,10 @@ #include "ssh.h" #include "misc.h" -int makekey(unsigned char *data, int len, struct RSAKey *result, - unsigned char **keystr, int order) +int makekey(const unsigned char *data, int len, struct RSAKey *result, + const unsigned char **keystr, int order) { - unsigned char *p = data; + const unsigned char *p = data; int i, n; if (len < 4) @@ -59,7 +59,7 @@ int makekey(unsigned char *data, int len, struct RSAKey *result, return p - data; } -int makeprivate(unsigned char *data, int len, struct RSAKey *result) +int makeprivate(const unsigned char *data, int len, struct RSAKey *result) { return ssh1_read_bignum(data, len, &result->private_exponent); } @@ -110,13 +110,87 @@ static void sha512_mpint(SHA512_State * s, Bignum b) lenbuf[0] = bignum_byte(b, len); SHA512_Bytes(s, lenbuf, 1); } - memset(lenbuf, 0, sizeof(lenbuf)); + smemclr(lenbuf, sizeof(lenbuf)); } /* - * This function is a wrapper on modpow(). It has the same effect - * as modpow(), but employs RSA blinding to protect against timing - * attacks. + * Compute (base ^ exp) % mod, provided mod == p * q, with p,q + * distinct primes, and iqmp is the multiplicative inverse of q mod p. + * Uses Chinese Remainder Theorem to speed computation up over the + * obvious implementation of a single big modpow. + */ +Bignum crt_modpow(Bignum base, Bignum exp, Bignum mod, + Bignum p, Bignum q, Bignum iqmp) +{ + Bignum pm1, qm1, pexp, qexp, presult, qresult, diff, multiplier, ret0, ret; + + /* + * Reduce the exponent mod phi(p) and phi(q), to save time when + * exponentiating mod p and mod q respectively. Of course, since p + * and q are prime, phi(p) == p-1 and similarly for q. + */ + pm1 = copybn(p); + decbn(pm1); + qm1 = copybn(q); + decbn(qm1); + pexp = bigmod(exp, pm1); + qexp = bigmod(exp, qm1); + + /* + * Do the two modpows. + */ + presult = modpow(base, pexp, p); + qresult = modpow(base, qexp, q); + + /* + * Recombine the results. We want a value which is congruent to + * qresult mod q, and to presult mod p. + * + * We know that iqmp * q is congruent to 1 * mod p (by definition + * of iqmp) and to 0 mod q (obviously). So we start with qresult + * (which is congruent to qresult mod both primes), and add on + * (presult-qresult) * (iqmp * q) which adjusts it to be congruent + * to presult mod p without affecting its value mod q. + */ + if (bignum_cmp(presult, qresult) < 0) { + /* + * Can't subtract presult from qresult without first adding on + * p. + */ + Bignum tmp = presult; + presult = bigadd(presult, p); + freebn(tmp); + } + diff = bigsub(presult, qresult); + multiplier = bigmul(iqmp, q); + ret0 = bigmuladd(multiplier, diff, qresult); + + /* + * Finally, reduce the result mod n. + */ + ret = bigmod(ret0, mod); + + /* + * Free all the intermediate results before returning. + */ + freebn(pm1); + freebn(qm1); + freebn(pexp); + freebn(qexp); + freebn(presult); + freebn(qresult); + freebn(diff); + freebn(multiplier); + freebn(ret0); + + return ret; +} + +/* + * This function is a wrapper on modpow(). It has the same effect as + * modpow(), but employs RSA blinding to protect against timing + * attacks and also uses the Chinese Remainder Theorem (implemented + * above, in crt_modpow()) to speed up the main operation. */ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key) { @@ -190,6 +264,7 @@ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key) bitsleft--; bignum_set_bit(random, bits, v); } + bn_restore_invariant(random); /* * Now check that this number is strictly greater than @@ -199,9 +274,18 @@ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key) bignum_cmp(random, key->modulus) >= 0) { freebn(random); continue; - } else { - break; } + + /* + * Also, make sure it has an inverse mod modulus. + */ + random_inverse = modinv(random, key->modulus); + if (!random_inverse) { + freebn(random); + continue; + } + + break; } /* @@ -218,10 +302,11 @@ static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key) * _y^d_, and use the _public_ exponent to compute (y^d)^e = y * from it, which is much faster to do. */ - random_encrypted = modpow(random, key->exponent, key->modulus); - random_inverse = modinv(random, key->modulus); + random_encrypted = crt_modpow(random, key->exponent, + key->modulus, key->p, key->q, key->iqmp); input_blinded = modmul(input, random_encrypted, key->modulus); - ret_blinded = modpow(input_blinded, key->private_exponent, key->modulus); + ret_blinded = crt_modpow(input_blinded, key->private_exponent, + key->modulus, key->p, key->q, key->iqmp); ret = modmul(ret_blinded, random_inverse, key->modulus); freebn(ret_blinded); @@ -337,16 +422,18 @@ int rsa_verify(struct RSAKey *key) pm1 = copybn(key->p); decbn(pm1); ed = modmul(key->exponent, key->private_exponent, pm1); + freebn(pm1); cmp = bignum_cmp(ed, One); - sfree(ed); + freebn(ed); if (cmp != 0) return 0; qm1 = copybn(key->q); decbn(qm1); ed = modmul(key->exponent, key->private_exponent, qm1); + freebn(qm1); cmp = bignum_cmp(ed, One); - sfree(ed); + freebn(ed); if (cmp != 0) return 0; @@ -365,6 +452,8 @@ int rsa_verify(struct RSAKey *key) freebn(key->iqmp); key->iqmp = modinv(key->q, key->p); + if (!key->iqmp) + return 0; } /* @@ -372,7 +461,7 @@ int rsa_verify(struct RSAKey *key) */ n = modmul(key->iqmp, key->q, key->p); cmp = bignum_cmp(n, One); - sfree(n); + freebn(n); if (cmp != 0) return 0; @@ -444,12 +533,15 @@ void freersakey(struct RSAKey *key) * Implementation of the ssh-rsa signing key type. */ -static void getstring(char **data, int *datalen, char **p, int *length) +static void getstring(const char **data, int *datalen, + const char **p, int *length) { *p = NULL; if (*datalen < 4) return; - *length = GET_32BIT(*data); + *length = toint(GET_32BIT(*data)); + if (*length < 0) + return; *datalen -= 4; *data += 4; if (*datalen < *length) @@ -458,9 +550,9 @@ static void getstring(char **data, int *datalen, char **p, int *length) *data += *length; *datalen -= *length; } -static Bignum getmp(char **data, int *datalen) +static Bignum getmp(const char **data, int *datalen) { - char *p; + const char *p; int length; Bignum b; @@ -471,15 +563,16 @@ static Bignum getmp(char **data, int *datalen) return b; } -static void *rsa2_newkey(char *data, int len) +static void rsa2_freekey(void *key); /* forward reference */ + +static void *rsa2_newkey(const struct ssh_signkey *self, + const char *data, int len) { - char *p; + const char *p; int slen; struct RSAKey *rsa; rsa = snew(struct RSAKey); - if (!rsa) - return NULL; getstring(&data, &len, &p, &slen); if (!p || slen != 7 || memcmp(p, "ssh-rsa", 7)) { @@ -492,6 +585,11 @@ static void *rsa2_newkey(char *data, int len) rsa->p = rsa->q = rsa->iqmp = NULL; rsa->comment = NULL; + if (!rsa->exponent || !rsa->modulus) { + rsa2_freekey(rsa); + return NULL; + } + return rsa; } @@ -588,13 +686,14 @@ static unsigned char *rsa2_private_blob(void *key, int *len) return blob; } -static void *rsa2_createkey(unsigned char *pub_blob, int pub_len, - unsigned char *priv_blob, int priv_len) +static void *rsa2_createkey(const struct ssh_signkey *self, + const unsigned char *pub_blob, int pub_len, + const unsigned char *priv_blob, int priv_len) { struct RSAKey *rsa; - char *pb = (char *) priv_blob; + const char *pb = (const char *) priv_blob; - rsa = rsa2_newkey((char *) pub_blob, pub_len); + rsa = rsa2_newkey(self, (char *) pub_blob, pub_len); rsa->private_exponent = getmp(&pb, &priv_len); rsa->p = getmp(&pb, &priv_len); rsa->q = getmp(&pb, &priv_len); @@ -608,14 +707,13 @@ static void *rsa2_createkey(unsigned char *pub_blob, int pub_len, return rsa; } -static void *rsa2_openssh_createkey(unsigned char **blob, int *len) +static void *rsa2_openssh_createkey(const struct ssh_signkey *self, + const unsigned char **blob, int *len) { - char **b = (char **) blob; + const char **b = (const char **) blob; struct RSAKey *rsa; rsa = snew(struct RSAKey); - if (!rsa) - return NULL; rsa->comment = NULL; rsa->modulus = getmp(b, len); @@ -627,13 +725,12 @@ static void *rsa2_openssh_createkey(unsigned char **blob, int *len) if (!rsa->modulus || !rsa->exponent || !rsa->private_exponent || !rsa->iqmp || !rsa->p || !rsa->q) { - sfree(rsa->modulus); - sfree(rsa->exponent); - sfree(rsa->private_exponent); - sfree(rsa->iqmp); - sfree(rsa->p); - sfree(rsa->q); - sfree(rsa); + rsa2_freekey(rsa); + return NULL; + } + + if (!rsa_verify(rsa)) { + rsa2_freekey(rsa); return NULL; } @@ -669,53 +766,21 @@ static int rsa2_openssh_fmtkey(void *key, unsigned char *blob, int len) return bloblen; } -static int rsa2_pubkey_bits(void *blob, int len) +static int rsa2_pubkey_bits(const struct ssh_signkey *self, + const void *blob, int len) { struct RSAKey *rsa; int ret; - rsa = rsa2_newkey((char *) blob, len); + rsa = rsa2_newkey(self, (const char *) blob, len); + if (!rsa) + return -1; ret = bignum_bitcount(rsa->modulus); rsa2_freekey(rsa); return ret; } -static char *rsa2_fingerprint(void *key) -{ - struct RSAKey *rsa = (struct RSAKey *) key; - struct MD5Context md5c; - unsigned char digest[16], lenbuf[4]; - char buffer[16 * 3 + 40]; - char *ret; - int numlen, i; - - MD5Init(&md5c); - MD5Update(&md5c, (unsigned char *)"\0\0\0\7ssh-rsa", 11); - -#define ADD_BIGNUM(bignum) \ - numlen = (bignum_bitcount(bignum)+8)/8; \ - PUT_32BIT(lenbuf, numlen); MD5Update(&md5c, lenbuf, 4); \ - for (i = numlen; i-- ;) { \ - unsigned char c = bignum_byte(bignum, i); \ - MD5Update(&md5c, &c, 1); \ - } - ADD_BIGNUM(rsa->exponent); - ADD_BIGNUM(rsa->modulus); -#undef ADD_BIGNUM - - MD5Final(digest, &md5c); - - sprintf(buffer, "ssh-rsa %d ", bignum_bitcount(rsa->modulus)); - for (i = 0; i < 16; i++) - sprintf(buffer + strlen(buffer), "%s%02x", i ? ":" : "", - digest[i]); - ret = snewn(strlen(buffer) + 1, char); - if (ret) - strcpy(ret, buffer); - return ret; -} - /* * This is the magic ASN.1/DER prefix that goes in the decoded * signature, between the string of FFs and the actual SHA hash @@ -747,12 +812,12 @@ static const unsigned char asn1_weird_stuff[] = { #define ASN1_LEN ( (int) sizeof(asn1_weird_stuff) ) -static int rsa2_verifysig(void *key, char *sig, int siglen, - char *data, int datalen) +static int rsa2_verifysig(void *key, const char *sig, int siglen, + const char *data, int datalen) { struct RSAKey *rsa = (struct RSAKey *) key; Bignum in, out; - char *p; + const char *p; int slen; int bytes, i, j, ret; unsigned char hash[20]; @@ -762,6 +827,8 @@ static int rsa2_verifysig(void *key, char *sig, int siglen, return 0; } in = getmp(&sig, &siglen); + if (!in) + return 0; out = modpow(in, rsa->exponent, rsa->modulus); freebn(in); @@ -795,7 +862,7 @@ static int rsa2_verifysig(void *key, char *sig, int siglen, return ret; } -static unsigned char *rsa2_sign(void *key, char *data, int datalen, +static unsigned char *rsa2_sign(void *key, const char *data, int datalen, int *siglen) { struct RSAKey *rsa = (struct RSAKey *) key; @@ -847,17 +914,18 @@ const struct ssh_signkey ssh_rsa = { rsa2_createkey, rsa2_openssh_createkey, rsa2_openssh_fmtkey, + 6 /* n,e,d,iqmp,q,p */, rsa2_pubkey_bits, - rsa2_fingerprint, rsa2_verifysig, rsa2_sign, "ssh-rsa", - "rsa2" + "rsa2", + NULL, }; void *ssh_rsakex_newkey(char *data, int len) { - return rsa2_newkey(data, len); + return rsa2_newkey(&ssh_rsa, data, len); } void ssh_rsakex_freekey(void *key) @@ -992,11 +1060,11 @@ void ssh_rsakex_encrypt(const struct ssh_hash *h, unsigned char *in, int inlen, } static const struct ssh_kex ssh_rsa_kex_sha1 = { - "rsa1024-sha1", NULL, KEXTYPE_RSA, NULL, NULL, 0, 0, &ssh_sha1 + "rsa1024-sha1", NULL, KEXTYPE_RSA, &ssh_sha1, NULL, }; static const struct ssh_kex ssh_rsa_kex_sha256 = { - "rsa2048-sha256", NULL, KEXTYPE_RSA, NULL, NULL, 0, 0, &ssh_sha256 + "rsa2048-sha256", NULL, KEXTYPE_RSA, &ssh_sha256, NULL, }; static const struct ssh_kex *const rsa_kex_list[] = {