]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - testdata/bignum.py
first pass
[PuTTY.git] / testdata / bignum.py
1 # Generate test cases for a bignum implementation.
2
3 import sys
4
5 # integer square roots
6 def sqrt(n):
7     d = long(n)
8     a = 0L
9     # b must start off as a power of 4 at least as large as n
10     ndigits = len(hex(long(n)))
11     b = 1L << (ndigits*4)
12     while 1:
13         a = a >> 1
14         di = 2*a + b
15         if di <= d:
16             d = d - di
17             a = a + b
18         b = b >> 2
19         if b == 0: break
20     return a
21
22 # continued fraction convergents of a rational
23 def confrac(n, d):
24     coeffs = [(1,0),(0,1)]
25     while d != 0:
26         i = n / d
27         n, d = d, n % d
28         coeffs.append((coeffs[-2][0]-i*coeffs[-1][0],
29                        coeffs[-2][1]-i*coeffs[-1][1]))
30     return coeffs
31
32 def findprod(target, dir = +1, ratio=(1,1)):
33     # Return two numbers whose product is as close as we can get to
34     # 'target', with any deviation having the sign of 'dir', and in
35     # the same approximate ratio as 'ratio'.
36
37     r = sqrt(target * ratio[0] * ratio[1])
38     a = r / ratio[1]
39     b = r / ratio[0]
40     if a*b * dir < target * dir:
41         a = a + 1
42         b = b + 1
43     assert a*b * dir >= target * dir
44
45     best = (a,b,a*b)
46
47     while 1:
48         improved = 0
49         a, b = best[:2]
50
51         coeffs = confrac(a, b)
52         for c in coeffs:
53             # a*c[0]+b*c[1] is as close as we can get it to zero. So
54             # if we replace a and b with a+c[1] and b+c[0], then that
55             # will be added to our product, along with c[0]*c[1].
56             da, db = c[1], c[0]
57
58             # Flip signs as appropriate.
59             if (a+da) * (b+db) * dir < target * dir:
60                 da, db = -da, -db
61
62             # Multiply up. We want to get as close as we can to a
63             # solution of the quadratic equation in n
64             #
65             #    (a + n da) (b + n db) = target
66             # => n^2 da db + n (b da + a db) + (a b - target) = 0
67             A,B,C = da*db, b*da+a*db, a*b-target
68             discrim = B^2-4*A*C
69             if discrim > 0 and A != 0:
70                 root = sqrt(discrim)
71                 vals = []
72                 vals.append((-B + root) / (2*A))
73                 vals.append((-B - root) / (2*A))
74                 if root * root != discrim:
75                     root = root + 1
76                     vals.append((-B + root) / (2*A))
77                     vals.append((-B - root) / (2*A))
78
79                 for n in vals:
80                     ap = a + da*n
81                     bp = b + db*n
82                     pp = ap*bp
83                     if pp * dir >= target * dir and pp * dir < best[2]*dir:
84                         best = (ap, bp, pp)
85                         improved = 1
86
87         if not improved:
88             break
89
90     return best
91
92 def hexstr(n):
93     s = hex(n)
94     if s[:2] == "0x": s = s[2:]
95     if s[-1:] == "L": s = s[:-1]
96     return s
97
98 # Tests of multiplication which exercise the propagation of the last
99 # carry to the very top of the number.
100 for i in range(1,4200):
101     a, b, p = findprod((1<<i)+1, +1, (i, i*i+1))
102     print "mul", hexstr(a), hexstr(b), hexstr(p)
103     a, b, p = findprod((1<<i)+1, +1, (i, i+1))
104     print "mul", hexstr(a), hexstr(b), hexstr(p)
105
106 # Bare tests of division/modulo.
107 prefixes = [2**63, int(2**63.5), 2**64-1]
108 for nsize in range(20, 200):
109     for dsize in range(20, 200):
110         for dprefix in prefixes:
111             d = sqrt(3<<(2*dsize)) + (dprefix<<dsize)
112             for nprefix in prefixes:
113                 nbase = sqrt(3<<(2*nsize)) + (nprefix<<nsize)
114                 for modulus in sorted({-1, 0, +1, d/2, nbase % d}):
115                     n = nbase - (nbase % d) + modulus
116                     if n < 0:
117                         n += d
118                         assert n >= 0
119                     print "divmod", hexstr(n), hexstr(d), hexstr(n/d), hexstr(n%d)
120
121 # Simple tests of modmul.
122 for ai in range(20, 200, 60):
123     a = sqrt(3<<(2*ai-1))
124     for bi in range(20, 200, 60):
125         b = sqrt(5<<(2*bi-1))
126         for m in range(20, 600, 32):
127             m = sqrt(2**(m+1))
128             print "modmul", hexstr(a), hexstr(b), hexstr(m), hexstr((a*b) % m)
129
130 # Simple tests of modpow.
131 for i in range(64, 4097, 63):
132     modulus = sqrt(1<<(2*i-1)) | 1
133     base = sqrt(3*modulus*modulus) % modulus
134     expt = sqrt(modulus*modulus*2/5)
135     print "pow", hexstr(base), hexstr(expt), hexstr(modulus), hexstr(pow(base, expt, modulus))
136     if i <= 1024:
137         # Test even moduli, which can't be done by Montgomery.
138         modulus = modulus - 1
139         print "pow", hexstr(base), hexstr(expt), hexstr(modulus), hexstr(pow(base, expt, modulus))
140         print "pow", hexstr(i), hexstr(expt), hexstr(modulus), hexstr(pow(i, expt, modulus))