]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - sshzlib.c
Make the SSH2 traffic analysis defence robust in the face of Zlib
[PuTTY.git] / sshzlib.c
1 /*
2  * Zlib (RFC1950 / RFC1951) compression for PuTTY.
3  * 
4  * There will no doubt be criticism of my decision to reimplement
5  * Zlib compression from scratch instead of using the existing zlib
6  * code. People will cry `reinventing the wheel'; they'll claim
7  * that the `fundamental basis of OSS' is code reuse; they'll want
8  * to see a really good reason for me having chosen not to use the
9  * existing code.
10  * 
11  * Well, here are my reasons. Firstly, I don't want to link the
12  * whole of zlib into the PuTTY binary; PuTTY is justifiably proud
13  * of its small size and I think zlib contains a lot of unnecessary
14  * baggage for the kind of compression that SSH requires.
15  * 
16  * Secondly, I also don't like the alternative of using zlib.dll.
17  * Another thing PuTTY is justifiably proud of is its ease of
18  * installation, and the last thing I want to do is to start
19  * mandating DLLs. Not only that, but there are two _kinds_ of
20  * zlib.dll kicking around, one with C calling conventions on the
21  * exported functions and another with WINAPI conventions, and
22  * there would be a significant danger of getting the wrong one.
23  * 
24  * Thirdly, there seems to be a difference of opinion on the IETF
25  * secsh mailing list about the correct way to round off a
26  * compressed packet and start the next. In particular, there's
27  * some talk of switching to a mechanism zlib isn't currently
28  * capable of supporting (see below for an explanation). Given that
29  * sort of uncertainty, I thought it might be better to have code
30  * that will support even the zlib-incompatible worst case.
31  * 
32  * Fourthly, it's a _second implementation_. Second implementations
33  * are fundamentally a Good Thing in standardisation efforts. The
34  * difference of opinion mentioned above has arisen _precisely_
35  * because there has been only one zlib implementation and
36  * everybody has used it. I don't intend that this should happen
37  * again.
38  */
39
40 #include <stdlib.h>
41 #include <assert.h>
42
43 /* FIXME */
44 #include <windows.h>
45 #include <stdio.h>
46 #include "putty.h"
47
48 #include "ssh.h"
49
50 /* ----------------------------------------------------------------------
51  * Basic LZ77 code. This bit is designed modularly, so it could be
52  * ripped out and used in a different LZ77 compressor. Go to it,
53  * and good luck :-)
54  */
55
56 struct LZ77InternalContext;
57 struct LZ77Context {
58     struct LZ77InternalContext *ictx;
59     void *userdata;
60     void (*literal)(struct LZ77Context *ctx, unsigned char c);
61     void (*match)(struct LZ77Context *ctx, int distance, int len);
62 };
63
64 /*
65  * Initialise the private fields of an LZ77Context. It's up to the
66  * user to initialise the public fields.
67  */
68 static int lz77_init(struct LZ77Context *ctx);
69
70 /*
71  * Supply data to be compressed. Will update the private fields of
72  * the LZ77Context, and will call literal() and match() to output.
73  * If `compress' is FALSE, it will never emit a match, but will
74  * instead call literal() for everything.
75  */
76 static void lz77_compress(struct LZ77Context *ctx,
77                           unsigned char *data, int len, int compress);
78
79 /*
80  * Modifiable parameters.
81  */
82 #define WINSIZE 32768                  /* window size. Must be power of 2! */
83 #define HASHMAX 2039                   /* one more than max hash value */
84 #define MAXMATCH 32                    /* how many matches we track */
85 #define HASHCHARS 3                    /* how many chars make a hash */
86
87 /*
88  * This compressor takes a less slapdash approach than the
89  * gzip/zlib one. Rather than allowing our hash chains to fall into
90  * disuse near the far end, we keep them doubly linked so we can
91  * _find_ the far end, and then every time we add a new byte to the
92  * window (thus rolling round by one and removing the previous
93  * byte), we can carefully remove the hash chain entry.
94  */
95
96 #define INVALID -1                     /* invalid hash _and_ invalid offset */
97 struct WindowEntry {
98     int next, prev;                    /* array indices within the window */
99     int hashval;
100 };
101
102 struct HashEntry {
103     int first;                         /* window index of first in chain */
104 };
105
106 struct Match {
107     int distance, len;
108 };
109
110 struct LZ77InternalContext {
111     struct WindowEntry win[WINSIZE];
112     unsigned char data[WINSIZE];
113     int winpos;
114     struct HashEntry hashtab[HASHMAX];
115     unsigned char pending[HASHCHARS];
116     int npending;
117 };
118
119 static int lz77_hash(unsigned char *data) {
120     return (257*data[0] + 263*data[1] + 269*data[2]) % HASHMAX;
121 }
122
123 static int lz77_init(struct LZ77Context *ctx) {
124     struct LZ77InternalContext *st;
125     int i;
126
127     st = (struct LZ77InternalContext *)smalloc(sizeof(*st));
128     if (!st)
129         return 0;
130
131     ctx->ictx = st;
132
133     for (i = 0; i < WINSIZE; i++)
134         st->win[i].next = st->win[i].prev = st->win[i].hashval = INVALID;
135     for (i = 0; i < HASHMAX; i++)
136         st->hashtab[i].first = INVALID;
137     st->winpos = 0;
138
139     st->npending = 0;
140
141     return 1;
142 }
143
144 static void lz77_advance(struct LZ77InternalContext *st,
145                          unsigned char c, int hash) {
146     int off;
147
148     /*
149      * Remove the hash entry at winpos from the tail of its chain,
150      * or empty the chain if it's the only thing on the chain.
151      */
152     if (st->win[st->winpos].prev != INVALID) {
153         st->win[st->win[st->winpos].prev].next = INVALID;
154     } else if (st->win[st->winpos].hashval != INVALID) {
155         st->hashtab[st->win[st->winpos].hashval].first = INVALID;
156     }
157
158     /*
159      * Create a new entry at winpos and add it to the head of its
160      * hash chain.
161      */
162     st->win[st->winpos].hashval = hash;
163     st->win[st->winpos].prev = INVALID;
164     off = st->win[st->winpos].next = st->hashtab[hash].first;
165     st->hashtab[hash].first = st->winpos;
166     if (off != INVALID)
167         st->win[off].prev = st->winpos;
168     st->data[st->winpos] = c;
169
170     /*
171      * Advance the window pointer.
172      */
173     st->winpos = (st->winpos + 1) & (WINSIZE-1);
174 }
175
176 #define CHARAT(k) ( (k)<0 ? st->data[(st->winpos+k)&(WINSIZE-1)] : data[k] )
177
178 static void lz77_compress(struct LZ77Context *ctx,
179                           unsigned char *data, int len, int compress) {
180     struct LZ77InternalContext *st = ctx->ictx;
181     int i, hash, distance, off, nmatch, matchlen, advance;
182     struct Match defermatch, matches[MAXMATCH];
183     int deferchr;
184
185     /*
186      * Add any pending characters from last time to the window. (We
187      * might not be able to.)
188      */
189     for (i = 0; i < st->npending; i++) {
190         unsigned char foo[HASHCHARS];
191         int j;
192         if (len + st->npending - i < HASHCHARS) {
193             /* Update the pending array. */
194             for (j = i; j < st->npending; j++)
195                 st->pending[j-i] = st->pending[j];
196             break;
197         }
198         for (j = 0; j < HASHCHARS; j++)
199             foo[j] = (i + j < st->npending ? st->pending[i+j] :
200                       data[i + j - st->npending]);
201         lz77_advance(st, foo[0], lz77_hash(foo));
202     }
203     st->npending -= i;
204
205     defermatch.len = 0;
206     while (len > 0) {
207
208         /* Don't even look for a match, if we're not compressing. */
209         if (compress && len >= HASHCHARS) {
210             /*
211              * Hash the next few characters.
212              */
213             hash = lz77_hash(data);
214
215             /*
216              * Look the hash up in the corresponding hash chain and see
217              * what we can find.
218              */
219             nmatch = 0;
220             for (off = st->hashtab[hash].first;
221                  off != INVALID; off = st->win[off].next) {
222                 /* distance = 1       if off == st->winpos-1 */
223                 /* distance = WINSIZE if off == st->winpos   */
224                 distance = WINSIZE - (off + WINSIZE - st->winpos) % WINSIZE;
225                 for (i = 0; i < HASHCHARS; i++)
226                     if (CHARAT(i) != CHARAT(i-distance))
227                         break;
228                 if (i == HASHCHARS) {
229                     matches[nmatch].distance = distance;
230                     matches[nmatch].len = 3;
231                     if (++nmatch >= MAXMATCH)
232                         break;
233                 }
234             }
235         } else {
236             nmatch = 0;
237             hash = INVALID;
238         }
239
240         if (nmatch > 0) {
241             /*
242              * We've now filled up matches[] with nmatch potential
243              * matches. Follow them down to find the longest. (We
244              * assume here that it's always worth favouring a
245              * longer match over a shorter one.)
246              */
247             matchlen = HASHCHARS;
248             while (matchlen < len) {
249                 int j;
250                 for (i = j = 0; i < nmatch; i++) {
251                     if (CHARAT(matchlen) ==
252                         CHARAT(matchlen - matches[i].distance)) {
253                         matches[j++] = matches[i];
254                     }
255                 }
256                 if (j == 0)
257                     break;
258                 matchlen++;
259                 nmatch = j;
260             }
261
262             /*
263              * We've now got all the longest matches. We favour the
264              * shorter distances, which means we go with matches[0].
265              * So see if we want to defer it or throw it away.
266              */
267             matches[0].len = matchlen;
268             if (defermatch.len > 0) {
269                 if (matches[0].len > defermatch.len + 1) {
270                     /* We have a better match. Emit the deferred char,
271                      * and defer this match. */
272                     ctx->literal(ctx, (unsigned char)deferchr);
273                     defermatch = matches[0];
274                     deferchr = data[0];
275                     advance = 1;
276                 } else {
277                     /* We don't have a better match. Do the deferred one. */
278                     ctx->match(ctx, defermatch.distance, defermatch.len);
279                     advance = defermatch.len - 1;
280                     defermatch.len = 0;
281                 }
282             } else {
283                 /* There was no deferred match. Defer this one. */
284                 defermatch = matches[0];
285                 deferchr = data[0];
286                 advance = 1;
287             }       
288         } else {
289             /*
290              * We found no matches. Emit the deferred match, if
291              * any; otherwise emit a literal.
292              */
293             if (defermatch.len > 0) {
294                 ctx->match(ctx, defermatch.distance, defermatch.len);
295                 advance = defermatch.len - 1;
296                 defermatch.len = 0;
297             } else {
298                 ctx->literal(ctx, data[0]);
299                 advance = 1;
300             }
301         }
302
303         /*
304          * Now advance the position by `advance' characters,
305          * keeping the window and hash chains consistent.
306          */
307         while (advance > 0) {
308             if (len >= HASHCHARS) {
309                 lz77_advance(st, *data, lz77_hash(data));
310             } else {
311                 st->pending[st->npending++] = *data;
312             }
313             data++;
314             len--;
315             advance--;
316         }
317     }
318 }
319
320 /* ----------------------------------------------------------------------
321  * Zlib compression. We always use the static Huffman tree option.
322  * Mostly this is because it's hard to scan a block in advance to
323  * work out better trees; dynamic trees are great when you're
324  * compressing a large file under no significant time constraint,
325  * but when you're compressing little bits in real time, things get
326  * hairier.
327  * 
328  * I suppose it's possible that I could compute Huffman trees based
329  * on the frequencies in the _previous_ block, as a sort of
330  * heuristic, but I'm not confident that the gain would balance out
331  * having to transmit the trees.
332  */
333
334 static struct LZ77Context ectx;
335
336 struct Outbuf {
337     unsigned char *outbuf;
338     int outlen, outsize;
339     unsigned long outbits;
340     int noutbits;
341     int firstblock;
342     int comp_disabled;
343 };
344
345 static void outbits(struct Outbuf *out, unsigned long bits, int nbits) {
346     assert(out->noutbits + nbits <= 32);
347     out->outbits |= bits << out->noutbits;
348     out->noutbits += nbits;
349     while (out->noutbits >= 8) {
350         if (out->outlen >= out->outsize) {
351             out->outsize = out->outlen + 64;
352             out->outbuf = srealloc(out->outbuf, out->outsize);
353         }
354         out->outbuf[out->outlen++] = (unsigned char)(out->outbits & 0xFF);
355         out->outbits >>= 8;
356         out->noutbits -= 8;
357     }
358 }
359
360 static const unsigned char mirrorbytes[256] = {
361     0x00, 0x80, 0x40, 0xc0, 0x20, 0xa0, 0x60, 0xe0,
362     0x10, 0x90, 0x50, 0xd0, 0x30, 0xb0, 0x70, 0xf0,
363     0x08, 0x88, 0x48, 0xc8, 0x28, 0xa8, 0x68, 0xe8,
364     0x18, 0x98, 0x58, 0xd8, 0x38, 0xb8, 0x78, 0xf8,
365     0x04, 0x84, 0x44, 0xc4, 0x24, 0xa4, 0x64, 0xe4,
366     0x14, 0x94, 0x54, 0xd4, 0x34, 0xb4, 0x74, 0xf4,
367     0x0c, 0x8c, 0x4c, 0xcc, 0x2c, 0xac, 0x6c, 0xec,
368     0x1c, 0x9c, 0x5c, 0xdc, 0x3c, 0xbc, 0x7c, 0xfc,
369     0x02, 0x82, 0x42, 0xc2, 0x22, 0xa2, 0x62, 0xe2,
370     0x12, 0x92, 0x52, 0xd2, 0x32, 0xb2, 0x72, 0xf2,
371     0x0a, 0x8a, 0x4a, 0xca, 0x2a, 0xaa, 0x6a, 0xea,
372     0x1a, 0x9a, 0x5a, 0xda, 0x3a, 0xba, 0x7a, 0xfa,
373     0x06, 0x86, 0x46, 0xc6, 0x26, 0xa6, 0x66, 0xe6,
374     0x16, 0x96, 0x56, 0xd6, 0x36, 0xb6, 0x76, 0xf6,
375     0x0e, 0x8e, 0x4e, 0xce, 0x2e, 0xae, 0x6e, 0xee,
376     0x1e, 0x9e, 0x5e, 0xde, 0x3e, 0xbe, 0x7e, 0xfe,
377     0x01, 0x81, 0x41, 0xc1, 0x21, 0xa1, 0x61, 0xe1,
378     0x11, 0x91, 0x51, 0xd1, 0x31, 0xb1, 0x71, 0xf1,
379     0x09, 0x89, 0x49, 0xc9, 0x29, 0xa9, 0x69, 0xe9,
380     0x19, 0x99, 0x59, 0xd9, 0x39, 0xb9, 0x79, 0xf9,
381     0x05, 0x85, 0x45, 0xc5, 0x25, 0xa5, 0x65, 0xe5,
382     0x15, 0x95, 0x55, 0xd5, 0x35, 0xb5, 0x75, 0xf5,
383     0x0d, 0x8d, 0x4d, 0xcd, 0x2d, 0xad, 0x6d, 0xed,
384     0x1d, 0x9d, 0x5d, 0xdd, 0x3d, 0xbd, 0x7d, 0xfd,
385     0x03, 0x83, 0x43, 0xc3, 0x23, 0xa3, 0x63, 0xe3,
386     0x13, 0x93, 0x53, 0xd3, 0x33, 0xb3, 0x73, 0xf3,
387     0x0b, 0x8b, 0x4b, 0xcb, 0x2b, 0xab, 0x6b, 0xeb,
388     0x1b, 0x9b, 0x5b, 0xdb, 0x3b, 0xbb, 0x7b, 0xfb,
389     0x07, 0x87, 0x47, 0xc7, 0x27, 0xa7, 0x67, 0xe7,
390     0x17, 0x97, 0x57, 0xd7, 0x37, 0xb7, 0x77, 0xf7,
391     0x0f, 0x8f, 0x4f, 0xcf, 0x2f, 0xaf, 0x6f, 0xef,
392     0x1f, 0x9f, 0x5f, 0xdf, 0x3f, 0xbf, 0x7f, 0xff,
393 };
394
395 typedef struct {
396     int code, extrabits, min, max;
397 } coderecord;
398
399 static const coderecord lencodes[] = {
400     {257, 0, 3,3},
401     {258, 0, 4,4},
402     {259, 0, 5,5},
403     {260, 0, 6,6},
404     {261, 0, 7,7},
405     {262, 0, 8,8},
406     {263, 0, 9,9},
407     {264, 0, 10,10},
408     {265, 1, 11,12},
409     {266, 1, 13,14},
410     {267, 1, 15,16},
411     {268, 1, 17,18},
412     {269, 2, 19,22},
413     {270, 2, 23,26},
414     {271, 2, 27,30},
415     {272, 2, 31,34},
416     {273, 3, 35,42},
417     {274, 3, 43,50},
418     {275, 3, 51,58},
419     {276, 3, 59,66},
420     {277, 4, 67,82},
421     {278, 4, 83,98},
422     {279, 4, 99,114},
423     {280, 4, 115,130},
424     {281, 5, 131,162},
425     {282, 5, 163,194},
426     {283, 5, 195,226},
427     {284, 5, 227,257},
428     {285, 0, 258,258},
429 };
430
431 static const coderecord distcodes[] = {
432     {0, 0, 1,1},
433     {1, 0, 2,2},
434     {2, 0, 3,3},
435     {3, 0, 4,4},
436     {4, 1, 5,6},
437     {5, 1, 7,8},
438     {6, 2, 9,12},
439     {7, 2, 13,16},
440     {8, 3, 17,24},
441     {9, 3, 25,32},
442     {10, 4, 33,48},
443     {11, 4, 49,64},
444     {12, 5, 65,96},
445     {13, 5, 97,128},
446     {14, 6, 129,192},
447     {15, 6, 193,256},
448     {16, 7, 257,384},
449     {17, 7, 385,512},
450     {18, 8, 513,768},
451     {19, 8, 769,1024},
452     {20, 9, 1025,1536},
453     {21, 9, 1537,2048},
454     {22, 10, 2049,3072},
455     {23, 10, 3073,4096},
456     {24, 11, 4097,6144},
457     {25, 11, 6145,8192},
458     {26, 12, 8193,12288},
459     {27, 12, 12289,16384},
460     {28, 13, 16385,24576},
461     {29, 13, 24577,32768},
462 };
463
464 static void zlib_literal(struct LZ77Context *ectx, unsigned char c) {
465     struct Outbuf *out = (struct Outbuf *)ectx->userdata;
466
467     if (out->comp_disabled) {
468         /*
469          * We're in an uncompressed block, so just output the byte.
470          */
471         outbits(out, c, 8);
472         return;
473     }
474
475     if (c <= 143) {
476         /* 0 through 143 are 8 bits long starting at 00110000. */
477         outbits(out, mirrorbytes[0x30 + c], 8);
478     } else {
479         /* 144 through 255 are 9 bits long starting at 110010000. */
480         outbits(out, 1 + 2*mirrorbytes[0x90 - 144 + c], 9);
481     }
482 }
483
484 static void zlib_match(struct LZ77Context *ectx, int distance, int len) {
485     const coderecord *d, *l;
486     int i, j, k;
487     struct Outbuf *out = (struct Outbuf *)ectx->userdata;
488
489     assert(!out->comp_disabled);
490
491     while (len > 0) {
492         int thislen;
493         
494         /*
495          * We can transmit matches of lengths 3 through 258
496          * inclusive. So if len exceeds 258, we must transmit in
497          * several steps, with 258 or less in each step.
498          * 
499          * Specifically: if len >= 261, we can transmit 258 and be
500          * sure of having at least 3 left for the next step. And if
501          * len <= 258, we can just transmit len. But if len == 259
502          * or 260, we must transmit len-3.
503          */
504         thislen = (len > 260 ? 258 : len <= 258 ? len : len-3);
505         len -= thislen;
506
507         /*
508          * Binary-search to find which length code we're
509          * transmitting.
510          */
511         i = -1; j = sizeof(lencodes)/sizeof(*lencodes);
512         while (j - i >= 2) {
513             k = (j+i)/2;
514             if (thislen < lencodes[k].min)
515                 j = k;
516             else if (thislen > lencodes[k].max)
517                 i = k;
518             else {
519                 l = &lencodes[k];
520                 break;                 /* found it! */
521             }
522         }
523
524         /*
525          * Transmit the length code. 256-279 are seven bits
526          * starting at 0000000; 280-287 are eight bits starting at
527          * 11000000.
528          */
529         if (l->code <= 279) {
530             outbits(out, mirrorbytes[(l->code-256)*2], 7);
531         } else {
532             outbits(out, mirrorbytes[0xc0 - 280 + l->code], 8);
533         }
534
535         /*
536          * Transmit the extra bits.
537          */
538         if (l->extrabits)
539             outbits(out, thislen - l->min, l->extrabits);
540
541         /*
542          * Binary-search to find which distance code we're
543          * transmitting.
544          */
545         i = -1; j = sizeof(distcodes)/sizeof(*distcodes);
546         while (j - i >= 2) {
547             k = (j+i)/2;
548             if (distance < distcodes[k].min)
549                 j = k;
550             else if (distance > distcodes[k].max)
551                 i = k;
552             else {
553                 d = &distcodes[k];
554                 break;                 /* found it! */
555             }
556         }
557
558         /*
559          * Transmit the distance code. Five bits starting at 00000.
560          */
561         outbits(out, mirrorbytes[d->code*8], 5);
562
563         /*
564          * Transmit the extra bits.
565          */
566         if (d->extrabits)
567             outbits(out, distance - d->min, d->extrabits);
568     }
569 }
570
571 void zlib_compress_init(void) {
572     struct Outbuf *out;
573
574     lz77_init(&ectx);
575     ectx.literal = zlib_literal;
576     ectx.match = zlib_match;
577
578     out = smalloc(sizeof(struct Outbuf));
579     out->outbits = out->noutbits = 0;
580     out->firstblock = 1;
581     out->comp_disabled = FALSE;
582     ectx.userdata = out;
583
584     logevent("Initialised zlib (RFC1950) compression");
585 }
586
587 /*
588  * Turn off actual LZ77 analysis for one block, to facilitate
589  * construction of a precise-length IGNORE packet. Returns the
590  * length adjustment (which is only valid for packets < 65536
591  * bytes, but that seems reasonable enough).
592  */
593 int zlib_disable_compression(void) {
594     struct Outbuf *out = (struct Outbuf *)ectx.userdata;
595     int n, startbits;
596
597     out->comp_disabled = TRUE;
598
599     n = 0;
600     /*
601      * If this is the first block, we will start by outputting two
602      * header bytes, and then three bits to begin an uncompressed
603      * block. This will cost three bytes (because we will start on
604      * a byte boundary, this is certain).
605      */
606     if (out->firstblock) {
607         n = 3;
608     } else {
609         /*
610          * Otherwise, we will output seven bits to close the
611          * previous static block, and _then_ three bits to begin an
612          * uncompressed block, and then flush the current byte.
613          * This may cost two bytes or three, depending on noutbits.
614          */
615         n += (out->noutbits + 10) / 8;
616     }
617
618     /*
619      * Now we output four bytes for the length / ~length pair in
620      * the uncompressed block.
621      */
622     n += 4;
623
624     return n;
625 }
626
627 int zlib_compress_block(unsigned char *block, int len,
628                         unsigned char **outblock, int *outlen) {
629     struct Outbuf *out = (struct Outbuf *)ectx.userdata;
630     int in_block;
631
632     out->outbuf = NULL;
633     out->outlen = out->outsize = 0;
634
635     /*
636      * If this is the first block, output the Zlib (RFC1950) header
637      * bytes 78 9C. (Deflate compression, 32K window size, default
638      * algorithm.)
639      */
640     if (out->firstblock) {
641         outbits(out, 0x9C78, 16);
642         out->firstblock = 0;
643
644         in_block = FALSE;
645     }
646
647     if (out->comp_disabled) {
648         if (in_block)
649             outbits(out, 0, 7);                /* close static block */
650
651         while (len > 0) {
652             int blen = (len < 65535 ? len : 65535);
653
654             /*
655              * Start a Deflate (RFC1951) uncompressed block. We
656              * transmit a zero bit (BFINAL=0), followed by a zero
657              * bit and a one bit (BTYPE=00). Of course these are in
658              * the wrong order (00 0).
659              */
660             outbits(out, 0, 3);
661
662             /*
663              * Output zero bits to align to a byte boundary.
664              */
665             if (out->noutbits)
666                 outbits(out, 0, 8 - out->noutbits);
667
668             /*
669              * Output the block length, and then its one's
670              * complement. They're little-endian, so all we need to
671              * do is pass them straight to outbits() with bit count
672              * 16.
673              */
674             outbits(out, blen, 16);
675             outbits(out, blen ^ 0xFFFF, 16);
676
677             /*
678              * Do the `compression': we need to pass the data to
679              * lz77_compress so that it will be taken into account
680              * for subsequent (distance,length) pairs. But
681              * lz77_compress is passed FALSE, which means it won't
682              * actually find (or even look for) any matches; so
683              * every character will be passed straight to
684              * zlib_literal which will spot out->comp_disabled and
685              * emit in the uncompressed format.
686              */
687             lz77_compress(&ectx, block, blen, FALSE);
688
689             len -= blen;
690             block += blen;
691         }
692         outbits(out, 2, 3);                    /* open new block */
693     } else {
694         if (!in_block) {
695             /*
696              * Start a Deflate (RFC1951) fixed-trees block. We
697              * transmit a zero bit (BFINAL=0), followed by a zero
698              * bit and a one bit (BTYPE=01). Of course these are in
699              * the wrong order (01 0).
700              */
701             outbits(out, 2, 3);
702         }
703
704         /*
705          * Do the compression.
706          */
707         lz77_compress(&ectx, block, len, TRUE);
708
709         /*
710          * End the block (by transmitting code 256, which is
711          * 0000000 in fixed-tree mode), and transmit some empty
712          * blocks to ensure we have emitted the byte containing the
713          * last piece of genuine data. There are three ways we can
714          * do this:
715          *
716          *  - Minimal flush. Output end-of-block and then open a
717          *    new static block. This takes 9 bits, which is
718          *    guaranteed to flush out the last genuine code in the
719          *    closed block; but allegedly zlib can't handle it.
720          *
721          *  - Zlib partial flush. Output EOB, open and close an
722          *    empty static block, and _then_ open the new block.
723          *    This is the best zlib can handle.
724          *
725          *  - Zlib sync flush. Output EOB, then an empty
726          *    _uncompressed_ block (000, then sync to byte
727          *    boundary, then send bytes 00 00 FF FF). Then open the
728          *    new block.
729          *
730          * For the moment, we will use Zlib partial flush.
731          */
732         outbits(out, 0, 7);                    /* close block */
733         outbits(out, 2, 3+7);          /* empty static block */
734         outbits(out, 2, 3);                    /* open new block */
735     }
736
737     out->comp_disabled = FALSE;
738
739     *outblock = out->outbuf;
740     *outlen = out->outlen;
741
742     return 1;
743 }
744
745 /* ----------------------------------------------------------------------
746  * Zlib decompression. Of course, even though our compressor always
747  * uses static trees, our _decompressor_ has to be capable of
748  * handling dynamic trees if it sees them.
749  */
750
751 /*
752  * The way we work the Huffman decode is to have a table lookup on
753  * the first N bits of the input stream (in the order they arrive,
754  * of course, i.e. the first bit of the Huffman code is in bit 0).
755  * Each table entry lists the number of bits to consume, plus
756  * either an output code or a pointer to a secondary table.
757  */
758 struct zlib_table;
759 struct zlib_tableentry;
760
761 struct zlib_tableentry {
762     unsigned char nbits;
763     int code;
764     struct zlib_table *nexttable;
765 };
766
767 struct zlib_table {
768     int mask;                          /* mask applied to input bit stream */
769     struct zlib_tableentry *table;
770 };
771
772 #define MAXCODELEN 16
773 #define MAXSYMS 288
774
775 /*
776  * Build a single-level decode table for elements
777  * [minlength,maxlength) of the provided code/length tables, and
778  * recurse to build subtables.
779  */
780 static struct zlib_table *zlib_mkonetab(int *codes, unsigned char *lengths,
781                                         int nsyms,
782                                         int pfx, int pfxbits, int bits) {
783     struct zlib_table *tab = smalloc(sizeof(struct zlib_table));
784     int pfxmask = (1 << pfxbits) - 1;
785     int nbits, i, j, code;
786
787     tab->table = smalloc((1 << bits) * sizeof(struct zlib_tableentry));
788     tab->mask = (1 << bits) - 1;
789
790     for (code = 0; code <= tab->mask; code++) {
791         tab->table[code].code = -1;
792         tab->table[code].nbits = 0;
793         tab->table[code].nexttable = NULL;
794     }
795
796     for (i = 0; i < nsyms; i++) {
797         if (lengths[i] <= pfxbits || (codes[i] & pfxmask) != pfx)
798             continue;
799         code = (codes[i] >> pfxbits) & tab->mask;
800         for (j = code; j <= tab->mask; j += 1 << (lengths[i]-pfxbits)) {
801             tab->table[j].code = i;
802             nbits = lengths[i] - pfxbits;
803             if (tab->table[j].nbits < nbits)
804                 tab->table[j].nbits = nbits;
805         }
806     }
807     for (code = 0; code <= tab->mask; code++) {
808         if (tab->table[code].nbits <= bits)
809             continue;
810         /* Generate a subtable. */
811         tab->table[code].code = -1;
812         nbits = tab->table[code].nbits - bits;
813         if (nbits > 7)
814             nbits = 7;
815         tab->table[code].nbits = bits;
816         tab->table[code].nexttable = zlib_mkonetab(codes, lengths, nsyms,
817                                                    pfx | (code << pfxbits),
818                                                    pfxbits + bits, nbits);
819     }
820
821     return tab;
822 }
823
824 /*
825  * Build a decode table, given a set of Huffman tree lengths.
826  */
827 static struct zlib_table *zlib_mktable(unsigned char *lengths, int nlengths) {
828     int count[MAXCODELEN], startcode[MAXCODELEN], codes[MAXSYMS];
829     int code, maxlen;
830     int i, j;
831
832     /* Count the codes of each length. */
833     maxlen = 0;
834     for (i = 1; i < MAXCODELEN; i++) count[i] = 0;
835     for (i = 0; i < nlengths; i++) {
836         count[lengths[i]]++;
837         if (maxlen < lengths[i])
838             maxlen = lengths[i];
839     }
840     /* Determine the starting code for each length block. */
841     code = 0;
842     for (i = 1; i < MAXCODELEN; i++) {
843         startcode[i] = code;
844         code += count[i];
845         code <<= 1;
846     }
847     /* Determine the code for each symbol. Mirrored, of course. */
848     for (i = 0; i < nlengths; i++) {
849         code = startcode[lengths[i]]++;
850         codes[i] = 0;
851         for (j = 0; j < lengths[i]; j++) {
852             codes[i] = (codes[i] << 1) | (code & 1);
853             code >>= 1;
854         }
855     }
856
857     /*
858      * Now we have the complete list of Huffman codes. Build a
859      * table.
860      */
861     return zlib_mkonetab(codes, lengths, nlengths, 0, 0,
862                          maxlen < 9 ? maxlen : 9);
863 }
864
865 static int zlib_freetable(struct zlib_table ** ztab) {
866     struct zlib_table *tab;
867     int code;
868
869     if (ztab == NULL)
870         return -1;
871
872     if (*ztab == NULL)
873         return 0;
874
875     tab = *ztab;
876
877     for (code = 0; code <= tab->mask; code++)
878         if (tab->table[code].nexttable != NULL)
879             zlib_freetable(&tab->table[code].nexttable);
880
881     sfree(tab->table);
882     tab->table = NULL;
883
884     sfree(tab);
885     *ztab = NULL;
886
887     return(0);
888 }
889
890 static struct zlib_decompress_ctx {
891     struct zlib_table *staticlentable, *staticdisttable;
892     struct zlib_table *currlentable, *currdisttable, *lenlentable;
893     enum {
894         START, OUTSIDEBLK,
895         TREES_HDR, TREES_LENLEN, TREES_LEN, TREES_LENREP,
896         INBLK, GOTLENSYM, GOTLEN, GOTDISTSYM,
897         UNCOMP_LEN, UNCOMP_NLEN, UNCOMP_DATA
898     } state;
899     int sym, hlit, hdist, hclen, lenptr, lenextrabits, lenaddon, len, lenrep;
900     int uncomplen;
901     unsigned char lenlen[19];
902     unsigned char lengths[286+32];
903     unsigned long bits;
904     int nbits;
905     unsigned char window[WINSIZE];
906     int winpos;
907     unsigned char *outblk;
908     int outlen, outsize;
909 } dctx;
910
911 void zlib_decompress_init(void) {
912     unsigned char lengths[288];
913     memset(lengths, 8, 144);
914     memset(lengths+144, 9, 256-144);
915     memset(lengths+256, 7, 280-256);
916     memset(lengths+280, 8, 288-280);
917     dctx.staticlentable = zlib_mktable(lengths, 288);
918     memset(lengths, 5, 32);
919     dctx.staticdisttable = zlib_mktable(lengths, 32);
920     dctx.state = START;                /* even before header */
921     dctx.currlentable = dctx.currdisttable = dctx.lenlentable = NULL;
922     dctx.bits = 0;
923     dctx.nbits = 0;
924     logevent("Initialised zlib (RFC1950) decompression");
925 }
926
927 int zlib_huflookup(unsigned long *bitsp, int *nbitsp, struct zlib_table *tab) {
928     unsigned long bits = *bitsp;
929     int nbits = *nbitsp;
930     while (1) {
931         struct zlib_tableentry *ent;
932         ent = &tab->table[bits & tab->mask];
933         if (ent->nbits > nbits)
934             return -1;                 /* not enough data */
935         bits >>= ent->nbits;
936         nbits -= ent->nbits;
937         if (ent->code == -1)
938             tab = ent->nexttable;
939         else {
940             *bitsp = bits;
941             *nbitsp = nbits;
942             return ent->code;
943         }
944     }
945 }
946
947 static void zlib_emit_char(int c) {
948     dctx.window[dctx.winpos] = c;
949     dctx.winpos = (dctx.winpos + 1) & (WINSIZE-1);
950     if (dctx.outlen >= dctx.outsize) {
951         dctx.outsize = dctx.outlen + 512;
952         dctx.outblk = srealloc(dctx.outblk, dctx.outsize);
953     }
954     dctx.outblk[dctx.outlen++] = c;
955 }
956
957 #define EATBITS(n) ( dctx.nbits -= (n), dctx.bits >>= (n) )
958
959 int zlib_decompress_block(unsigned char *block, int len,
960                           unsigned char **outblock, int *outlen) {
961     const coderecord *rec;
962     int code, blktype, rep, dist, nlen;
963     static const unsigned char lenlenmap[] = {
964         16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
965     };
966
967     dctx.outblk = NULL;
968     dctx.outsize = dctx.outlen = 0;
969
970     while (len > 0 || dctx.nbits > 0) {
971         while (dctx.nbits < 24 && len > 0) {
972             dctx.bits |= (*block++) << dctx.nbits;
973             dctx.nbits += 8;
974             len--;
975         }
976         switch (dctx.state) {
977           case START:
978             /* Expect 16-bit zlib header, which we'll dishonourably ignore. */
979             if (dctx.nbits < 16)
980                 goto finished;         /* done all we can */
981             EATBITS(16);
982             dctx.state = OUTSIDEBLK;
983             break;
984           case OUTSIDEBLK:
985             /* Expect 3-bit block header. */
986             if (dctx.nbits < 3)
987                 goto finished;         /* done all we can */
988             EATBITS(1);
989             blktype = dctx.bits & 3;
990             EATBITS(2);
991             if (blktype == 0) {
992                 int to_eat = dctx.nbits & 7;
993                 dctx.state = UNCOMP_LEN;
994                 EATBITS(to_eat);       /* align to byte boundary */
995             } else if (blktype == 1) {
996                 dctx.currlentable = dctx.staticlentable;
997                 dctx.currdisttable = dctx.staticdisttable;
998                 dctx.state = INBLK;
999             } else if (blktype == 2) {
1000                 dctx.state = TREES_HDR;
1001             }
1002             break;
1003           case TREES_HDR:
1004             /*
1005              * Dynamic block header. Five bits of HLIT, five of
1006              * HDIST, four of HCLEN.
1007              */
1008             if (dctx.nbits < 5+5+4)
1009                 goto finished;         /* done all we can */
1010             dctx.hlit = 257 + (dctx.bits & 31); EATBITS(5);
1011             dctx.hdist = 1 + (dctx.bits & 31); EATBITS(5);
1012             dctx.hclen = 4 + (dctx.bits & 15); EATBITS(4);
1013             dctx.lenptr = 0;
1014             dctx.state = TREES_LENLEN;
1015             memset(dctx.lenlen, 0, sizeof(dctx.lenlen));
1016             break;
1017           case TREES_LENLEN:
1018             if (dctx.nbits < 3)
1019                 goto finished;
1020             while (dctx.lenptr < dctx.hclen && dctx.nbits >= 3) {
1021                 dctx.lenlen[lenlenmap[dctx.lenptr++]] =
1022                     (unsigned char)(dctx.bits & 7);
1023                 EATBITS(3);
1024             }
1025             if (dctx.lenptr == dctx.hclen) {
1026                 dctx.lenlentable = zlib_mktable(dctx.lenlen, 19);
1027                 dctx.state = TREES_LEN;
1028                 dctx.lenptr = 0;
1029             }
1030             break;
1031           case TREES_LEN:
1032             if (dctx.lenptr >= dctx.hlit+dctx.hdist) {
1033                 dctx.currlentable = zlib_mktable(dctx.lengths, dctx.hlit);
1034                 dctx.currdisttable = zlib_mktable(dctx.lengths + dctx.hlit,
1035                                                   dctx.hdist);
1036                 zlib_freetable(&dctx.lenlentable);
1037                 dctx.state = INBLK;
1038                 break;
1039             }
1040             code = zlib_huflookup(&dctx.bits, &dctx.nbits, dctx.lenlentable);
1041             if (code == -1)
1042                 goto finished;
1043             if (code < 16)
1044                 dctx.lengths[dctx.lenptr++] = code;
1045             else {
1046                 dctx.lenextrabits = (code == 16 ? 2 : code == 17 ? 3 : 7);
1047                 dctx.lenaddon = (code == 18 ? 11 : 3);
1048                 dctx.lenrep = (code == 16 && dctx.lenptr > 0 ?
1049                                dctx.lengths[dctx.lenptr-1] : 0);
1050                 dctx.state = TREES_LENREP;
1051             }
1052             break;
1053           case TREES_LENREP:
1054             if (dctx.nbits < dctx.lenextrabits)
1055                 goto finished;
1056             rep = dctx.lenaddon + (dctx.bits & ((1<<dctx.lenextrabits)-1));
1057             EATBITS(dctx.lenextrabits);
1058             while (rep > 0 && dctx.lenptr < dctx.hlit+dctx.hdist) {
1059                 dctx.lengths[dctx.lenptr] = dctx.lenrep;
1060                 dctx.lenptr++;
1061                 rep--;
1062             }
1063             dctx.state = TREES_LEN;
1064             break;
1065           case INBLK:
1066             code = zlib_huflookup(&dctx.bits, &dctx.nbits, dctx.currlentable);
1067             if (code == -1)
1068                 goto finished;
1069             if (code < 256)
1070                 zlib_emit_char(code);
1071             else if (code == 256) {
1072                 dctx.state = OUTSIDEBLK;
1073                 if (dctx.currlentable != dctx.staticlentable)
1074                     zlib_freetable(&dctx.currlentable);
1075                 if (dctx.currdisttable != dctx.staticdisttable)
1076                     zlib_freetable(&dctx.currdisttable);
1077             } else if (code < 286) {   /* static tree can give >285; ignore */
1078                 dctx.state = GOTLENSYM;
1079                 dctx.sym = code;
1080             }
1081             break;
1082           case GOTLENSYM:
1083             rec = &lencodes[dctx.sym - 257];
1084             if (dctx.nbits < rec->extrabits)
1085                 goto finished;
1086             dctx.len = rec->min + (dctx.bits & ((1<<rec->extrabits)-1));
1087             EATBITS(rec->extrabits);
1088             dctx.state = GOTLEN;
1089             break;
1090           case GOTLEN:
1091             code = zlib_huflookup(&dctx.bits, &dctx.nbits, dctx.currdisttable);
1092             if (code == -1)
1093                 goto finished;
1094             dctx.state = GOTDISTSYM;
1095             dctx.sym = code;
1096             break;
1097           case GOTDISTSYM:
1098             rec = &distcodes[dctx.sym];
1099             if (dctx.nbits < rec->extrabits)
1100                 goto finished;
1101             dist = rec->min + (dctx.bits & ((1<<rec->extrabits)-1));
1102             EATBITS(rec->extrabits);
1103             dctx.state = INBLK;
1104             while (dctx.len--)
1105                 zlib_emit_char(dctx.window[(dctx.winpos-dist) & (WINSIZE-1)]);
1106             break;
1107           case UNCOMP_LEN:
1108             /*
1109              * Uncompressed block. We expect to see a 16-bit LEN.
1110              */
1111             if (dctx.nbits < 16)
1112                 goto finished;
1113             dctx.uncomplen = dctx.bits & 0xFFFF;
1114             EATBITS(16);
1115             dctx.state = UNCOMP_NLEN;
1116             break;
1117           case UNCOMP_NLEN:
1118             /*
1119              * Uncompressed block. We expect to see a 16-bit NLEN,
1120              * which should be the one's complement of the previous
1121              * LEN.
1122              */
1123             if (dctx.nbits < 16)
1124                 goto finished;
1125             nlen = dctx.bits & 0xFFFF;
1126             EATBITS(16);
1127             dctx.state = UNCOMP_DATA;
1128             break;
1129           case UNCOMP_DATA:
1130             if (dctx.nbits < 8)
1131                 goto finished;
1132             zlib_emit_char(dctx.bits & 0xFF);
1133             EATBITS(8);
1134             if (--dctx.uncomplen == 0)
1135                 dctx.state = OUTSIDEBLK;   /* end of uncompressed block */
1136             break;
1137         }
1138     }
1139
1140     finished:
1141     *outblock = dctx.outblk;
1142     *outlen = dctx.outlen;
1143
1144     return 1;
1145 }
1146
1147 const struct ssh_compress ssh_zlib = {
1148     "zlib",
1149     zlib_compress_init,
1150     zlib_compress_block,
1151     zlib_decompress_init,
1152     zlib_decompress_block,
1153     zlib_disable_compression
1154 };