]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshccp.c
Merge tag '0.65'
[PuTTY.git] / sshccp.c
1 /*
2  * ChaCha20-Poly1305 Implementation for SSH-2
3  *
4  * Protocol spec:
5  *  http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.chacha20poly1305?rev=1.2&content-type=text/x-cvsweb-markup
6  *
7  * ChaCha20 spec:
8  *  http://cr.yp.to/chacha/chacha-20080128.pdf
9  *
10  * Salsa20 spec:
11  *  http://cr.yp.to/snuffle/spec.pdf
12  *
13  * Poly1305-AES spec:
14  *  http://cr.yp.to/mac/poly1305-20050329.pdf
15  *
16  * The nonce for the Poly1305 is the second part of the key output
17  * from the first round of ChaCha20. This removes the AES requirement.
18  * This is undocumented!
19  *
20  * This has an intricate link between the cipher and the MAC. The
21  * keying of both is done in by the cipher and setting of the IV is
22  * done by the MAC. One cannot operate without the other. The
23  * configuration of the ssh2_cipher structure ensures that the MAC is
24  * set (and others ignored) if this cipher is chosen.
25  *
26  * This cipher also encrypts the length using a different
27  * instantiation of the cipher using a different key and IV made from
28  * the sequence number which is passed in addition when calling
29  * encrypt/decrypt on it.
30  */
31
32 #include "ssh.h"
33 #include "sshbn.h"
34
35 #ifndef INLINE
36 #define INLINE
37 #endif
38
39 /* ChaCha20 implementation, only supporting 256-bit keys */
40
41 /* State for each ChaCha20 instance */
42 struct chacha20 {
43     /* Current context, usually with the count incremented
44      * 0-3 are the static constant
45      * 4-11 are the key
46      * 12-13 are the counter
47      * 14-15 are the IV */
48     uint32 state[16];
49     /* The output of the state above ready to xor */
50     unsigned char current[64];
51     /* The index of the above currently used to allow a true streaming cipher */
52     int currentIndex;
53 };
54
55 static INLINE void chacha20_round(struct chacha20 *ctx)
56 {
57     int i;
58     uint32 copy[16];
59
60     /* Take a copy */
61     memcpy(copy, ctx->state, sizeof(copy));
62
63     /* A circular rotation for a 32bit number */
64 #define rotl(x, shift) x = ((x << shift) | (x >> (32 - shift)))
65
66     /* What to do for each quarter round operation */
67 #define qrop(a, b, c, d)                        \
68     copy[a] += copy[b];                         \
69     copy[c] ^= copy[a];                         \
70     rotl(copy[c], d)
71
72     /* A quarter round */
73 #define quarter(a, b, c, d)                     \
74     qrop(a, b, d, 16);                          \
75     qrop(c, d, b, 12);                          \
76     qrop(a, b, d, 8);                           \
77     qrop(c, d, b, 7)
78
79     /* Do 20 rounds, in pairs because every other is different */
80     for (i = 0; i < 20; i += 2) {
81         /* A round */
82         quarter(0, 4, 8, 12);
83         quarter(1, 5, 9, 13);
84         quarter(2, 6, 10, 14);
85         quarter(3, 7, 11, 15);
86         /* Another slightly different round */
87         quarter(0, 5, 10, 15);
88         quarter(1, 6, 11, 12);
89         quarter(2, 7, 8, 13);
90         quarter(3, 4, 9, 14);
91     }
92
93     /* Dump the macros, don't need them littering */
94 #undef rotl
95 #undef qrop
96 #undef quarter
97
98     /* Add the initial state */
99     for (i = 0; i < 16; ++i) {
100         copy[i] += ctx->state[i];
101     }
102
103     /* Update the content of the xor buffer */
104     for (i = 0; i < 16; ++i) {
105         ctx->current[i * 4 + 0] = copy[i] >> 0;
106         ctx->current[i * 4 + 1] = copy[i] >> 8;
107         ctx->current[i * 4 + 2] = copy[i] >> 16;
108         ctx->current[i * 4 + 3] = copy[i] >> 24;
109     }
110     /* State full, reset pointer to beginning */
111     ctx->currentIndex = 0;
112     smemclr(copy, sizeof(copy));
113
114     /* Increment round counter */
115     ++ctx->state[12];
116     /* Check for overflow, not done in one line so the 32 bits are chopped by the type */
117     if (!(uint32)(ctx->state[12])) {
118         ++ctx->state[13];
119     }
120 }
121
122 /* Initialise context with 256bit key */
123 static void chacha20_key(struct chacha20 *ctx, const unsigned char *key)
124 {
125     static const char constant[16] = "expand 32-byte k";
126
127     /* Add the fixed string to the start of the state */
128     ctx->state[0] = GET_32BIT_LSB_FIRST(constant + 0);
129     ctx->state[1] = GET_32BIT_LSB_FIRST(constant + 4);
130     ctx->state[2] = GET_32BIT_LSB_FIRST(constant + 8);
131     ctx->state[3] = GET_32BIT_LSB_FIRST(constant + 12);
132
133     /* Add the key */
134     ctx->state[4]  = GET_32BIT_LSB_FIRST(key + 0);
135     ctx->state[5]  = GET_32BIT_LSB_FIRST(key + 4);
136     ctx->state[6]  = GET_32BIT_LSB_FIRST(key + 8);
137     ctx->state[7]  = GET_32BIT_LSB_FIRST(key + 12);
138     ctx->state[8]  = GET_32BIT_LSB_FIRST(key + 16);
139     ctx->state[9]  = GET_32BIT_LSB_FIRST(key + 20);
140     ctx->state[10] = GET_32BIT_LSB_FIRST(key + 24);
141     ctx->state[11] = GET_32BIT_LSB_FIRST(key + 28);
142
143     /* New key, dump context */
144     ctx->currentIndex = 64;
145 }
146
147 static void chacha20_iv(struct chacha20 *ctx, const unsigned char *iv)
148 {
149     ctx->state[12] = 0;
150     ctx->state[13] = 0;
151     ctx->state[14] = GET_32BIT_MSB_FIRST(iv);
152     ctx->state[15] = GET_32BIT_MSB_FIRST(iv + 4);
153
154     /* New IV, dump context */
155     ctx->currentIndex = 64;
156 }
157
158 static void chacha20_encrypt(struct chacha20 *ctx, unsigned char *blk, int len)
159 {
160     while (len) {
161         /* If we don't have any state left, then cycle to the next */
162         if (ctx->currentIndex >= 64) {
163             chacha20_round(ctx);
164         }
165
166         /* Do the xor while there's some state left and some plaintext left */
167         while (ctx->currentIndex < 64 && len) {
168             *blk++ ^= ctx->current[ctx->currentIndex++];
169             --len;
170         }
171     }
172 }
173
174 /* Decrypt is encrypt... It's xor against a PRNG... */
175 static INLINE void chacha20_decrypt(struct chacha20 *ctx,
176                                     unsigned char *blk, int len)
177 {
178     chacha20_encrypt(ctx, blk, len);
179 }
180
181 /* Poly1305 implementation (no AES, nonce is not encrypted) */
182
183 #define NWORDS ((130 + BIGNUM_INT_BITS-1) / BIGNUM_INT_BITS)
184 typedef struct bigval {
185     BignumInt w[NWORDS];
186 } bigval;
187
188 static void bigval_clear(bigval *r)
189 {
190     int i;
191     for (i = 0; i < NWORDS; i++)
192         r->w[i] = 0;
193 }
194
195 static void bigval_import_le(bigval *r, const void *vdata, int len)
196 {
197     const unsigned char *data = (const unsigned char *)vdata;
198     int i;
199     bigval_clear(r);
200     for (i = 0; i < len; i++)
201         r->w[i / BIGNUM_INT_BYTES] |=
202             (BignumInt)data[i] << (8 * (i % BIGNUM_INT_BYTES));
203 }
204
205 static void bigval_export_le(const bigval *r, void *vdata, int len)
206 {
207     unsigned char *data = (unsigned char *)vdata;
208     int i;
209     for (i = 0; i < len; i++)
210         data[i] = r->w[i / BIGNUM_INT_BYTES] >> (8 * (i % BIGNUM_INT_BYTES));
211 }
212
213 /*
214  * Addition of bigvals, not mod p.
215  */
216 static void bigval_add(bigval *r, const bigval *a, const bigval *b)
217 {
218 #if BIGNUM_INT_BITS == 64
219     /* ./contrib/make1305.py add 64 */
220     BignumDblInt acclo;
221     acclo = 0;
222     acclo += a->w[0];
223     acclo += b->w[0];
224     r->w[0] = acclo;
225     acclo >>= 64;
226     acclo += a->w[1];
227     acclo += b->w[1];
228     r->w[1] = acclo;
229     acclo >>= 64;
230     acclo += a->w[2];
231     acclo += b->w[2];
232     r->w[2] = acclo;
233     acclo >>= 64;
234 #elif BIGNUM_INT_BITS == 32
235     /* ./contrib/make1305.py add 32 */
236     BignumDblInt acclo;
237     acclo = 0;
238     acclo += a->w[0];
239     acclo += b->w[0];
240     r->w[0] = acclo;
241     acclo >>= 32;
242     acclo += a->w[1];
243     acclo += b->w[1];
244     r->w[1] = acclo;
245     acclo >>= 32;
246     acclo += a->w[2];
247     acclo += b->w[2];
248     r->w[2] = acclo;
249     acclo >>= 32;
250     acclo += a->w[3];
251     acclo += b->w[3];
252     r->w[3] = acclo;
253     acclo >>= 32;
254     acclo += a->w[4];
255     acclo += b->w[4];
256     r->w[4] = acclo;
257     acclo >>= 32;
258 #elif BIGNUM_INT_BITS == 16
259     /* ./contrib/make1305.py add 16 */
260     BignumDblInt acclo;
261     acclo = 0;
262     acclo += a->w[0];
263     acclo += b->w[0];
264     r->w[0] = acclo;
265     acclo >>= 16;
266     acclo += a->w[1];
267     acclo += b->w[1];
268     r->w[1] = acclo;
269     acclo >>= 16;
270     acclo += a->w[2];
271     acclo += b->w[2];
272     r->w[2] = acclo;
273     acclo >>= 16;
274     acclo += a->w[3];
275     acclo += b->w[3];
276     r->w[3] = acclo;
277     acclo >>= 16;
278     acclo += a->w[4];
279     acclo += b->w[4];
280     r->w[4] = acclo;
281     acclo >>= 16;
282     acclo += a->w[5];
283     acclo += b->w[5];
284     r->w[5] = acclo;
285     acclo >>= 16;
286     acclo += a->w[6];
287     acclo += b->w[6];
288     r->w[6] = acclo;
289     acclo >>= 16;
290     acclo += a->w[7];
291     acclo += b->w[7];
292     r->w[7] = acclo;
293     acclo >>= 16;
294     acclo += a->w[8];
295     acclo += b->w[8];
296     r->w[8] = acclo;
297     acclo >>= 16;
298 #else
299 #error Run contrib/make1305.py again with a different bit count
300 #endif
301 }
302
303 /*
304  * Multiplication of bigvals mod p. Uses r as temporary storage, so
305  * don't pass r aliasing a or b.
306  */
307 static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b)
308 {
309 #if BIGNUM_INT_BITS == 64
310     /* ./contrib/make1305.py mul 64 */
311     BignumDblInt tmp;
312     BignumDblInt acclo;
313     BignumDblInt acchi;
314     BignumDblInt acc2lo;
315     acclo = 0;
316     acchi = 0;
317     tmp = (BignumDblInt)(a->w[0]) * (b->w[0]);
318     acclo += tmp & BIGNUM_INT_MASK;
319     acchi += tmp >> 64;
320     r->w[0] = acclo;
321     acclo = acchi + (acclo >> 64);
322     acchi = 0;
323     tmp = (BignumDblInt)(a->w[0]) * (b->w[1]);
324     acclo += tmp & BIGNUM_INT_MASK;
325     acchi += tmp >> 64;
326     tmp = (BignumDblInt)(a->w[1]) * (b->w[0]);
327     acclo += tmp & BIGNUM_INT_MASK;
328     acchi += tmp >> 64;
329     r->w[1] = acclo;
330     acclo = acchi + (acclo >> 64);
331     acchi = 0;
332     tmp = (BignumDblInt)(a->w[0]) * (b->w[2]);
333     acclo += tmp & BIGNUM_INT_MASK;
334     acchi += tmp >> 64;
335     tmp = (BignumDblInt)(a->w[1]) * (b->w[1]);
336     acclo += tmp & BIGNUM_INT_MASK;
337     acchi += tmp >> 64;
338     tmp = (BignumDblInt)(a->w[2]) * (b->w[0]);
339     acclo += tmp & BIGNUM_INT_MASK;
340     acchi += tmp >> 64;
341     r->w[2] = acclo & (((BignumInt)1 << 2)-1);
342     acc2lo = 0;
343     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 62)-1)) * ((BignumDblInt)5 << 0);
344     acclo = acchi + (acclo >> 64);
345     acchi = 0;
346     tmp = (BignumDblInt)(a->w[1]) * (b->w[2]);
347     acclo += tmp & BIGNUM_INT_MASK;
348     acchi += tmp >> 64;
349     tmp = (BignumDblInt)(a->w[2]) * (b->w[1]);
350     acclo += tmp & BIGNUM_INT_MASK;
351     acchi += tmp >> 64;
352     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 62);
353     acc2lo += r->w[0];
354     r->w[0] = acc2lo;
355     acc2lo >>= 64;
356     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 62)-1)) * ((BignumDblInt)5 << 0);
357     acclo = acchi + (acclo >> 64);
358     acchi = 0;
359     tmp = (BignumDblInt)(a->w[2]) * (b->w[2]);
360     acclo += tmp & BIGNUM_INT_MASK;
361     acchi += tmp >> 64;
362     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 62);
363     acc2lo += r->w[1];
364     r->w[1] = acc2lo;
365     acc2lo >>= 64;
366     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 0);
367     acc2lo += r->w[2];
368     r->w[2] = acc2lo;
369     acc2lo = 0;
370     acc2lo += ((acclo >> 4) & (((BignumInt)1 << 60)-1)) * ((BignumDblInt)25 << 0);
371     acclo = acchi + (acclo >> 64);
372     acchi = 0;
373     acc2lo += (acclo & (((BignumInt)1 << 4)-1)) * ((BignumDblInt)25 << 60);
374     acc2lo += r->w[0];
375     r->w[0] = acc2lo;
376     acc2lo >>= 64;
377     acc2lo += ((acclo >> 4) & (((BignumInt)1 << 60)-1)) * ((BignumDblInt)25 << 0);
378     acclo = acchi + (acclo >> 64);
379     acchi = 0;
380     acc2lo += r->w[1];
381     r->w[1] = acc2lo;
382     acc2lo >>= 64;
383     acc2lo += r->w[2];
384     r->w[2] = acc2lo;
385     acc2lo >>= 64;
386 #elif BIGNUM_INT_BITS == 32
387     /* ./contrib/make1305.py mul 32 */
388     BignumDblInt tmp;
389     BignumDblInt acclo;
390     BignumDblInt acchi;
391     BignumDblInt acc2lo;
392     acclo = 0;
393     acchi = 0;
394     tmp = (BignumDblInt)(a->w[0]) * (b->w[0]);
395     acclo += tmp & BIGNUM_INT_MASK;
396     acchi += tmp >> 32;
397     r->w[0] = acclo;
398     acclo = acchi + (acclo >> 32);
399     acchi = 0;
400     tmp = (BignumDblInt)(a->w[0]) * (b->w[1]);
401     acclo += tmp & BIGNUM_INT_MASK;
402     acchi += tmp >> 32;
403     tmp = (BignumDblInt)(a->w[1]) * (b->w[0]);
404     acclo += tmp & BIGNUM_INT_MASK;
405     acchi += tmp >> 32;
406     r->w[1] = acclo;
407     acclo = acchi + (acclo >> 32);
408     acchi = 0;
409     tmp = (BignumDblInt)(a->w[0]) * (b->w[2]);
410     acclo += tmp & BIGNUM_INT_MASK;
411     acchi += tmp >> 32;
412     tmp = (BignumDblInt)(a->w[1]) * (b->w[1]);
413     acclo += tmp & BIGNUM_INT_MASK;
414     acchi += tmp >> 32;
415     tmp = (BignumDblInt)(a->w[2]) * (b->w[0]);
416     acclo += tmp & BIGNUM_INT_MASK;
417     acchi += tmp >> 32;
418     r->w[2] = acclo;
419     acclo = acchi + (acclo >> 32);
420     acchi = 0;
421     tmp = (BignumDblInt)(a->w[0]) * (b->w[3]);
422     acclo += tmp & BIGNUM_INT_MASK;
423     acchi += tmp >> 32;
424     tmp = (BignumDblInt)(a->w[1]) * (b->w[2]);
425     acclo += tmp & BIGNUM_INT_MASK;
426     acchi += tmp >> 32;
427     tmp = (BignumDblInt)(a->w[2]) * (b->w[1]);
428     acclo += tmp & BIGNUM_INT_MASK;
429     acchi += tmp >> 32;
430     tmp = (BignumDblInt)(a->w[3]) * (b->w[0]);
431     acclo += tmp & BIGNUM_INT_MASK;
432     acchi += tmp >> 32;
433     r->w[3] = acclo;
434     acclo = acchi + (acclo >> 32);
435     acchi = 0;
436     tmp = (BignumDblInt)(a->w[0]) * (b->w[4]);
437     acclo += tmp & BIGNUM_INT_MASK;
438     acchi += tmp >> 32;
439     tmp = (BignumDblInt)(a->w[1]) * (b->w[3]);
440     acclo += tmp & BIGNUM_INT_MASK;
441     acchi += tmp >> 32;
442     tmp = (BignumDblInt)(a->w[2]) * (b->w[2]);
443     acclo += tmp & BIGNUM_INT_MASK;
444     acchi += tmp >> 32;
445     tmp = (BignumDblInt)(a->w[3]) * (b->w[1]);
446     acclo += tmp & BIGNUM_INT_MASK;
447     acchi += tmp >> 32;
448     tmp = (BignumDblInt)(a->w[4]) * (b->w[0]);
449     acclo += tmp & BIGNUM_INT_MASK;
450     acchi += tmp >> 32;
451     r->w[4] = acclo & (((BignumInt)1 << 2)-1);
452     acc2lo = 0;
453     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 30)-1)) * ((BignumDblInt)5 << 0);
454     acclo = acchi + (acclo >> 32);
455     acchi = 0;
456     tmp = (BignumDblInt)(a->w[1]) * (b->w[4]);
457     acclo += tmp & BIGNUM_INT_MASK;
458     acchi += tmp >> 32;
459     tmp = (BignumDblInt)(a->w[2]) * (b->w[3]);
460     acclo += tmp & BIGNUM_INT_MASK;
461     acchi += tmp >> 32;
462     tmp = (BignumDblInt)(a->w[3]) * (b->w[2]);
463     acclo += tmp & BIGNUM_INT_MASK;
464     acchi += tmp >> 32;
465     tmp = (BignumDblInt)(a->w[4]) * (b->w[1]);
466     acclo += tmp & BIGNUM_INT_MASK;
467     acchi += tmp >> 32;
468     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 30);
469     acc2lo += r->w[0];
470     r->w[0] = acc2lo;
471     acc2lo >>= 32;
472     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 30)-1)) * ((BignumDblInt)5 << 0);
473     acclo = acchi + (acclo >> 32);
474     acchi = 0;
475     tmp = (BignumDblInt)(a->w[2]) * (b->w[4]);
476     acclo += tmp & BIGNUM_INT_MASK;
477     acchi += tmp >> 32;
478     tmp = (BignumDblInt)(a->w[3]) * (b->w[3]);
479     acclo += tmp & BIGNUM_INT_MASK;
480     acchi += tmp >> 32;
481     tmp = (BignumDblInt)(a->w[4]) * (b->w[2]);
482     acclo += tmp & BIGNUM_INT_MASK;
483     acchi += tmp >> 32;
484     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 30);
485     acc2lo += r->w[1];
486     r->w[1] = acc2lo;
487     acc2lo >>= 32;
488     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 30)-1)) * ((BignumDblInt)5 << 0);
489     acclo = acchi + (acclo >> 32);
490     acchi = 0;
491     tmp = (BignumDblInt)(a->w[3]) * (b->w[4]);
492     acclo += tmp & BIGNUM_INT_MASK;
493     acchi += tmp >> 32;
494     tmp = (BignumDblInt)(a->w[4]) * (b->w[3]);
495     acclo += tmp & BIGNUM_INT_MASK;
496     acchi += tmp >> 32;
497     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 30);
498     acc2lo += r->w[2];
499     r->w[2] = acc2lo;
500     acc2lo >>= 32;
501     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 30)-1)) * ((BignumDblInt)5 << 0);
502     acclo = acchi + (acclo >> 32);
503     acchi = 0;
504     tmp = (BignumDblInt)(a->w[4]) * (b->w[4]);
505     acclo += tmp & BIGNUM_INT_MASK;
506     acchi += tmp >> 32;
507     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 30);
508     acc2lo += r->w[3];
509     r->w[3] = acc2lo;
510     acc2lo >>= 32;
511     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 0);
512     acc2lo += r->w[4];
513     r->w[4] = acc2lo;
514     acc2lo = 0;
515     acc2lo += ((acclo >> 4) & (((BignumInt)1 << 28)-1)) * ((BignumDblInt)25 << 0);
516     acclo = acchi + (acclo >> 32);
517     acchi = 0;
518     acc2lo += (acclo & (((BignumInt)1 << 4)-1)) * ((BignumDblInt)25 << 28);
519     acc2lo += r->w[0];
520     r->w[0] = acc2lo;
521     acc2lo >>= 32;
522     acc2lo += ((acclo >> 4) & (((BignumInt)1 << 28)-1)) * ((BignumDblInt)25 << 0);
523     acclo = acchi + (acclo >> 32);
524     acchi = 0;
525     acc2lo += r->w[1];
526     r->w[1] = acc2lo;
527     acc2lo >>= 32;
528     acc2lo += r->w[2];
529     r->w[2] = acc2lo;
530     acc2lo >>= 32;
531     acc2lo += r->w[3];
532     r->w[3] = acc2lo;
533     acc2lo >>= 32;
534     acc2lo += r->w[4];
535     r->w[4] = acc2lo;
536     acc2lo >>= 32;
537 #elif BIGNUM_INT_BITS == 16
538     /* ./contrib/make1305.py mul 16 */
539     BignumDblInt tmp;
540     BignumDblInt acclo;
541     BignumDblInt acchi;
542     BignumDblInt acc2lo;
543     acclo = 0;
544     acchi = 0;
545     tmp = (BignumDblInt)(a->w[0]) * (b->w[0]);
546     acclo += tmp & BIGNUM_INT_MASK;
547     acchi += tmp >> 16;
548     r->w[0] = acclo;
549     acclo = acchi + (acclo >> 16);
550     acchi = 0;
551     tmp = (BignumDblInt)(a->w[0]) * (b->w[1]);
552     acclo += tmp & BIGNUM_INT_MASK;
553     acchi += tmp >> 16;
554     tmp = (BignumDblInt)(a->w[1]) * (b->w[0]);
555     acclo += tmp & BIGNUM_INT_MASK;
556     acchi += tmp >> 16;
557     r->w[1] = acclo;
558     acclo = acchi + (acclo >> 16);
559     acchi = 0;
560     tmp = (BignumDblInt)(a->w[0]) * (b->w[2]);
561     acclo += tmp & BIGNUM_INT_MASK;
562     acchi += tmp >> 16;
563     tmp = (BignumDblInt)(a->w[1]) * (b->w[1]);
564     acclo += tmp & BIGNUM_INT_MASK;
565     acchi += tmp >> 16;
566     tmp = (BignumDblInt)(a->w[2]) * (b->w[0]);
567     acclo += tmp & BIGNUM_INT_MASK;
568     acchi += tmp >> 16;
569     r->w[2] = acclo;
570     acclo = acchi + (acclo >> 16);
571     acchi = 0;
572     tmp = (BignumDblInt)(a->w[0]) * (b->w[3]);
573     acclo += tmp & BIGNUM_INT_MASK;
574     acchi += tmp >> 16;
575     tmp = (BignumDblInt)(a->w[1]) * (b->w[2]);
576     acclo += tmp & BIGNUM_INT_MASK;
577     acchi += tmp >> 16;
578     tmp = (BignumDblInt)(a->w[2]) * (b->w[1]);
579     acclo += tmp & BIGNUM_INT_MASK;
580     acchi += tmp >> 16;
581     tmp = (BignumDblInt)(a->w[3]) * (b->w[0]);
582     acclo += tmp & BIGNUM_INT_MASK;
583     acchi += tmp >> 16;
584     r->w[3] = acclo;
585     acclo = acchi + (acclo >> 16);
586     acchi = 0;
587     tmp = (BignumDblInt)(a->w[0]) * (b->w[4]);
588     acclo += tmp & BIGNUM_INT_MASK;
589     acchi += tmp >> 16;
590     tmp = (BignumDblInt)(a->w[1]) * (b->w[3]);
591     acclo += tmp & BIGNUM_INT_MASK;
592     acchi += tmp >> 16;
593     tmp = (BignumDblInt)(a->w[2]) * (b->w[2]);
594     acclo += tmp & BIGNUM_INT_MASK;
595     acchi += tmp >> 16;
596     tmp = (BignumDblInt)(a->w[3]) * (b->w[1]);
597     acclo += tmp & BIGNUM_INT_MASK;
598     acchi += tmp >> 16;
599     tmp = (BignumDblInt)(a->w[4]) * (b->w[0]);
600     acclo += tmp & BIGNUM_INT_MASK;
601     acchi += tmp >> 16;
602     r->w[4] = acclo;
603     acclo = acchi + (acclo >> 16);
604     acchi = 0;
605     tmp = (BignumDblInt)(a->w[0]) * (b->w[5]);
606     acclo += tmp & BIGNUM_INT_MASK;
607     acchi += tmp >> 16;
608     tmp = (BignumDblInt)(a->w[1]) * (b->w[4]);
609     acclo += tmp & BIGNUM_INT_MASK;
610     acchi += tmp >> 16;
611     tmp = (BignumDblInt)(a->w[2]) * (b->w[3]);
612     acclo += tmp & BIGNUM_INT_MASK;
613     acchi += tmp >> 16;
614     tmp = (BignumDblInt)(a->w[3]) * (b->w[2]);
615     acclo += tmp & BIGNUM_INT_MASK;
616     acchi += tmp >> 16;
617     tmp = (BignumDblInt)(a->w[4]) * (b->w[1]);
618     acclo += tmp & BIGNUM_INT_MASK;
619     acchi += tmp >> 16;
620     tmp = (BignumDblInt)(a->w[5]) * (b->w[0]);
621     acclo += tmp & BIGNUM_INT_MASK;
622     acchi += tmp >> 16;
623     r->w[5] = acclo;
624     acclo = acchi + (acclo >> 16);
625     acchi = 0;
626     tmp = (BignumDblInt)(a->w[0]) * (b->w[6]);
627     acclo += tmp & BIGNUM_INT_MASK;
628     acchi += tmp >> 16;
629     tmp = (BignumDblInt)(a->w[1]) * (b->w[5]);
630     acclo += tmp & BIGNUM_INT_MASK;
631     acchi += tmp >> 16;
632     tmp = (BignumDblInt)(a->w[2]) * (b->w[4]);
633     acclo += tmp & BIGNUM_INT_MASK;
634     acchi += tmp >> 16;
635     tmp = (BignumDblInt)(a->w[3]) * (b->w[3]);
636     acclo += tmp & BIGNUM_INT_MASK;
637     acchi += tmp >> 16;
638     tmp = (BignumDblInt)(a->w[4]) * (b->w[2]);
639     acclo += tmp & BIGNUM_INT_MASK;
640     acchi += tmp >> 16;
641     tmp = (BignumDblInt)(a->w[5]) * (b->w[1]);
642     acclo += tmp & BIGNUM_INT_MASK;
643     acchi += tmp >> 16;
644     tmp = (BignumDblInt)(a->w[6]) * (b->w[0]);
645     acclo += tmp & BIGNUM_INT_MASK;
646     acchi += tmp >> 16;
647     r->w[6] = acclo;
648     acclo = acchi + (acclo >> 16);
649     acchi = 0;
650     tmp = (BignumDblInt)(a->w[0]) * (b->w[7]);
651     acclo += tmp & BIGNUM_INT_MASK;
652     acchi += tmp >> 16;
653     tmp = (BignumDblInt)(a->w[1]) * (b->w[6]);
654     acclo += tmp & BIGNUM_INT_MASK;
655     acchi += tmp >> 16;
656     tmp = (BignumDblInt)(a->w[2]) * (b->w[5]);
657     acclo += tmp & BIGNUM_INT_MASK;
658     acchi += tmp >> 16;
659     tmp = (BignumDblInt)(a->w[3]) * (b->w[4]);
660     acclo += tmp & BIGNUM_INT_MASK;
661     acchi += tmp >> 16;
662     tmp = (BignumDblInt)(a->w[4]) * (b->w[3]);
663     acclo += tmp & BIGNUM_INT_MASK;
664     acchi += tmp >> 16;
665     tmp = (BignumDblInt)(a->w[5]) * (b->w[2]);
666     acclo += tmp & BIGNUM_INT_MASK;
667     acchi += tmp >> 16;
668     tmp = (BignumDblInt)(a->w[6]) * (b->w[1]);
669     acclo += tmp & BIGNUM_INT_MASK;
670     acchi += tmp >> 16;
671     tmp = (BignumDblInt)(a->w[7]) * (b->w[0]);
672     acclo += tmp & BIGNUM_INT_MASK;
673     acchi += tmp >> 16;
674     r->w[7] = acclo;
675     acclo = acchi + (acclo >> 16);
676     acchi = 0;
677     tmp = (BignumDblInt)(a->w[0]) * (b->w[8]);
678     acclo += tmp & BIGNUM_INT_MASK;
679     acchi += tmp >> 16;
680     tmp = (BignumDblInt)(a->w[1]) * (b->w[7]);
681     acclo += tmp & BIGNUM_INT_MASK;
682     acchi += tmp >> 16;
683     tmp = (BignumDblInt)(a->w[2]) * (b->w[6]);
684     acclo += tmp & BIGNUM_INT_MASK;
685     acchi += tmp >> 16;
686     tmp = (BignumDblInt)(a->w[3]) * (b->w[5]);
687     acclo += tmp & BIGNUM_INT_MASK;
688     acchi += tmp >> 16;
689     tmp = (BignumDblInt)(a->w[4]) * (b->w[4]);
690     acclo += tmp & BIGNUM_INT_MASK;
691     acchi += tmp >> 16;
692     tmp = (BignumDblInt)(a->w[5]) * (b->w[3]);
693     acclo += tmp & BIGNUM_INT_MASK;
694     acchi += tmp >> 16;
695     tmp = (BignumDblInt)(a->w[6]) * (b->w[2]);
696     acclo += tmp & BIGNUM_INT_MASK;
697     acchi += tmp >> 16;
698     tmp = (BignumDblInt)(a->w[7]) * (b->w[1]);
699     acclo += tmp & BIGNUM_INT_MASK;
700     acchi += tmp >> 16;
701     tmp = (BignumDblInt)(a->w[8]) * (b->w[0]);
702     acclo += tmp & BIGNUM_INT_MASK;
703     acchi += tmp >> 16;
704     r->w[8] = acclo & (((BignumInt)1 << 2)-1);
705     acc2lo = 0;
706     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
707     acclo = acchi + (acclo >> 16);
708     acchi = 0;
709     tmp = (BignumDblInt)(a->w[1]) * (b->w[8]);
710     acclo += tmp & BIGNUM_INT_MASK;
711     acchi += tmp >> 16;
712     tmp = (BignumDblInt)(a->w[2]) * (b->w[7]);
713     acclo += tmp & BIGNUM_INT_MASK;
714     acchi += tmp >> 16;
715     tmp = (BignumDblInt)(a->w[3]) * (b->w[6]);
716     acclo += tmp & BIGNUM_INT_MASK;
717     acchi += tmp >> 16;
718     tmp = (BignumDblInt)(a->w[4]) * (b->w[5]);
719     acclo += tmp & BIGNUM_INT_MASK;
720     acchi += tmp >> 16;
721     tmp = (BignumDblInt)(a->w[5]) * (b->w[4]);
722     acclo += tmp & BIGNUM_INT_MASK;
723     acchi += tmp >> 16;
724     tmp = (BignumDblInt)(a->w[6]) * (b->w[3]);
725     acclo += tmp & BIGNUM_INT_MASK;
726     acchi += tmp >> 16;
727     tmp = (BignumDblInt)(a->w[7]) * (b->w[2]);
728     acclo += tmp & BIGNUM_INT_MASK;
729     acchi += tmp >> 16;
730     tmp = (BignumDblInt)(a->w[8]) * (b->w[1]);
731     acclo += tmp & BIGNUM_INT_MASK;
732     acchi += tmp >> 16;
733     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
734     acc2lo += r->w[0];
735     r->w[0] = acc2lo;
736     acc2lo >>= 16;
737     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
738     acclo = acchi + (acclo >> 16);
739     acchi = 0;
740     tmp = (BignumDblInt)(a->w[2]) * (b->w[8]);
741     acclo += tmp & BIGNUM_INT_MASK;
742     acchi += tmp >> 16;
743     tmp = (BignumDblInt)(a->w[3]) * (b->w[7]);
744     acclo += tmp & BIGNUM_INT_MASK;
745     acchi += tmp >> 16;
746     tmp = (BignumDblInt)(a->w[4]) * (b->w[6]);
747     acclo += tmp & BIGNUM_INT_MASK;
748     acchi += tmp >> 16;
749     tmp = (BignumDblInt)(a->w[5]) * (b->w[5]);
750     acclo += tmp & BIGNUM_INT_MASK;
751     acchi += tmp >> 16;
752     tmp = (BignumDblInt)(a->w[6]) * (b->w[4]);
753     acclo += tmp & BIGNUM_INT_MASK;
754     acchi += tmp >> 16;
755     tmp = (BignumDblInt)(a->w[7]) * (b->w[3]);
756     acclo += tmp & BIGNUM_INT_MASK;
757     acchi += tmp >> 16;
758     tmp = (BignumDblInt)(a->w[8]) * (b->w[2]);
759     acclo += tmp & BIGNUM_INT_MASK;
760     acchi += tmp >> 16;
761     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
762     acc2lo += r->w[1];
763     r->w[1] = acc2lo;
764     acc2lo >>= 16;
765     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
766     acclo = acchi + (acclo >> 16);
767     acchi = 0;
768     tmp = (BignumDblInt)(a->w[3]) * (b->w[8]);
769     acclo += tmp & BIGNUM_INT_MASK;
770     acchi += tmp >> 16;
771     tmp = (BignumDblInt)(a->w[4]) * (b->w[7]);
772     acclo += tmp & BIGNUM_INT_MASK;
773     acchi += tmp >> 16;
774     tmp = (BignumDblInt)(a->w[5]) * (b->w[6]);
775     acclo += tmp & BIGNUM_INT_MASK;
776     acchi += tmp >> 16;
777     tmp = (BignumDblInt)(a->w[6]) * (b->w[5]);
778     acclo += tmp & BIGNUM_INT_MASK;
779     acchi += tmp >> 16;
780     tmp = (BignumDblInt)(a->w[7]) * (b->w[4]);
781     acclo += tmp & BIGNUM_INT_MASK;
782     acchi += tmp >> 16;
783     tmp = (BignumDblInt)(a->w[8]) * (b->w[3]);
784     acclo += tmp & BIGNUM_INT_MASK;
785     acchi += tmp >> 16;
786     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
787     acc2lo += r->w[2];
788     r->w[2] = acc2lo;
789     acc2lo >>= 16;
790     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
791     acclo = acchi + (acclo >> 16);
792     acchi = 0;
793     tmp = (BignumDblInt)(a->w[4]) * (b->w[8]);
794     acclo += tmp & BIGNUM_INT_MASK;
795     acchi += tmp >> 16;
796     tmp = (BignumDblInt)(a->w[5]) * (b->w[7]);
797     acclo += tmp & BIGNUM_INT_MASK;
798     acchi += tmp >> 16;
799     tmp = (BignumDblInt)(a->w[6]) * (b->w[6]);
800     acclo += tmp & BIGNUM_INT_MASK;
801     acchi += tmp >> 16;
802     tmp = (BignumDblInt)(a->w[7]) * (b->w[5]);
803     acclo += tmp & BIGNUM_INT_MASK;
804     acchi += tmp >> 16;
805     tmp = (BignumDblInt)(a->w[8]) * (b->w[4]);
806     acclo += tmp & BIGNUM_INT_MASK;
807     acchi += tmp >> 16;
808     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
809     acc2lo += r->w[3];
810     r->w[3] = acc2lo;
811     acc2lo >>= 16;
812     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
813     acclo = acchi + (acclo >> 16);
814     acchi = 0;
815     tmp = (BignumDblInt)(a->w[5]) * (b->w[8]);
816     acclo += tmp & BIGNUM_INT_MASK;
817     acchi += tmp >> 16;
818     tmp = (BignumDblInt)(a->w[6]) * (b->w[7]);
819     acclo += tmp & BIGNUM_INT_MASK;
820     acchi += tmp >> 16;
821     tmp = (BignumDblInt)(a->w[7]) * (b->w[6]);
822     acclo += tmp & BIGNUM_INT_MASK;
823     acchi += tmp >> 16;
824     tmp = (BignumDblInt)(a->w[8]) * (b->w[5]);
825     acclo += tmp & BIGNUM_INT_MASK;
826     acchi += tmp >> 16;
827     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
828     acc2lo += r->w[4];
829     r->w[4] = acc2lo;
830     acc2lo >>= 16;
831     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
832     acclo = acchi + (acclo >> 16);
833     acchi = 0;
834     tmp = (BignumDblInt)(a->w[6]) * (b->w[8]);
835     acclo += tmp & BIGNUM_INT_MASK;
836     acchi += tmp >> 16;
837     tmp = (BignumDblInt)(a->w[7]) * (b->w[7]);
838     acclo += tmp & BIGNUM_INT_MASK;
839     acchi += tmp >> 16;
840     tmp = (BignumDblInt)(a->w[8]) * (b->w[6]);
841     acclo += tmp & BIGNUM_INT_MASK;
842     acchi += tmp >> 16;
843     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
844     acc2lo += r->w[5];
845     r->w[5] = acc2lo;
846     acc2lo >>= 16;
847     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
848     acclo = acchi + (acclo >> 16);
849     acchi = 0;
850     tmp = (BignumDblInt)(a->w[7]) * (b->w[8]);
851     acclo += tmp & BIGNUM_INT_MASK;
852     acchi += tmp >> 16;
853     tmp = (BignumDblInt)(a->w[8]) * (b->w[7]);
854     acclo += tmp & BIGNUM_INT_MASK;
855     acchi += tmp >> 16;
856     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
857     acc2lo += r->w[6];
858     r->w[6] = acc2lo;
859     acc2lo >>= 16;
860     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 14)-1)) * ((BignumDblInt)5 << 0);
861     acclo = acchi + (acclo >> 16);
862     acchi = 0;
863     tmp = (BignumDblInt)(a->w[8]) * (b->w[8]);
864     acclo += tmp & BIGNUM_INT_MASK;
865     acchi += tmp >> 16;
866     acc2lo += (acclo & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 14);
867     acc2lo += r->w[7];
868     r->w[7] = acc2lo;
869     acc2lo >>= 16;
870     acc2lo += ((acclo >> 2) & (((BignumInt)1 << 2)-1)) * ((BignumDblInt)5 << 0);
871     acc2lo += r->w[8];
872     r->w[8] = acc2lo;
873     acc2lo = 0;
874     acc2lo += ((acclo >> 4) & (((BignumInt)1 << 12)-1)) * ((BignumDblInt)25 << 0);
875     acclo = acchi + (acclo >> 16);
876     acchi = 0;
877     acc2lo += (acclo & (((BignumInt)1 << 4)-1)) * ((BignumDblInt)25 << 12);
878     acc2lo += r->w[0];
879     r->w[0] = acc2lo;
880     acc2lo >>= 16;
881     acc2lo += ((acclo >> 4) & (((BignumInt)1 << 12)-1)) * ((BignumDblInt)25 << 0);
882     acclo = acchi + (acclo >> 16);
883     acchi = 0;
884     acc2lo += r->w[1];
885     r->w[1] = acc2lo;
886     acc2lo >>= 16;
887     acc2lo += r->w[2];
888     r->w[2] = acc2lo;
889     acc2lo >>= 16;
890     acc2lo += r->w[3];
891     r->w[3] = acc2lo;
892     acc2lo >>= 16;
893     acc2lo += r->w[4];
894     r->w[4] = acc2lo;
895     acc2lo >>= 16;
896     acc2lo += r->w[5];
897     r->w[5] = acc2lo;
898     acc2lo >>= 16;
899     acc2lo += r->w[6];
900     r->w[6] = acc2lo;
901     acc2lo >>= 16;
902     acc2lo += r->w[7];
903     r->w[7] = acc2lo;
904     acc2lo >>= 16;
905     acc2lo += r->w[8];
906     r->w[8] = acc2lo;
907     acc2lo >>= 16;
908 #else
909 #error Run contrib/make1305.py again with a different bit count
910 #endif
911 }
912
913 static void bigval_final_reduce(bigval *n)
914 {
915 #if BIGNUM_INT_BITS == 64
916     /* ./contrib/make1305.py final_reduce 64 */
917     BignumDblInt acclo;
918     acclo = 0;
919     acclo += 5 * ((n->w[2] >> 2) + 1);
920     acclo += n->w[0];
921     acclo >>= 64;
922     acclo += n->w[1];
923     acclo >>= 64;
924     acclo += n->w[2];
925     acclo = 5 * (acclo >> 2);
926     acclo += n->w[0];
927     n->w[0] = acclo;
928     acclo >>= 64;
929     acclo += n->w[1];
930     n->w[1] = acclo;
931     acclo >>= 64;
932     acclo += n->w[2];
933     n->w[2] = acclo;
934     acclo >>= 64;
935     n->w[2] &= (1 << 2) - 1;
936 #elif BIGNUM_INT_BITS == 32
937     /* ./contrib/make1305.py final_reduce 32 */
938     BignumDblInt acclo;
939     acclo = 0;
940     acclo += 5 * ((n->w[4] >> 2) + 1);
941     acclo += n->w[0];
942     acclo >>= 32;
943     acclo += n->w[1];
944     acclo >>= 32;
945     acclo += n->w[2];
946     acclo >>= 32;
947     acclo += n->w[3];
948     acclo >>= 32;
949     acclo += n->w[4];
950     acclo = 5 * (acclo >> 2);
951     acclo += n->w[0];
952     n->w[0] = acclo;
953     acclo >>= 32;
954     acclo += n->w[1];
955     n->w[1] = acclo;
956     acclo >>= 32;
957     acclo += n->w[2];
958     n->w[2] = acclo;
959     acclo >>= 32;
960     acclo += n->w[3];
961     n->w[3] = acclo;
962     acclo >>= 32;
963     acclo += n->w[4];
964     n->w[4] = acclo;
965     acclo >>= 32;
966     n->w[4] &= (1 << 2) - 1;
967 #elif BIGNUM_INT_BITS == 16
968     /* ./contrib/make1305.py final_reduce 16 */
969     BignumDblInt acclo;
970     acclo = 0;
971     acclo += 5 * ((n->w[8] >> 2) + 1);
972     acclo += n->w[0];
973     acclo >>= 16;
974     acclo += n->w[1];
975     acclo >>= 16;
976     acclo += n->w[2];
977     acclo >>= 16;
978     acclo += n->w[3];
979     acclo >>= 16;
980     acclo += n->w[4];
981     acclo >>= 16;
982     acclo += n->w[5];
983     acclo >>= 16;
984     acclo += n->w[6];
985     acclo >>= 16;
986     acclo += n->w[7];
987     acclo >>= 16;
988     acclo += n->w[8];
989     acclo = 5 * (acclo >> 2);
990     acclo += n->w[0];
991     n->w[0] = acclo;
992     acclo >>= 16;
993     acclo += n->w[1];
994     n->w[1] = acclo;
995     acclo >>= 16;
996     acclo += n->w[2];
997     n->w[2] = acclo;
998     acclo >>= 16;
999     acclo += n->w[3];
1000     n->w[3] = acclo;
1001     acclo >>= 16;
1002     acclo += n->w[4];
1003     n->w[4] = acclo;
1004     acclo >>= 16;
1005     acclo += n->w[5];
1006     n->w[5] = acclo;
1007     acclo >>= 16;
1008     acclo += n->w[6];
1009     n->w[6] = acclo;
1010     acclo >>= 16;
1011     acclo += n->w[7];
1012     n->w[7] = acclo;
1013     acclo >>= 16;
1014     acclo += n->w[8];
1015     n->w[8] = acclo;
1016     acclo >>= 16;
1017     n->w[8] &= (1 << 2) - 1;
1018 #else
1019 #error Run contrib/make1305.py again with a different bit count
1020 #endif
1021 }
1022
1023 struct poly1305 {
1024     unsigned char nonce[16];
1025     bigval r;
1026     bigval h;
1027
1028     /* Buffer in case we get less that a multiple of 16 bytes */
1029     unsigned char buffer[16];
1030     int bufferIndex;
1031 };
1032
1033 static void poly1305_init(struct poly1305 *ctx)
1034 {
1035     memset(ctx->nonce, 0, 16);
1036     ctx->bufferIndex = 0;
1037     bigval_clear(&ctx->h);
1038 }
1039
1040 /* Takes a 256 bit key */
1041 static void poly1305_key(struct poly1305 *ctx, const unsigned char *key)
1042 {
1043     unsigned char key_copy[16];
1044     memcpy(key_copy, key, 16);
1045
1046     /* Key the MAC itself
1047      * bytes 4, 8, 12 and 16 are required to have their top four bits clear */
1048     key_copy[3] &= 0x0f;
1049     key_copy[7] &= 0x0f;
1050     key_copy[11] &= 0x0f;
1051     key_copy[15] &= 0x0f;
1052     /* bytes 5, 9 and 13 are required to have their bottom two bits clear */
1053     key_copy[4] &= 0xfc;
1054     key_copy[8] &= 0xfc;
1055     key_copy[12] &= 0xfc;
1056     bigval_import_le(&ctx->r, key_copy, 16);
1057     smemclr(key_copy, sizeof(key_copy));
1058
1059     /* Use second 128 bits are the nonce */
1060     memcpy(ctx->nonce, key+16, 16);
1061 }
1062
1063 /* Feed up to 16 bytes (should only be less for the last chunk) */
1064 static void poly1305_feed_chunk(struct poly1305 *ctx,
1065                                 const unsigned char *chunk, int len)
1066 {
1067     bigval c;
1068     bigval_import_le(&c, chunk, len);
1069     c.w[len / BIGNUM_INT_BYTES] |=
1070         (BignumInt)1 << (8 * (len % BIGNUM_INT_BYTES));
1071     bigval_add(&c, &c, &ctx->h);
1072     bigval_mul_mod_p(&ctx->h, &c, &ctx->r);
1073 }
1074
1075 static void poly1305_feed(struct poly1305 *ctx,
1076                           const unsigned char *buf, int len)
1077 {
1078     /* Check for stuff left in the buffer from last time */
1079     if (ctx->bufferIndex) {
1080         /* Try to fill up to 16 */
1081         while (ctx->bufferIndex < 16 && len) {
1082             ctx->buffer[ctx->bufferIndex++] = *buf++;
1083             --len;
1084         }
1085         if (ctx->bufferIndex == 16) {
1086             poly1305_feed_chunk(ctx, ctx->buffer, 16);
1087             ctx->bufferIndex = 0;
1088         }
1089     }
1090
1091     /* Process 16 byte whole chunks */
1092     while (len >= 16) {
1093         poly1305_feed_chunk(ctx, buf, 16);
1094         len -= 16;
1095         buf += 16;
1096     }
1097
1098     /* Cache stuff that's left over */
1099     if (len) {
1100         memcpy(ctx->buffer, buf, len);
1101         ctx->bufferIndex = len;
1102     }
1103 }
1104
1105 /* Finalise and populate buffer with 16 byte with MAC */
1106 static void poly1305_finalise(struct poly1305 *ctx, unsigned char *mac)
1107 {
1108     bigval tmp;
1109
1110     if (ctx->bufferIndex) {
1111         poly1305_feed_chunk(ctx, ctx->buffer, ctx->bufferIndex);
1112     }
1113
1114     bigval_import_le(&tmp, ctx->nonce, 16);
1115     bigval_final_reduce(&ctx->h);
1116     bigval_add(&tmp, &tmp, &ctx->h);
1117     bigval_export_le(&tmp, mac, 16);
1118 }
1119
1120 /* SSH-2 wrapper */
1121
1122 struct ccp_context {
1123     struct chacha20 a_cipher; /* Used for length */
1124     struct chacha20 b_cipher; /* Used for content */
1125
1126     /* Cache of the first 4 bytes because they are the sequence number */
1127     /* Kept in 8 bytes with the top as zero to allow easy passing to setiv */
1128     int mac_initialised; /* Where we have got to in filling mac_iv */
1129     unsigned char mac_iv[8];
1130
1131     struct poly1305 mac;
1132 };
1133
1134 static void *poly_make_context(void *ctx)
1135 {
1136     return ctx;
1137 }
1138
1139 static void poly_free_context(void *ctx)
1140 {
1141     /* Not allocated, just forwarded, no need to free */
1142 }
1143
1144 static void poly_setkey(void *ctx, unsigned char *key)
1145 {
1146     /* Uses the same context as ChaCha20, so ignore */
1147 }
1148
1149 static void poly_start(void *handle)
1150 {
1151     struct ccp_context *ctx = (struct ccp_context *)handle;
1152
1153     ctx->mac_initialised = 0;
1154     memset(ctx->mac_iv, 0, 8);
1155     poly1305_init(&ctx->mac);
1156 }
1157
1158 static void poly_bytes(void *handle, unsigned char const *blk, int len)
1159 {
1160     struct ccp_context *ctx = (struct ccp_context *)handle;
1161
1162     /* First 4 bytes are the IV */
1163     while (ctx->mac_initialised < 4 && len) {
1164         ctx->mac_iv[7 - ctx->mac_initialised] = *blk++;
1165         ++ctx->mac_initialised;
1166         --len;
1167     }
1168
1169     /* Initialise the IV if needed */
1170     if (ctx->mac_initialised == 4) {
1171         chacha20_iv(&ctx->b_cipher, ctx->mac_iv);
1172         ++ctx->mac_initialised;  /* Don't do it again */
1173
1174         /* Do first rotation */
1175         chacha20_round(&ctx->b_cipher);
1176
1177         /* Set the poly key */
1178         poly1305_key(&ctx->mac, ctx->b_cipher.current);
1179
1180         /* Set the first round as used */
1181         ctx->b_cipher.currentIndex = 64;
1182     }
1183
1184     /* Update the MAC with anything left */
1185     if (len) {
1186         poly1305_feed(&ctx->mac, blk, len);
1187     }
1188 }
1189
1190 static void poly_genresult(void *handle, unsigned char *blk)
1191 {
1192     struct ccp_context *ctx = (struct ccp_context *)handle;
1193     poly1305_finalise(&ctx->mac, blk);
1194 }
1195
1196 static int poly_verresult(void *handle, unsigned char const *blk)
1197 {
1198     struct ccp_context *ctx = (struct ccp_context *)handle;
1199     int res;
1200     unsigned char mac[16];
1201     poly1305_finalise(&ctx->mac, mac);
1202     res = smemeq(blk, mac, 16);
1203     return res;
1204 }
1205
1206 /* The generic poly operation used before generate and verify */
1207 static void poly_op(void *handle, unsigned char *blk, int len, unsigned long seq)
1208 {
1209     unsigned char iv[4];
1210     poly_start(handle);
1211     PUT_32BIT_MSB_FIRST(iv, seq);
1212     /* poly_bytes expects the first 4 bytes to be the IV */
1213     poly_bytes(handle, iv, 4);
1214     smemclr(iv, sizeof(iv));
1215     poly_bytes(handle, blk, len);
1216 }
1217
1218 static void poly_generate(void *handle, unsigned char *blk, int len, unsigned long seq)
1219 {
1220     poly_op(handle, blk, len, seq);
1221     poly_genresult(handle, blk+len);
1222 }
1223
1224 static int poly_verify(void *handle, unsigned char *blk, int len, unsigned long seq)
1225 {
1226     poly_op(handle, blk, len, seq);
1227     return poly_verresult(handle, blk+len);
1228 }
1229
1230 static const struct ssh_mac ssh2_poly1305 = {
1231     poly_make_context, poly_free_context,
1232     poly_setkey,
1233
1234     /* whole-packet operations */
1235     poly_generate, poly_verify,
1236
1237     /* partial-packet operations */
1238     poly_start, poly_bytes, poly_genresult, poly_verresult,
1239
1240     "", "", /* Not selectable individually, just part of ChaCha20-Poly1305 */
1241     16, "Poly1305"
1242 };
1243
1244 static void *ccp_make_context(void)
1245 {
1246     struct ccp_context *ctx = snew(struct ccp_context);
1247     if (ctx) {
1248         poly1305_init(&ctx->mac);
1249     }
1250     return ctx;
1251 }
1252
1253 static void ccp_free_context(void *vctx)
1254 {
1255     struct ccp_context *ctx = (struct ccp_context *)vctx;
1256     smemclr(&ctx->a_cipher, sizeof(ctx->a_cipher));
1257     smemclr(&ctx->b_cipher, sizeof(ctx->b_cipher));
1258     smemclr(&ctx->mac, sizeof(ctx->mac));
1259     sfree(ctx);
1260 }
1261
1262 static void ccp_iv(void *vctx, unsigned char *iv)
1263 {
1264     /* struct ccp_context *ctx = (struct ccp_context *)vctx; */
1265     /* IV is set based on the sequence number */
1266 }
1267
1268 static void ccp_key(void *vctx, unsigned char *key)
1269 {
1270     struct ccp_context *ctx = (struct ccp_context *)vctx;
1271     /* Initialise the a_cipher (for decrypting lengths) with the first 256 bits */
1272     chacha20_key(&ctx->a_cipher, key + 32);
1273     /* Initialise the b_cipher (for content and MAC) with the second 256 bits */
1274     chacha20_key(&ctx->b_cipher, key);
1275 }
1276
1277 static void ccp_encrypt(void *vctx, unsigned char *blk, int len)
1278 {
1279     struct ccp_context *ctx = (struct ccp_context *)vctx;
1280     chacha20_encrypt(&ctx->b_cipher, blk, len);
1281 }
1282
1283 static void ccp_decrypt(void *vctx, unsigned char *blk, int len)
1284 {
1285     struct ccp_context *ctx = (struct ccp_context *)vctx;
1286     chacha20_decrypt(&ctx->b_cipher, blk, len);
1287 }
1288
1289 static void ccp_length_op(struct ccp_context *ctx, unsigned char *blk, int len,
1290                           unsigned long seq)
1291 {
1292     unsigned char iv[8];
1293     /*
1294      * According to RFC 4253 (section 6.4), the packet sequence number wraps
1295      * at 2^32, so its 32 high-order bits will always be zero.
1296      */
1297     PUT_32BIT_LSB_FIRST(iv, 0);
1298     PUT_32BIT_LSB_FIRST(iv + 4, seq);
1299     chacha20_iv(&ctx->a_cipher, iv);
1300     chacha20_iv(&ctx->b_cipher, iv);
1301     /* Reset content block count to 1, as the first is the key for Poly1305 */
1302     ++ctx->b_cipher.state[12];
1303     smemclr(iv, sizeof(iv));
1304 }
1305
1306 static void ccp_encrypt_length(void *vctx, unsigned char *blk, int len,
1307                                unsigned long seq)
1308 {
1309     struct ccp_context *ctx = (struct ccp_context *)vctx;
1310     ccp_length_op(ctx, blk, len, seq);
1311     chacha20_encrypt(&ctx->a_cipher, blk, len);
1312 }
1313
1314 static void ccp_decrypt_length(void *vctx, unsigned char *blk, int len,
1315                                unsigned long seq)
1316 {
1317     struct ccp_context *ctx = (struct ccp_context *)vctx;
1318     ccp_length_op(ctx, blk, len, seq);
1319     chacha20_decrypt(&ctx->a_cipher, blk, len);
1320 }
1321
1322 static const struct ssh2_cipher ssh2_chacha20_poly1305 = {
1323
1324     ccp_make_context,
1325     ccp_free_context,
1326     ccp_iv,
1327     ccp_key,
1328     ccp_encrypt,
1329     ccp_decrypt,
1330     ccp_encrypt_length,
1331     ccp_decrypt_length,
1332
1333     "chacha20-poly1305@openssh.com",
1334     1, 512, SSH_CIPHER_SEPARATE_LENGTH, "ChaCha20",
1335
1336     &ssh2_poly1305
1337 };
1338
1339 static const struct ssh2_cipher *const ccp_list[] = {
1340     &ssh2_chacha20_poly1305
1341 };
1342
1343 const struct ssh2_ciphers ssh2_ccp = {
1344     sizeof(ccp_list) / sizeof(*ccp_list),
1345     ccp_list
1346 };