]> asedeno.scripts.mit.edu Git - pssh.git/blob - data/hostkeys.c
My 2006-07-10 release.
[pssh.git] / data / hostkeys.c
1 /**********
2  * Copyright (c) 2003-2004 Greg Parker.  All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY GREG PARKER ``AS IS'' AND ANY EXPRESS OR
14  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
15  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
16  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
17  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
18  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
19  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
20  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
21  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
22  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  **********/
24
25 #include "includes.h"
26 #include "recordlist.h"
27 #include "rsrc/rsrc.h"
28 #include "ssh/openssh/key.h"
29 #include "ssh/openssh/match.h"
30
31 #include "hostkeys.h"
32
33
34 /* host key record format:
35
36    2 bytes length
37    n bytes hostname,hostname,hostname\0
38    2 bytes length
39    n bytes key blob
40 */
41
42
43 #define HostKeyDBName "pssh Known Host Keys"
44 #define HostKeyDBType 'HKey'
45 static DmOpenRef HostKeyDB = 0;
46 static RecordList *HostKeyList = NULL;
47
48 static Boolean ReadHostKeyRecord(uint8_t *recordP, char **hostnames, uint8_t **keyblob, uint16_t *keybloblen) HOSTKEYS_SEGMENT;
49 static Boolean WriteHostKeyRecord(MemPtr recordP, const char *hostnames, uint8_t *keyblob, uint16_t keybloblen) HOSTKEYS_SEGMENT;
50 static void DrawHostKeyRecord(MemPtr recordP, UInt16 index, RectanglePtr bounds) HOSTKEYS_SEGMENT;
51
52 extern DmOpenRef OpenDB(UInt32 type, char *name, Boolean resDB, Boolean create);
53
54 Boolean HostKeysInit(void)
55 {
56     HostKeyDB = OpenDB(HostKeyDBType, HostKeyDBName, false, true);
57     if (!HostKeyDB) return false;
58
59     HostKeyList = 
60         RecordListNew(HostKeyDB, HostKeysFormID, HostKeysFormKeyTableID, 
61                       HostKeysFormKeyScrollbarID, DrawHostKeyRecord);
62     if (!HostKeyList) return false;
63
64     return true;
65 }
66
67
68 void HostKeysFree(void)
69 {
70     RecordListFree(HostKeyList);
71     DmCloseDatabase(HostKeyDB);
72 }
73
74
75 void HostKeysUpdate(void)
76 {
77     return RecordListUpdate(HostKeyList);
78 }
79
80
81 Boolean HostKeysHandleEvent(EventPtr event)
82 {
83     return RecordListHandleEvent(HostKeyList, event);
84 }
85
86
87 UInt16 HostKeysSelectedIndex(void)
88 {
89     return RecordListSelectedIndex(HostKeyList);
90 }
91
92
93 void HostKeysDeleteSelectedRecord(void)
94 {
95     return RecordListDeleteSelectedRecord(HostKeyList);
96 }
97
98
99 static Boolean ReadHostKeyRecord(uint8_t *recordP, char **hostnames, 
100                                  uint8_t **keyblob, uint16_t *keybloblen)
101 {
102 #define CHECK_SPACE(n) do { if (p+(n)>end) goto bad; } while (0)
103
104     uint8_t *p;
105     uint8_t *end;
106     uint16_t len;
107
108     p = recordP;
109     end = recordP + MemPtrSize(recordP);
110
111     // read hostnames string
112
113     CHECK_SPACE(2);
114     len = *(uint16_t *)p;
115     p += 2;
116     CHECK_SPACE(len);
117     *hostnames = (char *)p;
118     if ((*hostnames)[len-1] != '\0') goto bad;
119     p += len;
120
121     // read key blob
122     
123     CHECK_SPACE(2);
124     *keybloblen = *(uint16_t *)p;
125     p += 2;
126     CHECK_SPACE(*keybloblen);
127     *keyblob = p;
128
129     p += *keybloblen;
130
131     // allow trailing data for forward compatibility
132
133     return true;
134
135  bad:
136     return false;
137
138 #undef CHECK_SPACE
139 }
140
141
142 static Boolean WriteHostKeyRecord(MemPtr recordP, const char *hostnames, 
143                                   uint8_t *keyblob, uint16_t keybloblen)
144 {
145     Err err = 0;
146     uint32_t offset = 0;
147     uint16_t hostnameslen = strlen(hostnames) + 1;
148
149     if (!err) err = DmWrite(recordP, offset, &hostnameslen, 2);
150     offset += 2;
151     if (!err) err = DmWrite(recordP, offset, hostnames, hostnameslen);
152     offset += hostnameslen;
153     if (!err) err = DmWrite(recordP, offset, &keybloblen, 2);
154     offset += 2;
155     if (!err) err = DmWrite(recordP, offset, keyblob, keybloblen);
156
157     return (err == 0);
158 }
159
160
161
162 MemHandle HostKeysQuerySelectedRecord(char **hostnames, 
163                                       uint8_t **keyblob, uint16_t *keybloblen)
164 {
165     return HostKeysQueryIndexedRecord(RecordListSelectedIndex(HostKeyList), 
166                                       hostnames, keyblob, keybloblen);
167 }
168
169
170 MemHandle HostKeysQueryIndexedRecord(UInt16 index, char **hostnames, 
171                                      uint8_t **keyblob, uint16_t *keybloblen)
172 {
173     MemHandle recordH;
174     MemPtr recordP;
175     Boolean ok;
176
177     recordH = RecordListQueryIndexedRecord(HostKeyList, index);
178     if (!recordH) return NULL;
179
180     recordP = MemHandleLock(recordH);
181     ok = ReadHostKeyRecord(recordP, hostnames, keyblob, keybloblen);
182     
183     if (!ok) {
184         MemHandleUnlock(recordH);
185         return NULL;
186     } else {
187         return recordH;
188     }
189 }
190
191
192 UInt16 HostKeysFindRecordForHostname(const char *hostname)
193 {
194     UInt16 count = RecordListCount(HostKeyList);
195     UInt16 index;
196     MemHandle recordH;
197     char *savedhosts;
198     uint8_t *keyblob;
199     uint16_t keybloblen;
200
201     for (index = 0; index < count; index++) {
202         if ((recordH = HostKeysQueryIndexedRecord(index, &savedhosts, 
203                                                   &keyblob, &keybloblen)))
204         {
205             char *host = match_list(hostname, savedhosts, NULL);
206             MemHandleUnlock(recordH);
207             if (host) {
208                 xfree(host);
209                 return index;
210             }
211         }
212     }
213
214     return noRecord;
215 }
216
217
218 UInt16 HostKeysFindRecordForKey(Key *hostkey)
219 {
220     UInt16 count = RecordListCount(HostKeyList);
221     UInt16 index;
222     MemHandle recordH;
223     char *savedhosts;
224     uint8_t *keyblob;
225     uint16_t keybloblen;
226
227     for (index = 0; index < count; index++) {
228         if ((recordH = HostKeysQueryIndexedRecord(index, &savedhosts, 
229                                                   &keyblob, &keybloblen)))
230         {
231             Key *savedkey = key_from_blob(keyblob, keybloblen);
232             Boolean match = key_equal(savedkey, hostkey);
233             key_free(savedkey);
234             MemHandleUnlock(recordH);
235             
236             if (match) return index;
237         }
238     }
239
240     return noRecord;
241 }
242
243
244 Boolean HostKeysAddRecord(const char *hostname, Key *hostkey)
245 {
246     uint16_t keybloblen = 1;
247     uint32_t recordlen;
248     MemHandle recordH;
249     MemPtr recordP;
250     uint8_t *keyblob;
251     Boolean ok;
252
253     key_to_blob(hostkey, &keyblob, &keybloblen);
254
255     recordlen = 2L + strlen(hostname) + 1 + 2 + keybloblen;
256     RecordListClearSelection(HostKeyList);
257     recordH = RecordListGetSelectedRecord(HostKeyList, recordlen);
258     if (!recordH) {
259         xfree(keyblob);
260         return false;
261     }
262
263     recordP = MemHandleLock(recordH);
264
265     ok = WriteHostKeyRecord(recordP, hostname, keyblob, keybloblen);
266
267     // fixme delete new record on failure
268     xfree(keyblob);
269     MemHandleUnlock(recordH);
270     RecordListReleaseRecord(HostKeyList, recordH, true);
271
272     return ok;
273 }
274
275
276 Boolean HostKeysAddHostnameToRecord(const char *hostname, UInt16 index)
277 {
278     MemHandle recordH;
279     MemPtr recordP;
280     uint32_t recordlen;
281     char *hostnames;
282     uint8_t *keyblob;
283     uint16_t keybloblen;
284     Boolean result = false;
285     Boolean ok;
286     char *newhostnames;
287     uint8_t *newkeyblob;
288
289     RecordListClearSelection(HostKeyList);
290     RecordListSetSelectedIndex(HostKeyList, index);
291
292     recordH = RecordListQuerySelectedRecord(HostKeyList);
293     if (!recordH) return false;
294
295     // new length includes comma and new hostname
296     // (assumes record contains at least 1 hostname, and that 
297     // the new hostname isn't there already)
298     recordlen = MemHandleSize(recordH) + strlen(hostname) + strlen(",");
299
300     recordH = DmResizeRecord(HostKeyDB, index, recordlen);
301     if (!recordH) return false;
302     recordP = MemHandleLock(recordH);
303
304     ok = ReadHostKeyRecord(recordP, &hostnames, &keyblob, &keybloblen);
305     if (!ok) goto bad;
306     
307     newhostnames = arena_malloc(strlen(hostnames) + strlen(",") + strlen(hostname) + 1);
308     strcpy(newhostnames, hostnames);
309     strcat(newhostnames, ",");
310     strcat(newhostnames, hostname);
311
312     newkeyblob = arena_malloc(keybloblen);
313     memcpy(newkeyblob, keyblob, keybloblen);
314     
315     ok = WriteHostKeyRecord(recordP, newhostnames, newkeyblob, keybloblen);
316     arena_free(newhostnames);
317     arena_free(newkeyblob);
318     if (!ok) goto bad;
319
320     result = true;
321
322  bad:
323     // fixme destroy record on failure
324     MemHandleUnlock(recordH);
325     RecordListReleaseRecord(HostKeyList, recordH, true);
326     return result;
327 }
328
329
330
331 Boolean HostKeysRemoveHostnameFromRecord(const char *hostname, UInt16 index)
332 {
333     MemHandle recordH;
334     MemPtr recordP;
335     uint32_t recordlen;
336     char *hostnames;
337     uint8_t *keyblob;
338     uint16_t keybloblen;
339     Boolean ok;
340     char *newhostnames;
341     uint8_t *newkeyblob;
342     char *start, *end;
343
344     RecordListClearSelection(HostKeyList);
345     RecordListSetSelectedIndex(HostKeyList, index);
346
347     recordH = RecordListQuerySelectedRecord(HostKeyList);
348     if (!recordH) return false;
349     recordP = MemHandleLock(recordH);
350
351     ok = ReadHostKeyRecord(recordP, &hostnames, &keyblob, &keybloblen);
352     if (!ok) {
353         MemHandleUnlock(recordH);
354         return false;
355     }
356
357     // If this is the only hostname in the record - kill it completely
358     if (0 == strcasecmp(hostnames, hostname)) {
359         MemHandleUnlock(recordH);
360         RecordListDeleteSelectedRecord(HostKeyList);
361         return true;
362     }
363
364     // Make a copy of newhostnames that does not include hostname
365     newhostnames = arena_strdup(hostnames);
366     start = end = NULL;
367     if (!start) {
368         // try ...,hostname,...
369         char *commahostnamecomma = arena_malloc(1 + strlen(hostname) + 1 + 1);
370         strcpy(commahostnamecomma, ",");
371         strcat(commahostnamecomma, hostname);
372         strcat(commahostnamecomma, ",");
373         start = strcasestr(newhostnames, commahostnamecomma);
374         if (start) end = start + strlen(commahostnamecomma);
375         arena_free(commahostnamecomma);
376     }
377     if (!start) {
378         // try hostname,...
379         char *hostnamecomma = arena_malloc(strlen(hostname) + 1 + 1);
380         strcpy(hostnamecomma, hostname);
381         strcat(hostnamecomma, ",");
382         if (0 == strncmp(newhostnames, hostnamecomma, strlen(hostnamecomma))) {
383             start = newhostnames;
384             end = start + strlen(hostnamecomma);
385         }
386         arena_free(hostnamecomma);
387     }
388     if (!start) {
389         // try ...,hostname
390         char *commahostname = arena_malloc(1 + strlen(hostname) + 1);
391         strcpy(commahostname, ",");
392         strcat(commahostname, hostname);
393         start = strrchr(newhostnames, ',');
394         if (start  &&  0 == strcmp(start, commahostname)) {
395             end = start + strlen(commahostname);
396         }
397         arena_free(commahostname);
398     }
399     if (!start) {
400         // didn't find hostname in hostname list
401         MemHandleUnlock(recordH);
402         arena_free(newhostnames);
403         return false;
404     }
405
406     // kill [start..end]
407     memmove(start, end, strlen(end)+1);
408     
409     newkeyblob = arena_malloc(keybloblen);
410     memcpy(newkeyblob, keyblob, keybloblen);
411     
412     recordlen = 2L + strlen(newhostnames) + 1 + 2 + keybloblen;
413     
414     // Reopen the record for writing and resize
415     MemHandleUnlock(recordH);
416     recordH = DmResizeRecord(HostKeyDB, RecordListSelectedIndex(HostKeyList), 
417                              recordlen);
418     if (!recordH) {
419         arena_free(newhostnames);
420         arena_free(newkeyblob);
421         return false;
422     }
423
424     // Write the new record data
425     recordP = MemHandleLock(recordH);
426     ok = WriteHostKeyRecord(recordP, newhostnames, newkeyblob, keybloblen);
427     MemHandleUnlock(recordH);
428     RecordListReleaseRecord(HostKeyList, recordH, true);
429     arena_free(newhostnames);
430     arena_free(newkeyblob);
431     if (!ok) return false;
432
433     return true;
434 }
435
436
437 static void DrawHostKeyRecord(MemPtr recordP, UInt16 index, 
438                               RectanglePtr bounds)
439 {
440     char *hostnames;
441     uint8_t *keyblob;
442     uint16_t keybloblen;
443
444     if (ReadHostKeyRecord(recordP, &hostnames, 
445                           &keyblob, &keybloblen))
446     {
447         // "hostname,hostname,hostname"
448         int len;
449         int x = bounds->topLeft.x + 1;
450         int y = bounds->topLeft.y;
451
452         len = StrLen(hostnames);
453         WinDrawTruncChars(hostnames, len, x, y, 
454                           bounds->topLeft.x + bounds->extent.x - x - 1);
455         x += FntCharsWidth(hostnames, len);
456     }
457 }