* add more module handlers:
[citadel.git] / webcit / tcp_sockets.c
1 /*
2  * $Id$
3  */
4
5 /*
6  * Uncomment this to log all communications with the Citadel server
7 #define SERV_TRACE 1
8  */
9
10
11 #include "webcit.h"
12 #include "webserver.h"
13
14 extern int DisableGzip;
15
16 /*
17  *  register the timeout
18  *  signum signalhandler number
19  * \return signals
20  */
21 RETSIGTYPE timeout(int signum)
22 {
23         lprintf(1, "Connection timed out; unable to reach citserver\n");
24         /* no exit here, since we need to server the connection unreachable thing. exit(3); */
25 }
26
27
/*
 * Connect to a Unix domain socket.
 *
 * sockpath  file system path of the socket to connect to
 *
 * Returns the connected socket, or -1 on failure (the reason is logged).
 * A path that does not fit into sun_path together with its terminating NUL
 * is rejected rather than truncated, so we never connect to the wrong path.
 */
int uds_connectsock(char *sockpath)
{
	struct sockaddr_un addr;
	size_t pathlen;
	int s;

	pathlen = strlen(sockpath);
	if (pathlen >= sizeof addr.sun_path) {
		lprintf(1, "Can't connect [%s]: %s\n",
			sockpath,
			strerror(ENAMETOOLONG));
		return(-1);
	}

	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;
	/* Length checked above: this copy always includes the NUL terminator. */
	memcpy(addr.sun_path, sockpath, pathlen + 1);

	s = socket(AF_UNIX, SOCK_STREAM, 0);
	if (s < 0) {
		lprintf(1, "Can't create socket[%s]: %s\n",
			sockpath,
			strerror(errno));
		return(-1);
	}

	if (connect(s, (struct sockaddr *) &addr, sizeof(addr)) < 0) {
		lprintf(1, "Can't connect [%s]: %s\n",
			sockpath,
			strerror(errno));
		close(s);
		return(-1);
	}

	return s;
}
59
60
61 /*
62  *  Connect a TCP/IP socket
63  *  host the host to connect to
64  *  service the service on the host to call
65  */
66 int tcp_connectsock(char *host, char *service)
67 {
68         int fdflags;
69         struct hostent *phe;
70         struct servent *pse;
71         struct protoent *ppe;
72         struct sockaddr_in sin;
73         int s;
74
75         memset(&sin, 0, sizeof(sin));
76         sin.sin_family = AF_INET;
77
78         pse = getservbyname(service, "tcp");
79         if (pse) {
80                 sin.sin_port = pse->s_port;
81         } else if ((sin.sin_port = htons((u_short) atoi(service))) == 0) {
82                 lprintf(1, "Can't get %s service entry\n", service);
83                 return (-1);
84         }
85         phe = gethostbyname(host);
86         if (phe) {
87                 memcpy(&sin.sin_addr, phe->h_addr, phe->h_length);
88         } else if ((sin.sin_addr.s_addr = inet_addr(host)) == INADDR_NONE) {
89                 lprintf(1, "Can't get %s host entry: %s\n",
90                         host, strerror(errno));
91                 return (-1);
92         }
93         if ((ppe = getprotobyname("tcp")) == 0) {
94                 lprintf(1, "Can't get TCP protocol entry: %s\n",
95                         strerror(errno));
96                 return (-1);
97         }
98
99         s = socket(PF_INET, SOCK_STREAM, ppe->p_proto);
100         if (s < 0) {
101                 lprintf(1, "Can't create socket: %s\n", strerror(errno));
102                 return (-1);
103         }
104
105         fdflags = fcntl(s, F_GETFL);
106         if (fdflags < 0)
107                 lprintf(1, "unable to get socket flags!  %s.%s: %s \n",
108                         host, service, strerror(errno));
109         fdflags = fdflags | O_NONBLOCK;
110         if (fcntl(s, F_SETFD, fdflags) < 0)
111                 lprintf(1, "unable to set socket nonblocking flags!  %s.%s: %s \n",
112                         host, service, strerror(errno));
113
114         signal(SIGALRM, timeout);
115         alarm(30);
116
117         if (connect(s, (struct sockaddr *) &sin, sizeof(sin)) < 0) {
118                 lprintf(1, "Can't connect to %s.%s: %s\n",
119                         host, service, strerror(errno));
120                 close(s);
121                 return (-1);
122         }
123         alarm(0);
124         signal(SIGALRM, SIG_IGN);
125
126         fdflags = fcntl(s, F_GETFL);
127         if (fdflags < 0)
128                 lprintf(1, "unable to get socket flags!  %s.%s: %s \n",
129                         host, service, strerror(errno));
130         fdflags = fdflags | O_NONBLOCK;
131         if (fcntl(s, F_SETFD, fdflags) < 0)
132                 lprintf(1, "unable to set socket nonblocking flags!  %s.%s: %s \n",
133                         host, service, strerror(errno));
134         return (s);
135 }
136
137
138
139 /*
140  *  input string from pipe
141  */
142 int serv_getln(char *strbuf, int bufsize)
143 {
144         wcsession *WCC = WC;
145         int len;
146
147         *strbuf = '\0';
148         StrBuf_ServGetln(WCC->MigrateReadLineBuf);
149         len = StrLength(WCC->MigrateReadLineBuf);
150         if (len > bufsize)
151                 len = bufsize - 1;
152         memcpy(strbuf, ChrPtr(WCC->MigrateReadLineBuf), len);
153         FlushStrBuf(WCC->MigrateReadLineBuf);
154         strbuf[len] = '\0';
155 #ifdef SERV_TRACE
156         lprintf(9, "%3d>%s\n", WC->serv_sock, strbuf);
157 #endif
158         return len;
159 }
160
161
162 int StrBuf_ServGetln(StrBuf *buf)
163 {
164         wcsession *WCC = WC;
165         const char *ErrStr = NULL;
166         int rc;
167
168         rc = StrBufTCP_read_buffered_line_fast(buf, 
169                                                WCC->ReadBuf, 
170                                                &WCC->ReadPos, 
171                                                &WCC->serv_sock, 
172                                                5, 1, 
173                                                &ErrStr);
174         if (rc < 0)
175         {
176                 lprintf(1, "Server connection broken: %s\n",
177                         ErrStr);
178                 wc_backtrace();
179                 WCC->serv_sock = (-1);
180                 WCC->connected = 0;
181                 WCC->logged_in = 0;
182         }
183         return rc;
184 }
185
186 int StrBuf_ServGetBLOBBuffered(StrBuf *buf, long BlobSize)
187 {
188         wcsession *WCC = WC;
189         const char *Err;
190         int rc;
191         
192         rc = StrBufReadBLOBBuffered(buf, 
193                                     WCC->ReadBuf, 
194                                     &WCC->ReadPos,
195                                     &WCC->serv_sock, 
196                                     1, 
197                                     BlobSize, 
198                                     NNN_TERM,
199                                     &Err);
200         if (rc < 0)
201         {
202                 lprintf(1, "Server connection broken: %s\n",
203                         Err);
204                 wc_backtrace();
205                 WCC->serv_sock = (-1);
206                 WCC->connected = 0;
207                 WCC->logged_in = 0;
208         }
209         return rc;
210 }
211
212 int StrBuf_ServGetBLOB(StrBuf *buf, long BlobSize)
213 {
214         wcsession *WCC = WC;
215         const char *Err;
216         int rc;
217         
218         WCC->ReadPos = NULL;
219         rc = StrBufReadBLOB(buf, &WCC->serv_sock, 1, BlobSize, &Err);
220         if (rc < 0)
221         {
222                 lprintf(1, "Server connection broken: %s\n",
223                         Err);
224                 wc_backtrace();
225                 WCC->serv_sock = (-1);
226                 WCC->connected = 0;
227                 WCC->logged_in = 0;
228         }
229         return rc;
230 }
231
232 /*
233  *  send binary to server
234  *  buf the buffer to write to citadel server
235  *  nbytes how many bytes to send to citadel server
236  */
237 void serv_write(const char *buf, int nbytes)
238 {
239         int bytes_written = 0;
240         int retval;
241         while (bytes_written < nbytes) {
242                 retval = write(WC->serv_sock, &buf[bytes_written],
243                                nbytes - bytes_written);
244                 if (retval < 1) {
245                         lprintf(1, "Server connection broken: %s\n",
246                                 strerror(errno));
247                         close(WC->serv_sock);
248                         WC->serv_sock = (-1);
249                         WC->connected = 0;
250                         WC->logged_in = 0;
251                         return;
252                 }
253                 bytes_written = bytes_written + retval;
254         }
255 }
256
257
258 /*
259  *  send line to server
260  *  string the line to send to the citadel server
261  */
262 void serv_puts(const char *string)
263 {
264         wcsession *WCC = WC;
265 #ifdef SERV_TRACE
266         lprintf(9, "%3d<%s\n", WC->serv_sock, string);
267 #endif
268         FlushStrBuf(WCC->ReadBuf);
269         WCC->ReadPos = NULL;
270
271         serv_write(string, strlen(string));
272         serv_write("\n", 1);
273 }
274
275 /*
276  *  send line to server
277  *  string the line to send to the citadel server
278  */
279 void serv_putbuf(const StrBuf *string)
280 {
281         wcsession *WCC = WC;
282 #ifdef SERV_TRACE
283         lprintf(9, "%3d<%s\n", WC->serv_sock, ChrPtr(string));
284 #endif
285         FlushStrBuf(WCC->ReadBuf);
286         WCC->ReadPos = NULL;
287
288         serv_write(ChrPtr(string), StrLength(string));
289         serv_write("\n", 1);
290 }
291
292
293 /*
294  *  convenience function to send stuff to the server
295  *  format the formatstring
296  *  ... the entities to insert into format 
297  */
298 void serv_printf(const char *format,...)
299 {
300         wcsession *WCC = WC;
301         va_list arg_ptr;
302         char buf[SIZ];
303         size_t len;
304
305         FlushStrBuf(WCC->ReadBuf);
306         WCC->ReadPos = NULL;
307
308         va_start(arg_ptr, format);
309         vsnprintf(buf, sizeof buf, format, arg_ptr);
310         va_end(arg_ptr);
311
312         len = strlen(buf);
313         buf[len++] = '\n';
314         buf[len] = '\0';
315         serv_write(buf, len);
316 #ifdef SERV_TRACE
317         lprintf(9, "<%s", buf);
318 #endif
319 }
320
321
322
323
/*
 * Read one line sent by the HTTP client into Target.
 *
 * sock      pointer to the client socket fd (only used for plain HTTP)
 * Target    receives the line, without its "\n" or "\r\n" terminator
 * CLineBuf  read-ahead buffer holding client data not consumed yet
 * Pos       read position inside CLineBuf (only used for plain HTTP)
 *
 * HTTPS: returns the line length, or 0 after more than 10 polls in which
 * client_read_sslbuffer() returned 0.  Returns -1 otherwise (read error, or
 * a read that produced no complete line).
 * Plain HTTP: returns whatever StrBufTCP_read_buffered_line_fast() returns.
 */
int ClientGetLine(int *sock, StrBuf *Target, StrBuf *CLineBuf, const char **Pos)
{
        const char *Error, *pch, *pchs;
        int rlen, len, retval = 0;

#ifdef HAVE_OPENSSL
        if (is_https) {
                int ntries = 0;
                /* Fast path: a complete line is already buffered. */
                if (StrLength(CLineBuf) > 0) {
                        pchs = ChrPtr(CLineBuf);
                        pch = strchr(pchs, '\n');
                        if (pch != NULL) {
                                rlen = 0;
                                len = pch - pchs;
                                /* do not copy a CR that precedes the LF */
                                if (len > 0 && (*(pch - 1) == '\r') )
                                        rlen ++;
                                StrBufSub(Target, CLineBuf, 0, len - rlen);
                                /* drop the line and its LF from the buffer */
                                StrBufCutLeft(CLineBuf, len + 1);
                                return len - rlen;
                        }
                }

                /*
                 * Poll client_read_sslbuffer() until it returns non-zero,
                 * calling sleeeeeeeeeep(1) after every read that returned 0;
                 * give up with 0 after more than 10 such reads.
                 */
                while (retval == 0) { 
                                pch = NULL;
                                pchs = ChrPtr(CLineBuf);
                                if (*pchs != '\0')
                                        pch = strchr(pchs, '\n');
                                if (pch == NULL) {
                                        retval = client_read_sslbuffer(CLineBuf, SLEEPING);
                                        pchs = ChrPtr(CLineBuf);
                                        pch = strchr(pchs, '\n');
                                }
                                if (retval == 0) {
                                        sleeeeeeeeeep(1);
                                        ntries ++;
                                }
                                if (ntries > 10)
                                        return 0;
                }
                /*
                 * NOTE(review): a read that returns data but no complete line
                 * (pch == NULL) lands in the -1 branch below, so a request line
                 * split across two SSL reads is reported as an error.  Confirm
                 * whether that is intended.
                 */
                if ((retval > 0) && (pch != NULL)) {
                        rlen = 0;
                        len = pch - pchs;
                        /* do not copy a CR that precedes the LF */
                        if (len > 0 && (*(pch - 1) == '\r') )
                                rlen ++;
                        StrBufSub(Target, CLineBuf, 0, len - rlen);
                        StrBufCutLeft(CLineBuf, len + 1);
                        return len - rlen;

                }
                else 
                        return -1;
        }
        else 
#endif
                /* NOTE(review): Error is filled in by the reader but never logged. */
                return StrBufTCP_read_buffered_line_fast(Target, 
                                                         CLineBuf,
                                                         Pos,
                                                         sock,
                                                         5,
                                                         1,
                                                         &Error);
}
386
387 /* 
388  * This is a generic function to set up a master socket for listening on
389  * a TCP port.  The server shuts down if the bind fails.
390  *
391  * ip_addr      IP address to bind
392  * port_number  port number to bind
393  * queue_len    number of incoming connections to allow in the queue
394  */
395 int ig_tcp_server(char *ip_addr, int port_number, int queue_len)
396 {
397         struct sockaddr_in sin;
398         int s, i;
399
400         memset(&sin, 0, sizeof(sin));
401         sin.sin_family = AF_INET;
402         if (ip_addr == NULL) {
403                 sin.sin_addr.s_addr = INADDR_ANY;
404         } else {
405                 sin.sin_addr.s_addr = inet_addr(ip_addr);
406         }
407
408         if (sin.sin_addr.s_addr == INADDR_NONE) {
409                 sin.sin_addr.s_addr = INADDR_ANY;
410         }
411
412         if (port_number == 0) {
413                 lprintf(1, "Cannot start: no port number specified.\n");
414                 exit(WC_EXIT_BIND);
415         }
416         sin.sin_port = htons((u_short) port_number);
417
418         s = socket(PF_INET, SOCK_STREAM, (getprotobyname("tcp")->p_proto));
419         if (s < 0) {
420                 lprintf(1, "Can't create a socket: %s\n", strerror(errno));
421                 exit(WC_EXIT_BIND);
422         }
423         /* Set some socket options that make sense. */
424         i = 1;
425         setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &i, sizeof(i));
426
427         #ifndef __APPLE__
428         fcntl(s, F_SETFL, O_NONBLOCK); /* maide: this statement is incorrect
429                                           there should be a preceding F_GETFL
430                                           and a bitwise OR with the previous
431                                           fd flags */
432         #endif
433         
434         if (bind(s, (struct sockaddr *) &sin, sizeof(sin)) < 0) {
435                 lprintf(1, "Can't bind: %s\n", strerror(errno));
436                 exit(WC_EXIT_BIND);
437         }
438         if (listen(s, queue_len) < 0) {
439                 lprintf(1, "Can't listen: %s\n", strerror(errno));
440                 exit(WC_EXIT_BIND);
441         }
442         return (s);
443 }
444
445
446
447 /*
448  * Create a Unix domain socket and listen on it
449  * sockpath - file name of the unix domain socket
450  * queue_len - Number of incoming connections to allow in the queue
451  */
452 int ig_uds_server(char *sockpath, int queue_len)
453 {
454         struct sockaddr_un addr;
455         int s;
456         int i;
457         int actual_queue_len;
458
459         actual_queue_len = queue_len;
460         if (actual_queue_len < 5) actual_queue_len = 5;
461
462         i = unlink(sockpath);
463         if ((i != 0) && (errno != ENOENT)) {
464                 lprintf(1, "webcit: can't unlink %s: %s\n",
465                         sockpath, strerror(errno));
466                 exit(WC_EXIT_BIND);
467         }
468
469         memset(&addr, 0, sizeof(addr));
470         addr.sun_family = AF_UNIX;
471         safestrncpy(addr.sun_path, sockpath, sizeof addr.sun_path);
472
473         s = socket(AF_UNIX, SOCK_STREAM, 0);
474         if (s < 0) {
475                 lprintf(1, "webcit: Can't create a socket: %s\n",
476                         strerror(errno));
477                 exit(WC_EXIT_BIND);
478         }
479
480         if (bind(s, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
481                 lprintf(1, "webcit: Can't bind: %s\n",
482                         strerror(errno));
483                 exit(WC_EXIT_BIND);
484         }
485
486         if (listen(s, actual_queue_len) < 0) {
487                 lprintf(1, "webcit: Can't listen: %s\n",
488                         strerror(errno));
489                 exit(WC_EXIT_BIND);
490         }
491
492         chmod(sockpath, 0777);
493         return(s);
494 }
495
496
497
498
499 /*
500  * Read data from the client socket.
501  *
502  * sock         socket fd to read from
503  * buf          buffer to read into 
504  * bytes        number of bytes to read
505  * timeout      Number of seconds to wait before timing out
506  *
507  * Possible return values:
508  *      1       Requested number of bytes has been read.
509  *      0       Request timed out.
510  *      -1      Connection is broken, or other error.
511  */
512 int client_read_to(int *sock, StrBuf *Target, StrBuf *Buf, const char **Pos, int bytes, int timeout)
513 {
514         const char *Error;
515         int retval = 0;
516
517 #ifdef HAVE_OPENSSL
518         if (is_https) {
519                 long bufremain = StrLength(Buf) - (*Pos - ChrPtr(Buf));
520                 StrBufAppendBufPlain(Target, *Pos, bufremain, 0);
521                 *Pos = NULL;
522                 FlushStrBuf(Buf);
523
524                 while ((StrLength(Buf) + StrLength(Target) < bytes) &&
525                        (retval >= 0))
526                         retval = client_read_sslbuffer(Buf, timeout);
527                 if (retval >= 0) {
528                         StrBufAppendBuf(Target, Buf, 0); /* todo: Buf > bytes? */
529 #ifdef HTTP_TRACING
530                         write(2, "\033[32m", 5);
531                         write(2, buf, bytes);
532                         write(2, "\033[30m", 5);
533 #endif
534                         return 1;
535                 }
536                 else {
537                         lprintf(2, "client_read_ssl() failed\n");
538                         return -1;
539                 }
540         }
541 #endif
542
543         retval = StrBufReadBLOBBuffered(Target, 
544                                         Buf, Pos, 
545                                         sock, 
546                                         1, 
547                                         bytes,
548                                         O_TERM,
549                                         &Error);
550         if (retval < 0) {
551                 lprintf(2, "client_read() failed: %s\n",
552                         Error);
553                 return retval;
554         }
555
556 #ifdef HTTP_TRACING
557         write(2, "\033[32m", 5);
558         write(2, buf, bytes);
559         write(2, "\033[30m", 5);
560 #endif
561         return 1;
562 }
563
564
565 /*
566  * Begin buffering HTTP output so we can transmit it all in one write operation later.
567  */
568 void begin_burst(void)
569 {
570         if (WC->WBuf == NULL) {
571                 WC->WBuf = NewStrBufPlain(NULL, 32768);
572         }
573 }
574
575
576 /*
577  * Finish buffering HTTP output.  [Compress using zlib and] output with a Content-Length: header.
578  */
579 long end_burst(void)
580 {
581         wcsession *WCC = WC;
582         const char *ptr, *eptr;
583         long count;
584         ssize_t res;
585         fd_set wset;
586         int fdflags;
587
588         if (!DisableGzip && (WCC->gzip_ok) && CompressBuffer(WCC->WBuf))
589         {
590                 hprintf("Content-encoding: gzip\r\n");
591         }
592
593         hprintf("Content-length: %d\r\n\r\n", StrLength(WCC->WBuf));
594
595         ptr = ChrPtr(WCC->HBuf);
596         count = StrLength(WCC->HBuf);
597         eptr = ptr + count;
598
599 #ifdef HAVE_OPENSSL
600         if (is_https) {
601                 client_write_ssl(WCC->HBuf);
602                 client_write_ssl(WCC->WBuf);
603                 return (count);
604         }
605 #endif
606
607         
608 #ifdef HTTP_TRACING
609         
610         write(2, "\033[34m", 5);
611         write(2, ptr, StrLength(WCC->WBuf));
612         write(2, "\033[30m", 5);
613 #endif
614         fdflags = fcntl(WC->http_sock, F_GETFL);
615
616         while (ptr < eptr) {
617                 if ((fdflags & O_NONBLOCK) == O_NONBLOCK) {
618                         FD_ZERO(&wset);
619                         FD_SET(WCC->http_sock, &wset);
620                         if (select(WCC->http_sock + 1, NULL, &wset, NULL, NULL) == -1) {
621                                 lprintf(2, "client_write: Socket select failed (%s)\n", strerror(errno));
622                                 return -1;
623                         }
624                 }
625
626                 if ((res = write(WCC->http_sock, 
627                                  ptr,
628                                  count)) == -1) {
629                         lprintf(2, "client_write: Socket write failed (%s)\n", strerror(errno));
630                         wc_backtrace();
631                         return res;
632                 }
633                 count -= res;
634                 ptr += res;
635         }
636
637         ptr = ChrPtr(WCC->WBuf);
638         count = StrLength(WCC->WBuf);
639         eptr = ptr + count;
640
641 #ifdef HTTP_TRACING
642         
643         write(2, "\033[34m", 5);
644         write(2, ptr, StrLength(WCC->WBuf));
645         write(2, "\033[30m", 5);
646 #endif
647
648         while (ptr < eptr) {
649                 if ((fdflags & O_NONBLOCK) == O_NONBLOCK) {
650                         FD_ZERO(&wset);
651                         FD_SET(WCC->http_sock, &wset);
652                         if (select(WCC->http_sock + 1, NULL, &wset, NULL, NULL) == -1) {
653                                 lprintf(2, "client_write: Socket select failed (%s)\n", strerror(errno));
654                                 return -1;
655                         }
656                 }
657
658                 if ((res = write(WCC->http_sock, 
659                                  ptr,
660                                  count)) == -1) {
661                         lprintf(2, "client_write: Socket write failed (%s)\n", strerror(errno));
662                         wc_backtrace();
663                         return res;
664                 }
665                 count -= res;
666                 ptr += res;
667         }
668
669         return StrLength(WCC->WBuf);
670 }
671
672
673
674 void
675 SessionNewModule_TCPSOCKETS
676 (wcsession *sess)
677 {
678         sess->CLineBuf = NewStrBuf();
679         sess->MigrateReadLineBuf = NewStrBuf();
680 }
681
682 void 
683 SessionDestroyModule_TCPSOCKETS
684 (wcsession *sess)
685 {
686         FreeStrBuf(&sess->CLineBuf);
687         FreeStrBuf(&sess->ReadBuf);
688         FreeStrBuf(&sess->MigrateReadLineBuf);
689 }