5 from collections import namedtuple
# Multiprecision: a symbolic multi-word integer used by the C code
# generator. 'words' holds one C expression per BignumInt word (a value
# name or a literal), least significant word first; 'minval'/'maxval'
# bound the value so that range analysis can give every result exactly
# as many words as it can need.
7 class Multiprecision(object):
# target: the CodegenTarget that records statements and fixes the word size.
8 def __init__(self, target, minval, maxval, words):
# NOTE(review): the assignments of target/minval/maxval/words to self are
# not visible in this listing (its numbering jumps from 8 to 13); the
# checks below read them.
# Invariants: the range is non-negative and non-empty, and there are
# exactly as many words as maxval requires.
13 assert 0 <= self.minval
14 assert self.minval <= self.maxval
15 assert self.target.nwords(self.maxval) == len(words)
# Word accessor (its def line is not visible here): positions beyond the
# stored words read as the C literal "0", so shorter operands are
# zero-extended.
18 return self.words[n] if n < len(self.words) else "0"
# Symbolic addition: emits a carry chain (add for word 0, adc for every
# higher word) and returns the sum as a new Multiprecision.
20 def __add__(self, rhs):
# The bounds of a sum are the sums of the bounds; the result may need one
# more word than the wider operand.
21 newmin = self.minval + rhs.minval
22 newmax = self.maxval + rhs.maxval
23 nwords = self.target.nwords(newmax)
# NOTE(review): 'words' is not initialised on any visible line (numbering
# jumps from 23 to 26); presumably an empty list - confirm.
26 addfn = self.target.add
27 for i in range(nwords):
# getword() yields "0" past the end of the shorter operand.
28 words.append(addfn(self.getword(i), rhs.getword(i)))
# Every word after the first must also take the incoming carry.
29 addfn = self.target.adc
# No extra word is made for the final carry-out: nwords(newmax) already
# covers the largest possible sum, so that carry is always zero.
31 return Multiprecision(self.target, newmin, newmax, words)
# Symbolic schoolbook multiplication: one row of muladd calls per word of
# self, with each row accumulated into the partial product left by the
# row before.
33 def __mul__(self, rhs):
# The bounds of a product of non-negative ranges are the products of the
# bounds.
34 newmin = self.minval * rhs.minval
35 newmax = self.maxval * rhs.maxval
# NOTE(review): this 'nwords' is not read on any visible line; the trim
# below recomputes self.target.nwords(newmax).
36 nwords = self.target.nwords(newmax)
39 # There are basically two strategies we could take for
40 # multiplying two multiprecision integers. One is to enumerate
41 # the space of pairs of word indices in lexicographic order,
42 # essentially computing a*b[i] for each i and adding them
43 # together; the other is to enumerate in diagonal order,
44 # computing everything together that belongs at a particular
# output word (all pairs i, j with the same i + j).
#
47 # For the moment, I've gone for the former.
# NOTE(review): the initialisation of sprev, rprev and prevwords, and the
# bookkeeping after each muladd, are not visible in this listing.
50 for i, sword in enumerate(self.words):
53 for j, rword in enumerate(rhs.words):
# Up to two addends go into each muladd: sprev[i+j], the word already
# accumulated at output position i+j, and rprev, presumably the high
# word carried along from the previous column of this row.
56 prevwords.append(sprev[i+j])
58 prevwords.append(rprev)
# muladd produces the high and low words of sword*rword + sum(prevwords).
59 vhi, vlo = self.target.muladd(sword, rword, *prevwords)
65 # Remove unneeded words from the top of the output, if we can
66 # prove by range analysis that they'll always be zero.
67 sprev = sprev[:self.target.nwords(newmax)]
69 return Multiprecision(self.target, newmin, newmax, sprev)
# Return bits [start, start+bits) of this value as a new Multiprecision,
# i.e. (value >> start) mod 2^bits. If bits is omitted it is taken as
# everything from 'start' up to the top of maxval.
71 def extract_bits(self, start, bits=None):
# NOTE(review): presumably guarded by 'if bits is None:' on a line not
# visible here; otherwise an explicit 'bits' argument would be ignored.
73 bits = (self.maxval >> start).bit_length()
75 # Overly thorough range analysis: if min and max have the same
76 # *quotient* by 2^bits, then the result of reducing anything
77 # in the range [min,max] mod 2^bits has to fall within the
78 # obvious range. But if they have different quotients, then
79 # you can wrap round the modulus and so any value mod 2^bits
# is possible, and the upper bound becomes 2^bits - 1.
81 newmin = self.minval >> start
82 newmax = self.maxval >> start
# NOTE(review): the matching newmin update in the wrap-around branch is
# not visible here. Also, when both quotients are equal but non-zero, no
# visible line reduces newmin/newmax by that quotient, so the range (and
# nwords) would come out too large. The calls visible in this file only
# hit the wrap-around case or a zero quotient.
83 if (newmin >> bits) != (newmax >> bits):
85 newmax = (1 << bits) - 1
87 nwords = self.target.nwords(newmax)
# For each output word i, srcpos is the source bit it starts at and
# maxbits is how many of its bits belong to the extracted field.
# NOTE(review): '/' below floors only under Python 2; under Python 3 it
# yields a float and getword()'s list indexing fails. '//' works in both.
89 for i in range(nwords):
90 srcpos = i * self.target.bits + start
91 maxbits = min(self.target.bits, start + bits - srcpos)
92 wordindex = srcpos / self.target.bits
# Word-aligned: the source word can be used as-is.
93 if srcpos % self.target.bits == 0:
94 word = self.getword(srcpos / self.target.bits)
# Either there is no next source word, or every wanted bit lies in this
# word: a single right shift suffices.
95 elif (wordindex+1 >= len(self.words) or
96 srcpos % self.target.bits + maxbits < self.target.bits):
97 word = self.target.new_value(
98 "(%%s) >> %d" % (srcpos % self.target.bits),
99 self.getword(srcpos / self.target.bits))
# Otherwise splice the top of this word with the bottom of the next one.
101 word = self.target.new_value(
102 "((%%s) >> %d) | ((%%s) << %d)" % (
103 srcpos % self.target.bits,
104 self.target.bits - (srcpos % self.target.bits)),
105 self.getword(srcpos / self.target.bits),
106 self.getword(srcpos / self.target.bits + 1))
# Mask off anything above the field in a partial word.
# NOTE(review): for i == 0 with an explicit bits < self.target.bits,
# maxbits == bits, so this test skips the mask even though the source
# word can hold higher bits. The calls visible in this file pass
# bits=130 (never below the word size) or None, where maxval already
# bounds the high bits.
107 if maxbits < self.target.bits and maxbits < bits:
108 word = self.target.new_value(
109 "(%%s) & ((((BignumInt)1) << %d)-1)" % maxbits,
113 return Multiprecision(self.target, newmin, newmax, words)
# Statement: one C statement the code generator may emit.
#
#   rvars - the variable names the statement reads;
#   wvars - the variable names it writes;
#   forms - the candidate C texts for it, chosen by which of its outputs
#           end up being used (there is no point calling BignumADC if the
#           carry it produces is never read, or BignumMUL if nobody needs
#           the top half of the product).
#
# 'forms' is indexed by a bitmap with one bit per entry of wvars, a set
# bit meaning that variable is used: wvars[0] supplies the most
# significant bit and wvars[-1] the least significant.
Statement = namedtuple("Statement", ["rvars", "wvars", "forms"])
# CodegenTarget: collects the statements of one generated C function for
# a given BignumInt width ('bits' bits per word) and renders them as C.
124 class CodegenTarget(object):
125 def __init__(self, bits):
# NOTE(review): the remaining attribute set-up (self.bits and the
# statement/counter state the methods below read) is not visible in this
# listing; only the bv_words line is.
# bv_words: words per bigval, i.e. enough words to hold a 130-bit value.
# NOTE(review): '/' floors only under Python 2. Under Python 3 this is a
# float (e.g. 9.0625 for 16-bit words), and write_bigval's
# range(self.bv_words) would raise TypeError; '//' works in both.
130 self.bv_words = (130 + self.bits - 1) / self.bits
def nwords(self, maxval):
    """Return how many BignumInt words are needed to hold maxval.

    This is the ceiling of maxval.bit_length() / self.bits, so a maxval
    of 0 needs 0 words.

    Floor division ('//') keeps the result an int under both Python 2
    and Python 3. Plain '/' is true division on Python 3 and gives a
    float (e.g. 9.0625 instead of 9), which breaks the range(nwords)
    loops in the callers.
    """
    return (maxval.bit_length() + self.bits - 1) // self.bits
def stmt(self, stmt, needed=False):
    """Append a Statement to the statement list and index its outputs.

    The entry is stored as a mutable [needed, stmt] pair so that the
    needed flag can be switched on later, and every name in stmt.wvars
    is mapped in self.generators to the position of that entry.
    """
    position = len(self.stmts)
    self.stmts.append([needed, stmt])
    self.generators.update((var, position) for var in stmt.wvars)
# Allocate a fresh value name "v<N>". If formatstr is given, also record
# a statement assigning formatstr % deps to it, where deps are the names
# the expression reads. Without formatstr only the name is allocated;
# callers such as add/adc/muladd record the statement that writes it.
142 def new_value(self, formatstr=None, *deps):
143 name = "v%d" % self.valindex
# NOTE(review): the valindex increment, the line opening the Statement
# and the return are not visible in this listing.
145 if formatstr is not None:
# forms[0] (value unused) is None; forms[1] is the assignment.
147 rvars=deps, wvars=[name],
148 forms=[None, name + " = " + formatstr % deps]))
# Declare a bigval input: word i of the result is a fresh value read from
# the C expression "<name>->w[i]", and its range is [0, 2^bits - 1].
151 def bigval_input(self, name, bits):
# NOTE(review): '/' floors only under Python 2. Under Python 3 'words' is
# a float and the assert below fails (e.g. 9.25 != 9.0625 for 133 bits
# in 16-bit words); '//' works in both.
152 words = (bits + self.bits - 1) / self.bits
153 # Expect not to require an entire extra word
# ...i.e. 'bits' must round up to the same word count as a 130-bit
# bigval, so only w[0] .. w[bv_words-1] are ever read.
154 assert words == self.bv_words
156 return Multiprecision(self, 0, (1<<bits)-1, [
157 self.new_value("%s->w[%d]" % (name, i)) for i in range(words)])
# Wrap a small integer constant as a one-word Multiprecision whose only
# word is its decimal C literal and whose range is exactly [value, value].
159 def const(self, value):
160 # We only support constants small enough to both fit in a
161 # BignumInt (of any size supported) _and_ be expressible in C
162 # with no weird integer literal syntax like a trailing LL.
#
164 # Supporting larger constants would be possible - you could
165 # break 'value' up into word-sized pieces on the Python side,
166 # and generate a legal C expression for each piece by
167 # splitting it further into pieces within the
168 # standards-guaranteed 'unsigned long' limit of 32 bits and
169 # then casting those to BignumInt before combining them with
170 # shifts. But it would be a lot of effort, and since the
171 # application for this code doesn't even need it, there's no
172 # point in bothering.
# NOTE(review): no visible line here enforces the restriction above.
# Multiprecision's nwords(maxval) == len(words) check does reject values
# needing more than one word of *this* target (and also value 0, for
# which nwords is 0), but not values that would overflow a smaller
# supported BignumInt. Confirm whether an explicit assert is intended.
174 return Multiprecision(self, value, value, ["%d" % value])
def current_carry(self):
    """Return the bookkeeping name "carry<N>" of the carry most recently
    allocated (N is the current carry_index)."""
    index = self.carry_index
    return "carry%d" % (index,)
# Emit the lowest-word addition a1 + a2 (no carry in), producing a new
# sum value and a new carry.
179 def add(self, a1, a2):
180 ret = self.new_value()
# forms is indexed by the (ret, carryout) usage bitmap: 1 = only the
# carry is used, 2 = only the sum, 3 = both. A plain C '+' suffices when
# the carry is not needed; otherwise BignumADC is used.
181 adcform = "BignumADC(%s, carry, %s, %s, 0)" % (ret, a1, a2)
182 plainform = "%s = %s + %s" % (ret, a1, a2)
# Each carry gets its own bookkeeping name (carry<N>) for dependency
# tracking, while the emitted C always reuses the single 'carry' variable.
183 self.carry_index += 1
184 carryout = self.current_carry()
# NOTE(review): the line opening the Statement, and the return
# (presumably of ret), are not visible in this listing.
186 rvars=[a1,a2], wvars=[ret,carryout],
187 forms=[None, adcform, plainform, adcform]))
# Emit a higher-word addition a1 + a2 + carry, consuming the previous
# carry and producing a new sum value and a new carry.
190 def adc(self, a1, a2):
191 ret = self.new_value()
# Same usage-bitmap layout as add(): 1 = only carry, 2 = only sum,
# 3 = both.
192 adcform = "BignumADC(%s, carry, %s, %s, carry)" % (ret, a1, a2)
193 plainform = "%s = %s + %s + carry" % (ret, a1, a2)
# carryin is the bookkeeping name of the carry produced just before this
# statement. Listing it in rvars is what makes compute_needed keep the
# previous statement's carry output alive.
194 carryin = self.current_carry()
195 self.carry_index += 1
196 carryout = self.current_carry()
# NOTE(review): the line opening the Statement, and the return
# (presumably of ret), are not visible in this listing.
198 rvars=[a1,a2,carryin], wvars=[ret,carryout],
199 forms=[None, adcform, plainform, adcform]))
# Emit rhi:rlo = m1 * m2 + sum(addends), with zero to two single-word
# addends.
202 def muladd(self, m1, m2, *addends):
203 rlo = self.new_value()
204 rhi = self.new_value()
# The table picks BignumMUL, BignumMULADD or BignumMULADD2; more than two
# addends raises KeyError. Two is also the most that can be absorbed
# without overflowing the two-word result, since for w-bit words
# (2^w - 1)^2 + 2*(2^w - 1) = 2^(2w) - 1.
205 wideform = "BignumMUL%s(%s)" % (
206 { 0:"", 1:"ADD", 2:"ADD2" }[len(addends)],
207 ", ".join([rhi, rlo, m1, m2] + list(addends)))
# If only the low word is used (bitmap 1, as wvars is [rhi, rlo]), a
# plain C multiply-and-add truncated to one word suffices.
# NOTE(review): the continuation of this expression and the line opening
# the Statement are not visible in this listing.
208 narrowform = " + ".join(["%s = %s * %s" % (rlo, m1, m2)] +
211 rvars=[m1,m2]+list(addends), wvars=[rhi,rlo],
212 forms=[None, narrowform, wideform, wideform]))
# Emit stores of val's words into the C bigval '<name>'
# (name->w[0] .. name->w[bv_words-1]).
215 def write_bigval(self, name, val):
# Words past the end of val come back as "0" from getword, so all
# bv_words words are always stored. Any words of val beyond bv_words are
# not written.
216 for i in range(self.bv_words):
217 word = val.getword(i)
# wvars is empty, so the usage bitmap is always 0 and forms[0] (the
# store itself) is the form that applies.
# NOTE(review): the Statement opener and closing line are not visible
# here; presumably these statements are recorded as needed, since
# compute_needed starts its search from the needed ones.
219 rvars=[word], wvars=[],
220 forms=["%s->w[%d] = %s" % (name, i, word)]),
# Decide which recorded statements must be emitted and in which form.
# Returns (used_carry, used_vars, forms):
#   used_carry - whether any used variable is a carry;
#   used_vars  - the used "v<N>" value names, sorted by N;
#   forms      - the C statement strings to emit, in order.
223 def compute_needed(self):
# Phase 1: backward liveness. Start from the statements already flagged
# as needed and repeatedly flag the statements that generate their
# inputs.
# NOTE(review): queue.pop(0) is O(len(queue)); collections.deque would
# make this pass linear overall.
226 self.queue = [stmt for (needed,stmt) in self.stmts if needed]
227 while len(self.queue) > 0:
228 stmt = self.queue.pop(0)
230 for var in stmt.rvars:
# Names starting with a digit are C literals (e.g. "0" or a constant)
# and have no generating statement.
231 if var[0] in string.digits:
233 deps.append(self.generators[var])
# (The loop over deps that yields 'index' is not visible in this listing.)
236 if not self.stmts[index][0]:
237 self.stmts[index][0] = True
238 self.queue.append(self.stmts[index][1])
# Phase 2: for each statement, pick the form indexed by the bitmap of
# which of its written variables are used (wvars[0] is the MSB).
241 for i, (needed, stmt) in enumerate(self.stmts):
244 for (j, var) in enumerate(stmt.wvars):
248 forms.append(stmt.forms[formindex])
250 # Now we must check whether this form of the statement
251 # also writes some variables we _don't_ actually need
252 # (e.g. if you only wanted the top half from a mul, or
253 # only the carry from an adc, you'd be forced to
254 # generate the other output too). Easiest way to do
255 # this is to look for an identical statement form
256 # later in the array.
# Any bit set in maxindex but not in formindex is an output the chosen
# form writes even though nothing reads it. The generator expression
# has its own scope, so its 'i' does not disturb the outer loop's 'i'.
257 maxindex = max(i for i in range(len(stmt.forms))
258 if stmt.forms[i] == stmt.forms[formindex])
259 extra_vars = maxindex & ~formindex
261 while extra_vars != 0:
262 if extra_vars & (1 << bitpos):
263 extra_vars &= ~(1 << bitpos)
# Bit position b corresponds to wvars[-1-b], matching the bitmap order.
264 var = stmt.wvars[-1-bitpos]
266 # Also, write out a cast-to-void for each
267 # subsequently unused value, to prevent gcc
268 # warnings when the output code is compiled.
269 forms.append("(void)" + var)
# Any "carry..." name means the C 'carry' variable is needed; "v<N>"
# names are sorted numerically for the declarations.
272 used_carry = any(v.startswith("carry") for v in used_vars)
273 used_vars = [v for v in used_vars if v.startswith("v")]
274 used_vars.sort(key=lambda v: int(v[1:]))
276 return used_carry, used_vars, forms
# Body of the C-text renderer (presumably text(), which the gen_*
# functions call). Its def line and the initialisation of 'ret' are not
# visible in this listing. It declares the needed variables, then emits
# each chosen statement.
279 used_carry, values, forms = self.compute_needed()
# Pack the value declarations greedily: keep appending names to the
# current "BignumInt ..." line while it would stay shorter than 79
# characters.
282 while len(values) > 0:
283 prefix, sep, suffix = " BignumInt ", ", ", ";"
284 currline = values.pop(0)
285 while (len(values) > 0 and
286 len(prefix+currline+sep+values[0]+suffix) < 79):
287 currline += sep + values.pop(0)
288 ret += prefix + currline + suffix + "\n"
# NOTE(review): used_carry is not read on any visible line; presumably it
# guards this declaration - confirm.
290 ret += " BignumCarry carry;\n"
# One C statement per chosen form, each terminated with ';'.
293 for stmtform in forms:
294 ret += " %s;\n" % stmtform
# Generator for the C function bigval_add(r, a, b): r = a + b without
# modular reduction (the Python def line is not visible in this listing).
298 # This is an addition _without_ reduction mod p, so that it can be
299 # used both during accumulation of the polynomial and for adding
300 # on the encrypted nonce at the end (which is mod 2^128, not mod
# p).
303 # Because one of the inputs will have come from our
304 # not-completely-reducing multiplication function, we expect up to
305 # 3 extra bits of input.
# 133 bits still round up to bv_words words (bigval_input asserts this).
# The sum of two such values also fits in bv_words words for 16-, 32- and
# 64-bit words, so write_bigval drops nothing.
307 a = target.bigval_input("a", 133)
308 b = target.bigval_input("b", 133)
# NOTE(review): 'ret' is assigned on a line not visible here; presumably
# ret = a + b.
310 target.write_bigval("r", ret)
312 static void bigval_add(bigval *r, const bigval *a, const bigval *b)
315 \n""" % target.text()
318 # The inputs are not 100% reduced mod p. Specifically, we can get
319 # a full 130-bit number from the pow5==0 pass, and then a 130-bit
320 # number times 5 from the pow5==1 pass, plus a possible carry. The
321 # total of that can be easily bounded above by 2^130 * 8, so we
322 # need to assume we're multiplying two 133-bit numbers.
324 a = target.bigval_input("a", 133)
325 b = target.bigval_input("b", 133)
# NOTE(review): 'ab' is assigned on a line not visible here; presumably
# the full product a * b, which is below 2^266.
# Split ab = ab0 + ab1*2^130 + ab2*2^260. The constants 5 and 25 rely on
# 2^130 = 5 and 2^260 = 25 (mod p), i.e. p = 2^130 - 5, so
# ab0 + 5*ab1 + 25*ab2 is congruent to ab. That sum stays below 2^133,
# so it fits the same bigval representation as the inputs.
327 ab0 = ab.extract_bits(0, 130)
328 ab1 = ab.extract_bits(130, 130)
329 ab2 = ab.extract_bits(260)
330 ab1_5 = target.const(5) * ab1
331 ab2_25 = target.const(25) * ab2
332 ret = ab0 + ab1_5 + ab2_25
333 target.write_bigval("r", ret)
335 static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b)
338 \n""" % target.text()
# Generator for the C function bigval_final_reduce(n): reduce n (up to
# 133 bits) in place so that it ends up strictly less than p.
340 def gen_final_reduce(target):
341 # We take our input number n, and compute k = n + 5*(n >> 130).
342 # Then k >> 130 is precisely the multiple of p that needs to be
343 # subtracted from n to reduce it to strictly less than p.
#
# With p = 2^130 - 5, subtracting q*p is n - q*p = (n + 5*q) - q*2^130.
# The reduced result lies in [0, p), inside [0, 2^130), so it is just
# the low 130 bits of n + 5*q.
345 a = target.bigval_input("n", 133)
# a1 = n >> 130 (at most 3 bits for a 133-bit n).
346 a1 = a.extract_bits(130, 130)
347 k = a + target.const(5) * a1
348 q = k.extract_bits(130)
349 adjusted = a + target.const(5) * q
350 ret = adjusted.extract_bits(0, 130)
351 target.write_bigval("n", ret)
353 static void bigval_final_reduce(bigval *n)
356 \n""" % target.text()
# Emit one preprocessor-guarded section per supported BIGNUM_INT_BITS
# value, each containing the three generated functions; every function
# gets its own fresh CodegenTarget.
# NOTE(review): pp_keyword is set on lines not visible in this listing,
# presumably "#if" for the first size and "#elif" for the rest.
359 for bits in [16, 32, 64]:
360 sys.stdout.write("%s BIGNUM_INT_BITS == %d\n\n" % (pp_keyword, bits))
362 sys.stdout.write(gen_add(CodegenTarget(bits)))
363 sys.stdout.write(gen_mul(CodegenTarget(bits)))
364 sys.stdout.write(gen_final_reduce(CodegenTarget(bits)))
# Any other word size reaches the '#else' branch and fails with #error.
365 sys.stdout.write("""#else
366 #error Add another bit count to contrib/make1305.py and rerun it