]> asedeno.scripts.mit.edu Git - PuTTY.git/blob - windows/winsftp.c
General mechanism for ensuring a dodgy SFTP server can't return
[PuTTY.git] / windows / winsftp.c
1 /*
2  * winsftp.c: the Windows-specific parts of PSFTP and PSCP.
3  */
4
5 #include <assert.h>
6
7 #include "putty.h"
8 #include "psftp.h"
9
10 /* ----------------------------------------------------------------------
11  * Interface to GUI driver program.
12  */
13
14 /* This is just a base value from which the main message numbers are
15  * derived. */
16 #define   WM_APP_BASE           0x8000
17
18 /* These two pass a single character value in wParam. They represent
19  * the visible output from PSCP. */
20 #define   WM_STD_OUT_CHAR       ( WM_APP_BASE+400 )
21 #define   WM_STD_ERR_CHAR       ( WM_APP_BASE+401 )
22
23 /* These pass a transfer status update. WM_STATS_CHAR passes a single
24  * character in wParam, and is called repeatedly to pass the name of
25  * the file, terminated with "\n". WM_STATS_SIZE passes the size of
26  * the file being transferred in wParam. WM_STATS_ELAPSED is called
27  * to pass the elapsed time (in seconds) in wParam, and
28  * WM_STATS_PERCENT passes the percentage of the transfer which is
29  * complete, also in wParam. */
30 #define   WM_STATS_CHAR         ( WM_APP_BASE+402 )
31 #define   WM_STATS_SIZE         ( WM_APP_BASE+403 )
32 #define   WM_STATS_PERCENT      ( WM_APP_BASE+404 )
33 #define   WM_STATS_ELAPSED      ( WM_APP_BASE+405 )
34
35 /* These are used at the end of a run to pass an error code in
36  * wParam: zero means success, nonzero means failure. WM_RET_ERR_CNT
37  * is used after a copy, and WM_LS_RET_ERR_CNT is used after a file
38  * list operation. */
39 #define   WM_RET_ERR_CNT        ( WM_APP_BASE+406 )
40 #define   WM_LS_RET_ERR_CNT     ( WM_APP_BASE+407 )
41
42 /* More transfer status update messages. WM_STATS_DONE passes the
43  * number of bytes sent so far in wParam. WM_STATS_ETA passes the
44  * estimated time to completion (in seconds). WM_STATS_RATEBS passes
45  * the average transfer rate (in bytes per second). */
46 #define   WM_STATS_DONE         ( WM_APP_BASE+408 )
47 #define   WM_STATS_ETA          ( WM_APP_BASE+409 )
48 #define   WM_STATS_RATEBS       ( WM_APP_BASE+410 )
49
50 #define NAME_STR_MAX 2048
51 static char statname[NAME_STR_MAX + 1];
52 static unsigned long statsize = 0;
53 static unsigned long statdone = 0;
54 static unsigned long stateta = 0;
55 static unsigned long statratebs = 0;
56 static int statperct = 0;
57 static unsigned long statelapsed = 0;
58
59 static HWND gui_hwnd = NULL;
60
61 static void send_msg(HWND h, UINT message, WPARAM wParam)
62 {
63     while (!PostMessage(h, message, wParam, 0))
64         SleepEx(1000, TRUE);
65 }
66
67 void gui_send_char(int is_stderr, int c)
68 {
69     unsigned int msg_id = WM_STD_OUT_CHAR;
70     if (is_stderr)
71         msg_id = WM_STD_ERR_CHAR;
72     send_msg(gui_hwnd, msg_id, (WPARAM) c);
73 }
74
75 void gui_send_errcount(int list, int errs)
76 {
77     unsigned int msg_id = WM_RET_ERR_CNT;
78     if (list)
79         msg_id = WM_LS_RET_ERR_CNT;
80     while (!PostMessage(gui_hwnd, msg_id, (WPARAM) errs, 0))
81         SleepEx(1000, TRUE);
82 }
83
84 void gui_update_stats(char *name, unsigned long size,
85                       int percentage, unsigned long elapsed,
86                       unsigned long done, unsigned long eta,
87                       unsigned long ratebs)
88 {
89     unsigned int i;
90
91     if (strcmp(name, statname) != 0) {
92         for (i = 0; i < strlen(name); ++i)
93             send_msg(gui_hwnd, WM_STATS_CHAR, (WPARAM) name[i]);
94         send_msg(gui_hwnd, WM_STATS_CHAR, (WPARAM) '\n');
95         strcpy(statname, name);
96     }
97     if (statsize != size) {
98         send_msg(gui_hwnd, WM_STATS_SIZE, (WPARAM) size);
99         statsize = size;
100     }
101     if (statdone != done) {
102         send_msg(gui_hwnd, WM_STATS_DONE, (WPARAM) done);
103         statdone = done;
104     }
105     if (stateta != eta) {
106         send_msg(gui_hwnd, WM_STATS_ETA, (WPARAM) eta);
107         stateta = eta;
108     }
109     if (statratebs != ratebs) {
110         send_msg(gui_hwnd, WM_STATS_RATEBS, (WPARAM) ratebs);
111         statratebs = ratebs;
112     }
113     if (statelapsed != elapsed) {
114         send_msg(gui_hwnd, WM_STATS_ELAPSED, (WPARAM) elapsed);
115         statelapsed = elapsed;
116     }
117     if (statperct != percentage) {
118         send_msg(gui_hwnd, WM_STATS_PERCENT, (WPARAM) percentage);
119         statperct = percentage;
120     }
121 }
122
123 void gui_enable(char *arg)
124 {
125     gui_hwnd = (HWND) atoi(arg);
126 }
127
128 /* ----------------------------------------------------------------------
129  * File access abstraction.
130  */
131
132 /*
133  * Set local current directory. Returns NULL on success, or else an
134  * error message which must be freed after printing.
135  */
136 char *psftp_lcd(char *dir)
137 {
138     char *ret = NULL;
139
140     if (!SetCurrentDirectory(dir)) {
141         LPVOID message;
142         int i;
143         FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER |
144                       FORMAT_MESSAGE_FROM_SYSTEM |
145                       FORMAT_MESSAGE_IGNORE_INSERTS,
146                       NULL, GetLastError(),
147                       MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
148                       (LPTSTR)&message, 0, NULL);
149         i = strcspn((char *)message, "\n");
150         ret = dupprintf("%.*s", i, (LPCTSTR)message);
151         LocalFree(message);
152     }
153
154     return ret;
155 }
156
157 /*
158  * Get local current directory. Returns a string which must be
159  * freed.
160  */
161 char *psftp_getcwd(void)
162 {
163     char *ret = snewn(256, char);
164     int len = GetCurrentDirectory(256, ret);
165     if (len > 256)
166         ret = sresize(ret, len, char);
167     GetCurrentDirectory(len, ret);
168     return ret;
169 }
170
171 #define TIME_POSIX_TO_WIN(t, ft) (*(LONGLONG*)&(ft) = \
172         ((LONGLONG) (t) + (LONGLONG) 11644473600) * (LONGLONG) 10000000)
173 #define TIME_WIN_TO_POSIX(ft, t) ((t) = (unsigned long) \
174         ((*(LONGLONG*)&(ft)) / (LONGLONG) 10000000 - (LONGLONG) 11644473600))
175
176 struct RFile {
177     HANDLE h;
178 };
179
180 RFile *open_existing_file(char *name, unsigned long *size,
181                           unsigned long *mtime, unsigned long *atime)
182 {
183     HANDLE h;
184     RFile *ret;
185
186     h = CreateFile(name, GENERIC_READ, FILE_SHARE_READ, NULL,
187                    OPEN_EXISTING, 0, 0);
188     if (h == INVALID_HANDLE_VALUE)
189         return NULL;
190
191     ret = snew(RFile);
192     ret->h = h;
193
194     if (size)
195         *size = GetFileSize(h, NULL);
196
197     if (mtime || atime) {
198         FILETIME actime, wrtime;
199         GetFileTime(h, NULL, &actime, &wrtime);
200         if (atime)
201             TIME_WIN_TO_POSIX(actime, *atime);
202         if (mtime)
203             TIME_WIN_TO_POSIX(wrtime, *mtime);
204     }
205
206     return ret;
207 }
208
209 int read_from_file(RFile *f, void *buffer, int length)
210 {
211     int ret, read;
212     ret = ReadFile(f->h, buffer, length, &read, NULL);
213     if (!ret)
214         return -1;                     /* error */
215     else
216         return read;
217 }
218
219 void close_rfile(RFile *f)
220 {
221     CloseHandle(f->h);
222     sfree(f);
223 }
224
225 struct WFile {
226     HANDLE h;
227 };
228
229 WFile *open_new_file(char *name)
230 {
231     HANDLE h;
232     WFile *ret;
233
234     h = CreateFile(name, GENERIC_WRITE, 0, NULL,
235                    CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0);
236     if (h == INVALID_HANDLE_VALUE)
237         return NULL;
238
239     ret = snew(WFile);
240     ret->h = h;
241
242     return ret;
243 }
244
245 int write_to_file(WFile *f, void *buffer, int length)
246 {
247     int ret, written;
248     ret = WriteFile(f->h, buffer, length, &written, NULL);
249     if (!ret)
250         return -1;                     /* error */
251     else
252         return written;
253 }
254
255 void set_file_times(WFile *f, unsigned long mtime, unsigned long atime)
256 {
257     FILETIME actime, wrtime;
258     TIME_POSIX_TO_WIN(atime, actime);
259     TIME_POSIX_TO_WIN(mtime, wrtime);
260     SetFileTime(f->h, NULL, &actime, &wrtime);
261 }
262
263 void close_wfile(WFile *f)
264 {
265     CloseHandle(f->h);
266     sfree(f);
267 }
268
269 int file_type(char *name)
270 {
271     DWORD attr;
272     attr = GetFileAttributes(name);
273     /* We know of no `weird' files under Windows. */
274     if (attr == (DWORD)-1)
275         return FILE_TYPE_NONEXISTENT;
276     else if (attr & FILE_ATTRIBUTE_DIRECTORY)
277         return FILE_TYPE_DIRECTORY;
278     else
279         return FILE_TYPE_FILE;
280 }
281
282 struct DirHandle {
283     HANDLE h;
284     char *name;
285 };
286
287 DirHandle *open_directory(char *name)
288 {
289     HANDLE h;
290     WIN32_FIND_DATA fdat;
291     char *findfile;
292     DirHandle *ret;
293
294     /* Enumerate files in dir `foo'. */
295     findfile = dupcat(name, "/*", NULL);
296     h = FindFirstFile(findfile, &fdat);
297     if (h == INVALID_HANDLE_VALUE)
298         return NULL;
299     sfree(findfile);
300
301     ret = snew(DirHandle);
302     ret->h = h;
303     ret->name = dupstr(fdat.cFileName);
304     return ret;
305 }
306
307 char *read_filename(DirHandle *dir)
308 {
309     while (!dir->name) {
310         WIN32_FIND_DATA fdat;
311         int ok = FindNextFile(dir->h, &fdat);
312
313         if (!ok)
314             return NULL;
315
316         if (fdat.cFileName[0] == '.' &&
317             (fdat.cFileName[1] == '\0' ||
318              (fdat.cFileName[1] == '.' && fdat.cFileName[2] == '\0')))
319             dir->name = NULL;
320         else
321             dir->name = dupstr(fdat.cFileName);
322     }
323
324     if (dir->name) {
325         char *ret = dir->name;
326         dir->name = NULL;
327         return ret;
328     } else
329         return NULL;
330 }
331
332 void close_directory(DirHandle *dir)
333 {
334     FindClose(dir->h);
335     if (dir->name)
336         sfree(dir->name);
337     sfree(dir);
338 }
339
340 int test_wildcard(char *name, int cmdline)
341 {
342     HANDLE fh;
343     WIN32_FIND_DATA fdat;
344
345     /* First see if the exact name exists. */
346     if (GetFileAttributes(name) != (DWORD)-1)
347         return WCTYPE_FILENAME;
348
349     /* Otherwise see if a wildcard match finds anything. */
350     fh = FindFirstFile(name, &fdat);
351     if (fh == INVALID_HANDLE_VALUE)
352         return WCTYPE_NONEXISTENT;
353
354     FindClose(fh);
355     return WCTYPE_WILDCARD;
356 }
357
358 struct WildcardMatcher {
359     HANDLE h;
360     char *name;
361     char *srcpath;
362 };
363
364 /*
365  * Return a pointer to the portion of str that comes after the last
366  * slash (or backslash or colon, if `local' is TRUE).
367  */
368 static char *stripslashes(char *str, int local)
369 {
370     char *p;
371
372     if (local) {
373         p = strchr(str, ':');
374         if (p) str = p+1;
375     }
376
377     p = strrchr(str, '/');
378     if (p) str = p+1;
379
380     if (local) {
381         p = strrchr(str, '\\');
382         if (p) str = p+1;
383     }
384
385     return str;
386 }
387
388 WildcardMatcher *begin_wildcard_matching(char *name)
389 {
390     HANDLE h;
391     WIN32_FIND_DATA fdat;
392     WildcardMatcher *ret;
393     char *last;
394
395     h = FindFirstFile(name, &fdat);
396     if (h == INVALID_HANDLE_VALUE)
397         return NULL;
398
399     ret = snew(WildcardMatcher);
400     ret->h = h;
401     ret->srcpath = dupstr(name);
402     last = stripslashes(ret->srcpath, 1);
403     *last = '\0';
404     if (fdat.cFileName[0] == '.' &&
405         (fdat.cFileName[1] == '\0' ||
406          (fdat.cFileName[1] == '.' && fdat.cFileName[2] == '\0')))
407         ret->name = NULL;
408     else
409         ret->name = dupcat(ret->srcpath, fdat.cFileName, NULL);
410
411     return ret;
412 }
413
414 char *wildcard_get_filename(WildcardMatcher *dir)
415 {
416     while (!dir->name) {
417         WIN32_FIND_DATA fdat;
418         int ok = FindNextFile(dir->h, &fdat);
419
420         if (!ok)
421             return NULL;
422
423         if (fdat.cFileName[0] == '.' &&
424             (fdat.cFileName[1] == '\0' ||
425              (fdat.cFileName[1] == '.' && fdat.cFileName[2] == '\0')))
426             dir->name = NULL;
427         else
428             dir->name = dupcat(dir->srcpath, fdat.cFileName, NULL);
429     }
430
431     if (dir->name) {
432         char *ret = dir->name;
433         dir->name = NULL;
434         return ret;
435     } else
436         return NULL;
437 }
438
439 void finish_wildcard_matching(WildcardMatcher *dir)
440 {
441     FindClose(dir->h);
442     if (dir->name)
443         sfree(dir->name);
444     sfree(dir->srcpath);
445     sfree(dir);
446 }
447
448 int vet_filename(char *name)
449 {
450     if (strchr(name, '/') || strchr(name, '\\') || strchr(name, ':'))
451         return FALSE;
452
453     if (!name[strspn(name, ".")])      /* entirely composed of dots */
454         return FALSE;
455
456     return TRUE;
457 }
458
459 int create_directory(char *name)
460 {
461     return CreateDirectory(name, NULL) != 0;
462 }
463
464 char *dir_file_cat(char *dir, char *file)
465 {
466     return dupcat(dir, "\\", file, NULL);
467 }
468
469 /* ----------------------------------------------------------------------
470  * Platform-specific network handling.
471  */
472
473 /*
474  * Be told what socket we're supposed to be using.
475  */
476 static SOCKET sftp_ssh_socket = INVALID_SOCKET;
477 static HANDLE netevent = NULL;
478 char *do_select(SOCKET skt, int startup)
479 {
480     int events;
481     if (startup)
482         sftp_ssh_socket = skt;
483     else
484         sftp_ssh_socket = INVALID_SOCKET;
485
486     if (p_WSAEventSelect) {
487         if (startup) {
488             events = (FD_CONNECT | FD_READ | FD_WRITE |
489                       FD_OOB | FD_CLOSE | FD_ACCEPT);
490             netevent = CreateEvent(NULL, FALSE, FALSE, NULL);
491         } else {
492             events = 0;
493         }
494         if (p_WSAEventSelect(skt, netevent, events) == SOCKET_ERROR) {
495             switch (p_WSAGetLastError()) {
496               case WSAENETDOWN:
497                 return "Network is down";
498               default:
499                 return "WSAEventSelect(): unknown error";
500             }
501         }
502     }
503     return NULL;
504 }
505 extern int select_result(WPARAM, LPARAM);
506
507 int do_eventsel_loop(HANDLE other_event)
508 {
509     int n;
510     long next, ticks;
511     HANDLE handles[2];
512     SOCKET *sklist;
513     int skcount;
514     long now = GETTICKCOUNT();
515
516     if (!netevent) {
517         return -1;                     /* doom */
518     }
519
520     handles[0] = netevent;
521     handles[1] = other_event;
522
523     if (run_timers(now, &next)) {
524         ticks = next - GETTICKCOUNT();
525         if (ticks < 0) ticks = 0;  /* just in case */
526     } else {
527         ticks = INFINITE;
528     }
529
530     n = MsgWaitForMultipleObjects(other_event ? 2 : 1, handles, FALSE, ticks,
531                                   QS_POSTMESSAGE);
532
533     if (n == WAIT_OBJECT_0 + 0) {
534         WSANETWORKEVENTS things;
535         SOCKET socket;
536         extern SOCKET first_socket(int *), next_socket(int *);
537         extern int select_result(WPARAM, LPARAM);
538         int i, socketstate;
539
540         /*
541          * We must not call select_result() for any socket
542          * until we have finished enumerating within the
543          * tree. This is because select_result() may close
544          * the socket and modify the tree.
545          */
546         /* Count the active sockets. */
547         i = 0;
548         for (socket = first_socket(&socketstate);
549              socket != INVALID_SOCKET;
550              socket = next_socket(&socketstate)) i++;
551
552         /* Expand the buffer if necessary. */
553         sklist = snewn(i, SOCKET);
554
555         /* Retrieve the sockets into sklist. */
556         skcount = 0;
557         for (socket = first_socket(&socketstate);
558              socket != INVALID_SOCKET;
559              socket = next_socket(&socketstate)) {
560             sklist[skcount++] = socket;
561         }
562
563         /* Now we're done enumerating; go through the list. */
564         for (i = 0; i < skcount; i++) {
565             WPARAM wp;
566             socket = sklist[i];
567             wp = (WPARAM) socket;
568             if (!p_WSAEnumNetworkEvents(socket, NULL, &things)) {
569                 static const struct { int bit, mask; } eventtypes[] = {
570                     {FD_CONNECT_BIT, FD_CONNECT},
571                     {FD_READ_BIT, FD_READ},
572                     {FD_CLOSE_BIT, FD_CLOSE},
573                     {FD_OOB_BIT, FD_OOB},
574                     {FD_WRITE_BIT, FD_WRITE},
575                     {FD_ACCEPT_BIT, FD_ACCEPT},
576                 };
577                 int e;
578
579                 noise_ultralight(socket);
580                 noise_ultralight(things.lNetworkEvents);
581
582                 for (e = 0; e < lenof(eventtypes); e++)
583                     if (things.lNetworkEvents & eventtypes[e].mask) {
584                         LPARAM lp;
585                         int err = things.iErrorCode[eventtypes[e].bit];
586                         lp = WSAMAKESELECTREPLY(eventtypes[e].mask, err);
587                         select_result(wp, lp);
588                     }
589             }
590         }
591
592         sfree(sklist);
593     }
594
595     if (n == WAIT_TIMEOUT) {
596         now = next;
597     } else {
598         now = GETTICKCOUNT();
599     }
600
601     if (other_event && n == WAIT_OBJECT_0 + 1)
602         return 1;
603
604     return 0;
605 }
606
607 /*
608  * Wait for some network data and process it.
609  *
610  * We have two variants of this function. One uses select() so that
611  * it's compatible with WinSock 1. The other uses WSAEventSelect
612  * and MsgWaitForMultipleObjects, so that we can consistently use
613  * WSAEventSelect throughout; this enables us to also implement
614  * ssh_sftp_get_cmdline() using a parallel mechanism.
615  */
616 int ssh_sftp_loop_iteration(void)
617 {
618     if (sftp_ssh_socket == INVALID_SOCKET)
619         return -1;                     /* doom */
620
621     if (p_WSAEventSelect == NULL) {
622         fd_set readfds;
623         int ret;
624         long now = GETTICKCOUNT();
625
626         if (socket_writable(sftp_ssh_socket))
627             select_result((WPARAM) sftp_ssh_socket, (LPARAM) FD_WRITE);
628
629         do {
630             long next, ticks;
631             struct timeval tv, *ptv;
632
633             if (run_timers(now, &next)) {
634                 ticks = next - GETTICKCOUNT();
635                 if (ticks <= 0)
636                     ticks = 1;         /* just in case */
637                 tv.tv_sec = ticks / 1000;
638                 tv.tv_usec = ticks % 1000 * 1000;
639                 ptv = &tv;
640             } else {
641                 ptv = NULL;
642             }
643
644             FD_ZERO(&readfds);
645             FD_SET(sftp_ssh_socket, &readfds);
646             ret = p_select(1, &readfds, NULL, NULL, ptv);
647
648             if (ret < 0)
649                 return -1;                     /* doom */
650             else if (ret == 0)
651                 now = next;
652             else
653                 now = GETTICKCOUNT();
654
655         } while (ret == 0);
656
657         select_result((WPARAM) sftp_ssh_socket, (LPARAM) FD_READ);
658
659         return 0;
660     } else {
661         return do_eventsel_loop(NULL);
662     }
663 }
664
665 /*
666  * Read a command line from standard input.
667  * 
668  * In the presence of WinSock 2, we can use WSAEventSelect to
669  * mediate between the socket and stdin, meaning we can send
670  * keepalives and respond to server events even while waiting at
671  * the PSFTP command prompt. Without WS2, we fall back to a simple
672  * fgets.
673  */
674 struct command_read_ctx {
675     HANDLE event;
676     char *line;
677 };
678
679 static DWORD WINAPI command_read_thread(void *param)
680 {
681     struct command_read_ctx *ctx = (struct command_read_ctx *) param;
682
683     ctx->line = fgetline(stdin);
684
685     SetEvent(ctx->event);
686
687     return 0;
688 }
689
690 char *ssh_sftp_get_cmdline(char *prompt, int no_fds_ok)
691 {
692     int ret;
693     struct command_read_ctx actx, *ctx = &actx;
694     DWORD threadid;
695
696     fputs(prompt, stdout);
697     fflush(stdout);
698
699     if ((sftp_ssh_socket == INVALID_SOCKET && no_fds_ok) ||
700         p_WSAEventSelect == NULL) {
701         return fgetline(stdin);        /* very simple */
702     }
703
704     /*
705      * Create a second thread to read from stdin. Process network
706      * and timing events until it terminates.
707      */
708     ctx->event = CreateEvent(NULL, FALSE, FALSE, NULL);
709     ctx->line = NULL;
710
711     if (!CreateThread(NULL, 0, command_read_thread,
712                       ctx, 0, &threadid)) {
713         fprintf(stderr, "Unable to create command input thread\n");
714         cleanup_exit(1);
715     }
716
717     do {
718         ret = do_eventsel_loop(ctx->event);
719
720         /* Error return can only occur if netevent==NULL, and it ain't. */
721         assert(ret >= 0);
722     } while (ret == 0);
723
724     return ctx->line;
725 }
726
727 /* ----------------------------------------------------------------------
728  * Main program. Parse arguments etc.
729  */
730 int main(int argc, char *argv[])
731 {
732     int ret;
733
734     ret = psftp_main(argc, argv);
735
736     return ret;
737 }