]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - contrib/make1305.py
Merge tag '0.66'
[PuTTY.git] / contrib / make1305.py
1 #!/usr/bin/env python
2
3 import sys
4
5 class Output(object):
6     def __init__(self, bignum_int_bits):
7         self.bignum_int_bits = bignum_int_bits
8         self.text = ""
9         self.vars = []
10     def stmt(self, statement):
11         self.text += "    %s;\n" % statement
12     def register_var(self, var):
13         self.vars.append(var)
14     def finalise(self):
15         for var in self.vars:
16             assert var.maxval == 0, "Variable not clear: %s" % var.name
17         return self.text
18
19 class Variable(object):
20     def __init__(self, out, name):
21         self.out = out
22         self.maxval = 0
23         self.name = name
24         self.placeval = None
25         self.out.stmt("BignumDblInt %s" % (self.name))
26         self.out.register_var(self)
27     def clear(self, placeval):
28         self.maxval = 0
29         self.placeval = placeval
30         self.out.stmt("%s = 0" % (self.name))
31     def set_word(self, name, limit=None):
32         if limit is not None:
33             self.maxval = limit-1
34         else:
35             self.maxval = (1 << self.out.bignum_int_bits) - 1
36         assert self.maxval < (1 << 2*self.out.bignum_int_bits)
37         self.out.stmt("%s = %s" % (self.name, name))
38     def add_word(self, name, limit=None):
39         if limit is not None:
40             self.maxval += limit-1
41         else:
42             self.maxval += (1 << self.out.bignum_int_bits) - 1
43         assert self.maxval < (1 << 2*self.out.bignum_int_bits)
44         self.out.stmt("%s += %s" % (self.name, name))
45     def add_input_word(self, fmt, wordpos, limit=None):
46         assert self.placeval == wordpos * self.out.bignum_int_bits
47         self.add_word(fmt % wordpos, limit)
48     def set_to_product(self, a, b, placeval):
49         self.maxval = ((1 << self.out.bignum_int_bits) - 1) ** 2
50         assert self.maxval < (1 << 2*self.out.bignum_int_bits)        
51         self.out.stmt("%s = (BignumDblInt)(%s) * (%s)" % (self.name, a, b))
52         self.placeval = placeval
53     def add_bottom_half(self, srcvar):
54         self.add_word("%s & BIGNUM_INT_MASK" % (srcvar.name))
55     def add_top_half(self, srcvar):
56         self.add_word("%s >> %d" % (srcvar.name, self.out.bignum_int_bits))
57     def unload_into(self, topvar, botvar):
58         assert botvar.placeval == self.placeval
59         botvar.add_bottom_half(self)
60         assert topvar.placeval == self.placeval + self.out.bignum_int_bits
61         topvar.add_top_half(self)
62         self.maxval = 0
63     def output_word(self, bitpos, bits, destfmt, destwordpos):
64         assert bitpos == 0
65         assert self.placeval == destwordpos * self.out.bignum_int_bits
66         dest = destfmt % destwordpos
67         if bits == self.out.bignum_int_bits:
68             self.out.stmt("%s = %s" % (dest, self.name))
69         else:
70             self.out.stmt("%s = %s & (((BignumInt)1 << %d)-1)" %
71                           (dest, self.name, bits))
72     def transfer_to_next_acc(self, bitpos, bits, pow5, destvar):
73         destbitpos = self.placeval + bitpos - 130 * pow5 - destvar.placeval
74         #print "transfer", "*%d" % 5**pow5, self.name, self.placeval, bitpos, destvar.name, destvar.placeval, destbitpos, bits
75         assert 0 <= bitpos < bitpos+bits <= self.out.bignum_int_bits
76         assert 0 <= destbitpos < destbitpos+bits <= self.out.bignum_int_bits
77         expr = self.name
78         if bitpos > 0:
79             expr = "(%s >> %d)" % (expr, bitpos)
80         expr = "(%s & (((BignumInt)1 << %d)-1))" % (expr, bits)
81         self.out.stmt("%s += %s * ((BignumDblInt)%d << %d)" %
82                       (destvar.name, expr, 5**pow5, destbitpos))
83         destvar.maxval += (((1 << bits)-1) << destbitpos) * (5**pow5)
84     def shift_down_from(self, top):
85         if top is not None:
86             self.out.stmt("%s = %s + (%s >> %d)" %
87                           (self.name, top.name, self.name,
88                            self.out.bignum_int_bits))
89             topmaxval = top.maxval
90         else:
91             self.out.stmt("%s >>= %d" % (self.name, self.out.bignum_int_bits))
92             topmaxval = 0
93         self.maxval = topmaxval + self.maxval >> self.out.bignum_int_bits
94         assert self.maxval < (1 << 2*self.out.bignum_int_bits)
95         if top is not None:
96             assert self.placeval + self.out.bignum_int_bits == top.placeval
97             top.clear(top.placeval + self.out.bignum_int_bits)
98         self.placeval += self.out.bignum_int_bits
99
100 def gen_add(bignum_int_bits):
101     out = Output(bignum_int_bits)
102
103     inbits = 130
104     inwords = (inbits + bignum_int_bits - 1) / bignum_int_bits
105
106     # This is an addition _without_ reduction mod p, so that it can be
107     # used both during accumulation of the polynomial and for adding
108     # on the encrypted nonce at the end (which is mod 2^128, not mod
109     # p).
110     #
111     # Because one of the inputs will have come from our
112     # not-completely-reducing multiplication function, we expect up to
113     # 3 extra bits of input.
114     acclo = Variable(out, "acclo")
115
116     acclo.clear(0)
117
118     for wordpos in range(inwords):
119         limit = min(1 << bignum_int_bits, 1 << (130 - wordpos*bignum_int_bits))
120         acclo.add_input_word("a->w[%d]", wordpos, limit)
121         acclo.add_input_word("b->w[%d]", wordpos, limit)
122         acclo.output_word(0, bignum_int_bits, "r->w[%d]", wordpos)
123         acclo.shift_down_from(None)
124
125     return out.finalise()
126
127 def gen_mul_1305(bignum_int_bits):
128     out = Output(bignum_int_bits)
129
130     inbits = 130
131     inwords = (inbits + bignum_int_bits - 1) / bignum_int_bits
132
133     # The inputs are not 100% reduced mod p. Specifically, we can get
134     # a full 130-bit number from the pow5==0 pass, and then a 130-bit
135     # number times 5 from the pow5==1 pass, plus a possible carry. The
136     # total of that can be easily bounded above by 2^130 * 8, so we
137     # need to assume we're multiplying two 133-bit numbers.
138     outbits = (inbits + 3) * 2
139     outwords = (outbits + bignum_int_bits - 1) / bignum_int_bits + 1
140
141     tmp = Variable(out, "tmp")
142     acclo = Variable(out, "acclo")
143     acchi = Variable(out, "acchi")
144     acc2lo = Variable(out, "acc2lo")
145
146     pow5, bits_at_pow5 = 0, inbits
147
148     acclo.clear(0)
149     acchi.clear(bignum_int_bits)
150     bits_needed_in_acc2 = bignum_int_bits
151
152     for outwordpos in range(outwords):
153         for a in range(inwords):
154             b = outwordpos - a
155             if 0 <= b < inwords:
156                 tmp.set_to_product("a->w[%d]" % a, "b->w[%d]" % b,
157                                    outwordpos * bignum_int_bits)
158                 tmp.unload_into(acchi, acclo)
159
160         bits_in_word = bignum_int_bits
161         bitpos = 0
162         #print "begin output"
163         while bits_in_word > 0:
164             chunk = min(bits_in_word, bits_at_pow5)
165             if pow5 > 0:
166                 chunk = min(chunk, bits_needed_in_acc2)
167             if pow5 == 0:
168                 acclo.output_word(bitpos, chunk, "r->w[%d]", outwordpos)
169             else:
170                 acclo.transfer_to_next_acc(bitpos, chunk, pow5, acc2lo)
171                 bits_needed_in_acc2 -= chunk
172                 if bits_needed_in_acc2 == 0:
173                     assert acc2lo.placeval % bignum_int_bits == 0
174                     other_outwordpos = acc2lo.placeval / bignum_int_bits
175                     acc2lo.add_input_word("r->w[%d]", other_outwordpos)
176                     acc2lo.output_word(bitpos, bignum_int_bits, "r->w[%d]",
177                                        other_outwordpos)
178                     acc2lo.shift_down_from(None)
179                     bits_needed_in_acc2 = bignum_int_bits
180             bits_in_word -= chunk
181             bits_at_pow5 -= chunk
182             bitpos += chunk
183             if bits_at_pow5 == 0:
184                 if pow5 > 0:
185                     assert acc2lo.placeval % bignum_int_bits == 0
186                     other_outwordpos = acc2lo.placeval / bignum_int_bits
187                     acc2lo.add_input_word("r->w[%d]", other_outwordpos)
188                     acc2lo.output_word(0, bignum_int_bits, "r->w[%d]",
189                                        other_outwordpos)
190                 pow5 += 1
191                 bits_at_pow5 = inbits
192                 acc2lo.clear(0)
193                 bits_needed_in_acc2 = bignum_int_bits
194         acclo.shift_down_from(acchi)
195
196     while acc2lo.maxval > 0:
197         other_outwordpos = acc2lo.placeval / bignum_int_bits
198         bitsleft = inbits - other_outwordpos * bignum_int_bits
199         limit = 1<<bitsleft if bitsleft < bignum_int_bits else None
200         acc2lo.add_input_word("r->w[%d]", other_outwordpos, limit=limit)
201         acc2lo.output_word(0, bignum_int_bits, "r->w[%d]", other_outwordpos)
202         acc2lo.shift_down_from(None)
203
204     return out.finalise()
205
206 def gen_final_reduce_1305(bignum_int_bits):
207     out = Output(bignum_int_bits)
208
209     inbits = 130
210     inwords = (inbits + bignum_int_bits - 1) / bignum_int_bits
211
212     # We take our input number n, and compute k = 5 + 5*(n >> 130).
213     # Then k >> 130 is precisely the multiple of p that needs to be
214     # subtracted from n to reduce it to strictly less than p.
215
216     acclo = Variable(out, "acclo")
217
218     acclo.clear(0)
219     # Hopefully all the bits we're shifting down fit in the same word.
220     assert 130 / bignum_int_bits == (130 + 3 - 1) / bignum_int_bits
221     acclo.add_word("5 * ((n->w[%d] >> %d) + 1)" %
222                    (130 / bignum_int_bits, 130 % bignum_int_bits),
223                    limit = 5 * (7 + 1))
224     for wordpos in range(inwords):
225         acclo.add_input_word("n->w[%d]", wordpos)
226         # Notionally, we could call acclo.output_word here to store
227         # our adjusted value k. But we don't need to, because all we
228         # actually want is the very top word of it.
229         if wordpos == 130 / bignum_int_bits:
230             break
231         acclo.shift_down_from(None)
232
233     # Now we can find the right multiple of p to subtract. We actually
234     # subtract it by adding 5 times it, and then finally discarding
235     # the top bits of the output.
236
237     # Hopefully all the bits we're shifting down fit in the same word.
238     assert 130 / bignum_int_bits == (130 + 3 - 1) / bignum_int_bits
239     acclo.set_word("5 * (acclo >> %d)" % (130 % bignum_int_bits),
240                    limit = 5 * (7 + 1))
241     acclo.placeval = 0
242     for wordpos in range(inwords):
243         acclo.add_input_word("n->w[%d]", wordpos)
244         acclo.output_word(0, bignum_int_bits, "n->w[%d]", wordpos)
245         acclo.shift_down_from(None)
246
247     out.stmt("n->w[%d] &= (1 << %d) - 1" %
248              (130 / bignum_int_bits, 130 % bignum_int_bits))
249
250     # Here we don't call out.finalise(), because that will complain
251     # that there are bits of output we never dealt with. This is true,
252     # but all the bits in question are above 2^130, so they're bits
253     # we're discarding anyway.
254     return out.text # not out.finalise()
255
256 ops = { "mul" : gen_mul_1305,
257         "add" : gen_add,
258         "final_reduce" : gen_final_reduce_1305 }
259
260 args = sys.argv[1:]
261 if len(args) != 2 or args[0] not in ops:
262     sys.stderr.write("usage: make1305.py (%s) <bits>\n" % (" | ".join(sorted(ops))))
263     sys.exit(1)
264
265 sys.stdout.write("    /* ./contrib/make1305.py %s %s */\n" % tuple(args))
266 s = ops[args[0]](int(args[1]))
267 sys.stdout.write(s)