+/* ----------------------------------------------------------------------
+ * Code to read and write OpenSSH private keys in the new-style format.
+ */
+
+typedef enum {
+ ON_E_NONE, ON_E_AES256CBC
+} openssh_new_cipher;
+typedef enum {
+ ON_K_NONE, ON_K_BCRYPT
+} openssh_new_kdf;
+
+struct openssh_new_key {
+ openssh_new_cipher cipher;
+ openssh_new_kdf kdf;
+ union {
+ struct {
+ int rounds;
+ /* This points to a position within keyblob, not a
+ * separately allocated thing */
+ const unsigned char *salt;
+ int saltlen;
+ } bcrypt;
+ } kdfopts;
+ int nkeys, key_wanted;
+ /* This too points to a position within keyblob */
+ unsigned char *privatestr;
+ int privatelen;
+
+ unsigned char *keyblob;
+ int keyblob_len, keyblob_size;
+};
+
+static struct openssh_new_key *load_openssh_new_key(const Filename *filename,
+ const char **errmsg_p)
+{
+ struct openssh_new_key *ret;
+ FILE *fp = NULL;
+ char *line = NULL;
+ const char *errmsg;
+ char *p;
+ char base64_bit[4];
+ int base64_chars = 0;
+ const void *filedata;
+ int filelen;
+ const void *string, *kdfopts, *bcryptsalt, *pubkey;
+ int stringlen, kdfoptlen, bcryptsaltlen, pubkeylen;
+ unsigned bcryptrounds, nkeys, key_index;
+
+ ret = snew(struct openssh_new_key);
+ ret->keyblob = NULL;
+ ret->keyblob_len = ret->keyblob_size = 0;
+
+ fp = f_open(filename, "r", FALSE);
+ if (!fp) {
+ errmsg = "unable to open key file";
+ goto error;
+ }
+
+ if (!(line = fgetline(fp))) {
+ errmsg = "unexpected end of file";
+ goto error;
+ }
+ strip_crlf(line);
+ if (0 != strcmp(line, "-----BEGIN OPENSSH PRIVATE KEY-----")) {
+ errmsg = "file does not begin with OpenSSH new-style key header";
+ goto error;
+ }
+ smemclr(line, strlen(line));
+ sfree(line);
+ line = NULL;
+
+ while (1) {
+ if (!(line = fgetline(fp))) {
+ errmsg = "unexpected end of file";
+ goto error;
+ }
+ strip_crlf(line);
+ if (0 == strcmp(line, "-----END OPENSSH PRIVATE KEY-----")) {
+ sfree(line);
+ line = NULL;
+ break; /* done */
+ }
+
+ p = line;
+ while (isbase64(*p)) {
+ base64_bit[base64_chars++] = *p;
+ if (base64_chars == 4) {
+ unsigned char out[3];
+ int len;
+
+ base64_chars = 0;
+
+ len = base64_decode_atom(base64_bit, out);
+
+ if (len <= 0) {
+ errmsg = "invalid base64 encoding";
+ goto error;
+ }
+
+ if (ret->keyblob_len + len > ret->keyblob_size) {
+ ret->keyblob_size = ret->keyblob_len + len + 256;
+ ret->keyblob = sresize(ret->keyblob, ret->keyblob_size,
+ unsigned char);
+ }
+
+ memcpy(ret->keyblob + ret->keyblob_len, out, len);
+ ret->keyblob_len += len;
+
+ smemclr(out, sizeof(out));
+ }
+
+ p++;
+ }
+ smemclr(line, strlen(line));
+ sfree(line);
+ line = NULL;
+ }
+
+ fclose(fp);
+ fp = NULL;
+
+ if (ret->keyblob_len == 0 || !ret->keyblob) {
+ errmsg = "key body not present";
+ goto error;
+ }
+
+ filedata = ret->keyblob;
+ filelen = ret->keyblob_len;
+
+ if (filelen < 15 || 0 != memcmp(filedata, "openssh-key-v1\0", 15)) {
+ errmsg = "new-style OpenSSH magic number missing\n";
+ goto error;
+ }
+ filedata = (const char *)filedata + 15;
+ filelen -= 15;
+
+ if (!(string = get_ssh_string(&filelen, &filedata, &stringlen))) {
+ errmsg = "encountered EOF before cipher name\n";
+ goto error;
+ }
+ if (match_ssh_id(stringlen, string, "none")) {
+ ret->cipher = ON_E_NONE;
+ } else if (match_ssh_id(stringlen, string, "aes256-cbc")) {
+ ret->cipher = ON_E_AES256CBC;
+ } else {
+ errmsg = "unrecognised cipher name\n";
+ goto error;
+ }
+
+ if (!(string = get_ssh_string(&filelen, &filedata, &stringlen))) {
+ errmsg = "encountered EOF before kdf name\n";
+ goto error;
+ }
+ if (match_ssh_id(stringlen, string, "none")) {
+ ret->kdf = ON_K_NONE;
+ } else if (match_ssh_id(stringlen, string, "bcrypt")) {
+ ret->kdf = ON_K_BCRYPT;
+ } else {
+ errmsg = "unrecognised kdf name\n";
+ goto error;
+ }
+
+ if (!(kdfopts = get_ssh_string(&filelen, &filedata, &kdfoptlen))) {
+ errmsg = "encountered EOF before kdf options\n";
+ goto error;
+ }
+ switch (ret->kdf) {
+ case ON_K_NONE:
+ if (kdfoptlen != 0) {
+ errmsg = "expected empty options string for 'none' kdf";
+ goto error;
+ }
+ break;
+ case ON_K_BCRYPT:
+ if (!(bcryptsalt = get_ssh_string(&kdfoptlen, &kdfopts,
+ &bcryptsaltlen))) {
+ errmsg = "bcrypt options string did not contain salt\n";
+ goto error;
+ }
+ if (!get_ssh_uint32(&kdfoptlen, &kdfopts, &bcryptrounds)) {
+ errmsg = "bcrypt options string did not contain round count\n";
+ goto error;
+ }
+ ret->kdfopts.bcrypt.salt = bcryptsalt;
+ ret->kdfopts.bcrypt.saltlen = bcryptsaltlen;
+ ret->kdfopts.bcrypt.rounds = bcryptrounds;
+ break;
+ }
+
+ /*
+ * At this point we expect a uint32 saying how many keys are
+ * stored in this file. OpenSSH new-style key files can
+ * contain more than one. Currently we don't have any user
+ * interface to specify which one we're trying to extract, so
+ * we just bomb out with an error if more than one is found in
+ * the file. However, I've put in all the mechanism here to
+ * extract the nth one for a given n, in case we later connect
+ * up some UI to that mechanism. Just arrange that the
+ * 'key_wanted' field is set to a value in the range [0,
+ * nkeys) by some mechanism.
+ */
+ if (!get_ssh_uint32(&filelen, &filedata, &nkeys)) {
+ errmsg = "encountered EOF before key count\n";
+ goto error;
+ }
+ if (nkeys != 1) {
+ errmsg = "multiple keys in new-style OpenSSH key file "
+ "not supported\n";
+ goto error;
+ }
+ ret->nkeys = nkeys;
+ ret->key_wanted = 0;
+
+ for (key_index = 0; key_index < nkeys; key_index++) {
+ if (!(pubkey = get_ssh_string(&filelen, &filedata, &pubkeylen))) {
+ errmsg = "encountered EOF before kdf options\n";
+ goto error;
+ }
+ }
+
+ /*
+ * Now we expect a string containing the encrypted part of the
+ * key file.
+ */
+ if (!(string = get_ssh_string(&filelen, &filedata, &stringlen))) {
+ errmsg = "encountered EOF before private key container\n";
+ goto error;
+ }
+ ret->privatestr = (unsigned char *)string;
+ ret->privatelen = stringlen;
+
+ /*
+ * And now we're done, until asked to actually decrypt.
+ */
+
+ smemclr(base64_bit, sizeof(base64_bit));
+ if (errmsg_p) *errmsg_p = NULL;
+ return ret;
+
+ error:
+ if (line) {
+ smemclr(line, strlen(line));
+ sfree(line);
+ line = NULL;
+ }
+ smemclr(base64_bit, sizeof(base64_bit));
+ if (ret) {
+ if (ret->keyblob) {
+ smemclr(ret->keyblob, ret->keyblob_size);
+ sfree(ret->keyblob);
+ }
+ smemclr(ret, sizeof(*ret));
+ sfree(ret);
+ }
+ if (errmsg_p) *errmsg_p = errmsg;
+ if (fp) fclose(fp);
+ return NULL;
+}
+
+int openssh_new_encrypted(const Filename *filename)
+{
+ struct openssh_new_key *key = load_openssh_new_key(filename, NULL);
+ int ret;
+
+ if (!key)
+ return 0;
+ ret = (key->cipher != ON_E_NONE);
+ smemclr(key->keyblob, key->keyblob_size);
+ sfree(key->keyblob);
+ smemclr(key, sizeof(*key));
+ sfree(key);
+ return ret;
+}
+
+struct ssh2_userkey *openssh_new_read(const Filename *filename,
+ char *passphrase,
+ const char **errmsg_p)
+{
+ struct openssh_new_key *key = load_openssh_new_key(filename, errmsg_p);
+ struct ssh2_userkey *retkey;
+ int i;
+ struct ssh2_userkey *retval = NULL;
+ const char *errmsg;
+ unsigned char *blob;
+ int blobsize = 0;
+ unsigned checkint0, checkint1;
+ const void *priv, *string;
+ int privlen, stringlen, key_index;
+ const struct ssh_signkey *alg;
+
+ blob = NULL;
+
+ if (!key)
+ return NULL;
+
+ if (key->cipher != ON_E_NONE) {
+ unsigned char keybuf[48];
+ int keysize;
+
+ /*
+ * Construct the decryption key, and decrypt the string.
+ */
+ switch (key->cipher) {
+ case ON_E_NONE:
+ keysize = 0;
+ break;
+ case ON_E_AES256CBC:
+ keysize = 48; /* 32 byte key + 16 byte IV */
+ break;
+ default:
+ assert(0 && "Bad cipher enumeration value");
+ }
+ assert(keysize <= sizeof(keybuf));
+ switch (key->kdf) {
+ case ON_K_NONE:
+ memset(keybuf, 0, keysize);
+ break;
+ case ON_K_BCRYPT:
+ openssh_bcrypt(passphrase,
+ key->kdfopts.bcrypt.salt,
+ key->kdfopts.bcrypt.saltlen,
+ key->kdfopts.bcrypt.rounds,
+ keybuf, keysize);
+ break;
+ default:
+ assert(0 && "Bad kdf enumeration value");
+ }
+ switch (key->cipher) {
+ case ON_E_NONE:
+ break;
+ case ON_E_AES256CBC:
+ if (key->privatelen % 16 != 0) {
+ errmsg = "private key container length is not a"
+ " multiple of AES block size\n";
+ goto error;
+ }
+ {
+ void *ctx = aes_make_context();
+ aes256_key(ctx, keybuf);
+ aes_iv(ctx, keybuf + 32);
+ aes_ssh2_decrypt_blk(ctx, key->privatestr,
+ key->privatelen);
+ aes_free_context(ctx);
+ }
+ break;
+ default:
+ assert(0 && "Bad cipher enumeration value");
+ }
+ }
+
+ /*
+ * Now parse the entire encrypted section, and extract the key
+ * identified by key_wanted.
+ */
+ priv = key->privatestr;
+ privlen = key->privatelen;
+
+ if (!get_ssh_uint32(&privlen, &priv, &checkint0) ||
+ !get_ssh_uint32(&privlen, &priv, &checkint1) ||
+ checkint0 != checkint1) {
+ errmsg = "decryption check failed";
+ goto error;
+ }
+
+ retkey = NULL;
+ for (key_index = 0; key_index < key->nkeys; key_index++) {
+ const unsigned char *thiskey;
+ int thiskeylen;
+
+ /*
+ * Read the key type, which will tell us how to scan over
+ * the key to get to the next one.
+ */
+ if (!(string = get_ssh_string(&privlen, &priv, &stringlen))) {
+ errmsg = "expected key type in private string";
+ goto error;
+ }
+
+ /*
+ * Preliminary key type identification, and decide how
+ * many pieces of key we expect to see. Currently
+ * (conveniently) all key types can be seen as some number
+ * of strings, so we just need to know how many of them to
+ * skip over. (The numbers below exclude the key comment.)
+ */
+ {
+ /* find_pubkey_alg needs a zero-terminated copy of the
+ * algorithm name */
+ char *name_zt = dupprintf("%.*s", stringlen, (char *)string);
+ alg = find_pubkey_alg(name_zt);
+ sfree(name_zt);
+ }
+
+ if (!alg) {
+ errmsg = "private key type not recognised\n";
+ goto error;
+ }
+
+ thiskey = priv;
+
+ /*
+ * Skip over the pieces of key.
+ */
+ for (i = 0; i < alg->openssh_private_npieces; i++) {
+ if (!(string = get_ssh_string(&privlen, &priv, &stringlen))) {
+ errmsg = "ran out of data in mid-private-key";
+ goto error;
+ }
+ }
+
+ thiskeylen = (int)((const unsigned char *)priv -
+ (const unsigned char *)thiskey);
+ if (key_index == key->key_wanted) {
+ retkey = snew(struct ssh2_userkey);
+ retkey->alg = alg;
+ retkey->data = alg->openssh_createkey(alg, &thiskey, &thiskeylen);
+ if (!retkey->data) {
+ sfree(retkey);
+ errmsg = "unable to create key data structure";
+ goto error;
+ }
+ }
+
+ /*
+ * Read the key comment.
+ */
+ if (!(string = get_ssh_string(&privlen, &priv, &stringlen))) {
+ errmsg = "ran out of data at key comment";
+ goto error;
+ }
+ if (key_index == key->key_wanted) {
+ assert(retkey);
+ retkey->comment = dupprintf("%.*s", stringlen,
+ (const char *)string);
+ }
+ }
+
+ if (!retkey) {
+ errmsg = "key index out of range";
+ goto error;
+ }
+
+ /*
+ * Now we expect nothing left but padding.
+ */
+ for (i = 0; i < privlen; i++) {
+ if (((const unsigned char *)priv)[i] != (unsigned char)(i+1)) {
+ errmsg = "padding at end of private string did not match";
+ goto error;
+ }
+ }
+
+ errmsg = NULL; /* no error */
+ retval = retkey;
+
+ error:
+ if (blob) {
+ smemclr(blob, blobsize);
+ sfree(blob);
+ }
+ smemclr(key->keyblob, key->keyblob_size);
+ sfree(key->keyblob);
+ smemclr(key, sizeof(*key));
+ sfree(key);
+ if (errmsg_p) *errmsg_p = errmsg;
+ return retval;
+}
+