X-Git-Url: https://asedeno.scripts.mit.edu/gitweb/?a=blobdiff_plain;f=contrib%2Fmake1305.py;h=c8597040556c391c020f0dbd8d9cced4f9a97595;hb=991d30412d0911e7727a852d0a00ae0f1bec1b3e;hp=ede220df70a9c742110ca628aaef15238432421b;hpb=0bd014e456a0e5f755c45a8a5a420d6fad85c1d8;p=PuTTY.git diff --git a/contrib/make1305.py b/contrib/make1305.py index ede220df..c8597040 100755 --- a/contrib/make1305.py +++ b/contrib/make1305.py @@ -1,108 +1,300 @@ #!/usr/bin/env python import sys +import string +from collections import namedtuple -class Output(object): - def __init__(self, bignum_int_bits): - self.bignum_int_bits = bignum_int_bits - self.text = "" - self.vars = [] - def stmt(self, statement): - self.text += " %s;\n" % statement - def register_var(self, var): - self.vars.append(var) - def finalise(self): - for var in self.vars: - assert var.maxval == 0, "Variable not clear: %s" % var.name - return self.text - -class Variable(object): - def __init__(self, out, name): - self.out = out - self.maxval = 0 - self.name = name - self.placeval = None - self.out.stmt("BignumDblInt %s" % (self.name)) - self.out.register_var(self) - def clear(self, placeval): - self.maxval = 0 - self.placeval = placeval - self.out.stmt("%s = 0" % (self.name)) - def set_word(self, name, limit=None): - if limit is not None: - self.maxval = limit-1 - else: - self.maxval = (1 << self.out.bignum_int_bits) - 1 - assert self.maxval < (1 << 2*self.out.bignum_int_bits) - self.out.stmt("%s = %s" % (self.name, name)) - def add_word(self, name, limit=None): - if limit is not None: - self.maxval += limit-1 - else: - self.maxval += (1 << self.out.bignum_int_bits) - 1 - assert self.maxval < (1 << 2*self.out.bignum_int_bits) - self.out.stmt("%s += %s" % (self.name, name)) - def add_input_word(self, fmt, wordpos, limit=None): - assert self.placeval == wordpos * self.out.bignum_int_bits - self.add_word(fmt % wordpos, limit) - def set_to_product(self, a, b, placeval): - self.maxval = ((1 << self.out.bignum_int_bits) - 1) ** 2 - assert self.maxval < (1 << 2*self.out.bignum_int_bits) - self.out.stmt("%s = (BignumDblInt)(%s) * (%s)" % (self.name, a, b)) - self.placeval = placeval - def add_bottom_half(self, srcvar): - self.add_word("%s & BIGNUM_INT_MASK" % (srcvar.name)) - def add_top_half(self, srcvar): - self.add_word("%s >> %d" % (srcvar.name, self.out.bignum_int_bits)) - def unload_into(self, topvar, botvar): - assert botvar.placeval == self.placeval - botvar.add_bottom_half(self) - assert topvar.placeval == self.placeval + self.out.bignum_int_bits - topvar.add_top_half(self) - self.maxval = 0 - def output_word(self, bitpos, bits, destfmt, destwordpos): - assert bitpos == 0 - assert self.placeval == destwordpos * self.out.bignum_int_bits - dest = destfmt % destwordpos - if bits == self.out.bignum_int_bits: - self.out.stmt("%s = %s" % (dest, self.name)) - else: - self.out.stmt("%s = %s & (((BignumInt)1 << %d)-1)" % - (dest, self.name, bits)) - def transfer_to_next_acc(self, bitpos, bits, pow5, destvar): - destbitpos = self.placeval + bitpos - 130 * pow5 - destvar.placeval - #print "transfer", "*%d" % 5**pow5, self.name, self.placeval, bitpos, destvar.name, destvar.placeval, destbitpos, bits - assert 0 <= bitpos < bitpos+bits <= self.out.bignum_int_bits - assert 0 <= destbitpos < destbitpos+bits <= self.out.bignum_int_bits - expr = self.name - if bitpos > 0: - expr = "(%s >> %d)" % (expr, bitpos) - expr = "(%s & (((BignumInt)1 << %d)-1))" % (expr, bits) - self.out.stmt("%s += %s * ((BignumDblInt)%d << %d)" % - (destvar.name, expr, 5**pow5, destbitpos)) - destvar.maxval += (((1 << bits)-1) << destbitpos) * (5**pow5) - def shift_down_from(self, top): - if top is not None: - self.out.stmt("%s = %s + (%s >> %d)" % - (self.name, top.name, self.name, - self.out.bignum_int_bits)) - topmaxval = top.maxval - else: - self.out.stmt("%s >>= %d" % (self.name, self.out.bignum_int_bits)) - topmaxval = 0 - self.maxval = topmaxval + self.maxval >> self.out.bignum_int_bits - assert self.maxval < (1 << 2*self.out.bignum_int_bits) - if top is not None: - assert self.placeval + self.out.bignum_int_bits == top.placeval - top.clear(top.placeval + self.out.bignum_int_bits) - self.placeval += self.out.bignum_int_bits - -def gen_add(bignum_int_bits): - out = Output(bignum_int_bits) - - inbits = 130 - inwords = (inbits + bignum_int_bits - 1) / bignum_int_bits +class Multiprecision(object): + def __init__(self, target, minval, maxval, words): + self.target = target + self.minval = minval + self.maxval = maxval + self.words = words + assert 0 <= self.minval + assert self.minval <= self.maxval + assert self.target.nwords(self.maxval) == len(words) + def getword(self, n): + return self.words[n] if n < len(self.words) else "0" + + def __add__(self, rhs): + newmin = self.minval + rhs.minval + newmax = self.maxval + rhs.maxval + nwords = self.target.nwords(newmax) + words = [] + + addfn = self.target.add + for i in range(nwords): + words.append(addfn(self.getword(i), rhs.getword(i))) + addfn = self.target.adc + + return Multiprecision(self.target, newmin, newmax, words) + + def __mul__(self, rhs): + newmin = self.minval * rhs.minval + newmax = self.maxval * rhs.maxval + nwords = self.target.nwords(newmax) + words = [] + + # There are basically two strategies we could take for + # multiplying two multiprecision integers. One is to enumerate + # the space of pairs of word indices in lexicographic order, + # essentially computing a*b[i] for each i and adding them + # together; the other is to enumerate in diagonal order, + # computing everything together that belongs at a particular + # output word index. + # + # For the moment, I've gone for the former. + + sprev = [] + for i, sword in enumerate(self.words): + rprev = None + sthis = sprev[:i] + for j, rword in enumerate(rhs.words): + prevwords = [] + if i+j < len(sprev): + prevwords.append(sprev[i+j]) + if rprev is not None: + prevwords.append(rprev) + vhi, vlo = self.target.muladd(sword, rword, *prevwords) + sthis.append(vlo) + rprev = vhi + sthis.append(rprev) + sprev = sthis + + # Remove unneeded words from the top of the output, if we can + # prove by range analysis that they'll always be zero. + sprev = sprev[:self.target.nwords(newmax)] + + return Multiprecision(self.target, newmin, newmax, sprev) + + def extract_bits(self, start, bits=None): + if bits is None: + bits = (self.maxval >> start).bit_length() + + # Overly thorough range analysis: if min and max have the same + # *quotient* by 2^bits, then the result of reducing anything + # in the range [min,max] mod 2^bits has to fall within the + # obvious range. But if they have different quotients, then + # you can wrap round the modulus and so any value mod 2^bits + # is possible. + newmin = self.minval >> start + newmax = self.maxval >> start + if (newmin >> bits) != (newmax >> bits): + newmin = 0 + newmax = (1 << bits) - 1 + + nwords = self.target.nwords(newmax) + words = [] + for i in range(nwords): + srcpos = i * self.target.bits + start + maxbits = min(self.target.bits, start + bits - srcpos) + wordindex = srcpos / self.target.bits + if srcpos % self.target.bits == 0: + word = self.getword(srcpos / self.target.bits) + elif (wordindex+1 >= len(self.words) or + srcpos % self.target.bits + maxbits < self.target.bits): + word = self.target.new_value( + "(%%s) >> %d" % (srcpos % self.target.bits), + self.getword(srcpos / self.target.bits)) + else: + word = self.target.new_value( + "((%%s) >> %d) | ((%%s) << %d)" % ( + srcpos % self.target.bits, + self.target.bits - (srcpos % self.target.bits)), + self.getword(srcpos / self.target.bits), + self.getword(srcpos / self.target.bits + 1)) + if maxbits < self.target.bits and maxbits < bits: + word = self.target.new_value( + "(%%s) & ((((BignumInt)1) << %d)-1)" % maxbits, + word) + words.append(word) + + return Multiprecision(self.target, newmin, newmax, words) + +# Each Statement has a list of variables it reads, and a list of ones +# it writes. 'forms' is a list of multiple actual C statements it +# could be generated as, depending on which of its output variables is +# actually used (e.g. no point calling BignumADC if the generated +# carry in a particular case is unused, or BignumMUL if nobody needs +# the top half). It is indexed by a bitmap whose bits correspond to +# the entries in wvars, with wvars[0] the MSB and wvars[-1] the LSB. +Statement = namedtuple("Statement", "rvars wvars forms") + +class CodegenTarget(object): + def __init__(self, bits): + self.bits = bits + self.valindex = 0 + self.stmts = [] + self.generators = {} + self.bv_words = (130 + self.bits - 1) / self.bits + self.carry_index = 0 + + def nwords(self, maxval): + return (maxval.bit_length() + self.bits - 1) / self.bits + + def stmt(self, stmt, needed=False): + index = len(self.stmts) + self.stmts.append([needed, stmt]) + for val in stmt.wvars: + self.generators[val] = index + + def new_value(self, formatstr=None, *deps): + name = "v%d" % self.valindex + self.valindex += 1 + if formatstr is not None: + self.stmt(Statement( + rvars=deps, wvars=[name], + forms=[None, name + " = " + formatstr % deps])) + return name + + def bigval_input(self, name, bits): + words = (bits + self.bits - 1) / self.bits + # Expect not to require an entire extra word + assert words == self.bv_words + + return Multiprecision(self, 0, (1<w[%d]" % (name, i)) for i in range(words)]) + + def const(self, value): + # We only support constants small enough to both fit in a + # BignumInt (of any size supported) _and_ be expressible in C + # with no weird integer literal syntax like a trailing LL. + # + # Supporting larger constants would be possible - you could + # break 'value' up into word-sized pieces on the Python side, + # and generate a legal C expression for each piece by + # splitting it further into pieces within the + # standards-guaranteed 'unsigned long' limit of 32 bits and + # then casting those to BignumInt before combining them with + # shifts. But it would be a lot of effort, and since the + # application for this code doesn't even need it, there's no + # point in bothering. + assert value < 2**16 + return Multiprecision(self, value, value, ["%d" % value]) + + def current_carry(self): + return "carry%d" % self.carry_index + + def add(self, a1, a2): + ret = self.new_value() + adcform = "BignumADC(%s, carry, %s, %s, 0)" % (ret, a1, a2) + plainform = "%s = %s + %s" % (ret, a1, a2) + self.carry_index += 1 + carryout = self.current_carry() + self.stmt(Statement( + rvars=[a1,a2], wvars=[ret,carryout], + forms=[None, adcform, plainform, adcform])) + return ret + + def adc(self, a1, a2): + ret = self.new_value() + adcform = "BignumADC(%s, carry, %s, %s, carry)" % (ret, a1, a2) + plainform = "%s = %s + %s + carry" % (ret, a1, a2) + carryin = self.current_carry() + self.carry_index += 1 + carryout = self.current_carry() + self.stmt(Statement( + rvars=[a1,a2,carryin], wvars=[ret,carryout], + forms=[None, adcform, plainform, adcform])) + return ret + + def muladd(self, m1, m2, *addends): + rlo = self.new_value() + rhi = self.new_value() + wideform = "BignumMUL%s(%s)" % ( + { 0:"", 1:"ADD", 2:"ADD2" }[len(addends)], + ", ".join([rhi, rlo, m1, m2] + list(addends))) + narrowform = " + ".join(["%s = %s * %s" % (rlo, m1, m2)] + + list(addends)) + self.stmt(Statement( + rvars=[m1,m2]+list(addends), wvars=[rhi,rlo], + forms=[None, narrowform, wideform, wideform])) + return rhi, rlo + + def write_bigval(self, name, val): + for i in range(self.bv_words): + word = val.getword(i) + self.stmt(Statement( + rvars=[word], wvars=[], + forms=["%s->w[%d] = %s" % (name, i, word)]), + needed=True) + + def compute_needed(self): + used_vars = set() + + self.queue = [stmt for (needed,stmt) in self.stmts if needed] + while len(self.queue) > 0: + stmt = self.queue.pop(0) + deps = [] + for var in stmt.rvars: + if var[0] in string.digits: + continue # constant + deps.append(self.generators[var]) + used_vars.add(var) + for index in deps: + if not self.stmts[index][0]: + self.stmts[index][0] = True + self.queue.append(self.stmts[index][1]) + + forms = [] + for i, (needed, stmt) in enumerate(self.stmts): + if needed: + formindex = 0 + for (j, var) in enumerate(stmt.wvars): + formindex *= 2 + if var in used_vars: + formindex += 1 + forms.append(stmt.forms[formindex]) + + # Now we must check whether this form of the statement + # also writes some variables we _don't_ actually need + # (e.g. if you only wanted the top half from a mul, or + # only the carry from an adc, you'd be forced to + # generate the other output too). Easiest way to do + # this is to look for an identical statement form + # later in the array. + maxindex = max(i for i in range(len(stmt.forms)) + if stmt.forms[i] == stmt.forms[formindex]) + extra_vars = maxindex & ~formindex + bitpos = 0 + while extra_vars != 0: + if extra_vars & (1 << bitpos): + extra_vars &= ~(1 << bitpos) + var = stmt.wvars[-1-bitpos] + used_vars.add(var) + # Also, write out a cast-to-void for each + # subsequently unused value, to prevent gcc + # warnings when the output code is compiled. + forms.append("(void)" + var) + bitpos += 1 + + used_carry = any(v.startswith("carry") for v in used_vars) + used_vars = [v for v in used_vars if v.startswith("v")] + used_vars.sort(key=lambda v: int(v[1:])) + + return used_carry, used_vars, forms + + def text(self): + used_carry, values, forms = self.compute_needed() + + ret = "" + while len(values) > 0: + prefix, sep, suffix = " BignumInt ", ", ", ";" + currline = values.pop(0) + while (len(values) > 0 and + len(prefix+currline+sep+values[0]+suffix) < 79): + currline += sep + values.pop(0) + ret += prefix + currline + suffix + "\n" + if used_carry: + ret += " BignumCarry carry;\n" + if ret != "": + ret += "\n" + for stmtform in forms: + ret += " %s;\n" % stmtform + return ret + +def gen_add(target): # This is an addition _without_ reduction mod p, so that it can be # used both during accumulation of the polynomial and for adding # on the encrypted nonce at the end (which is mod 2^128, not mod @@ -111,157 +303,66 @@ def gen_add(bignum_int_bits): # Because one of the inputs will have come from our # not-completely-reducing multiplication function, we expect up to # 3 extra bits of input. - acclo = Variable(out, "acclo") - - acclo.clear(0) - - for wordpos in range(inwords): - limit = min(1 << bignum_int_bits, 1 << (130 - wordpos*bignum_int_bits)) - acclo.add_input_word("a->w[%d]", wordpos, limit) - acclo.add_input_word("b->w[%d]", wordpos, limit) - acclo.output_word(0, bignum_int_bits, "r->w[%d]", wordpos) - acclo.shift_down_from(None) - - return out.finalise() -def gen_mul_1305(bignum_int_bits): - out = Output(bignum_int_bits) - - inbits = 130 - inwords = (inbits + bignum_int_bits - 1) / bignum_int_bits + a = target.bigval_input("a", 133) + b = target.bigval_input("b", 133) + ret = a + b + target.write_bigval("r", ret) + return """\ +static void bigval_add(bigval *r, const bigval *a, const bigval *b) +{ +%s} +\n""" % target.text() +def gen_mul(target): # The inputs are not 100% reduced mod p. Specifically, we can get # a full 130-bit number from the pow5==0 pass, and then a 130-bit # number times 5 from the pow5==1 pass, plus a possible carry. The # total of that can be easily bounded above by 2^130 * 8, so we # need to assume we're multiplying two 133-bit numbers. - outbits = (inbits + 3) * 2 - outwords = (outbits + bignum_int_bits - 1) / bignum_int_bits + 1 - - tmp = Variable(out, "tmp") - acclo = Variable(out, "acclo") - acchi = Variable(out, "acchi") - acc2lo = Variable(out, "acc2lo") - - pow5, bits_at_pow5 = 0, inbits - - acclo.clear(0) - acchi.clear(bignum_int_bits) - bits_needed_in_acc2 = bignum_int_bits - - for outwordpos in range(outwords): - for a in range(inwords): - b = outwordpos - a - if 0 <= b < inwords: - tmp.set_to_product("a->w[%d]" % a, "b->w[%d]" % b, - outwordpos * bignum_int_bits) - tmp.unload_into(acchi, acclo) - - bits_in_word = bignum_int_bits - bitpos = 0 - #print "begin output" - while bits_in_word > 0: - chunk = min(bits_in_word, bits_at_pow5) - if pow5 > 0: - chunk = min(chunk, bits_needed_in_acc2) - if pow5 == 0: - acclo.output_word(bitpos, chunk, "r->w[%d]", outwordpos) - else: - acclo.transfer_to_next_acc(bitpos, chunk, pow5, acc2lo) - bits_needed_in_acc2 -= chunk - if bits_needed_in_acc2 == 0: - assert acc2lo.placeval % bignum_int_bits == 0 - other_outwordpos = acc2lo.placeval / bignum_int_bits - acc2lo.add_input_word("r->w[%d]", other_outwordpos) - acc2lo.output_word(bitpos, bignum_int_bits, "r->w[%d]", - other_outwordpos) - acc2lo.shift_down_from(None) - bits_needed_in_acc2 = bignum_int_bits - bits_in_word -= chunk - bits_at_pow5 -= chunk - bitpos += chunk - if bits_at_pow5 == 0: - if pow5 > 0: - assert acc2lo.placeval % bignum_int_bits == 0 - other_outwordpos = acc2lo.placeval / bignum_int_bits - acc2lo.add_input_word("r->w[%d]", other_outwordpos) - acc2lo.output_word(0, bignum_int_bits, "r->w[%d]", - other_outwordpos) - pow5 += 1 - bits_at_pow5 = inbits - acc2lo.clear(0) - bits_needed_in_acc2 = bignum_int_bits - acclo.shift_down_from(acchi) - - while acc2lo.maxval > 0: - other_outwordpos = acc2lo.placeval / bignum_int_bits - bitsleft = inbits - other_outwordpos * bignum_int_bits - limit = 1<w[%d]", other_outwordpos, limit=limit) - acc2lo.output_word(0, bignum_int_bits, "r->w[%d]", other_outwordpos) - acc2lo.shift_down_from(None) - - return out.finalise() - -def gen_final_reduce_1305(bignum_int_bits): - out = Output(bignum_int_bits) - - inbits = 130 - inwords = (inbits + bignum_int_bits - 1) / bignum_int_bits - - # We take our input number n, and compute k = 5 + 5*(n >> 130). + + a = target.bigval_input("a", 133) + b = target.bigval_input("b", 133) + ab = a * b + ab0 = ab.extract_bits(0, 130) + ab1 = ab.extract_bits(130, 130) + ab2 = ab.extract_bits(260) + ab1_5 = target.const(5) * ab1 + ab2_25 = target.const(25) * ab2 + ret = ab0 + ab1_5 + ab2_25 + target.write_bigval("r", ret) + return """\ +static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b) +{ +%s} +\n""" % target.text() + +def gen_final_reduce(target): + # We take our input number n, and compute k = n + 5*(n >> 130). # Then k >> 130 is precisely the multiple of p that needs to be # subtracted from n to reduce it to strictly less than p. - acclo = Variable(out, "acclo") - - acclo.clear(0) - # Hopefully all the bits we're shifting down fit in the same word. - assert 130 / bignum_int_bits == (130 + 3 - 1) / bignum_int_bits - acclo.add_word("5 * ((n->w[%d] >> %d) + 1)" % - (130 / bignum_int_bits, 130 % bignum_int_bits), - limit = 5 * (7 + 1)) - for wordpos in range(inwords): - acclo.add_input_word("n->w[%d]", wordpos) - # Notionally, we could call acclo.output_word here to store - # our adjusted value k. But we don't need to, because all we - # actually want is the very top word of it. - if wordpos == 130 / bignum_int_bits: - break - acclo.shift_down_from(None) - - # Now we can find the right multiple of p to subtract. We actually - # subtract it by adding 5 times it, and then finally discarding - # the top bits of the output. - - # Hopefully all the bits we're shifting down fit in the same word. - assert 130 / bignum_int_bits == (130 + 3 - 1) / bignum_int_bits - acclo.set_word("5 * (acclo >> %d)" % (130 % bignum_int_bits), - limit = 5 * (7 + 1)) - acclo.placeval = 0 - for wordpos in range(inwords): - acclo.add_input_word("n->w[%d]", wordpos) - acclo.output_word(0, bignum_int_bits, "n->w[%d]", wordpos) - acclo.shift_down_from(None) - - out.stmt("n->w[%d] &= (1 << %d) - 1" % - (130 / bignum_int_bits, 130 % bignum_int_bits)) - - # Here we don't call out.finalise(), because that will complain - # that there are bits of output we never dealt with. This is true, - # but all the bits in question are above 2^130, so they're bits - # we're discarding anyway. - return out.text # not out.finalise() - -ops = { "mul" : gen_mul_1305, - "add" : gen_add, - "final_reduce" : gen_final_reduce_1305 } - -args = sys.argv[1:] -if len(args) != 2 or args[0] not in ops: - sys.stderr.write("usage: make1305.py (%s) \n" % (" | ".join(sorted(ops)))) - sys.exit(1) - -sys.stdout.write(" /* ./contrib/make1305.py %s %s */\n" % tuple(args)) -s = ops[args[0]](int(args[1])) -sys.stdout.write(s) + a = target.bigval_input("n", 133) + a1 = a.extract_bits(130, 130) + k = a + target.const(5) * a1 + q = k.extract_bits(130) + adjusted = a + target.const(5) * q + ret = adjusted.extract_bits(0, 130) + target.write_bigval("n", ret) + return """\ +static void bigval_final_reduce(bigval *n) +{ +%s} +\n""" % target.text() + +pp_keyword = "#if" +for bits in [16, 32, 64]: + sys.stdout.write("%s BIGNUM_INT_BITS == %d\n\n" % (pp_keyword, bits)) + pp_keyword = "#elif" + sys.stdout.write(gen_add(CodegenTarget(bits))) + sys.stdout.write(gen_mul(CodegenTarget(bits))) + sys.stdout.write(gen_final_reduce(CodegenTarget(bits))) +sys.stdout.write("""#else +#error Add another bit count to contrib/make1305.py and rerun it +#endif +""")