]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshccp.c
Dedicated routines for poly1305 arithmetic.
[PuTTY.git] / sshccp.c
1 /*
2  * ChaCha20-Poly1305 Implementation for SSH-2
3  *
4  * Protocol spec:
5  *  http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.chacha20poly1305?rev=1.2&content-type=text/x-cvsweb-markup
6  *
7  * ChaCha20 spec:
8  *  http://cr.yp.to/chacha/chacha-20080128.pdf
9  *
10  * Salsa20 spec:
11  *  http://cr.yp.to/snuffle/spec.pdf
12  *
13  * Poly1305-AES spec:
14  *  http://cr.yp.to/mac/poly1305-20050329.pdf
15  *
16  * The nonce for the Poly1305 is the second part of the key output
17  * from the first round of ChaCha20. This removes the AES requirement.
18  * This is undocumented!
19  *
20  * This has an intricate link between the cipher and the MAC. The
21  * keying of both is done in by the cipher and setting of the IV is
22  * done by the MAC. One cannot operate without the other. The
23  * configuration of the ssh2_cipher structure ensures that the MAC is
24  * set (and others ignored) if this cipher is chosen.
25  *
26  * This cipher also encrypts the length using a different
27  * instantiation of the cipher using a different key and IV made from
28  * the sequence number which is passed in addition when calling
29  * encrypt/decrypt on it.
30  */
31
32 #include "ssh.h"
33 #include "sshbn.h"
34
35 #ifndef INLINE
36 #define INLINE
37 #endif
38
39 /* ChaCha20 implementation, only supporting 256-bit keys */
40
41 /* State for each ChaCha20 instance */
42 struct chacha20 {
43     /* Current context, usually with the count incremented
44      * 0-3 are the static constant
45      * 4-11 are the key
46      * 12-13 are the counter
47      * 14-15 are the IV */
48     uint32 state[16];
49     /* The output of the state above ready to xor */
50     unsigned char current[64];
51     /* The index of the above currently used to allow a true streaming cipher */
52     int currentIndex;
53 };
54
55 static INLINE void chacha20_round(struct chacha20 *ctx)
56 {
57     int i;
58     uint32 copy[16];
59
60     /* Take a copy */
61     memcpy(copy, ctx->state, sizeof(copy));
62
63     /* A circular rotation for a 32bit number */
64 #define rotl(x, shift) x = ((x << shift) | (x >> (32 - shift)))
65
66     /* What to do for each quarter round operation */
67 #define qrop(a, b, c, d)                        \
68     copy[a] += copy[b];                         \
69     copy[c] ^= copy[a];                         \
70     rotl(copy[c], d)
71
72     /* A quarter round */
73 #define quarter(a, b, c, d)                     \
74     qrop(a, b, d, 16);                          \
75     qrop(c, d, b, 12);                          \
76     qrop(a, b, d, 8);                           \
77     qrop(c, d, b, 7)
78
79     /* Do 20 rounds, in pairs because every other is different */
80     for (i = 0; i < 20; i += 2) {
81         /* A round */
82         quarter(0, 4, 8, 12);
83         quarter(1, 5, 9, 13);
84         quarter(2, 6, 10, 14);
85         quarter(3, 7, 11, 15);
86         /* Another slightly different round */
87         quarter(0, 5, 10, 15);
88         quarter(1, 6, 11, 12);
89         quarter(2, 7, 8, 13);
90         quarter(3, 4, 9, 14);
91     }
92
93     /* Dump the macros, don't need them littering */
94 #undef rotl
95 #undef qrop
96 #undef quarter
97
98     /* Add the initial state */
99     for (i = 0; i < 16; ++i) {
100         copy[i] += ctx->state[i];
101     }
102
103     /* Update the content of the xor buffer */
104     for (i = 0; i < 16; ++i) {
105         ctx->current[i * 4 + 0] = copy[i] >> 0;
106         ctx->current[i * 4 + 1] = copy[i] >> 8;
107         ctx->current[i * 4 + 2] = copy[i] >> 16;
108         ctx->current[i * 4 + 3] = copy[i] >> 24;
109     }
110     /* State full, reset pointer to beginning */
111     ctx->currentIndex = 0;
112     smemclr(copy, sizeof(copy));
113
114     /* Increment round counter */
115     ++ctx->state[12];
116     /* Check for overflow, not done in one line so the 32 bits are chopped by the type */
117     if (!(uint32)(ctx->state[12])) {
118         ++ctx->state[13];
119     }
120 }
121
122 /* Initialise context with 256bit key */
123 static void chacha20_key(struct chacha20 *ctx, const unsigned char *key)
124 {
125     static const char constant[16] = "expand 32-byte k";
126
127     /* Add the fixed string to the start of the state */
128     ctx->state[0] = GET_32BIT_LSB_FIRST(constant + 0);
129     ctx->state[1] = GET_32BIT_LSB_FIRST(constant + 4);
130     ctx->state[2] = GET_32BIT_LSB_FIRST(constant + 8);
131     ctx->state[3] = GET_32BIT_LSB_FIRST(constant + 12);
132
133     /* Add the key */
134     ctx->state[4]  = GET_32BIT_LSB_FIRST(key + 0);
135     ctx->state[5]  = GET_32BIT_LSB_FIRST(key + 4);
136     ctx->state[6]  = GET_32BIT_LSB_FIRST(key + 8);
137     ctx->state[7]  = GET_32BIT_LSB_FIRST(key + 12);
138     ctx->state[8]  = GET_32BIT_LSB_FIRST(key + 16);
139     ctx->state[9]  = GET_32BIT_LSB_FIRST(key + 20);
140     ctx->state[10] = GET_32BIT_LSB_FIRST(key + 24);
141     ctx->state[11] = GET_32BIT_LSB_FIRST(key + 28);
142
143     /* New key, dump context */
144     ctx->currentIndex = 64;
145 }
146
147 static void chacha20_iv(struct chacha20 *ctx, const unsigned char *iv)
148 {
149     ctx->state[12] = 0;
150     ctx->state[13] = 0;
151     ctx->state[14] = GET_32BIT_MSB_FIRST(iv);
152     ctx->state[15] = GET_32BIT_MSB_FIRST(iv + 4);
153
154     /* New IV, dump context */
155     ctx->currentIndex = 64;
156 }
157
158 static void chacha20_encrypt(struct chacha20 *ctx, unsigned char *blk, int len)
159 {
160     while (len) {
161         /* If we don't have any state left, then cycle to the next */
162         if (ctx->currentIndex >= 64) {
163             chacha20_round(ctx);
164         }
165
166         /* Do the xor while there's some state left and some plaintext left */
167         while (ctx->currentIndex < 64 && len) {
168             *blk++ ^= ctx->current[ctx->currentIndex++];
169             --len;
170         }
171     }
172 }
173
174 /* Decrypt is encrypt... It's xor against a PRNG... */
175 static INLINE void chacha20_decrypt(struct chacha20 *ctx,
176                                     unsigned char *blk, int len)
177 {
178     chacha20_encrypt(ctx, blk, len);
179 }
180
181 /* Poly1305 implementation (no AES, nonce is not encrypted) */
182
183 #define NWORDS ((130 + BIGNUM_INT_BITS-1) / BIGNUM_INT_BITS)
184 typedef struct bigval {
185     BignumInt w[NWORDS];
186 } bigval;
187
188 static void bigval_clear(bigval *r)
189 {
190     int i;
191     for (i = 0; i < NWORDS; i++)
192         r->w[i] = 0;
193 }
194
195 static void bigval_import_le(bigval *r, const void *vdata, int len)
196 {
197     const unsigned char *data = (const unsigned char *)vdata;
198     int i;
199     bigval_clear(r);
200     for (i = 0; i < len; i++)
201         r->w[i / BIGNUM_INT_BYTES] |= data[i] << (8 * (i % BIGNUM_INT_BYTES));
202 }
203
204 static void bigval_export_le(const bigval *r, void *vdata, int len)
205 {
206     unsigned char *data = (unsigned char *)vdata;
207     int i;
208     for (i = 0; i < len; i++)
209         data[i] = r->w[i / BIGNUM_INT_BYTES] >> (8 * (i % BIGNUM_INT_BYTES));
210 }
211
212 /*
213  * Addition of bigvals, not mod p.
214  */
215 static void bigval_add(bigval *r, const bigval *a, const bigval *b)
216 {
217 #if BIGNUM_INT_BITS == 32
218     /* ./contrib/make1305.py add 32 */
219     BignumDblInt acclo;
220     acclo = 0;
221     acclo += a->w[0];
222     acclo += b->w[0];
223     r->w[0] = acclo;
224     acclo >>= 32;
225     acclo += a->w[1];
226     acclo += b->w[1];
227     r->w[1] = acclo;
228     acclo >>= 32;
229     acclo += a->w[2];
230     acclo += b->w[2];
231     r->w[2] = acclo;
232     acclo >>= 32;
233     acclo += a->w[3];
234     acclo += b->w[3];
235     r->w[3] = acclo;
236     acclo >>= 32;
237     acclo += a->w[4];
238     acclo += b->w[4];
239     r->w[4] = acclo;
240     acclo >>= 32;
241 #elif BIGNUM_INT_BITS == 16
242     /* ./contrib/make1305.py add 16 */
243     BignumDblInt acclo;
244     acclo = 0;
245     acclo += a->w[0];
246     acclo += b->w[0];
247     r->w[0] = acclo;
248     acclo >>= 16;
249     acclo += a->w[1];
250     acclo += b->w[1];
251     r->w[1] = acclo;
252     acclo >>= 16;
253     acclo += a->w[2];
254     acclo += b->w[2];
255     r->w[2] = acclo;
256     acclo >>= 16;
257     acclo += a->w[3];
258     acclo += b->w[3];
259     r->w[3] = acclo;
260     acclo >>= 16;
261     acclo += a->w[4];
262     acclo += b->w[4];
263     r->w[4] = acclo;
264     acclo >>= 16;
265     acclo += a->w[5];
266     acclo += b->w[5];
267     r->w[5] = acclo;
268     acclo >>= 16;
269     acclo += a->w[6];
270     acclo += b->w[6];
271     r->w[6] = acclo;
272     acclo >>= 16;
273     acclo += a->w[7];
274     acclo += b->w[7];
275     r->w[7] = acclo;
276     acclo >>= 16;
277     acclo += a->w[8];
278     acclo += b->w[8];
279     r->w[8] = acclo;
280     acclo >>= 16;
281 #else
282 #error Run contrib/make1305.py again with a different bit count
283 #endif
284 }
285
286 /*
287  * Multiplication of bigvals mod p. Uses r as temporary storage, so
288  * don't pass r aliasing a or b.
289  */
290 static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b)
291 {
292 #if BIGNUM_INT_BITS == 32
293     /* ./contrib/make1305.py mul 32 */
294     BignumDblInt tmp;
295     BignumDblInt acclo;
296     BignumDblInt acchi;
297     BignumDblInt acc2lo;
298     acclo = 0;
299     acchi = 0;
300     tmp = (BignumDblInt)(a->w[0]) * (b->w[0]);
301     acclo += tmp & BIGNUM_INT_MASK;
302     acchi += tmp >> 32;
303     r->w[0] = acclo;
304     acclo = acchi + (acclo >> 32);
305     acchi = 0;
306     tmp = (BignumDblInt)(a->w[0]) * (b->w[1]);
307     acclo += tmp & BIGNUM_INT_MASK;
308     acchi += tmp >> 32;
309     tmp = (BignumDblInt)(a->w[1]) * (b->w[0]);
310     acclo += tmp & BIGNUM_INT_MASK;
311     acchi += tmp >> 32;
312     r->w[1] = acclo;
313     acclo = acchi + (acclo >> 32);
314     acchi = 0;
315     tmp = (BignumDblInt)(a->w[0]) * (b->w[2]);
316     acclo += tmp & BIGNUM_INT_MASK;
317     acchi += tmp >> 32;
318     tmp = (BignumDblInt)(a->w[1]) * (b->w[1]);
319     acclo += tmp & BIGNUM_INT_MASK;
320     acchi += tmp >> 32;
321     tmp = (BignumDblInt)(a->w[2]) * (b->w[0]);
322     acclo += tmp & BIGNUM_INT_MASK;
323     acchi += tmp >> 32;
324     r->w[2] = acclo;
325     acclo = acchi + (acclo >> 32);
326     acchi = 0;
327     tmp = (BignumDblInt)(a->w[0]) * (b->w[3]);
328     acclo += tmp & BIGNUM_INT_MASK;
329     acchi += tmp >> 32;
330     tmp = (BignumDblInt)(a->w[1]) * (b->w[2]);
331     acclo += tmp & BIGNUM_INT_MASK;
332     acchi += tmp >> 32;
333     tmp = (BignumDblInt)(a->w[2]) * (b->w[1]);
334     acclo += tmp & BIGNUM_INT_MASK;
335     acchi += tmp >> 32;
336     tmp = (BignumDblInt)(a->w[3]) * (b->w[0]);
337     acclo += tmp & BIGNUM_INT_MASK;
338     acchi += tmp >> 32;
339     r->w[3] = acclo;
340     acclo = acchi + (acclo >> 32);
341     acchi = 0;
342     tmp = (BignumDblInt)(a->w[0]) * (b->w[4]);
343     acclo += tmp & BIGNUM_INT_MASK;
344     acchi += tmp >> 32;
345     tmp = (BignumDblInt)(a->w[1]) * (b->w[3]);
346     acclo += tmp & BIGNUM_INT_MASK;
347     acchi += tmp >> 32;
348     tmp = (BignumDblInt)(a->w[2]) * (b->w[2]);
349     acclo += tmp & BIGNUM_INT_MASK;
350     acchi += tmp >> 32;
351     tmp = (BignumDblInt)(a->w[3]) * (b->w[1]);
352     acclo += tmp & BIGNUM_INT_MASK;
353     acchi += tmp >> 32;
354     tmp = (BignumDblInt)(a->w[4]) * (b->w[0]);
355     acclo += tmp & BIGNUM_INT_MASK;
356     acchi += tmp >> 32;
357     r->w[4] = acclo & (((BignumInt)1 << 2)-1);
358     acc2lo = 0;
359     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 30)-1)) * ((BignumDblInt)5 << 0);
360     acclo = acchi + (acclo >> 32);
361     acchi = 0;
362     tmp = (BignumDblInt)(a->w[1]) * (b->w[4]);
363     acclo += tmp & BIGNUM_INT_MASK;
364     acchi += tmp >> 32;
365     tmp = (BignumDblInt)(a->w[2]) * (b->w[3]);
366     acclo += tmp & BIGNUM_INT_MASK;
367     acchi += tmp >> 32;
368     tmp = (BignumDblInt)(a->w[3]) * (b->w[2]);
369     acclo += tmp & BIGNUM_INT_MASK;
370     acchi += tmp >> 32;
371     tmp = (BignumDblInt)(a->w[4]) * (b->w[1]);
372     acclo += tmp & BIGNUM_INT_MASK;
373     acchi += tmp >> 32;
374     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 30);
375     acc2lo += r->w[0];
376     r->w[0] = acc2lo;
377     acc2lo >>= 32;
378     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 30)-1)) * ((BignumDblInt)5 << 0);
379     acclo = acchi + (acclo >> 32);
380     acchi = 0;
381     tmp = (BignumDblInt)(a->w[2]) * (b->w[4]);
382     acclo += tmp & BIGNUM_INT_MASK;
383     acchi += tmp >> 32;
384     tmp = (BignumDblInt)(a->w[3]) * (b->w[3]);
385     acclo += tmp & BIGNUM_INT_MASK;
386     acchi += tmp >> 32;
387     tmp = (BignumDblInt)(a->w[4]) * (b->w[2]);
388     acclo += tmp & BIGNUM_INT_MASK;
389     acchi += tmp >> 32;
390     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 30);
391     acc2lo += r->w[1];
392     r->w[1] = acc2lo;
393     acc2lo >>= 32;
394     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 30)-1)) * ((BignumDblInt)5 << 0);
395     acclo = acchi + (acclo >> 32);
396     acchi = 0;
397     tmp = (BignumDblInt)(a->w[3]) * (b->w[4]);
398     acclo += tmp & BIGNUM_INT_MASK;
399     acchi += tmp >> 32;
400     tmp = (BignumDblInt)(a->w[4]) * (b->w[3]);
401     acclo += tmp & BIGNUM_INT_MASK;
402     acchi += tmp >> 32;
403     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 30);
404     acc2lo += r->w[2];
405     r->w[2] = acc2lo;
406     acc2lo >>= 32;
407     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 30)-1)) * ((BignumDblInt)5 << 0);
408     acclo = acchi + (acclo >> 32);
409     acchi = 0;
410     tmp = (BignumDblInt)(a->w[4]) * (b->w[4]);
411     acclo += tmp & BIGNUM_INT_MASK;
412     acchi += tmp >> 32;
413     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 30);
414     acc2lo += r->w[3];
415     r->w[3] = acc2lo;
416     acc2lo >>= 32;
417     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 0);
418     acc2lo += r->w[4];
419     r->w[4] = acc2lo;
420     acc2lo = 0;
421     acc2lo += ((acclo >> 4) & (((BignumInt)1 << 28)-1)) * ((BignumDblInt)25 << 0);
422     acclo = acchi + (acclo >> 32);
423     acchi = 0;
424     acc2lo += (acclo & (((BignumInt)1 << 4)-1)) * ((BignumDblInt)25 << 28);
425     acc2lo += r->w[0];
426     r->w[0] = acc2lo;
427     acc2lo >>= 32;
428     acc2lo += ((acclo >> 4) & (((BignumInt)1 << 28)-1)) * ((BignumDblInt)25 << 0);
429     acclo = acchi + (acclo >> 32);
430     acchi = 0;
431     acc2lo += r->w[1];
432     r->w[1] = acc2lo;
433     acc2lo >>= 32;
434     acc2lo += r->w[2];
435     r->w[2] = acc2lo;
436     acc2lo >>= 32;
437     acc2lo += r->w[3];
438     r->w[3] = acc2lo;
439     acc2lo >>= 32;
440     acc2lo += r->w[4];
441     r->w[4] = acc2lo;
442     acc2lo >>= 32;
443 #elif BIGNUM_INT_BITS == 16
444     /* ./contrib/make1305.py mul 16 */
445     BignumDblInt tmp;
446     BignumDblInt acclo;
447     BignumDblInt acchi;
448     BignumDblInt acc2lo;
449     acclo = 0;
450     acchi = 0;
451     tmp = (BignumDblInt)(a->w[0]) * (b->w[0]);
452     acclo += tmp & BIGNUM_INT_MASK;
453     acchi += tmp >> 16;
454     r->w[0] = acclo;
455     acclo = acchi + (acclo >> 16);
456     acchi = 0;
457     tmp = (BignumDblInt)(a->w[0]) * (b->w[1]);
458     acclo += tmp & BIGNUM_INT_MASK;
459     acchi += tmp >> 16;
460     tmp = (BignumDblInt)(a->w[1]) * (b->w[0]);
461     acclo += tmp & BIGNUM_INT_MASK;
462     acchi += tmp >> 16;
463     r->w[1] = acclo;
464     acclo = acchi + (acclo >> 16);
465     acchi = 0;
466     tmp = (BignumDblInt)(a->w[0]) * (b->w[2]);
467     acclo += tmp & BIGNUM_INT_MASK;
468     acchi += tmp >> 16;
469     tmp = (BignumDblInt)(a->w[1]) * (b->w[1]);
470     acclo += tmp & BIGNUM_INT_MASK;
471     acchi += tmp >> 16;
472     tmp = (BignumDblInt)(a->w[2]) * (b->w[0]);
473     acclo += tmp & BIGNUM_INT_MASK;
474     acchi += tmp >> 16;
475     r->w[2] = acclo;
476     acclo = acchi + (acclo >> 16);
477     acchi = 0;
478     tmp = (BignumDblInt)(a->w[0]) * (b->w[3]);
479     acclo += tmp & BIGNUM_INT_MASK;
480     acchi += tmp >> 16;
481     tmp = (BignumDblInt)(a->w[1]) * (b->w[2]);
482     acclo += tmp & BIGNUM_INT_MASK;
483     acchi += tmp >> 16;
484     tmp = (BignumDblInt)(a->w[2]) * (b->w[1]);
485     acclo += tmp & BIGNUM_INT_MASK;
486     acchi += tmp >> 16;
487     tmp = (BignumDblInt)(a->w[3]) * (b->w[0]);
488     acclo += tmp & BIGNUM_INT_MASK;
489     acchi += tmp >> 16;
490     r->w[3] = acclo;
491     acclo = acchi + (acclo >> 16);
492     acchi = 0;
493     tmp = (BignumDblInt)(a->w[0]) * (b->w[4]);
494     acclo += tmp & BIGNUM_INT_MASK;
495     acchi += tmp >> 16;
496     tmp = (BignumDblInt)(a->w[1]) * (b->w[3]);
497     acclo += tmp & BIGNUM_INT_MASK;
498     acchi += tmp >> 16;
499     tmp = (BignumDblInt)(a->w[2]) * (b->w[2]);
500     acclo += tmp & BIGNUM_INT_MASK;
501     acchi += tmp >> 16;
502     tmp = (BignumDblInt)(a->w[3]) * (b->w[1]);
503     acclo += tmp & BIGNUM_INT_MASK;
504     acchi += tmp >> 16;
505     tmp = (BignumDblInt)(a->w[4]) * (b->w[0]);
506     acclo += tmp & BIGNUM_INT_MASK;
507     acchi += tmp >> 16;
508     r->w[4] = acclo;
509     acclo = acchi + (acclo >> 16);
510     acchi = 0;
511     tmp = (BignumDblInt)(a->w[0]) * (b->w[5]);
512     acclo += tmp & BIGNUM_INT_MASK;
513     acchi += tmp >> 16;
514     tmp = (BignumDblInt)(a->w[1]) * (b->w[4]);
515     acclo += tmp & BIGNUM_INT_MASK;
516     acchi += tmp >> 16;
517     tmp = (BignumDblInt)(a->w[2]) * (b->w[3]);
518     acclo += tmp & BIGNUM_INT_MASK;
519     acchi += tmp >> 16;
520     tmp = (BignumDblInt)(a->w[3]) * (b->w[2]);
521     acclo += tmp & BIGNUM_INT_MASK;
522     acchi += tmp >> 16;
523     tmp = (BignumDblInt)(a->w[4]) * (b->w[1]);
524     acclo += tmp & BIGNUM_INT_MASK;
525     acchi += tmp >> 16;
526     tmp = (BignumDblInt)(a->w[5]) * (b->w[0]);
527     acclo += tmp & BIGNUM_INT_MASK;
528     acchi += tmp >> 16;
529     r->w[5] = acclo;
530     acclo = acchi + (acclo >> 16);
531     acchi = 0;
532     tmp = (BignumDblInt)(a->w[0]) * (b->w[6]);
533     acclo += tmp & BIGNUM_INT_MASK;
534     acchi += tmp >> 16;
535     tmp = (BignumDblInt)(a->w[1]) * (b->w[5]);
536     acclo += tmp & BIGNUM_INT_MASK;
537     acchi += tmp >> 16;
538     tmp = (BignumDblInt)(a->w[2]) * (b->w[4]);
539     acclo += tmp & BIGNUM_INT_MASK;
540     acchi += tmp >> 16;
541     tmp = (BignumDblInt)(a->w[3]) * (b->w[3]);
542     acclo += tmp & BIGNUM_INT_MASK;
543     acchi += tmp >> 16;
544     tmp = (BignumDblInt)(a->w[4]) * (b->w[2]);
545     acclo += tmp & BIGNUM_INT_MASK;
546     acchi += tmp >> 16;
547     tmp = (BignumDblInt)(a->w[5]) * (b->w[1]);
548     acclo += tmp & BIGNUM_INT_MASK;
549     acchi += tmp >> 16;
550     tmp = (BignumDblInt)(a->w[6]) * (b->w[0]);
551     acclo += tmp & BIGNUM_INT_MASK;
552     acchi += tmp >> 16;
553     r->w[6] = acclo;
554     acclo = acchi + (acclo >> 16);
555     acchi = 0;
556     tmp = (BignumDblInt)(a->w[0]) * (b->w[7]);
557     acclo += tmp & BIGNUM_INT_MASK;
558     acchi += tmp >> 16;
559     tmp = (BignumDblInt)(a->w[1]) * (b->w[6]);
560     acclo += tmp & BIGNUM_INT_MASK;
561     acchi += tmp >> 16;
562     tmp = (BignumDblInt)(a->w[2]) * (b->w[5]);
563     acclo += tmp & BIGNUM_INT_MASK;
564     acchi += tmp >> 16;
565     tmp = (BignumDblInt)(a->w[3]) * (b->w[4]);
566     acclo += tmp & BIGNUM_INT_MASK;
567     acchi += tmp >> 16;
568     tmp = (BignumDblInt)(a->w[4]) * (b->w[3]);
569     acclo += tmp & BIGNUM_INT_MASK;
570     acchi += tmp >> 16;
571     tmp = (BignumDblInt)(a->w[5]) * (b->w[2]);
572     acclo += tmp & BIGNUM_INT_MASK;
573     acchi += tmp >> 16;
574     tmp = (BignumDblInt)(a->w[6]) * (b->w[1]);
575     acclo += tmp & BIGNUM_INT_MASK;
576     acchi += tmp >> 16;
577     tmp = (BignumDblInt)(a->w[7]) * (b->w[0]);
578     acclo += tmp & BIGNUM_INT_MASK;
579     acchi += tmp >> 16;
580     r->w[7] = acclo;
581     acclo = acchi + (acclo >> 16);
582     acchi = 0;
583     tmp = (BignumDblInt)(a->w[0]) * (b->w[8]);
584     acclo += tmp & BIGNUM_INT_MASK;
585     acchi += tmp >> 16;
586     tmp = (BignumDblInt)(a->w[1]) * (b->w[7]);
587     acclo += tmp & BIGNUM_INT_MASK;
588     acchi += tmp >> 16;
589     tmp = (BignumDblInt)(a->w[2]) * (b->w[6]);
590     acclo += tmp & BIGNUM_INT_MASK;
591     acchi += tmp >> 16;
592     tmp = (BignumDblInt)(a->w[3]) * (b->w[5]);
593     acclo += tmp & BIGNUM_INT_MASK;
594     acchi += tmp >> 16;
595     tmp = (BignumDblInt)(a->w[4]) * (b->w[4]);
596     acclo += tmp & BIGNUM_INT_MASK;
597     acchi += tmp >> 16;
598     tmp = (BignumDblInt)(a->w[5]) * (b->w[3]);
599     acclo += tmp & BIGNUM_INT_MASK;
600     acchi += tmp >> 16;
601     tmp = (BignumDblInt)(a->w[6]) * (b->w[2]);
602     acclo += tmp & BIGNUM_INT_MASK;
603     acchi += tmp >> 16;
604     tmp = (BignumDblInt)(a->w[7]) * (b->w[1]);
605     acclo += tmp & BIGNUM_INT_MASK;
606     acchi += tmp >> 16;
607     tmp = (BignumDblInt)(a->w[8]) * (b->w[0]);
608     acclo += tmp & BIGNUM_INT_MASK;
609     acchi += tmp >> 16;
610     r->w[8] = acclo & (((BignumInt)1 << 2)-1);
611     acc2lo = 0;
612     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
613     acclo = acchi + (acclo >> 16);
614     acchi = 0;
615     tmp = (BignumDblInt)(a->w[1]) * (b->w[8]);
616     acclo += tmp & BIGNUM_INT_MASK;
617     acchi += tmp >> 16;
618     tmp = (BignumDblInt)(a->w[2]) * (b->w[7]);
619     acclo += tmp & BIGNUM_INT_MASK;
620     acchi += tmp >> 16;
621     tmp = (BignumDblInt)(a->w[3]) * (b->w[6]);
622     acclo += tmp & BIGNUM_INT_MASK;
623     acchi += tmp >> 16;
624     tmp = (BignumDblInt)(a->w[4]) * (b->w[5]);
625     acclo += tmp & BIGNUM_INT_MASK;
626     acchi += tmp >> 16;
627     tmp = (BignumDblInt)(a->w[5]) * (b->w[4]);
628     acclo += tmp & BIGNUM_INT_MASK;
629     acchi += tmp >> 16;
630     tmp = (BignumDblInt)(a->w[6]) * (b->w[3]);
631     acclo += tmp & BIGNUM_INT_MASK;
632     acchi += tmp >> 16;
633     tmp = (BignumDblInt)(a->w[7]) * (b->w[2]);
634     acclo += tmp & BIGNUM_INT_MASK;
635     acchi += tmp >> 16;
636     tmp = (BignumDblInt)(a->w[8]) * (b->w[1]);
637     acclo += tmp & BIGNUM_INT_MASK;
638     acchi += tmp >> 16;
639     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
640     acc2lo += r->w[0];
641     r->w[0] = acc2lo;
642     acc2lo >>= 16;
643     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
644     acclo = acchi + (acclo >> 16);
645     acchi = 0;
646     tmp = (BignumDblInt)(a->w[2]) * (b->w[8]);
647     acclo += tmp & BIGNUM_INT_MASK;
648     acchi += tmp >> 16;
649     tmp = (BignumDblInt)(a->w[3]) * (b->w[7]);
650     acclo += tmp & BIGNUM_INT_MASK;
651     acchi += tmp >> 16;
652     tmp = (BignumDblInt)(a->w[4]) * (b->w[6]);
653     acclo += tmp & BIGNUM_INT_MASK;
654     acchi += tmp >> 16;
655     tmp = (BignumDblInt)(a->w[5]) * (b->w[5]);
656     acclo += tmp & BIGNUM_INT_MASK;
657     acchi += tmp >> 16;
658     tmp = (BignumDblInt)(a->w[6]) * (b->w[4]);
659     acclo += tmp & BIGNUM_INT_MASK;
660     acchi += tmp >> 16;
661     tmp = (BignumDblInt)(a->w[7]) * (b->w[3]);
662     acclo += tmp & BIGNUM_INT_MASK;
663     acchi += tmp >> 16;
664     tmp = (BignumDblInt)(a->w[8]) * (b->w[2]);
665     acclo += tmp & BIGNUM_INT_MASK;
666     acchi += tmp >> 16;
667     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
668     acc2lo += r->w[1];
669     r->w[1] = acc2lo;
670     acc2lo >>= 16;
671     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
672     acclo = acchi + (acclo >> 16);
673     acchi = 0;
674     tmp = (BignumDblInt)(a->w[3]) * (b->w[8]);
675     acclo += tmp & BIGNUM_INT_MASK;
676     acchi += tmp >> 16;
677     tmp = (BignumDblInt)(a->w[4]) * (b->w[7]);
678     acclo += tmp & BIGNUM_INT_MASK;
679     acchi += tmp >> 16;
680     tmp = (BignumDblInt)(a->w[5]) * (b->w[6]);
681     acclo += tmp & BIGNUM_INT_MASK;
682     acchi += tmp >> 16;
683     tmp = (BignumDblInt)(a->w[6]) * (b->w[5]);
684     acclo += tmp & BIGNUM_INT_MASK;
685     acchi += tmp >> 16;
686     tmp = (BignumDblInt)(a->w[7]) * (b->w[4]);
687     acclo += tmp & BIGNUM_INT_MASK;
688     acchi += tmp >> 16;
689     tmp = (BignumDblInt)(a->w[8]) * (b->w[3]);
690     acclo += tmp & BIGNUM_INT_MASK;
691     acchi += tmp >> 16;
692     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
693     acc2lo += r->w[2];
694     r->w[2] = acc2lo;
695     acc2lo >>= 16;
696     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
697     acclo = acchi + (acclo >> 16);
698     acchi = 0;
699     tmp = (BignumDblInt)(a->w[4]) * (b->w[8]);
700     acclo += tmp & BIGNUM_INT_MASK;
701     acchi += tmp >> 16;
702     tmp = (BignumDblInt)(a->w[5]) * (b->w[7]);
703     acclo += tmp & BIGNUM_INT_MASK;
704     acchi += tmp >> 16;
705     tmp = (BignumDblInt)(a->w[6]) * (b->w[6]);
706     acclo += tmp & BIGNUM_INT_MASK;
707     acchi += tmp >> 16;
708     tmp = (BignumDblInt)(a->w[7]) * (b->w[5]);
709     acclo += tmp & BIGNUM_INT_MASK;
710     acchi += tmp >> 16;
711     tmp = (BignumDblInt)(a->w[8]) * (b->w[4]);
712     acclo += tmp & BIGNUM_INT_MASK;
713     acchi += tmp >> 16;
714     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
715     acc2lo += r->w[3];
716     r->w[3] = acc2lo;
717     acc2lo >>= 16;
718     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
719     acclo = acchi + (acclo >> 16);
720     acchi = 0;
721     tmp = (BignumDblInt)(a->w[5]) * (b->w[8]);
722     acclo += tmp & BIGNUM_INT_MASK;
723     acchi += tmp >> 16;
724     tmp = (BignumDblInt)(a->w[6]) * (b->w[7]);
725     acclo += tmp & BIGNUM_INT_MASK;
726     acchi += tmp >> 16;
727     tmp = (BignumDblInt)(a->w[7]) * (b->w[6]);
728     acclo += tmp & BIGNUM_INT_MASK;
729     acchi += tmp >> 16;
730     tmp = (BignumDblInt)(a->w[8]) * (b->w[5]);
731     acclo += tmp & BIGNUM_INT_MASK;
732     acchi += tmp >> 16;
733     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
734     acc2lo += r->w[4];
735     r->w[4] = acc2lo;
736     acc2lo >>= 16;
737     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
738     acclo = acchi + (acclo >> 16);
739     acchi = 0;
740     tmp = (BignumDblInt)(a->w[6]) * (b->w[8]);
741     acclo += tmp & BIGNUM_INT_MASK;
742     acchi += tmp >> 16;
743     tmp = (BignumDblInt)(a->w[7]) * (b->w[7]);
744     acclo += tmp & BIGNUM_INT_MASK;
745     acchi += tmp >> 16;
746     tmp = (BignumDblInt)(a->w[8]) * (b->w[6]);
747     acclo += tmp & BIGNUM_INT_MASK;
748     acchi += tmp >> 16;
749     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
750     acc2lo += r->w[5];
751     r->w[5] = acc2lo;
752     acc2lo >>= 16;
753     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
754     acclo = acchi + (acclo >> 16);
755     acchi = 0;
756     tmp = (BignumDblInt)(a->w[7]) * (b->w[8]);
757     acclo += tmp & BIGNUM_INT_MASK;
758     acchi += tmp >> 16;
759     tmp = (BignumDblInt)(a->w[8]) * (b->w[7]);
760     acclo += tmp & BIGNUM_INT_MASK;
761     acchi += tmp >> 16;
762     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
763     acc2lo += r->w[6];
764     r->w[6] = acc2lo;
765     acc2lo >>= 16;
766     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
767     acclo = acchi + (acclo >> 16);
768     acchi = 0;
769     tmp = (BignumDblInt)(a->w[8]) * (b->w[8]);
770     acclo += tmp & BIGNUM_INT_MASK;
771     acchi += tmp >> 16;
772     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
773     acc2lo += r->w[7];
774     r->w[7] = acc2lo;
775     acc2lo >>= 16;
776     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 0);
777     acc2lo += r->w[8];
778     r->w[8] = acc2lo;
779     acc2lo = 0;
780     acc2lo += ((acclo >> 4) & (((BignumInt)1 << 12)-1)) * ((BignumDblInt)25 << 0);
781     acclo = acchi + (acclo >> 16);
782     acchi = 0;
783     acc2lo += (acclo & (((BignumInt)1 << 4)-1)) * ((BignumDblInt)25 << 12);
784     acc2lo += r->w[0];
785     r->w[0] = acc2lo;
786     acc2lo >>= 16;
787     acc2lo += ((acclo >> 4) & (((BignumInt)1 << 12)-1)) * ((BignumDblInt)25 << 0);
788     acclo = acchi + (acclo >> 16);
789     acchi = 0;
790     acc2lo += r->w[1];
791     r->w[1] = acc2lo;
792     acc2lo >>= 16;
793     acc2lo += r->w[2];
794     r->w[2] = acc2lo;
795     acc2lo >>= 16;
796     acc2lo += r->w[3];
797     r->w[3] = acc2lo;
798     acc2lo >>= 16;
799     acc2lo += r->w[4];
800     r->w[4] = acc2lo;
801     acc2lo >>= 16;
802     acc2lo += r->w[5];
803     r->w[5] = acc2lo;
804     acc2lo >>= 16;
805     acc2lo += r->w[6];
806     r->w[6] = acc2lo;
807     acc2lo >>= 16;
808     acc2lo += r->w[7];
809     r->w[7] = acc2lo;
810     acc2lo >>= 16;
811     acc2lo += r->w[8];
812     r->w[8] = acc2lo;
813     acc2lo >>= 16;
814 #else
815 #error Run contrib/make1305.py again with a different bit count
816 #endif
817 }
818
819 static void bigval_final_reduce(bigval *n)
820 {
821 #if BIGNUM_INT_BITS == 32
822     /* ./contrib/make1305.py final_reduce 32 */
823     BignumDblInt acclo;
824     acclo = 0;
825     acclo += 5 * ((n->w[4] >> 2) + 1);
826     acclo += n->w[0];
827     acclo >>= 32;
828     acclo += n->w[1];
829     acclo >>= 32;
830     acclo += n->w[2];
831     acclo >>= 32;
832     acclo += n->w[3];
833     acclo >>= 32;
834     acclo += n->w[4];
835     acclo = 5 * (acclo >> 2);
836     acclo += n->w[0];
837     n->w[0] = acclo;
838     acclo >>= 32;
839     acclo += n->w[1];
840     n->w[1] = acclo;
841     acclo >>= 32;
842     acclo += n->w[2];
843     n->w[2] = acclo;
844     acclo >>= 32;
845     acclo += n->w[3];
846     n->w[3] = acclo;
847     acclo >>= 32;
848     acclo += n->w[4];
849     n->w[4] = acclo;
850     acclo >>= 32;
851     n->w[4] &= (1 << 2) - 1;
852 #elif BIGNUM_INT_BITS == 16
853     /* ./contrib/make1305.py final_reduce 16 */
854     BignumDblInt acclo;
855     acclo = 0;
856     acclo += 5 * ((n->w[8] >> 2) + 1);
857     acclo += n->w[0];
858     acclo >>= 16;
859     acclo += n->w[1];
860     acclo >>= 16;
861     acclo += n->w[2];
862     acclo >>= 16;
863     acclo += n->w[3];
864     acclo >>= 16;
865     acclo += n->w[4];
866     acclo >>= 16;
867     acclo += n->w[5];
868     acclo >>= 16;
869     acclo += n->w[6];
870     acclo >>= 16;
871     acclo += n->w[7];
872     acclo >>= 16;
873     acclo += n->w[8];
874     acclo = 5 * (acclo >> 2);
875     acclo += n->w[0];
876     n->w[0] = acclo;
877     acclo >>= 16;
878     acclo += n->w[1];
879     n->w[1] = acclo;
880     acclo >>= 16;
881     acclo += n->w[2];
882     n->w[2] = acclo;
883     acclo >>= 16;
884     acclo += n->w[3];
885     n->w[3] = acclo;
886     acclo >>= 16;
887     acclo += n->w[4];
888     n->w[4] = acclo;
889     acclo >>= 16;
890     acclo += n->w[5];
891     n->w[5] = acclo;
892     acclo >>= 16;
893     acclo += n->w[6];
894     n->w[6] = acclo;
895     acclo >>= 16;
896     acclo += n->w[7];
897     n->w[7] = acclo;
898     acclo >>= 16;
899     acclo += n->w[8];
900     n->w[8] = acclo;
901     acclo >>= 16;
902     n->w[8] &= (1 << 2) - 1;
903 #else
904 #error Run contrib/make1305.py again with a different bit count
905 #endif
906 }
907
908 struct poly1305 {
909     unsigned char nonce[16];
910     bigval r;
911     bigval h;
912
913     /* Buffer in case we get less that a multiple of 16 bytes */
914     unsigned char buffer[16];
915     int bufferIndex;
916 };
917
918 static void poly1305_init(struct poly1305 *ctx)
919 {
920     memset(ctx->nonce, 0, 16);
921     ctx->bufferIndex = 0;
922     bigval_clear(&ctx->h);
923 }
924
925 /* Takes a 256 bit key */
926 static void poly1305_key(struct poly1305 *ctx, const unsigned char *key)
927 {
928     unsigned char key_copy[16];
929     memcpy(key_copy, key, 16);
930
931     /* Key the MAC itself
932      * bytes 4, 8, 12 and 16 are required to have their top four bits clear */
933     key_copy[3] &= 0x0f;
934     key_copy[7] &= 0x0f;
935     key_copy[11] &= 0x0f;
936     key_copy[15] &= 0x0f;
937     /* bytes 5, 9 and 13 are required to have their bottom two bits clear */
938     key_copy[4] &= 0xfc;
939     key_copy[8] &= 0xfc;
940     key_copy[12] &= 0xfc;
941     bigval_import_le(&ctx->r, key_copy, 16);
942     smemclr(key_copy, sizeof(key_copy));
943
944     /* Use second 128 bits are the nonce */
945     memcpy(ctx->nonce, key+16, 16);
946 }
947
948 /* Feed up to 16 bytes (should only be less for the last chunk) */
949 static void poly1305_feed_chunk(struct poly1305 *ctx,
950                                 const unsigned char *chunk, int len)
951 {
952     bigval c;
953     bigval_import_le(&c, chunk, len);
954     c.w[len / BIGNUM_INT_BYTES] |= 1 << (8 * (len % BIGNUM_INT_BYTES));
955     bigval_add(&c, &c, &ctx->h);
956     bigval_mul_mod_p(&ctx->h, &c, &ctx->r);
957 }
958
959 static void poly1305_feed(struct poly1305 *ctx,
960                           const unsigned char *buf, int len)
961 {
962     /* Check for stuff left in the buffer from last time */
963     if (ctx->bufferIndex) {
964         /* Try to fill up to 16 */
965         while (ctx->bufferIndex < 16 && len) {
966             ctx->buffer[ctx->bufferIndex++] = *buf++;
967             --len;
968         }
969         if (ctx->bufferIndex == 16) {
970             poly1305_feed_chunk(ctx, ctx->buffer, 16);
971             ctx->bufferIndex = 0;
972         }
973     }
974
975     /* Process 16 byte whole chunks */
976     while (len >= 16) {
977         poly1305_feed_chunk(ctx, buf, 16);
978         len -= 16;
979         buf += 16;
980     }
981
982     /* Cache stuff that's left over */
983     if (len) {
984         memcpy(ctx->buffer, buf, len);
985         ctx->bufferIndex = len;
986     }
987 }
988
989 /* Finalise and populate buffer with 16 byte with MAC */
990 static void poly1305_finalise(struct poly1305 *ctx, unsigned char *mac)
991 {
992     bigval tmp;
993
994     if (ctx->bufferIndex) {
995         poly1305_feed_chunk(ctx, ctx->buffer, ctx->bufferIndex);
996     }
997
998     bigval_import_le(&tmp, ctx->nonce, 16);
999     bigval_final_reduce(&ctx->h);
1000     bigval_add(&tmp, &tmp, &ctx->h);
1001     bigval_export_le(&tmp, mac, 16);
1002 }
1003
1004 /* SSH-2 wrapper */
1005
1006 struct ccp_context {
1007     struct chacha20 a_cipher; /* Used for length */
1008     struct chacha20 b_cipher; /* Used for content */
1009
1010     /* Cache of the first 4 bytes because they are the sequence number */
1011     /* Kept in 8 bytes with the top as zero to allow easy passing to setiv */
1012     int mac_initialised; /* Where we have got to in filling mac_iv */
1013     unsigned char mac_iv[8];
1014
1015     struct poly1305 mac;
1016 };
1017
1018 static void *poly_make_context(void *ctx)
1019 {
1020     return ctx;
1021 }
1022
1023 static void poly_free_context(void *ctx)
1024 {
1025     /* Not allocated, just forwarded, no need to free */
1026 }
1027
1028 static void poly_setkey(void *ctx, unsigned char *key)
1029 {
1030     /* Uses the same context as ChaCha20, so ignore */
1031 }
1032
1033 static void poly_start(void *handle)
1034 {
1035     struct ccp_context *ctx = (struct ccp_context *)handle;
1036
1037     ctx->mac_initialised = 0;
1038     memset(ctx->mac_iv, 0, 8);
1039     poly1305_init(&ctx->mac);
1040 }
1041
1042 static void poly_bytes(void *handle, unsigned char const *blk, int len)
1043 {
1044     struct ccp_context *ctx = (struct ccp_context *)handle;
1045
1046     /* First 4 bytes are the IV */
1047     while (ctx->mac_initialised < 4 && len) {
1048         ctx->mac_iv[7 - ctx->mac_initialised] = *blk++;
1049         ++ctx->mac_initialised;
1050         --len;
1051     }
1052
1053     /* Initialise the IV if needed */
1054     if (ctx->mac_initialised == 4) {
1055         chacha20_iv(&ctx->b_cipher, ctx->mac_iv);
1056         ++ctx->mac_initialised;  /* Don't do it again */
1057
1058         /* Do first rotation */
1059         chacha20_round(&ctx->b_cipher);
1060
1061         /* Set the poly key */
1062         poly1305_key(&ctx->mac, ctx->b_cipher.current);
1063
1064         /* Set the first round as used */
1065         ctx->b_cipher.currentIndex = 64;
1066     }
1067
1068     /* Update the MAC with anything left */
1069     if (len) {
1070         poly1305_feed(&ctx->mac, blk, len);
1071     }
1072 }
1073
1074 static void poly_genresult(void *handle, unsigned char *blk)
1075 {
1076     struct ccp_context *ctx = (struct ccp_context *)handle;
1077     poly1305_finalise(&ctx->mac, blk);
1078 }
1079
1080 static int poly_verresult(void *handle, unsigned char const *blk)
1081 {
1082     struct ccp_context *ctx = (struct ccp_context *)handle;
1083     int res;
1084     unsigned char mac[16];
1085     poly1305_finalise(&ctx->mac, mac);
1086     res = smemeq(blk, mac, 16);
1087     return res;
1088 }
1089
1090 /* The generic poly operation used before generate and verify */
1091 static void poly_op(void *handle, unsigned char *blk, int len, unsigned long seq)
1092 {
1093     unsigned char iv[4];
1094     poly_start(handle);
1095     PUT_32BIT_MSB_FIRST(iv, seq);
1096     /* poly_bytes expects the first 4 bytes to be the IV */
1097     poly_bytes(handle, iv, 4);
1098     smemclr(iv, sizeof(iv));
1099     poly_bytes(handle, blk, len);
1100 }
1101
1102 static void poly_generate(void *handle, unsigned char *blk, int len, unsigned long seq)
1103 {
1104     poly_op(handle, blk, len, seq);
1105     poly_genresult(handle, blk+len);
1106 }
1107
1108 static int poly_verify(void *handle, unsigned char *blk, int len, unsigned long seq)
1109 {
1110     poly_op(handle, blk, len, seq);
1111     return poly_verresult(handle, blk+len);
1112 }
1113
1114 static const struct ssh_mac ssh2_poly1305 = {
1115     poly_make_context, poly_free_context,
1116     poly_setkey,
1117
1118     /* whole-packet operations */
1119     poly_generate, poly_verify,
1120
1121     /* partial-packet operations */
1122     poly_start, poly_bytes, poly_genresult, poly_verresult,
1123
1124     "", "", /* Not selectable individually, just part of ChaCha20-Poly1305 */
1125     16, "Poly1305"
1126 };
1127
1128 static void *ccp_make_context(void)
1129 {
1130     struct ccp_context *ctx = snew(struct ccp_context);
1131     if (ctx) {
1132         poly1305_init(&ctx->mac);
1133     }
1134     return ctx;
1135 }
1136
1137 static void ccp_free_context(void *vctx)
1138 {
1139     struct ccp_context *ctx = (struct ccp_context *)vctx;
1140     smemclr(&ctx->a_cipher, sizeof(ctx->a_cipher));
1141     smemclr(&ctx->b_cipher, sizeof(ctx->b_cipher));
1142     smemclr(&ctx->mac, sizeof(ctx->mac));
1143     sfree(ctx);
1144 }
1145
1146 static void ccp_iv(void *vctx, unsigned char *iv)
1147 {
1148     /* struct ccp_context *ctx = (struct ccp_context *)vctx; */
1149     /* IV is set based on the sequence number */
1150 }
1151
1152 static void ccp_key(void *vctx, unsigned char *key)
1153 {
1154     struct ccp_context *ctx = (struct ccp_context *)vctx;
1155     /* Initialise the a_cipher (for decrypting lengths) with the first 256 bits */
1156     chacha20_key(&ctx->a_cipher, key + 32);
1157     /* Initialise the b_cipher (for content and MAC) with the second 256 bits */
1158     chacha20_key(&ctx->b_cipher, key);
1159 }
1160
1161 static void ccp_encrypt(void *vctx, unsigned char *blk, int len)
1162 {
1163     struct ccp_context *ctx = (struct ccp_context *)vctx;
1164     chacha20_encrypt(&ctx->b_cipher, blk, len);
1165 }
1166
1167 static void ccp_decrypt(void *vctx, unsigned char *blk, int len)
1168 {
1169     struct ccp_context *ctx = (struct ccp_context *)vctx;
1170     chacha20_decrypt(&ctx->b_cipher, blk, len);
1171 }
1172
1173 static void ccp_length_op(struct ccp_context *ctx, unsigned char *blk, int len,
1174                           unsigned long seq)
1175 {
1176     unsigned char iv[8];
1177     PUT_32BIT_LSB_FIRST(iv, seq >> 32);
1178     PUT_32BIT_LSB_FIRST(iv + 4, seq);
1179     chacha20_iv(&ctx->a_cipher, iv);
1180     chacha20_iv(&ctx->b_cipher, iv);
1181     /* Reset content block count to 1, as the first is the key for Poly1305 */
1182     ++ctx->b_cipher.state[12];
1183     smemclr(iv, sizeof(iv));
1184 }
1185
1186 static void ccp_encrypt_length(void *vctx, unsigned char *blk, int len,
1187                                unsigned long seq)
1188 {
1189     struct ccp_context *ctx = (struct ccp_context *)vctx;
1190     ccp_length_op(ctx, blk, len, seq);
1191     chacha20_encrypt(&ctx->a_cipher, blk, len);
1192 }
1193
1194 static void ccp_decrypt_length(void *vctx, unsigned char *blk, int len,
1195                                unsigned long seq)
1196 {
1197     struct ccp_context *ctx = (struct ccp_context *)vctx;
1198     ccp_length_op(ctx, blk, len, seq);
1199     chacha20_decrypt(&ctx->a_cipher, blk, len);
1200 }
1201
1202 static const struct ssh2_cipher ssh2_chacha20_poly1305 = {
1203
1204     ccp_make_context,
1205     ccp_free_context,
1206     ccp_iv,
1207     ccp_key,
1208     ccp_encrypt,
1209     ccp_decrypt,
1210     ccp_encrypt_length,
1211     ccp_decrypt_length,
1212
1213     "chacha20-poly1305@openssh.com",
1214     1, 512, SSH_CIPHER_SEPARATE_LENGTH, "ChaCha20",
1215
1216     &ssh2_poly1305
1217 };
1218
1219 static const struct ssh2_cipher *const ccp_list[] = {
1220     &ssh2_chacha20_poly1305
1221 };
1222
1223 const struct ssh2_ciphers ssh2_ccp = {
1224     sizeof(ccp_list) / sizeof(*ccp_list),
1225     ccp_list
1226 };