]> asedeno.scripts.mit.edu Git - PuTTY.git/blobdiff - sshecc.c
Clean up hash selection in ECDSA.
[PuTTY.git] / sshecc.c
index ec2648b0896bdab6c0a01329bf66aa9ec2c20e9d..8b72d6f6767f31b30a9e653ea6d178c3f999be61 100644 (file)
--- a/sshecc.c
+++ b/sshecc.c
@@ -148,16 +148,7 @@ static int initialise_ecurve(struct ec_curve *curve, int bits, unsigned char *p,
     return 0;
 }
 
-unsigned char nistp256_oid[] = {0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07};
-int nistp256_oid_len = 8;
-unsigned char nistp384_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x22};
-int nistp384_oid_len = 5;
-unsigned char nistp521_oid[] = {0x2b, 0x81, 0x04, 0x00, 0x23};
-int nistp521_oid_len = 5;
-unsigned char curve25519_oid[] = {0x06, 0x0A, 0x2B, 0x06, 0x01, 0x04, 0x01, 0x97, 0x55, 0x01, 0x05, 0x01};
-int curve25519_oid_len = 12;
-
-struct ec_curve *ec_p256(void)
+static struct ec_curve *ec_p256(void)
 {
     static struct ec_curve curve = { 0 };
     static unsigned char initialised = 0;
@@ -205,6 +196,8 @@ struct ec_curve *ec_p256(void)
             return NULL;
         }
 
+        curve.name = "nistp256";
+
         /* Now initialised, no need to do it again */
         initialised = 1;
     }
@@ -212,7 +205,7 @@ struct ec_curve *ec_p256(void)
     return &curve;
 }
 
-struct ec_curve *ec_p384(void)
+static struct ec_curve *ec_p384(void)
 {
     static struct ec_curve curve = { 0 };
     static unsigned char initialised = 0;
@@ -272,6 +265,8 @@ struct ec_curve *ec_p384(void)
             return NULL;
         }
 
+        curve.name = "nistp384";
+
         /* Now initialised, no need to do it again */
         initialised = 1;
     }
@@ -279,7 +274,7 @@ struct ec_curve *ec_p384(void)
     return &curve;
 }
 
-struct ec_curve *ec_p521(void)
+static struct ec_curve *ec_p521(void)
 {
     static struct ec_curve curve = { 0 };
     static unsigned char initialised = 0;
@@ -357,6 +352,8 @@ struct ec_curve *ec_p521(void)
             return NULL;
         }
 
+        curve.name = "nistp521";
+
         /* Now initialised, no need to do it again */
         initialised = 1;
     }
@@ -364,7 +361,7 @@ struct ec_curve *ec_p521(void)
     return &curve;
 }
 
-struct ec_curve *ec_curve25519(void)
+static struct ec_curve *ec_curve25519(void)
 {
     static struct ec_curve curve = { 0 };
     static unsigned char initialised = 0;
@@ -400,13 +397,18 @@ struct ec_curve *ec_curve25519(void)
             return NULL;
         }
 
+        /* This curve doesn't need a name, because it's never used in
+         * any format that embeds the curve name */
+        curve.name = NULL;
+
         /* Now initialised, no need to do it again */
         initialised = 1;
     }
 
     return &curve;
 }
-struct ec_curve *ec_ed25519(void)
+
+static struct ec_curve *ec_ed25519(void)
 {
     static struct ec_curve curve = { 0 };
     static unsigned char initialised = 0;
@@ -444,6 +446,9 @@ struct ec_curve *ec_ed25519(void)
             0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0x58
         };
 
+        /* This curve doesn't need a name, because it's never used in
+         * any format that embeds the curve name */
+        curve.name = NULL;
 
         if (!initialise_ecurve(&curve, 256, q, l, d, Bx, By)) {
             return NULL;
@@ -456,131 +461,6 @@ struct ec_curve *ec_ed25519(void)
     return &curve;
 }
 
-static struct ec_curve *ec_name_to_curve(const char *name, int len) {
-    if (len > 11 && !memcmp(name, "ecdsa-sha2-", 11)) {
-        name += 11;
-        len -= 11;
-    } else if (len > 10 && !memcmp(name, "ecdh-sha2-", 10)) {
-        name += 10;
-        len -= 10;
-    } else if (len == 11 && !memcmp(name, "ssh-ed25519", 11)) {
-        return ec_ed25519();
-    }
-
-    if (len == 8 && !memcmp(name, "nistp", 5)) {
-        name += 5;
-        if (!memcmp(name, "256", 3)) {
-            return ec_p256();
-        } else if (!memcmp(name, "384", 3)) {
-            return ec_p384();
-        } else if (!memcmp(name, "521", 3)) {
-            return ec_p521();
-        }
-    }
-
-    if (len == 28 && !memcmp(name, "curve25519-sha256@libssh.org", 28)) {
-        return ec_curve25519();
-    }
-
-    return NULL;
-}
-
-/* Type enumeration for specifying the curve name */
-enum ec_name_type { EC_TYPE_DSA, EC_TYPE_DH, EC_TYPE_CURVE };
-
-static int ec_curve_to_name(enum ec_name_type type, const struct ec_curve *curve,
-                            unsigned char *name, int len) {
-    if (curve->type == EC_WEIERSTRASS) {
-        int length, loc;
-        if (type == EC_TYPE_DSA) {
-            length = 19;
-            loc = 16;
-        } else if (type == EC_TYPE_DH) {
-            length = 18;
-            loc = 15;
-        } else {
-            length = 8;
-            loc = 5;
-        }
-
-        /* Return length of string */
-        if (name == NULL) return length;
-
-        /* Not enough space for the name */
-        if (len < length) return 0;
-
-        /* Put the name in the buffer */
-        switch (curve->fieldBits) {
-          case 256:
-            memcpy(name+loc, "256", 3);
-            break;
-          case 384:
-            memcpy(name+loc, "384", 3);
-            break;
-          case 521:
-            memcpy(name+loc, "521", 3);
-            break;
-          default:
-            return 0;
-        }
-
-        if (type == EC_TYPE_DSA) {
-            memcpy(name, "ecdsa-sha2-nistp", 16);
-        } else if (type == EC_TYPE_DH) {
-            memcpy(name, "ecdh-sha2-nistp", 15);
-        } else {
-            memcpy(name, "nistp", 5);
-        }
-
-        return length;
-    } else if (curve->type == EC_EDWARDS) {
-        /* No DH for ed25519 - use Montgomery instead */
-        if (type == EC_TYPE_DH) return 0;
-
-        if (type == EC_TYPE_CURVE) {
-            /* Return length of string */
-            if (name == NULL) return 7;
-
-            /* Not enough space for the name */
-            if (len < 7) return 0;
-
-            /* Unknown curve field */
-            if (curve->fieldBits != 256) return 0;
-
-            memcpy(name, "ed25519", 7);
-            return 7;
-
-        } else {
-            /* Return length of string */
-            if (name == NULL) return 11;
-
-            /* Not enough space for the name */
-            if (len < 11) return 0;
-
-            /* Unknown curve field */
-            if (curve->fieldBits != 256) return 0;
-
-            memcpy(name, "ssh-ed25519", 11);
-            return 11;
-        }
-    } else {
-        /* No DSA for curve25519 */
-        if (type == EC_TYPE_DSA || type == EC_TYPE_CURVE) return 0;
-
-        /* Return length of string */
-        if (name == NULL) return 28;
-
-        /* Not enough space for the name */
-        if (len < 28) return 0;
-
-        /* Unknown curve field */
-        if (curve->fieldBits != 256) return 0;
-
-        memcpy(name, "curve25519-sha256@libssh.org", 28);
-        return 28;
-    }
-}
-
 /* Return 1 if a is -3 % p, otherwise return 0
  * This is used because there are some maths optimisations */
 static int ec_aminus3(const struct ec_curve *curve)
@@ -2569,6 +2449,15 @@ static int getmppoint(const char **data, int *datalen, struct ec_point *point)
  * Exposed ECDSA interface
  */
 
+struct ecsign_extra {
+    struct ec_curve *(*curve)(void);
+    const struct ssh_hash *hash;
+
+    /* These fields are used by the OpenSSH PEM format importer/exporter */
+    const unsigned char *oid;
+    int oidlen;
+};
+
 static void ecdsa_freekey(void *key)
 {
     struct ec_key *ec = (struct ec_key *) key;
@@ -2588,6 +2477,8 @@ static void ecdsa_freekey(void *key)
 static void *ecdsa_newkey(const struct ssh_signkey *self,
                           const char *data, int len)
 {
+    const struct ecsign_extra *extra =
+        (const struct ecsign_extra *)self->extra;
     const char *p;
     int slen;
     struct ec_key *ec;
@@ -2598,21 +2489,18 @@ static void *ecdsa_newkey(const struct ssh_signkey *self,
     if (!p) {
         return NULL;
     }
-    curve = ec_name_to_curve(p, slen);
-    if (!curve) return NULL;
-
-    if (curve->type != EC_WEIERSTRASS && curve->type != EC_EDWARDS) {
-        return NULL;
-    }
+    curve = extra->curve();
+    assert(curve->type == EC_WEIERSTRASS || curve->type == EC_EDWARDS);
 
     /* Curve name is duplicated for Weierstrass form */
     if (curve->type == EC_WEIERSTRASS) {
         getstring(&data, &len, &p, &slen);
-        if (curve != ec_name_to_curve(p, slen)) return NULL;
+        if (!match_ssh_id(slen, p, curve->name)) return NULL;
     }
 
     ec = snew(struct ec_key);
 
+    ec->signalg = self;
     ec->publicKey.curve = curve;
     ec->publicKey.infinity = 0;
     ec->publicKey.x = NULL;
@@ -2644,17 +2532,17 @@ static char *ecdsa_fmtkey(void *key)
     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
         return NULL;
 
-    pos = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
-    if (pos == 0) return NULL;
-
     len = 4 + 2 + 1;                  /* 2 x "0x", punctuation, \0 */
-    len += pos; /* Curve name */
+    if (ec->publicKey.curve->name)
+        len += strlen(ec->publicKey.curve->name); /* Curve name */
     len += 4 * (bignum_bitcount(ec->publicKey.x) + 15) / 16;
     len += 4 * (bignum_bitcount(ec->publicKey.y) + 15) / 16;
     p = snewn(len, char);
 
-    pos = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, (unsigned char*)p, pos);
-    pos += sprintf(p + pos, ",0x");
+    pos = 0;
+    if (ec->publicKey.curve->name)
+        pos += sprintf(p + pos, "%s,", ec->publicKey.curve->name);
+    pos += sprintf(p + pos, "0x");
     nibbles = (3 + bignum_bitcount(ec->publicKey.x)) / 4;
     if (nibbles < 1)
         nibbles = 1;
@@ -2681,10 +2569,10 @@ static unsigned char *ecdsa_public_blob(void *key, int *len)
     int i;
     unsigned char *blob, *p;
 
+    fullnamelen = strlen(ec->signalg->name);
+
     if (ec->publicKey.curve->type == EC_EDWARDS) {
         /* Edwards compressed form "ssh-ed25519" point y[:-1] + x[0:1] */
-        fullnamelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
-        if (fullnamelen == 0) return NULL;
 
         pointlen = ec->publicKey.curve->fieldBits / 8;
 
@@ -2698,7 +2586,8 @@ static unsigned char *ecdsa_public_blob(void *key, int *len)
         p = blob;
         PUT_32BIT(p, fullnamelen);
         p += 4;
-        p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, fullnamelen);
+        memcpy(p, ec->signalg->name, fullnamelen);
+        p += fullnamelen;
         PUT_32BIT(p, pointlen);
         p += 4;
 
@@ -2710,10 +2599,8 @@ static unsigned char *ecdsa_public_blob(void *key, int *len)
         *p = bignum_byte(ec->publicKey.y, i) & 0x7f;
         *p++ |= bignum_bit(ec->publicKey.x, 0) << 7;
     } else if (ec->publicKey.curve->type == EC_WEIERSTRASS) {
-        fullnamelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
-        if (fullnamelen == 0) return NULL;
-        namelen = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
-        if (namelen == 0) return NULL;
+        assert(ec->publicKey.curve->name);
+        namelen = strlen(ec->publicKey.curve->name);
 
         pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
 
@@ -2726,10 +2613,12 @@ static unsigned char *ecdsa_public_blob(void *key, int *len)
         p = blob;
         PUT_32BIT(p, fullnamelen);
         p += 4;
-        p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, fullnamelen);
+        memcpy(p, ec->signalg->name, fullnamelen);
+        p += fullnamelen;
         PUT_32BIT(p, namelen);
         p += 4;
-        p += ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, p, namelen);
+        memcpy(p, ec->publicKey.curve->name, namelen);
+        p += namelen;
         PUT_32BIT(p, (2 * pointlen) + 1);
         p += 4;
         *p++ = 0x04;
@@ -2853,6 +2742,7 @@ static void *ed25519_openssh_createkey(const struct ssh_signkey *self,
         return NULL;
     }
 
+    ec->signalg = self;
     ec->publicKey.curve = ec_ed25519();
     ec->publicKey.infinity = 0;
     ec->privateKey = NULL;
@@ -2956,6 +2846,8 @@ static int ed25519_openssh_fmtkey(void *key, unsigned char *blob, int len)
 static void *ecdsa_openssh_createkey(const struct ssh_signkey *self,
                                      const unsigned char **blob, int *len)
 {
+    const struct ecsign_extra *extra =
+        (const struct ecsign_extra *)self->extra;
     const char **b = (const char **) blob;
     const char *p;
     int slen;
@@ -2968,15 +2860,12 @@ static void *ecdsa_openssh_createkey(const struct ssh_signkey *self,
     if (!p) {
         return NULL;
     }
-    curve = ec_name_to_curve(p, slen);
-    if (!curve) return NULL;
-
-    if (curve->type != EC_WEIERSTRASS) {
-        return NULL;
-    }
+    curve = extra->curve();
+    assert(curve->type == EC_WEIERSTRASS);
 
     ec = snew(struct ec_key);
 
+    ec->signalg = self;
     ec->publicKey.curve = curve;
     ec->publicKey.infinity = 0;
     ec->publicKey.x = NULL;
@@ -3039,7 +2928,7 @@ static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
     }
 
     pointlen = (bignum_bitcount(ec->publicKey.curve->p) + 7) / 8;
-    namelen = ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, NULL, 0);
+    namelen = strlen(ec->publicKey.curve->name);
     bloblen =
         4 + namelen /* <LEN> nistpXXX */
         + 4 + 1 + (pointlen * 2) /* <LEN> 0x04 pX pY */
@@ -3052,8 +2941,8 @@ static int ecdsa_openssh_fmtkey(void *key, unsigned char *blob, int len)
 
     PUT_32BIT(blob+bloblen, namelen);
     bloblen += 4;
-
-    bloblen += ec_curve_to_name(EC_TYPE_CURVE, ec->publicKey.curve, blob+bloblen, namelen);
+    memcpy(blob+bloblen, ec->publicKey.curve->name, namelen);
+    bloblen += namelen;
 
     PUT_32BIT(blob+bloblen, 1 + (pointlen * 2));
     bloblen += 4;
@@ -3091,6 +2980,8 @@ static int ecdsa_verifysig(void *key, const char *sig, int siglen,
                            const char *data, int datalen)
 {
     struct ec_key *ec = (struct ec_key *) key;
+    const struct ecsign_extra *extra =
+        (const struct ecsign_extra *)ec->signalg->extra;
     const char *p;
     int slen;
     int digestLen;
@@ -3099,12 +2990,12 @@ static int ecdsa_verifysig(void *key, const char *sig, int siglen,
     if (!ec->publicKey.x || !ec->publicKey.y || !ec->publicKey.curve)
         return 0;
 
-    /* Check the signature curve matches the key curve */
+    /* Check the signature starts with the algorithm name */
     getstring(&sig, &siglen, &p, &slen);
     if (!p) {
         return 0;
     }
-    if (ec->publicKey.curve != ec_name_to_curve(p, slen)) {
+    if (!match_ssh_id(slen, p, ec->signalg->name)) {
         return 0;
     }
 
@@ -3219,6 +3110,7 @@ static int ecdsa_verifysig(void *key, const char *sig, int siglen,
     } else {
         Bignum r, s;
         unsigned char digest[512 / 8];
+        void *hashctx;
 
         r = getmp(&p, &slen);
         if (!r) return 0;
@@ -3228,17 +3120,11 @@ static int ecdsa_verifysig(void *key, const char *sig, int siglen,
             return 0;
         }
 
-        /* Perform correct hash function depending on curve size */
-        if (ec->publicKey.curve->fieldBits <= 256) {
-            SHA256_Simple(data, datalen, digest);
-            digestLen = 256 / 8;
-        } else if (ec->publicKey.curve->fieldBits <= 384) {
-            SHA384_Simple(data, datalen, digest);
-            digestLen = 384 / 8;
-        } else {
-            SHA512_Simple(data, datalen, digest);
-            digestLen = 512 / 8;
-        }
+        digestLen = extra->hash->hlen;
+        assert(digestLen <= sizeof(digest));
+        hashctx = extra->hash->init();
+        extra->hash->bytes(hashctx, data, datalen);
+        extra->hash->final(hashctx, digest);
 
         /* Verify the signature */
         ret = _ecdsa_verify(&ec->publicKey, digest, digestLen, r, s);
@@ -3254,6 +3140,8 @@ static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
                                  int *siglen)
 {
     struct ec_key *ec = (struct ec_key *) key;
+    const struct ecsign_extra *extra =
+        (const struct ecsign_extra *)ec->signalg->extra;
     unsigned char digest[512 / 8];
     int digestLen;
     Bignum r = NULL, s = NULL;
@@ -3382,13 +3270,14 @@ static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
         }
 
         /* Format the output */
-        namelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
+        namelen = strlen(ec->signalg->name);
         *siglen = 4+namelen+4+((ec->publicKey.curve->fieldBits / 8)*2);
         buf = snewn(*siglen, unsigned char);
         p = buf;
         PUT_32BIT(p, namelen);
         p += 4;
-        p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, namelen);
+        memcpy(p, ec->signalg->name, namelen);
+        p += namelen;
         PUT_32BIT(p, ((ec->publicKey.curve->fieldBits / 8)*2));
         p += 4;
 
@@ -3408,17 +3297,13 @@ static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
         }
         freebn(s);
     } else {
-        /* Perform correct hash function depending on curve size */
-        if (ec->publicKey.curve->fieldBits <= 256) {
-            SHA256_Simple(data, datalen, digest);
-            digestLen = 256 / 8;
-        } else if (ec->publicKey.curve->fieldBits <= 384) {
-            SHA384_Simple(data, datalen, digest);
-            digestLen = 384 / 8;
-        } else {
-            SHA512_Simple(data, datalen, digest);
-            digestLen = 512 / 8;
-        }
+        void *hashctx;
+
+        digestLen = extra->hash->hlen;
+        assert(digestLen <= sizeof(digest));
+        hashctx = extra->hash->init();
+        extra->hash->bytes(hashctx, data, datalen);
+        extra->hash->final(hashctx, digest);
 
         /* Do the signature */
         _ecdsa_sign(ec->privateKey, ec->publicKey.curve, digest, digestLen, &r, &s);
@@ -3431,7 +3316,7 @@ static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
         rlen = (bignum_bitcount(r) + 8) / 8;
         slen = (bignum_bitcount(s) + 8) / 8;
 
-        namelen = ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, NULL, 0);
+        namelen = strlen(ec->signalg->name);
 
         /* Format the output */
         *siglen = 8+namelen+rlen+slen+8;
@@ -3439,7 +3324,8 @@ static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
         p = buf;
         PUT_32BIT(p, namelen);
         p += 4;
-        p += ec_curve_to_name(EC_TYPE_DSA, ec->publicKey.curve, p, namelen);
+        memcpy(p, ec->signalg->name, namelen);
+        p += namelen;
         PUT_32BIT(p, rlen + slen + 8);
         p += 4;
         PUT_32BIT(p, rlen);
@@ -3458,6 +3344,10 @@ static unsigned char *ecdsa_sign(void *key, const char *data, int datalen,
     return buf;
 }
 
+const struct ecsign_extra sign_extra_ed25519 = {
+    ec_ed25519, NULL,
+    NULL, 0,
+};
 const struct ssh_signkey ssh_ecdsa_ed25519 = {
     ecdsa_newkey,
     ecdsa_freekey,
@@ -3473,9 +3363,17 @@ const struct ssh_signkey ssh_ecdsa_ed25519 = {
     ecdsa_sign,
     "ssh-ed25519",
     "ssh-ed25519",
-    NULL,
+    &sign_extra_ed25519,
 };
 
+/* OID: 1.2.840.10045.3.1.7 (ansiX9p256r1) */
+static const unsigned char nistp256_oid[] = {
+    0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07
+};
+const struct ecsign_extra sign_extra_nistp256 = {
+    ec_p256, &ssh_sha256,
+    nistp256_oid, lenof(nistp256_oid),
+};
 const struct ssh_signkey ssh_ecdsa_nistp256 = {
     ecdsa_newkey,
     ecdsa_freekey,
@@ -3491,9 +3389,17 @@ const struct ssh_signkey ssh_ecdsa_nistp256 = {
     ecdsa_sign,
     "ecdsa-sha2-nistp256",
     "ecdsa-sha2-nistp256",
-    NULL,
+    &sign_extra_nistp256,
 };
 
+/* OID: 1.3.132.0.34 (secp384r1) */
+static const unsigned char nistp384_oid[] = {
+    0x2b, 0x81, 0x04, 0x00, 0x22
+};
+const struct ecsign_extra sign_extra_nistp384 = {
+    ec_p384, &ssh_sha384,
+    nistp384_oid, lenof(nistp384_oid),
+};
 const struct ssh_signkey ssh_ecdsa_nistp384 = {
     ecdsa_newkey,
     ecdsa_freekey,
@@ -3509,9 +3415,17 @@ const struct ssh_signkey ssh_ecdsa_nistp384 = {
     ecdsa_sign,
     "ecdsa-sha2-nistp384",
     "ecdsa-sha2-nistp384",
-    NULL,
+    &sign_extra_nistp384,
 };
 
+/* OID: 1.3.132.0.35 (secp521r1) */
+static const unsigned char nistp521_oid[] = {
+    0x2b, 0x81, 0x04, 0x00, 0x23
+};
+const struct ecsign_extra sign_extra_nistp521 = {
+    ec_p521, &ssh_sha512,
+    nistp521_oid, lenof(nistp521_oid),
+};
 const struct ssh_signkey ssh_ecdsa_nistp521 = {
     ecdsa_newkey,
     ecdsa_freekey,
@@ -3527,13 +3441,17 @@ const struct ssh_signkey ssh_ecdsa_nistp521 = {
     ecdsa_sign,
     "ecdsa-sha2-nistp521",
     "ecdsa-sha2-nistp521",
-    NULL,
+    &sign_extra_nistp521,
 };
 
 /* ----------------------------------------------------------------------
  * Exposed ECDH interface
  */
 
+struct eckex_extra {
+    struct ec_curve *(*curve)(void);
+};
+
 static Bignum ecdh_calculate(const Bignum private,
                              const struct ec_point *public)
 {
@@ -3566,19 +3484,21 @@ static Bignum ecdh_calculate(const Bignum private,
     return ret;
 }
 
-void *ssh_ecdhkex_newkey(const char *name)
+void *ssh_ecdhkex_newkey(const struct ssh_kex *kex)
 {
+    const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra;
     struct ec_curve *curve;
     struct ec_key *key;
     struct ec_point *publicKey;
 
-    curve = ec_name_to_curve(name, strlen(name));
+    curve = extra->curve();
 
     key = snew(struct ec_key);
     if (!key) {
         return NULL;
     }
 
+    key->signalg = NULL;
     key->publicKey.curve = curve;
 
     if (curve->type == EC_MONTGOMERY) {
@@ -3704,30 +3624,99 @@ void ssh_ecdhkex_freekey(void *key)
     ecdsa_freekey(key);
 }
 
+static const struct eckex_extra kex_extra_curve25519 = { ec_curve25519 };
 static const struct ssh_kex ssh_ec_kex_curve25519 = {
-    "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH, &ssh_sha256, NULL
+    "curve25519-sha256@libssh.org", NULL, KEXTYPE_ECDH,
+    &ssh_sha256, &kex_extra_curve25519,
 };
 
+const struct eckex_extra kex_extra_nistp256 = { ec_p256 };
 static const struct ssh_kex ssh_ec_kex_nistp256 = {
-    "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH, &ssh_sha256, NULL
+    "ecdh-sha2-nistp256", NULL, KEXTYPE_ECDH,
+    &ssh_sha256, &kex_extra_nistp256,
 };
 
+const struct eckex_extra kex_extra_nistp384 = { ec_p384 };
 static const struct ssh_kex ssh_ec_kex_nistp384 = {
-    "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH, &ssh_sha384, NULL
+    "ecdh-sha2-nistp384", NULL, KEXTYPE_ECDH,
+    &ssh_sha384, &kex_extra_nistp384,
 };
 
+const struct eckex_extra kex_extra_nistp521 = { ec_p521 };
 static const struct ssh_kex ssh_ec_kex_nistp521 = {
-    "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH, &ssh_sha512, NULL
+    "ecdh-sha2-nistp521", NULL, KEXTYPE_ECDH,
+    &ssh_sha512, &kex_extra_nistp521,
 };
 
 static const struct ssh_kex *const ec_kex_list[] = {
     &ssh_ec_kex_curve25519,
     &ssh_ec_kex_nistp256,
     &ssh_ec_kex_nistp384,
-    &ssh_ec_kex_nistp521
+    &ssh_ec_kex_nistp521,
 };
 
 const struct ssh_kexes ssh_ecdh_kex = {
     sizeof(ec_kex_list) / sizeof(*ec_kex_list),
     ec_kex_list
 };
+
+/* ----------------------------------------------------------------------
+ * Helper functions for finding key algorithms and returning auxiliary
+ * data.
+ */
+
+const struct ssh_signkey *ec_alg_by_oid(int len, const void *oid,
+                                        const struct ec_curve **curve)
+{
+    static const struct ssh_signkey *algs_with_oid[] = {
+        &ssh_ecdsa_nistp256,
+        &ssh_ecdsa_nistp384,
+        &ssh_ecdsa_nistp521,
+    };
+    int i;
+
+    for (i = 0; i < lenof(algs_with_oid); i++) {
+        const struct ssh_signkey *alg = algs_with_oid[i];
+        const struct ecsign_extra *extra =
+            (const struct ecsign_extra *)alg->extra;
+        if (len == extra->oidlen && !memcmp(oid, extra->oid, len)) {
+            *curve = extra->curve();
+            return alg;
+        }
+    }
+    return NULL;
+}
+
+const unsigned char *ec_alg_oid(const struct ssh_signkey *alg,
+                                int *oidlen)
+{
+    const struct ecsign_extra *extra = (const struct ecsign_extra *)alg->extra;
+    *oidlen = extra->oidlen;
+    return extra->oid;
+}
+
+const int ec_nist_alg_and_curve_by_bits(int bits,
+                                        const struct ec_curve **curve,
+                                        const struct ssh_signkey **alg)
+{
+    switch (bits) {
+      case 256: *alg = &ssh_ecdsa_nistp256; break;
+      case 384: *alg = &ssh_ecdsa_nistp384; break;
+      case 521: *alg = &ssh_ecdsa_nistp521; break;
+      default: return FALSE;
+    }
+    *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
+    return TRUE;
+}
+
+const int ec_ed_alg_and_curve_by_bits(int bits,
+                                      const struct ec_curve **curve,
+                                      const struct ssh_signkey **alg)
+{
+    switch (bits) {
+      case 256: *alg = &ssh_ecdsa_ed25519; break;
+      default: return FALSE;
+    }
+    *curve = ((struct ecsign_extra *)(*alg)->extra)->curve();
+    return TRUE;
+}