]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - contrib/make1305.py
first pass
[PuTTY.git] / contrib / make1305.py
1 #!/usr/bin/env python
2
3 import sys
4 import string
5 from collections import namedtuple
6
7 class Multiprecision(object):
8     def __init__(self, target, minval, maxval, words):
9         self.target = target
10         self.minval = minval
11         self.maxval = maxval
12         self.words = words
13         assert 0 <= self.minval
14         assert self.minval <= self.maxval
15         assert self.target.nwords(self.maxval) == len(words)
16
17     def getword(self, n):
18         return self.words[n] if n < len(self.words) else "0"
19
20     def __add__(self, rhs):
21         newmin = self.minval + rhs.minval
22         newmax = self.maxval + rhs.maxval
23         nwords = self.target.nwords(newmax)
24         words = []
25
26         addfn = self.target.add
27         for i in range(nwords):
28             words.append(addfn(self.getword(i), rhs.getword(i)))
29             addfn = self.target.adc
30
31         return Multiprecision(self.target, newmin, newmax, words)
32
33     def __mul__(self, rhs):
34         newmin = self.minval * rhs.minval
35         newmax = self.maxval * rhs.maxval
36         nwords = self.target.nwords(newmax)
37         words = []
38
39         # There are basically two strategies we could take for
40         # multiplying two multiprecision integers. One is to enumerate
41         # the space of pairs of word indices in lexicographic order,
42         # essentially computing a*b[i] for each i and adding them
43         # together; the other is to enumerate in diagonal order,
44         # computing everything together that belongs at a particular
45         # output word index.
46         #
47         # For the moment, I've gone for the former.
48
49         sprev = []
50         for i, sword in enumerate(self.words):
51             rprev = None
52             sthis = sprev[:i]
53             for j, rword in enumerate(rhs.words):
54                 prevwords = []
55                 if i+j < len(sprev):
56                     prevwords.append(sprev[i+j])
57                 if rprev is not None:
58                     prevwords.append(rprev)
59                 vhi, vlo = self.target.muladd(sword, rword, *prevwords)
60                 sthis.append(vlo)
61                 rprev = vhi
62             sthis.append(rprev)
63             sprev = sthis
64
65         # Remove unneeded words from the top of the output, if we can
66         # prove by range analysis that they'll always be zero.
67         sprev = sprev[:self.target.nwords(newmax)]
68
69         return Multiprecision(self.target, newmin, newmax, sprev)
70
71     def extract_bits(self, start, bits=None):
72         if bits is None:
73             bits = (self.maxval >> start).bit_length()
74
75         # Overly thorough range analysis: if min and max have the same
76         # *quotient* by 2^bits, then the result of reducing anything
77         # in the range [min,max] mod 2^bits has to fall within the
78         # obvious range. But if they have different quotients, then
79         # you can wrap round the modulus and so any value mod 2^bits
80         # is possible.
81         newmin = self.minval >> start
82         newmax = self.maxval >> start
83         if (newmin >> bits) != (newmax >> bits):
84             newmin = 0
85             newmax = (1 << bits) - 1
86
87         nwords = self.target.nwords(newmax)
88         words = []
89         for i in range(nwords):
90             srcpos = i * self.target.bits + start
91             maxbits = min(self.target.bits, start + bits - srcpos)
92             wordindex = srcpos / self.target.bits
93             if srcpos % self.target.bits == 0:
94                 word = self.getword(srcpos / self.target.bits)
95             elif (wordindex+1 >= len(self.words) or
96                   srcpos % self.target.bits + maxbits < self.target.bits):
97                 word = self.target.new_value(
98                     "(%%s) >> %d" % (srcpos % self.target.bits),
99                     self.getword(srcpos / self.target.bits))
100             else:
101                 word = self.target.new_value(
102                     "((%%s) >> %d) | ((%%s) << %d)" % (
103                         srcpos % self.target.bits,
104                         self.target.bits - (srcpos % self.target.bits)),
105                     self.getword(srcpos / self.target.bits),
106                     self.getword(srcpos / self.target.bits + 1))
107             if maxbits < self.target.bits and maxbits < bits:
108                 word = self.target.new_value(
109                     "(%%s) & ((((BignumInt)1) << %d)-1)" % maxbits,
110                     word)
111             words.append(word)
112
113         return Multiprecision(self.target, newmin, newmax, words)
114
115 # Each Statement has a list of variables it reads, and a list of ones
116 # it writes. 'forms' is a list of multiple actual C statements it
117 # could be generated as, depending on which of its output variables is
118 # actually used (e.g. no point calling BignumADC if the generated
119 # carry in a particular case is unused, or BignumMUL if nobody needs
120 # the top half). It is indexed by a bitmap whose bits correspond to
121 # the entries in wvars, with wvars[0] the MSB and wvars[-1] the LSB.
122 Statement = namedtuple("Statement", "rvars wvars forms")
123
124 class CodegenTarget(object):
125     def __init__(self, bits):
126         self.bits = bits
127         self.valindex = 0
128         self.stmts = []
129         self.generators = {}
130         self.bv_words = (130 + self.bits - 1) / self.bits
131         self.carry_index = 0
132
133     def nwords(self, maxval):
134         return (maxval.bit_length() + self.bits - 1) / self.bits
135
136     def stmt(self, stmt, needed=False):
137         index = len(self.stmts)
138         self.stmts.append([needed, stmt])
139         for val in stmt.wvars:
140             self.generators[val] = index
141
142     def new_value(self, formatstr=None, *deps):
143         name = "v%d" % self.valindex
144         self.valindex += 1
145         if formatstr is not None:
146             self.stmt(Statement(
147                     rvars=deps, wvars=[name],
148                     forms=[None, name + " = " + formatstr % deps]))
149         return name
150
151     def bigval_input(self, name, bits):
152         words = (bits + self.bits - 1) / self.bits
153         # Expect not to require an entire extra word
154         assert words == self.bv_words
155
156         return Multiprecision(self, 0, (1<<bits)-1, [
157                 self.new_value("%s->w[%d]" % (name, i)) for i in range(words)])
158
159     def const(self, value):
160         # We only support constants small enough to both fit in a
161         # BignumInt (of any size supported) _and_ be expressible in C
162         # with no weird integer literal syntax like a trailing LL.
163         #
164         # Supporting larger constants would be possible - you could
165         # break 'value' up into word-sized pieces on the Python side,
166         # and generate a legal C expression for each piece by
167         # splitting it further into pieces within the
168         # standards-guaranteed 'unsigned long' limit of 32 bits and
169         # then casting those to BignumInt before combining them with
170         # shifts. But it would be a lot of effort, and since the
171         # application for this code doesn't even need it, there's no
172         # point in bothering.
173         assert value < 2**16
174         return Multiprecision(self, value, value, ["%d" % value])
175
176     def current_carry(self):
177         return "carry%d" % self.carry_index
178
179     def add(self, a1, a2):
180         ret = self.new_value()
181         adcform = "BignumADC(%s, carry, %s, %s, 0)" % (ret, a1, a2)
182         plainform = "%s = %s + %s" % (ret, a1, a2)
183         self.carry_index += 1
184         carryout = self.current_carry()
185         self.stmt(Statement(
186                 rvars=[a1,a2], wvars=[ret,carryout],
187                 forms=[None, adcform, plainform, adcform]))
188         return ret
189
190     def adc(self, a1, a2):
191         ret = self.new_value()
192         adcform = "BignumADC(%s, carry, %s, %s, carry)" % (ret, a1, a2)
193         plainform = "%s = %s + %s + carry" % (ret, a1, a2)
194         carryin = self.current_carry()
195         self.carry_index += 1
196         carryout = self.current_carry()
197         self.stmt(Statement(
198                 rvars=[a1,a2,carryin], wvars=[ret,carryout],
199                 forms=[None, adcform, plainform, adcform]))
200         return ret
201
202     def muladd(self, m1, m2, *addends):
203         rlo = self.new_value()
204         rhi = self.new_value()
205         wideform = "BignumMUL%s(%s)" % (
206             { 0:"", 1:"ADD", 2:"ADD2" }[len(addends)],
207             ", ".join([rhi, rlo, m1, m2] + list(addends)))
208         narrowform = " + ".join(["%s = %s * %s" % (rlo, m1, m2)] +
209                                 list(addends))
210         self.stmt(Statement(
211                 rvars=[m1,m2]+list(addends), wvars=[rhi,rlo],
212                 forms=[None, narrowform, wideform, wideform]))
213         return rhi, rlo
214
215     def write_bigval(self, name, val):
216         for i in range(self.bv_words):
217             word = val.getword(i)
218             self.stmt(Statement(
219                     rvars=[word], wvars=[],
220                     forms=["%s->w[%d] = %s" % (name, i, word)]),
221                       needed=True)
222
223     def compute_needed(self):
224         used_vars = set()
225
226         self.queue = [stmt for (needed,stmt) in self.stmts if needed]
227         while len(self.queue) > 0:
228             stmt = self.queue.pop(0)
229             deps = []
230             for var in stmt.rvars:
231                 if var[0] in string.digits:
232                     continue # constant
233                 deps.append(self.generators[var])
234                 used_vars.add(var)
235             for index in deps:
236                 if not self.stmts[index][0]:
237                     self.stmts[index][0] = True
238                     self.queue.append(self.stmts[index][1])
239
240         forms = []
241         for i, (needed, stmt) in enumerate(self.stmts):
242             if needed:
243                 formindex = 0
244                 for (j, var) in enumerate(stmt.wvars):
245                     formindex *= 2
246                     if var in used_vars:
247                         formindex += 1
248                 forms.append(stmt.forms[formindex])
249
250                 # Now we must check whether this form of the statement
251                 # also writes some variables we _don't_ actually need
252                 # (e.g. if you only wanted the top half from a mul, or
253                 # only the carry from an adc, you'd be forced to
254                 # generate the other output too). Easiest way to do
255                 # this is to look for an identical statement form
256                 # later in the array.
257                 maxindex = max(i for i in range(len(stmt.forms))
258                                if stmt.forms[i] == stmt.forms[formindex])
259                 extra_vars = maxindex & ~formindex
260                 bitpos = 0
261                 while extra_vars != 0:
262                     if extra_vars & (1 << bitpos):
263                         extra_vars &= ~(1 << bitpos)
264                         var = stmt.wvars[-1-bitpos]
265                         used_vars.add(var)
266                         # Also, write out a cast-to-void for each
267                         # subsequently unused value, to prevent gcc
268                         # warnings when the output code is compiled.
269                         forms.append("(void)" + var)
270                     bitpos += 1
271
272         used_carry = any(v.startswith("carry") for v in used_vars)
273         used_vars = [v for v in used_vars if v.startswith("v")]
274         used_vars.sort(key=lambda v: int(v[1:]))
275
276         return used_carry, used_vars, forms
277
278     def text(self):
279         used_carry, values, forms = self.compute_needed()
280
281         ret = ""
282         while len(values) > 0:
283             prefix, sep, suffix = "    BignumInt ", ", ", ";"
284             currline = values.pop(0)
285             while (len(values) > 0 and
286                    len(prefix+currline+sep+values[0]+suffix) < 79):
287                 currline += sep + values.pop(0)
288             ret += prefix + currline + suffix + "\n"
289         if used_carry:
290             ret += "    BignumCarry carry;\n"
291         if ret != "":
292             ret += "\n"
293         for stmtform in forms:
294             ret += "    %s;\n" % stmtform
295         return ret
296
297 def gen_add(target):
298     # This is an addition _without_ reduction mod p, so that it can be
299     # used both during accumulation of the polynomial and for adding
300     # on the encrypted nonce at the end (which is mod 2^128, not mod
301     # p).
302     #
303     # Because one of the inputs will have come from our
304     # not-completely-reducing multiplication function, we expect up to
305     # 3 extra bits of input.
306
307     a = target.bigval_input("a", 133)
308     b = target.bigval_input("b", 133)
309     ret = a + b
310     target.write_bigval("r", ret)
311     return """\
312 static void bigval_add(bigval *r, const bigval *a, const bigval *b)
313 {
314 %s}
315 \n""" % target.text()
316
317 def gen_mul(target):
318     # The inputs are not 100% reduced mod p. Specifically, we can get
319     # a full 130-bit number from the pow5==0 pass, and then a 130-bit
320     # number times 5 from the pow5==1 pass, plus a possible carry. The
321     # total of that can be easily bounded above by 2^130 * 8, so we
322     # need to assume we're multiplying two 133-bit numbers.
323
324     a = target.bigval_input("a", 133)
325     b = target.bigval_input("b", 133)
326     ab = a * b
327     ab0 = ab.extract_bits(0, 130)
328     ab1 = ab.extract_bits(130, 130)
329     ab2 = ab.extract_bits(260)
330     ab1_5 = target.const(5) * ab1
331     ab2_25 = target.const(25) * ab2
332     ret = ab0 + ab1_5 + ab2_25
333     target.write_bigval("r", ret)
334     return """\
335 static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b)
336 {
337 %s}
338 \n""" % target.text()
339
340 def gen_final_reduce(target):
341     # We take our input number n, and compute k = n + 5*(n >> 130).
342     # Then k >> 130 is precisely the multiple of p that needs to be
343     # subtracted from n to reduce it to strictly less than p.
344
345     a = target.bigval_input("n", 133)
346     a1 = a.extract_bits(130, 130)
347     k = a + target.const(5) * a1
348     q = k.extract_bits(130)
349     adjusted = a + target.const(5) * q
350     ret = adjusted.extract_bits(0, 130)
351     target.write_bigval("n", ret)
352     return """\
353 static void bigval_final_reduce(bigval *n)
354 {
355 %s}
356 \n""" % target.text()
357
358 pp_keyword = "#if"
359 for bits in [16, 32, 64]:
360     sys.stdout.write("%s BIGNUM_INT_BITS == %d\n\n" % (pp_keyword, bits))
361     pp_keyword = "#elif"
362     sys.stdout.write(gen_add(CodegenTarget(bits)))
363     sys.stdout.write(gen_mul(CodegenTarget(bits)))
364     sys.stdout.write(gen_final_reduce(CodegenTarget(bits)))
365 sys.stdout.write("""#else
366 #error Add another bit count to contrib/make1305.py and rerun it
367 #endif
368 """)