* move serv_func.c:read_server_binary() to tcp_sockets.c: serv_read_binary()
[citadel.git] / webcit / tcp_sockets.c
index 9c74b5782b8a4c1a094c7ecde2832426f81135f8..23bdf7bea76b777eeb73e1f3eb2e86b1a93e2b30 100644 (file)
 #include "webserver.h"
 
 extern int DisableGzip;
+long MaxRead = -1; /* should we do READ scattered or all at once? */
 
 /*
- *  register the timeout
- *  signum signalhandler number
- * \return signals
+ * register the timeout
  */
 RETSIGTYPE timeout(int signum)
 {
@@ -122,15 +121,16 @@ int tcp_connectsock(char *host, char *service)
        }
        alarm(0);
        signal(SIGALRM, SIG_IGN);
-
-       fdflags = fcntl(s, F_GETFL);
-       if (fdflags < 0)
-               lprintf(1, "unable to get socket flags!  %s.%s: %s \n",
-                       host, service, strerror(errno));
-       fdflags = fdflags | O_NONBLOCK;
-       if (fcntl(s, F_SETFD, fdflags) < 0)
-               lprintf(1, "unable to set socket nonblocking flags!  %s.%s: %s \n",
-                       host, service, strerror(errno));
+       if (!is_https) {
+               fdflags = fcntl(s, F_GETFL);
+               if (fdflags < 0)
+                       lprintf(1, "unable to get socket flags!  %s.%s: %s \n",
+                               host, service, strerror(errno));
+               fdflags = fdflags | O_NONBLOCK;
+               if (fcntl(s, F_SETFD, fdflags) < 0)
+                       lprintf(1, "unable to set socket nonblocking flags!  %s.%s: %s \n",
+                               host, service, strerror(errno));
+       }
        return (s);
 }
 
@@ -174,7 +174,7 @@ int StrBuf_ServGetln(StrBuf *buf)
        if (rc < 0)
        {
                lprintf(1, "Server connection broken: %s\n",
-                       ErrStr);
+                       (ErrStr)?ErrStr:"");
                wc_backtrace();
                WCC->serv_sock = (-1);
                WCC->connected = 0;
@@ -186,7 +186,7 @@ int StrBuf_ServGetln(StrBuf *buf)
 int StrBuf_ServGetBLOBBuffered(StrBuf *buf, long BlobSize)
 {
        wcsession *WCC = WC;
-       const char *Err;
+       const char *ErrStr;
        int rc;
        
        rc = StrBufReadBLOBBuffered(buf, 
@@ -196,11 +196,11 @@ int StrBuf_ServGetBLOBBuffered(StrBuf *buf, long BlobSize)
                                    1, 
                                    BlobSize, 
                                    NNN_TERM,
-                                   &Err);
+                                   &ErrStr);
        if (rc < 0)
        {
                lprintf(1, "Server connection broken: %s\n",
-                       Err);
+                       (ErrStr)?ErrStr:"");
                wc_backtrace();
                WCC->serv_sock = (-1);
                WCC->connected = 0;
@@ -212,15 +212,15 @@ int StrBuf_ServGetBLOBBuffered(StrBuf *buf, long BlobSize)
 int StrBuf_ServGetBLOB(StrBuf *buf, long BlobSize)
 {
        wcsession *WCC = WC;
-       const char *Err;
+       const char *ErrStr;
        int rc;
        
        WCC->ReadPos = NULL;
-       rc = StrBufReadBLOB(buf, &WCC->serv_sock, 1, BlobSize, &Err);
+       rc = StrBufReadBLOB(buf, &WCC->serv_sock, 1, BlobSize, &ErrStr);
        if (rc < 0)
        {
                lprintf(1, "Server connection broken: %s\n",
-                       Err);
+                       (ErrStr)?ErrStr:"");
                wc_backtrace();
                WCC->serv_sock = (-1);
                WCC->connected = 0;
@@ -236,18 +236,23 @@ int StrBuf_ServGetBLOB(StrBuf *buf, long BlobSize)
  */
 void serv_write(const char *buf, int nbytes)
 {
+       wcsession *WCC = WC;
        int bytes_written = 0;
        int retval;
+
+       FlushStrBuf(WCC->ReadBuf);
+       WCC->ReadPos = NULL;
        while (bytes_written < nbytes) {
-               retval = write(WC->serv_sock, &buf[bytes_written],
+               retval = write(WCC->serv_sock, &buf[bytes_written],
                               nbytes - bytes_written);
                if (retval < 1) {
+                       const char *ErrStr = strerror(errno);
                        lprintf(1, "Server connection broken: %s\n",
-                               strerror(errno));
-                       close(WC->serv_sock);
-                       WC->serv_sock = (-1);
-                       WC->connected = 0;
-                       WC->logged_in = 0;
+                               (ErrStr)?ErrStr:"");
+                       close(WCC->serv_sock);
+                       WCC->serv_sock = (-1);
+                       WCC->connected = 0;
+                       WCC->logged_in = 0;
                        return;
                }
                bytes_written = bytes_written + retval;
@@ -320,8 +325,118 @@ void serv_printf(const char *format,...)
 
 
 
+/**
+ * Read binary data from server into memory using a series of
+ * server READ commands.
+ * \return the read content as StrBuf
+ */
+int serv_read_binary(StrBuf *Ret, size_t total_len, StrBuf *Buf) 
+{
+       wcsession *WCC = WC;
+       size_t bytes = 0;
+       size_t thisblock = 0;
+       
+       if (Ret == NULL)
+           return -1;
+
+       if (MaxRead == -1)
+       {
+               serv_printf("READ %d|%d", 0, total_len);
+               if (StrBuf_ServGetln(Buf) > 0)
+               {
+                       long YetRead;
+                       const char *ErrStr;
+                       const char *pch;
+                       int rc;
+
+                       if (GetServerStatus(Buf, NULL) == 6)
+                       {
+                           StrBufCutLeft(Buf, 4);
+                           thisblock = StrTol(Buf);
+                           if (WCC->serv_sock==-1) {
+                                   FlushStrBuf(Ret); 
+                                   return -1; 
+                           }
+
+                           pch = ChrPtr(WCC->ReadBuf);
+                           YetRead = WCC->ReadPos - pch;
+                           if (YetRead > 0)
+                           {
+                                   long StillThere;
+                                   
+                                   StillThere = StrLength(WCC->ReadBuf) - 
+                                           YetRead;
+
+                                   StrBufPlain(Ret, 
+                                               WCC->ReadPos,
+                                               StillThere);
+                                   total_len -= StillThere;
+                           }
+                           FlushStrBuf(WCC->ReadBuf);
+                           WCC->ReadPos = NULL;
+                           
+                           if (total_len > 0)
+                           {
+                                   rc = StrBufReadBLOB(Ret, 
+                                                       &WCC->serv_sock, 
+                                                       1, 
+                                                       total_len,
+                                                       &ErrStr);
+                                   if (rc < 0)
+                                   {
+                                           lprintf(1, "Server connection broken: %s\n",
+                                                   (ErrStr)?ErrStr:"");
+                                           wc_backtrace();
+                                           WCC->serv_sock = (-1);
+                                           WCC->connected = 0;
+                                           WCC->logged_in = 0;
+                                           return rc;
+                                   }
+                                   else
+                                           return StrLength(Ret);
+                           }
+                           else 
+                                   return StrLength(Ret);
+                       }
+               }
+               else
+                       return -1;
+       }
+       else while ((WCC->serv_sock!=-1) &&
+              (bytes < total_len)) {
+               thisblock = MaxRead;
+               if ((total_len - bytes) < thisblock) {
+                       thisblock = total_len - bytes;
+                       if (thisblock == 0) {
+                               FlushStrBuf(Ret); 
+                               return -1; 
+                       }
+               }
+               serv_printf("READ %d|%d", (int)bytes, (int)thisblock);
+               if (StrBuf_ServGetln(Buf) > 0)
+               {
+                       if (GetServerStatus(Buf, NULL) == 6)
+                       {
+                           StrBufCutLeft(Buf, 4);
+                           thisblock = StrTol(Buf);
+                           if (WCC->serv_sock==-1) {
+                                   FlushStrBuf(Ret); 
+                                   return -1; 
+                           }
+                           StrBuf_ServGetBLOBBuffered(Ret, thisblock);
+                           bytes += thisblock;
+                   }
+                   else {
+                           lprintf(3, "Error: %s\n", ChrPtr(Buf) + 4);
+                           return -1;
+                   }
+               }
+       }
+       return StrLength(Ret);
+}
+
 
-int ClientGetLine(int *sock, StrBuf *Target, StrBuf *CLineBuf, const char **Pos)
+int ClientGetLine(ParsedHttpHdrs *Hdr, StrBuf *Target)
 {
        const char *Error, *pch, *pchs;
        int rlen, len, retval = 0;
@@ -329,28 +444,28 @@ int ClientGetLine(int *sock, StrBuf *Target, StrBuf *CLineBuf, const char **Pos)
 #ifdef HAVE_OPENSSL
        if (is_https) {
                int ntries = 0;
-               if (StrLength(CLineBuf) > 0) {
-                       pchs = ChrPtr(CLineBuf);
+               if (StrLength(Hdr->ReadBuf) > 0) {
+                       pchs = ChrPtr(Hdr->ReadBuf);
                        pch = strchr(pchs, '\n');
                        if (pch != NULL) {
                                rlen = 0;
                                len = pch - pchs;
                                if (len > 0 && (*(pch - 1) == '\r') )
                                        rlen ++;
-                               StrBufSub(Target, CLineBuf, 0, len - rlen);
-                               StrBufCutLeft(CLineBuf, len + 1);
+                               StrBufSub(Target, Hdr->ReadBuf, 0, len - rlen);
+                               StrBufCutLeft(Hdr->ReadBuf, len + 1);
                                return len - rlen;
                        }
                }
 
                while (retval == 0) { 
                                pch = NULL;
-                               pchs = ChrPtr(CLineBuf);
+                               pchs = ChrPtr(Hdr->ReadBuf);
                                if (*pchs != '\0')
                                        pch = strchr(pchs, '\n');
                                if (pch == NULL) {
-                                       retval = client_read_sslbuffer(CLineBuf, SLEEPING);
-                                       pchs = ChrPtr(CLineBuf);
+                                       retval = client_read_sslbuffer(Hdr->ReadBuf, SLEEPING);
+                                       pchs = ChrPtr(Hdr->ReadBuf);
                                        pch = strchr(pchs, '\n');
                                }
                                if (retval == 0) {
@@ -365,8 +480,8 @@ int ClientGetLine(int *sock, StrBuf *Target, StrBuf *CLineBuf, const char **Pos)
                        len = pch - pchs;
                        if (len > 0 && (*(pch - 1) == '\r') )
                                rlen ++;
-                       StrBufSub(Target, CLineBuf, 0, len - rlen);
-                       StrBufCutLeft(CLineBuf, len + 1);
+                       StrBufSub(Target, Hdr->ReadBuf, 0, len - rlen);
+                       StrBufCutLeft(Hdr->ReadBuf, len + 1);
                        return len - rlen;
 
                }
@@ -376,9 +491,9 @@ int ClientGetLine(int *sock, StrBuf *Target, StrBuf *CLineBuf, const char **Pos)
        else 
 #endif
                return StrBufTCP_read_buffered_line_fast(Target, 
-                                                        CLineBuf,
-                                                        Pos,
-                                                        sock,
+                                                        Hdr->ReadBuf,
+                                                        &Hdr->Pos,
+                                                        &Hdr->http_sock,
                                                         5,
                                                         1,
                                                         &Error);
@@ -394,6 +509,7 @@ int ClientGetLine(int *sock, StrBuf *Target, StrBuf *CLineBuf, const char **Pos)
  */
 int ig_tcp_server(char *ip_addr, int port_number, int queue_len)
 {
+       struct protoent *p;
        struct sockaddr_in sin;
        int s, i;
 
@@ -411,14 +527,16 @@ int ig_tcp_server(char *ip_addr, int port_number, int queue_len)
 
        if (port_number == 0) {
                lprintf(1, "Cannot start: no port number specified.\n");
-               exit(WC_EXIT_BIND);
+               return (-WC_EXIT_BIND);
        }
        sin.sin_port = htons((u_short) port_number);
 
-       s = socket(PF_INET, SOCK_STREAM, (getprotobyname("tcp")->p_proto));
+       p = getprotobyname("tcp");
+
+       s = socket(PF_INET, SOCK_STREAM, (p->p_proto));
        if (s < 0) {
                lprintf(1, "Can't create a socket: %s\n", strerror(errno));
-               exit(WC_EXIT_BIND);
+               return (-WC_EXIT_BIND);
        }
        /* Set some socket options that make sense. */
        i = 1;
@@ -433,11 +551,11 @@ int ig_tcp_server(char *ip_addr, int port_number, int queue_len)
        
        if (bind(s, (struct sockaddr *) &sin, sizeof(sin)) < 0) {
                lprintf(1, "Can't bind: %s\n", strerror(errno));
-               exit(WC_EXIT_BIND);
+               return (-WC_EXIT_BIND);
        }
        if (listen(s, queue_len) < 0) {
                lprintf(1, "Can't listen: %s\n", strerror(errno));
-               exit(WC_EXIT_BIND);
+               return (-WC_EXIT_BIND);
        }
        return (s);
 }
@@ -463,7 +581,7 @@ int ig_uds_server(char *sockpath, int queue_len)
        if ((i != 0) && (errno != ENOENT)) {
                lprintf(1, "webcit: can't unlink %s: %s\n",
                        sockpath, strerror(errno));
-               exit(WC_EXIT_BIND);
+               return (-WC_EXIT_BIND);
        }
 
        memset(&addr, 0, sizeof(addr));
@@ -474,19 +592,19 @@ int ig_uds_server(char *sockpath, int queue_len)
        if (s < 0) {
                lprintf(1, "webcit: Can't create a socket: %s\n",
                        strerror(errno));
-               exit(WC_EXIT_BIND);
+               return (-WC_EXIT_BIND);
        }
 
        if (bind(s, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
                lprintf(1, "webcit: Can't bind: %s\n",
                        strerror(errno));
-               exit(WC_EXIT_BIND);
+               return (-WC_EXIT_BIND);
        }
 
        if (listen(s, actual_queue_len) < 0) {
                lprintf(1, "webcit: Can't listen: %s\n",
                        strerror(errno));
-               exit(WC_EXIT_BIND);
+               return (-WC_EXIT_BIND);
        }
 
        chmod(sockpath, 0777);
@@ -509,40 +627,55 @@ int ig_uds_server(char *sockpath, int queue_len)
  *      0       Request timed out.
  *     -1      Connection is broken, or other error.
  */
-int client_read_to(int *sock, StrBuf *Target, StrBuf *Buf, const char **Pos, int bytes, int timeout)
+int client_read_to(ParsedHttpHdrs *Hdr, StrBuf *Target, int bytes, int timeout)
 {
        const char *Error;
        int retval = 0;
 
 #ifdef HAVE_OPENSSL
        if (is_https) {
-               long bufremain = StrLength(Buf) - (*Pos - ChrPtr(Buf));
-               StrBufAppendBufPlain(Target, *Pos, bufremain, 0);
-               *Pos = NULL;
-               FlushStrBuf(Buf);
-
-               while ((StrLength(Buf) + StrLength(Target) < bytes) &&
-                      (retval >= 0))
-                       retval = client_read_sslbuffer(Buf, timeout);
-               if (retval >= 0) {
-                       StrBufAppendBuf(Target, Buf, 0); /* todo: Buf > bytes? */
+               long bufremain;
+               long baselen;
+
+               baselen = StrLength(Target);
+
+               if (Hdr->Pos == NULL)
+                       Hdr->Pos = ChrPtr(Hdr->ReadBuf);
+               bufremain = StrLength(Hdr->ReadBuf) - (Hdr->Pos - ChrPtr(Hdr->ReadBuf));
+
+               if (bytes < bufremain)
+                       bufremain = bytes;
+               StrBufAppendBufPlain(Target, Hdr->Pos, bufremain, 0);
+               StrBufCutLeft(Hdr->ReadBuf, bufremain);
+
+               if (bytes > bufremain) 
+               {
+                       while ((StrLength(Hdr->ReadBuf) + StrLength(Target) < bytes + baselen) &&
+                              (retval >= 0))
+                               retval = client_read_sslbuffer(Hdr->ReadBuf, timeout);
+                       if (retval >= 0) {
+                               StrBufAppendBuf(Target, Hdr->ReadBuf, 0); /* todo: Buf > bytes? */
 #ifdef HTTP_TRACING
-                       write(2, "\033[32m", 5);
-                       write(2, buf, bytes);
-                       write(2, "\033[30m", 5);
+                               write(2, "\033[32m", 5);
+                               write(2, buf, bytes);
+                               write(2, "\033[30m", 5);
 #endif
-                       return 1;
-               }
-               else {
-                       lprintf(2, "client_read_ssl() failed\n");
-                       return -1;
+                               return 1;
+                       }
+                       else {
+                               lprintf(2, "client_read_ssl() failed\n");
+                               return -1;
+                       }
                }
+               else 
+                       return 1;
        }
 #endif
 
        retval = StrBufReadBLOBBuffered(Target, 
-                                       Buf, Pos, 
-                                       sock, 
+                                       Hdr->ReadBuf, 
+                                       &Hdr->Pos, 
+                                       &Hdr->http_sock, 
                                        1, 
                                        bytes,
                                        O_TERM,
@@ -581,15 +714,22 @@ long end_burst(void)
        wcsession *WCC = WC;
         const char *ptr, *eptr;
         long count;
-       ssize_t res;
+       ssize_t res = 0;
         fd_set wset;
         int fdflags;
 
-       if (!DisableGzip && (WCC->gzip_ok) && CompressBuffer(WCC->WBuf))
+       if (!DisableGzip && (WCC->Hdr->HR.gzip_ok))
        {
-               hprintf("Content-encoding: gzip\r\n");
+               if (CompressBuffer(WCC->WBuf) > 0)
+                       hprintf("Content-encoding: gzip\r\n");
+               else {
+                       lprintf(CTDL_ALERT, "Compression failed: %d [%s] sending uncompressed\n", errno, strerror(errno));
+                       wc_backtrace();
+               }
        }
 
+       if (WCC->Hdr->HR.prohibit_caching)
+               hprintf("Pragma: no-cache\r\nCache-Control: no-store\r\nExpires:-1\r\n");
        hprintf("Content-length: %d\r\n\r\n", StrLength(WCC->WBuf));
 
        ptr = ChrPtr(WCC->HBuf);
@@ -611,19 +751,22 @@ long end_burst(void)
        write(2, ptr, StrLength(WCC->WBuf));
        write(2, "\033[30m", 5);
 #endif
-       fdflags = fcntl(WC->http_sock, F_GETFL);
+       if (WCC->Hdr->http_sock == -1)
+               return -1;
+       fdflags = fcntl(WC->Hdr->http_sock, F_GETFL);
 
-       while (ptr < eptr) {
+       while ((ptr < eptr) && (WCC->Hdr->http_sock != -1)){
                 if ((fdflags & O_NONBLOCK) == O_NONBLOCK) {
                         FD_ZERO(&wset);
-                        FD_SET(WCC->http_sock, &wset);
-                        if (select(WCC->http_sock + 1, NULL, &wset, NULL, NULL) == -1) {
+                        FD_SET(WCC->Hdr->http_sock, &wset);
+                        if (select(WCC->Hdr->http_sock + 1, NULL, &wset, NULL, NULL) == -1) {
                                 lprintf(2, "client_write: Socket select failed (%s)\n", strerror(errno));
                                 return -1;
                         }
                 }
 
-                if ((res = write(WCC->http_sock, 
+                if ((WCC->Hdr->http_sock == -1) || 
+                   (res = write(WCC->Hdr->http_sock, 
                                 ptr,
                                 count)) == -1) {
                         lprintf(2, "client_write: Socket write failed (%s)\n", strerror(errno));
@@ -645,17 +788,18 @@ long end_burst(void)
        write(2, "\033[30m", 5);
 #endif
 
-        while (ptr < eptr) {
+        while ((ptr < eptr) && (WCC->Hdr->http_sock != -1)) {
                 if ((fdflags & O_NONBLOCK) == O_NONBLOCK) {
                         FD_ZERO(&wset);
-                        FD_SET(WCC->http_sock, &wset);
-                        if (select(WCC->http_sock + 1, NULL, &wset, NULL, NULL) == -1) {
+                        FD_SET(WCC->Hdr->http_sock, &wset);
+                        if (select(WCC->Hdr->http_sock + 1, NULL, &wset, NULL, NULL) == -1) {
                                 lprintf(2, "client_write: Socket select failed (%s)\n", strerror(errno));
                                 return -1;
                         }
                 }
 
-                if ((res = write(WCC->http_sock, 
+                if ((WCC->Hdr->http_sock == -1) || 
+                   (res = write(WCC->Hdr->http_sock, 
                                 ptr,
                                 count)) == -1) {
                         lprintf(2, "client_write: Socket write failed (%s)\n", strerror(errno));
@@ -670,6 +814,69 @@ long end_burst(void)
 }
 
 
+/*
+ * lingering_close() a`la Apache. see
+ * http://www.apache.org/docs/misc/fin_wait_2.html for rationale
+ */
+int lingering_close(int fd)
+{
+       char buf[SIZ];
+       int i;
+       fd_set set;
+       struct timeval tv, start;
+
+       gettimeofday(&start, NULL);
+       if (fd == -1)
+               return -1;
+       shutdown(fd, 1);
+       do {
+               do {
+                       gettimeofday(&tv, NULL);
+                       tv.tv_sec = SLEEPING - (tv.tv_sec - start.tv_sec);
+                       tv.tv_usec = start.tv_usec - tv.tv_usec;
+                       if (tv.tv_usec < 0) {
+                               tv.tv_sec--;
+                               tv.tv_usec += 1000000;
+                       }
+                       FD_ZERO(&set);
+                       FD_SET(fd, &set);
+                       i = select(fd + 1, &set, NULL, NULL, &tv);
+               } while (i == -1 && errno == EINTR);
+
+               if (i <= 0)
+                       break;
+
+               i = read(fd, buf, sizeof buf);
+       } while (i != 0 && (i != -1 || errno == EINTR));
+
+       return close(fd);
+}
+
+void
+HttpNewModule_TCPSOCKETS
+(ParsedHttpHdrs *httpreq)
+{
+
+       httpreq->ReadBuf = NewStrBufPlain(NULL, SIZ * 4);
+}
+
+void
+HttpDetachModule_TCPSOCKETS
+(ParsedHttpHdrs *httpreq)
+{
+
+       FlushStrBuf(httpreq->ReadBuf);
+       ReAdjustEmptyBuf(httpreq->ReadBuf, 4 * SIZ, SIZ);
+}
+
+void
+HttpDestroyModule_TCPSOCKETS
+(ParsedHttpHdrs *httpreq)
+{
+
+       FreeStrBuf(&httpreq->ReadBuf);
+}
+
 
 void
 SessionNewModule_TCPSOCKETS
@@ -685,6 +892,7 @@ SessionDestroyModule_TCPSOCKETS
 {
        FreeStrBuf(&sess->CLineBuf);
        FreeStrBuf(&sess->ReadBuf);
+       sess->ReadPos = NULL;
        FreeStrBuf(&sess->MigrateReadLineBuf);
        if (sess->serv_sock > 0)
                close(sess->serv_sock);