+ def getword(self, n):
+ return self.words[n] if n < len(self.words) else "0"
+
+ def __add__(self, rhs):
+ newmin = self.minval + rhs.minval
+ newmax = self.maxval + rhs.maxval
+ nwords = self.target.nwords(newmax)
+ words = []
+
+ addfn = self.target.add
+ for i in range(nwords):
+ words.append(addfn(self.getword(i), rhs.getword(i)))
+ addfn = self.target.adc
+
+ return Multiprecision(self.target, newmin, newmax, words)
+
+ def __mul__(self, rhs):
+ newmin = self.minval * rhs.minval
+ newmax = self.maxval * rhs.maxval
+ nwords = self.target.nwords(newmax)
+ words = []
+
+ # There are basically two strategies we could take for
+ # multiplying two multiprecision integers. One is to enumerate
+ # the space of pairs of word indices in lexicographic order,
+ # essentially computing a*b[i] for each i and adding them
+ # together; the other is to enumerate in diagonal order,
+ # computing everything together that belongs at a particular
+ # output word index.
+ #
+ # For the moment, I've gone for the former.
+
+ sprev = []
+ for i, sword in enumerate(self.words):
+ rprev = None
+ sthis = sprev[:i]
+ for j, rword in enumerate(rhs.words):
+ prevwords = []
+ if i+j < len(sprev):
+ prevwords.append(sprev[i+j])
+ if rprev is not None:
+ prevwords.append(rprev)
+ vhi, vlo = self.target.muladd(sword, rword, *prevwords)
+ sthis.append(vlo)
+ rprev = vhi
+ sthis.append(rprev)
+ sprev = sthis
+
+ # Remove unneeded words from the top of the output, if we can
+ # prove by range analysis that they'll always be zero.
+ sprev = sprev[:self.target.nwords(newmax)]
+
+ return Multiprecision(self.target, newmin, newmax, sprev)
+
+ def extract_bits(self, start, bits=None):
+ if bits is None:
+ bits = (self.maxval >> start).bit_length()
+
+ # Overly thorough range analysis: if min and max have the same
+ # *quotient* by 2^bits, then the result of reducing anything
+ # in the range [min,max] mod 2^bits has to fall within the
+ # obvious range. But if they have different quotients, then
+ # you can wrap round the modulus and so any value mod 2^bits
+ # is possible.
+ newmin = self.minval >> start
+ newmax = self.maxval >> start
+ if (newmin >> bits) != (newmax >> bits):
+ newmin = 0
+ newmax = (1 << bits) - 1
+
+ nwords = self.target.nwords(newmax)
+ words = []
+ for i in range(nwords):
+ srcpos = i * self.target.bits + start
+ maxbits = min(self.target.bits, start + bits - srcpos)
+ wordindex = srcpos / self.target.bits
+ if srcpos % self.target.bits == 0:
+ word = self.getword(srcpos / self.target.bits)
+ elif (wordindex+1 >= len(self.words) or
+ srcpos % self.target.bits + maxbits < self.target.bits):
+ word = self.target.new_value(
+ "(%%s) >> %d" % (srcpos % self.target.bits),
+ self.getword(srcpos / self.target.bits))
+ else:
+ word = self.target.new_value(
+ "((%%s) >> %d) | ((%%s) << %d)" % (
+ srcpos % self.target.bits,
+ self.target.bits - (srcpos % self.target.bits)),
+ self.getword(srcpos / self.target.bits),
+ self.getword(srcpos / self.target.bits + 1))
+ if maxbits < self.target.bits and maxbits < bits:
+ word = self.target.new_value(
+ "(%%s) & ((((BignumInt)1) << %d)-1)" % maxbits,
+ word)
+ words.append(word)
+
+ return Multiprecision(self.target, newmin, newmax, words)
+
+# Each Statement has a list of variables it reads, and a list of ones
+# it writes. 'forms' is a list of multiple actual C statements it
+# could be generated as, depending on which of its output variables is
+# actually used (e.g. no point calling BignumADC if the generated
+# carry in a particular case is unused, or BignumMUL if nobody needs
+# the top half). It is indexed by a bitmap whose bits correspond to
+# the entries in wvars, with wvars[0] the MSB and wvars[-1] the LSB.
+Statement = namedtuple("Statement", "rvars wvars forms")
+
+class CodegenTarget(object):
+ def __init__(self, bits):
+ self.bits = bits
+ self.valindex = 0
+ self.stmts = []
+ self.generators = {}
+ self.bv_words = (130 + self.bits - 1) / self.bits
+ self.carry_index = 0
+
+ def nwords(self, maxval):
+ return (maxval.bit_length() + self.bits - 1) / self.bits
+
+ def stmt(self, stmt, needed=False):
+ index = len(self.stmts)
+ self.stmts.append([needed, stmt])
+ for val in stmt.wvars:
+ self.generators[val] = index
+
+ def new_value(self, formatstr=None, *deps):
+ name = "v%d" % self.valindex
+ self.valindex += 1
+ if formatstr is not None:
+ self.stmt(Statement(
+ rvars=deps, wvars=[name],
+ forms=[None, name + " = " + formatstr % deps]))
+ return name
+
+ def bigval_input(self, name, bits):
+ words = (bits + self.bits - 1) / self.bits
+ # Expect not to require an entire extra word
+ assert words == self.bv_words
+
+ return Multiprecision(self, 0, (1<<bits)-1, [
+ self.new_value("%s->w[%d]" % (name, i)) for i in range(words)])
+
+ def const(self, value):
+ # We only support constants small enough to both fit in a
+ # BignumInt (of any size supported) _and_ be expressible in C
+ # with no weird integer literal syntax like a trailing LL.
+ #
+ # Supporting larger constants would be possible - you could
+ # break 'value' up into word-sized pieces on the Python side,
+ # and generate a legal C expression for each piece by
+ # splitting it further into pieces within the
+ # standards-guaranteed 'unsigned long' limit of 32 bits and
+ # then casting those to BignumInt before combining them with
+ # shifts. But it would be a lot of effort, and since the
+ # application for this code doesn't even need it, there's no
+ # point in bothering.
+ assert value < 2**16
+ return Multiprecision(self, value, value, ["%d" % value])
+
+ def current_carry(self):
+ return "carry%d" % self.carry_index
+
+ def add(self, a1, a2):
+ ret = self.new_value()
+ adcform = "BignumADC(%s, carry, %s, %s, 0)" % (ret, a1, a2)
+ plainform = "%s = %s + %s" % (ret, a1, a2)
+ self.carry_index += 1
+ carryout = self.current_carry()
+ self.stmt(Statement(
+ rvars=[a1,a2], wvars=[ret,carryout],
+ forms=[None, adcform, plainform, adcform]))
+ return ret
+
+ def adc(self, a1, a2):
+ ret = self.new_value()
+ adcform = "BignumADC(%s, carry, %s, %s, carry)" % (ret, a1, a2)
+ plainform = "%s = %s + %s + carry" % (ret, a1, a2)
+ carryin = self.current_carry()
+ self.carry_index += 1
+ carryout = self.current_carry()
+ self.stmt(Statement(
+ rvars=[a1,a2,carryin], wvars=[ret,carryout],
+ forms=[None, adcform, plainform, adcform]))
+ return ret
+
+ def muladd(self, m1, m2, *addends):
+ rlo = self.new_value()
+ rhi = self.new_value()
+ wideform = "BignumMUL%s(%s)" % (
+ { 0:"", 1:"ADD", 2:"ADD2" }[len(addends)],
+ ", ".join([rhi, rlo, m1, m2] + list(addends)))
+ narrowform = " + ".join(["%s = %s * %s" % (rlo, m1, m2)] +
+ list(addends))
+ self.stmt(Statement(
+ rvars=[m1,m2]+list(addends), wvars=[rhi,rlo],
+ forms=[None, narrowform, wideform, wideform]))
+ return rhi, rlo
+
+ def write_bigval(self, name, val):
+ for i in range(self.bv_words):
+ word = val.getword(i)
+ self.stmt(Statement(
+ rvars=[word], wvars=[],
+ forms=["%s->w[%d] = %s" % (name, i, word)]),
+ needed=True)
+
+ def compute_needed(self):
+ used_vars = set()
+
+ self.queue = [stmt for (needed,stmt) in self.stmts if needed]
+ while len(self.queue) > 0:
+ stmt = self.queue.pop(0)
+ deps = []
+ for var in stmt.rvars:
+ if var[0] in string.digits:
+ continue # constant
+ deps.append(self.generators[var])
+ used_vars.add(var)
+ for index in deps:
+ if not self.stmts[index][0]:
+ self.stmts[index][0] = True
+ self.queue.append(self.stmts[index][1])
+
+ forms = []
+ for i, (needed, stmt) in enumerate(self.stmts):
+ if needed:
+ formindex = 0
+ for (j, var) in enumerate(stmt.wvars):
+ formindex *= 2
+ if var in used_vars:
+ formindex += 1
+ forms.append(stmt.forms[formindex])
+
+ # Now we must check whether this form of the statement
+ # also writes some variables we _don't_ actually need
+ # (e.g. if you only wanted the top half from a mul, or
+ # only the carry from an adc, you'd be forced to
+ # generate the other output too). Easiest way to do
+ # this is to look for an identical statement form
+ # later in the array.
+ maxindex = max(i for i in range(len(stmt.forms))
+ if stmt.forms[i] == stmt.forms[formindex])
+ extra_vars = maxindex & ~formindex
+ bitpos = 0
+ while extra_vars != 0:
+ if extra_vars & (1 << bitpos):
+ extra_vars &= ~(1 << bitpos)
+ var = stmt.wvars[-1-bitpos]
+ used_vars.add(var)
+ # Also, write out a cast-to-void for each
+ # subsequently unused value, to prevent gcc
+ # warnings when the output code is compiled.
+ forms.append("(void)" + var)
+ bitpos += 1
+
+ used_carry = any(v.startswith("carry") for v in used_vars)
+ used_vars = [v for v in used_vars if v.startswith("v")]
+ used_vars.sort(key=lambda v: int(v[1:]))
+
+ return used_carry, used_vars, forms
+
+ def text(self):
+ used_carry, values, forms = self.compute_needed()
+
+ ret = ""
+ while len(values) > 0:
+ prefix, sep, suffix = " BignumInt ", ", ", ";"
+ currline = values.pop(0)
+ while (len(values) > 0 and
+ len(prefix+currline+sep+values[0]+suffix) < 79):
+ currline += sep + values.pop(0)
+ ret += prefix + currline + suffix + "\n"
+ if used_carry:
+ ret += " BignumCarry carry;\n"
+ if ret != "":
+ ret += "\n"
+ for stmtform in forms:
+ ret += " %s;\n" % stmtform
+ return ret
+
+def gen_add(target):