src/server/util/io.c

Sat, 30 Mar 2024 12:35:29 +0100

author
Olaf Wintermann <olaf.wintermann@gmail.com>
date
Sat, 30 Mar 2024 12:35:29 +0100
changeset 514
922bfe380c8e
parent 513
9a49c245a49c
child 539
d556b45b0d24
permissions
-rw-r--r--

merge

/*
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
 *
 * Copyright 2013 Olaf Wintermann. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *   1. Redistributions of source code must retain the above copyright
 *      notice, this list of conditions and the following disclaimer.
 *
 *   2. Redistributions in binary form must reproduce the above copyright
 *      notice, this list of conditions and the following disclaimer in the
 *      documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#ifdef __gnu_linux__
#define _GNU_SOURCE
#endif

#include <unistd.h>
#include <stdlib.h>

#ifdef XP_UNIX
#include <sys/uio.h>
#include <sys/uio.h>
#endif

#ifdef XP_WIN32

#endif

/*
 * Select the sendfile implementation used for native sockets.
 *
 * WS_SENDFILE is defined where net_sys_sendfile is built on the system
 * sendfile() call: Linux/Solaris (<sys/sendfile.h>) and BSD systems other
 * than NetBSD/OpenBSD. On NetBSD/OpenBSD the name net_sys_sendfile is
 * mapped to net_fallback_sendfile instead.
 *
 * NET_SYS_SENDFILE is the function placed into native_io_funcs: the native
 * implementation if WS_SENDFILE is defined, the read/write fallback
 * otherwise.
 */
#if defined(LINUX) || defined(SOLARIS)
#include <sys/sendfile.h>
#define WS_SENDFILE
#elif defined(BSD)
#if defined(__NetBSD__) || defined(__OpenBSD__)
#define net_sys_sendfile net_fallback_sendfile
#else
#define WS_SENDFILE
#endif
#endif

#ifdef WS_SENDFILE
#define NET_SYS_SENDFILE net_sys_sendfile
#else
#define NET_SYS_SENDFILE net_fallback_sendfile
#endif



#include "../daemon/vfs.h"
#include "io.h"
#include "pool.h"
#include "../daemon/netsite.h"
#include "../daemon/event.h"
#include "cx/utils.h"
#include <cx/printf.h>

/*
 * Function table for plain sockets (Sysstream), copied by Sysstream_new.
 * Slots in order: write, writev, read, sendfile, close, finish, setmode,
 * poll, followed by two integer fields (presumably io_errno and the stream
 * type - confirm against the IOStream declaration).
 * There is no finish function for plain sockets (NULL).
 */
IOStream native_io_funcs = {
    (io_write_f)net_sys_write,
    (io_writev_f)net_sys_writev,
    (io_read_f)net_sys_read,
    (io_sendfile_f)NET_SYS_SENDFILE,
    (io_close_f)net_sys_close,
    NULL, // finish
    (io_setmode_f)net_sys_setmode,
    (io_poll_f)net_sys_poll,
    0,
    0
};

/*
 * Function table for HTTP streams (HttpStream), copied by httpstream_new.
 * Same slot order as native_io_funcs. The last field is the stream type
 * IO_STREAM_TYPE_HTTP, which httpstream_enable_chunked_write checks.
 */
IOStream http_io_funcs = {
    (io_write_f)net_http_write,
    (io_writev_f)net_http_writev,
    (io_read_f)net_http_read,
    (io_sendfile_f)net_http_sendfile,
    (io_close_f)net_http_close,
    (io_finish_f)net_http_finish,
    (io_setmode_f)net_http_setmode,
    (io_poll_f)net_http_poll,
    0,
    IO_STREAM_TYPE_HTTP
};

/*
 * Function table for SSL streams (SSLStream), copied by sslstream_new.
 * Same slot order as native_io_funcs. sendfile is NULL, so net_sendfile and
 * net_http_sendfile use net_fallback_sendfile for these streams.
 */
IOStream ssl_io_funcs = {
    (io_write_f)net_ssl_write,
    (io_writev_f)net_ssl_writev,
    (io_read_f)net_ssl_read,
    NULL, // sendfile
    (io_close_f)net_ssl_close,
    (io_finish_f)net_ssl_finish,
    (io_setmode_f)net_ssl_setmode,
    (io_poll_f)net_ssl_poll,
    0,
    IO_STREAM_TYPE_SSL
};

/*
 * Maximum number of successful write calls net_write performs for one
 * request before it returns (possibly with a short count).
 */
static int net_write_max_attempts = 16384;

/*
 * Sets the maximum number of write calls net_write makes per request.
 *
 * n: the new limit. Values below 1 are clamped to 1, so that net_write
 *    always attempts at least one write.
 */
void io_set_max_writes(int n) {
    net_write_max_attempts = n > 0 ? n : 1;
}

/*
 * Sysstream implementation
 */

/*
 * Creates a stream for the native socket fd, allocated from pool and using
 * native_io_funcs.
 *
 * Returns the new stream, or NULL if the allocation fails.
 */
IOStream* Sysstream_new(pool_handle_t *pool, SYS_SOCKET fd) {
    Sysstream *st = pool_malloc(pool, sizeof(Sysstream));
    if(!st) {
        return NULL;
    }
    st->st = native_io_funcs;
    st->fd = fd;
    return (IOStream*)st;
}

#ifdef XP_UNIX
/*
 * Writes up to nbytes from buf to the socket with write(2).
 *
 * Returns the result of write. io_errno is set to errno if the write failed
 * and to 0 otherwise: errno is only meaningful after a failed call, so it
 * must not be copied after a successful one.
 */
ssize_t net_sys_write(Sysstream *st, const void *buf, size_t nbytes) {
    ssize_t r = write(st->fd, buf, nbytes);
    st->st.io_errno = r < 0 ? errno : 0;
    return r;
}

/*
 * Writes the iovec array to the socket with writev(2).
 *
 * Returns the result of writev. io_errno is set to errno if the call failed
 * and to 0 otherwise (errno is stale after a successful call).
 */
ssize_t net_sys_writev(Sysstream *st, struct iovec *iovec, int iovcnt) {
    ssize_t r = writev(st->fd, iovec, iovcnt);
    st->st.io_errno = r < 0 ? errno : 0;
    return r;
}

/*
 * Reads up to nbytes from the socket into buf with read(2).
 *
 * Returns the result of read (0 at end of stream). io_errno is set to errno
 * if the read failed and to 0 otherwise (errno is stale after success).
 */
ssize_t net_sys_read(Sysstream *st, void *buf, size_t nbytes) {
    ssize_t r = read(st->fd, buf, nbytes);
    st->st.io_errno = r < 0 ? errno : 0;
    return r;
}

#ifdef WS_SENDFILE
/*
 * Sends sfd->header, sfd->len bytes of the file sfd->fd (starting at
 * sfd->offset) and sfd->trailer to the socket with the platform sendfile.
 *
 * If the file has no native descriptor (fd == -1), the data is copied with
 * net_fallback_sendfile instead.
 *
 * Returns the number of bytes sent, or -1 if nothing could be sent.
 * On Linux/Solaris the parts are sent strictly in order and sending stops
 * at the first short or failed write, so the count may be partial.
 */
ssize_t net_sys_sendfile(Sysstream *st, sendfiledata *sfd) {
    ssize_t ret = 0;
    off_t fileoffset = sfd->offset;
    if(sfd->fd->fd != -1) {
#ifdef BSD
        struct iovec hdvec;
        hdvec.iov_base = (void*)sfd->header;
        hdvec.iov_len = sfd->hlen;
        struct iovec trvec;
        trvec.iov_base = (void*)sfd->trailer;
        trvec.iov_len = sfd->tlen;
        struct sf_hdtr hdtr;
        hdtr.headers = &hdvec;
        hdtr.hdr_cnt = sfd->hlen > 0 ? 1 : 0;
        hdtr.trailers = &trvec;
        hdtr.trl_cnt = sfd->tlen > 0 ? 1 : 0;
        
        off_t len = sfd->len;
#ifdef OSX
        ret = sendfile(sfd->fd->fd, st->fd, fileoffset, &len, &hdtr, 0);
#else // BSD
        ret = sendfile(
                sfd->fd->fd,
                st->fd,
                fileoffset,
                sfd->len,
                &hdtr,
                NULL,
                0);
#endif
        if(ret == 0) {
            ret = sfd->hlen + sfd->tlen + sfd->len;
        }
        
#else // Solaris/Linux
        // Send header, file data and trailer strictly in order. Stop at the
        // first error or short write: continuing would put later bytes on
        // the wire before the unsent rest of an earlier part. Negative
        // results (errors) are never added to the byte count.
        ssize_t w = 0;
        int part_complete = 1;
        if(sfd->header && sfd->hlen > 0) {
            w = write(st->fd, sfd->header, sfd->hlen);
            if(w > 0) {
                ret += w;
            }
            part_complete = w == (ssize_t)sfd->hlen;
        }
        if(part_complete && sfd->len > 0) {
            w = sendfile(st->fd, sfd->fd->fd, &fileoffset, sfd->len);
            if(w > 0) {
                ret += w;
            }
            part_complete = w == (ssize_t)sfd->len;
        }
        if(part_complete && sfd->trailer && sfd->tlen > 0) {
            w = write(st->fd, sfd->trailer, sfd->tlen);
            if(w > 0) {
                ret += w;
            }
        }
        if(w < 0 && ret == 0) {
            // nothing was sent: report the error (errno is still from the
            // failed call, no other call was made after it)
            ret = -1;
        }
#endif
    } else {
        return net_fallback_sendfile((IOStream*)st, sfd);
    }
    st->st.io_errno = errno;
    return ret;
}
#endif

/* Closes the socket descriptor owned by the stream. */
void net_sys_close(Sysstream *st) {
    int sock = st->fd;
    system_close(sock);
}

/*
 * Switches the socket between blocking (IO_MODE_BLOCKING) and non-blocking
 * (IO_MODE_NONBLOCKING) mode by clearing or setting O_NONBLOCK.
 * Other mode values are ignored. fcntl failures are only reported via
 * perror. If the current flags cannot be read, 0 is used as the base.
 */
void net_sys_setmode(Sysstream *st, int mode) {
    int cur_flags = fcntl(st->fd, F_GETFL, 0);
    if(cur_flags == -1) {
        cur_flags = 0;
    }
    
    int new_flags;
    if(mode == IO_MODE_BLOCKING) {
        new_flags = cur_flags & ~O_NONBLOCK;
    } else if(mode == IO_MODE_NONBLOCKING) {
        new_flags = cur_flags | O_NONBLOCK;
    } else {
        return;
    }
    
    if(fcntl(st->fd, F_SETFL, new_flags) != 0) {
        perror("fcntl");
        // TODO: error
    }
}

/*
 * Registers the socket with the event handler for input or output
 * readiness, or removes it (IO_POLL_NONE).
 * Returns the event handler's result, or -1 for unsupported combinations
 * (including IN|OUT, which is not implemented yet).
 */
int net_sys_poll(Sysstream *st, EventHandler *ev, int events, Event *cb) {
    switch(events) {
        case IO_POLL_NONE:
            return ev_remove_poll(ev, st->fd);
        case IO_POLL_IN:
            return ev_pollin(ev, st->fd, cb);
        case IO_POLL_OUT:
            return ev_pollout(ev, st->fd, cb);
        case IO_POLL_IN | IO_POLL_OUT:
            return -1; // TODO: implement
        default:
            return -1;
    }
}

#elif defined(XP_WIN32)

/*
 * Win32: sends up to nbytes from buf on the socket.
 * buf is const-qualified like in the XP_UNIX implementation and the
 * io_write_f callers (e.g. net_http_write passes a const buffer).
 * Returns the number of bytes sent or IO_ERROR.
 */
ssize_t net_sys_write(Sysstream *st, const void *buf, size_t nbytes) {
    int ret = send(st->fd, buf, nbytes, 0);
    if(ret == SOCKET_ERROR) {
        return IO_ERROR;
    }
    return ret;
}

/*
 * Win32: writev is not implemented yet.
 * Returns IO_ERROR instead of falling off the end of a non-void function,
 * which would hand an indeterminate value to the caller.
 */
ssize_t net_sys_writev(Sysstream *st, struct iovec *iovec, int iovcnt) {
    // TODO: implement
    return IO_ERROR;
}

/* Win32: receives up to nbytes into buf. Returns the count or IO_ERROR. */
ssize_t net_sys_read(Sysstream *st, void *buf, size_t nbytes) {
    int received = recv(st->fd, buf, nbytes, 0);
    return received == SOCKET_ERROR ? IO_ERROR : received;
}

/*
 * Win32: there is no native sendfile implementation yet, so the data is
 * copied with net_fallback_sendfile (previously the function returned
 * without a value, which is undefined behavior for the caller).
 */
ssize_t net_sys_sendfile(Sysstream *st, sendfiledata *sfd) {
    // TODO: native implementation
    return net_fallback_sendfile((IOStream*)st, sfd);
}

/* Win32: closes the socket owned by the stream. */
void net_sys_close(Sysstream *st) {
    (void)closesocket(st->fd);
}

#endif


/*
 * HttpStream implementation
 */

/*
 * Creates an HttpStream on top of the stream fd, allocated from pool and
 * using http_io_funcs.
 *
 * Chunked encoding is disabled for reading and writing, no read buffer is
 * attached, and read_eof starts as true.
 * Returns the new stream, or NULL if the allocation fails.
 */
IOStream* httpstream_new(pool_handle_t *pool, IOStream *fd) {
    HttpStream *st = pool_malloc(pool, sizeof(HttpStream));
    if(!st) {
        return NULL;
    }
    st->st = http_io_funcs;
    st->fd = fd;
    st->written = 0;
    st->max_read = 0;
    st->read = 0;
    st->read_total = 0;
    st->readbuf = NULL;
    st->bufsize = 0;
    st->buflen = NULL;
    st->bufpos = NULL;
    st->chunk_buf_pos = 0;
    st->current_chunk_length = 0;
    st->current_chunk_pos = 0;
    st->current_trailer = 0;
    st->write_chunk_buf_len = 0;
    st->write_chunk_buf_pos = 0;
    st->chunked_enc = WS_FALSE;
    st->read_eof = WS_TRUE;
    st->write_eof = WS_FALSE;
    return (IOStream*)st;
}

/*
 * Switches an HttpStream to chunked transfer decoding for reads.
 *
 * buffer/bufsize: read buffer shared with the caller.
 * cursize/pos:    the caller's fill level of and position in buffer.
 *
 * Returns 0 on success. Returns 1 if st->read is not net_http_read, i.e.
 * st is not an HttpStream or chunked reading is already enabled.
 */
int httpstream_enable_chunked_read(IOStream *st, char *buffer, size_t bufsize, int *cursize, int *pos) {
    if(st->read == (io_read_f)net_http_read) {
        HttpStream *http = (HttpStream*)st;
        http->readbuf = buffer;
        http->bufsize = bufsize;
        http->buflen = cursize;
        http->bufpos = pos;
        http->max_read = 0;
        http->read = 0;
        http->chunk_buf_pos = 0;
        http->read_eof = WS_FALSE;
        st->read = (io_read_f)net_http_read_chunked;
        return 0;
    }
    
    log_ereport(LOG_FAILURE, "%s", "httpstream_enable_chunked_read: IOStream is not an HttpStream");
    return 1;
}

/*
 * Enables chunked transfer encoding for writes on an HttpStream.
 * Returns 0 on success, 1 if st is not of type IO_STREAM_TYPE_HTTP.
 */
int httpstream_enable_chunked_write(IOStream *st) {
    if(st->type == IO_STREAM_TYPE_HTTP) {
        ((HttpStream*)st)->chunked_enc = WS_TRUE;
        return 0;
    }
    log_ereport(LOG_FAILURE, "%s", "httpstream_enable_chunked_write: IOStream is not an HttpStream");
    return 1;
}

/*
 * Sets the number of bytes net_http_read may return (the body length).
 * Returns 0 on success, 1 if st's write function is not net_http_write
 * (st is not an HttpStream).
 */
int httpstream_set_max_read(IOStream *st, int64_t maxread) {
    if(st->write == (io_write_f)net_http_write) {
        ((HttpStream*)st)->max_read = maxread;
        return 0;
    }
    log_ereport(LOG_FAILURE, "%s", "httpstream_set_max_read: IOStream is not an HttpStream");
    return 1;
}

/* Returns the read_eof flag of the HttpStream st (no type check). */
WSBool httpstream_eof(IOStream *st) {
    return ((HttpStream*)st)->read_eof;
}

/* Returns the number of payload bytes written to the HttpStream st. */
int64_t httpstream_written(IOStream *st) {
    return ((HttpStream*)st)->written;
}

/*
 * Completion callback for one iovec entry of net_http_write.
 *
 * After writev, the callback of every entry that was (partially) written is
 * called with the entry's base and len, the number of bytes of this entry
 * that were actually written, and the entry's udata pointer.
 * Returns the number of payload bytes written, i.e. the amount reported
 * back to the net_write caller; chunk framing entries return 0.
 */
typedef ssize_t(*writeop_finish_func)(HttpStream *st, char *base, size_t len, size_t written, void *udata);

/*
 * Bookkeeping after (part of) a buffered chunk header from an earlier call
 * was sent. Clears the header buffer once it has been sent completely.
 * Framing bytes only: always returns 0.
 */
static ssize_t httpstream_finish_prev_header(HttpStream *st, char *base, size_t len, size_t written, void *udata) {
    st->write_chunk_buf_pos += written;
    if(st->write_chunk_buf_pos != st->write_chunk_buf_len) {
        return 0; // header still incomplete
    }
    st->write_chunk_buf_pos = 0;
    st->write_chunk_buf_len = 0;
    return 0;
}

/*
 * Bookkeeping after chunk payload was sent. When the current chunk is
 * complete, its terminating CRLF (2 bytes) becomes pending.
 * Returns written (payload bytes).
 */
static ssize_t httpstream_finish_data(HttpStream *st, char *base, size_t len, size_t written, void *udata) {
    st->current_chunk_pos += written;
    if(st->current_chunk_pos != st->current_chunk_length) {
        return written;
    }
    st->current_trailer = 2;
    st->current_chunk_pos = 0;
    st->current_chunk_length = 0;
    return written;
}

/*
 * Bookkeeping after (part of) a new chunk header was sent.
 *
 * udata points to the new chunk's payload length (size_t). The chunk
 * becomes the current chunk. Unsent header bytes are copied to
 * st->write_chunk_buf so that the next write sends them first.
 * Framing bytes only: always returns 0.
 */
static ssize_t httpstream_finish_new_header(HttpStream *st, char *base, size_t len, size_t written, void *udata) {
    const size_t *new_chunk_len = udata;
    st->current_chunk_length = *new_chunk_len;
    st->current_chunk_pos = 0; // new chunk started
    
    size_t unsent = written < len ? len - written : 0;
    st->write_chunk_buf_pos = 0;
    st->write_chunk_buf_len = unsent;
    if(unsent > 0) {
        memcpy(st->write_chunk_buf, base + written, unsent);
    }
    return 0;
}

/*
 * Bookkeeping after (part of) a chunk-terminating CRLF was sent: reduces
 * the number of pending CRLF bytes. Framing bytes only: returns 0.
 */
static ssize_t httpstream_finish_trailer(HttpStream *st, char *base, size_t len, size_t written, void *udata) {
    st->current_trailer = st->current_trailer - written;
    return 0;
}

/*
 * Write function for HttpStreams.
 *
 * Without chunked encoding the data is passed to the underlying stream.
 * With chunked encoding a single writev may contain, in this order:
 *   - the unsent rest of a chunk header from an earlier call,
 *   - payload for the current, partially sent chunk, followed by its CRLF
 *     only if this call supplies the chunk's final payload bytes,
 *     or else the unsent rest of the previous chunk's CRLF,
 *   - a new chunk (size line, payload, CRLF) for the remaining bytes.
 * Per-entry callbacks (writeop_finish_func) update the framing state
 * according to how many bytes of each entry were actually written.
 *
 * Returns the number of payload bytes written. If only framing bytes (or
 * nothing) could be written, 0 is returned and io_errno is EWOULDBLOCK.
 * If writev fails, its result is returned and io_errno is copied from the
 * underlying stream. After net_http_finish (write_eof) 0 is returned.
 */
ssize_t net_http_write(HttpStream *st, const void *buf, size_t nbytes) {
    st->st.io_errno = 0;
    if(st->write_eof) return 0;
    IOStream *fd = st->fd;
    if(!st->chunked_enc) {
        ssize_t w = fd->write(fd, buf, nbytes);
        st->written += w > 0 ? w : 0;
        return w;
    } else {
        struct iovec io[8];
        writeop_finish_func io_finished[8];
        // only header entries use udata; the others get NULL
        void *io_finished_udata[8] = { NULL };
        int iovec_len = 0;
        
        char *str_crlf = "\r\n";
        
        size_t prev_chunk_len = st->current_chunk_length;
        size_t new_chunk_len = 0;
        
        // was the previous chunk header completely sent?
        if(st->write_chunk_buf_len > 0) {
            io[0].iov_base = &st->write_chunk_buf[st->write_chunk_buf_pos];
            io[0].iov_len = st->write_chunk_buf_len - st->write_chunk_buf_pos;
            io_finished[0] = httpstream_finish_prev_header;
            io_finished_udata[0] = &prev_chunk_len;
            iovec_len++;
        }
        
        // was the previous chunk payload completely sent?
        if(st->current_chunk_length != 0) {
            size_t chunk_remaining = st->current_chunk_length - st->current_chunk_pos;
            size_t prev_nbytes = chunk_remaining > nbytes ? nbytes : chunk_remaining;
            io[iovec_len].iov_base = (char*)buf;
            io[iovec_len].iov_len = prev_nbytes;
            io_finished[iovec_len] = httpstream_finish_data;
            buf = ((char*)buf) + prev_nbytes;
            nbytes -= prev_nbytes;
            iovec_len++;
            
            // The CRLF terminating a chunk must directly follow its last
            // payload byte. If this call does not complete the chunk
            // (nbytes was smaller than the rest of the chunk), sending the
            // CRLF now would corrupt the chunked body.
            if(prev_nbytes == chunk_remaining) {
                io[iovec_len].iov_base = str_crlf;
                io[iovec_len].iov_len = 2;
                io_finished[iovec_len] = httpstream_finish_trailer;
                iovec_len++;
            }
        } else if(st->current_trailer > 0) {
            io[iovec_len].iov_base = str_crlf + 2 - st->current_trailer;
            io[iovec_len].iov_len = st->current_trailer;
            io_finished[iovec_len] = httpstream_finish_trailer;
            iovec_len++;
        }
        
        // TODO: on some plattforms iov_len is smaller than size_t
        //       if nbytes > INT_MAX, it should be devided into multiple
        //       iovec entries
        char chunk_len[16];
        if(nbytes > 0) {
            new_chunk_len = nbytes;
            io[iovec_len].iov_base = chunk_len;
            io[iovec_len].iov_len = snprintf(chunk_len, 16, "%zx\r\n", nbytes);
            io_finished[iovec_len] = httpstream_finish_new_header;
            io_finished_udata[iovec_len] = &new_chunk_len;
            iovec_len++;
            
            io[iovec_len].iov_base = (char*)buf;
            io[iovec_len].iov_len = nbytes;
            io_finished[iovec_len] = httpstream_finish_data;
            iovec_len++;
            
            io[iovec_len].iov_base = str_crlf;
            io[iovec_len].iov_len = 2;
            io_finished[iovec_len] = httpstream_finish_trailer;
            iovec_len++;
        }
        
        ssize_t wv = fd->writev(fd, io, iovec_len);
        if(wv <= 0) {
            st->st.io_errno = net_errno(st->fd);
            return wv;
        }
        
        // hand the written byte count to the entries' callbacks in order
        ssize_t ret_w = 0;
        int i = 0;
        while(wv > 0) {
            char *base = io[i].iov_base;
            size_t len = io[i].iov_len;
            size_t wlen = wv > len ? len : wv;
            ret_w += io_finished[i](st, base, len, wlen, io_finished_udata[i]);
            wv -= wlen;
            i++;
        }
        
        st->written += ret_w;
        if(ret_w == 0) {
            st->st.io_errno = EWOULDBLOCK; // not sure if this is really correct
            //ret_w = -1;
        }
        return ret_w;
    }
}

/*
 * Writes the iovec array to an HttpStream.
 *
 * With chunked encoding, every non-empty entry is written with
 * net_http_write. This keeps the chunk framing (pending headers,
 * partially sent chunks, CRLF terminators) consistent with plain writes
 * and with the "0\r\n\r\n" written by net_http_finish. Writing stops at
 * the first short or failed write. net_http_write already adds the payload
 * bytes to st->written.
 *
 * Returns the number of payload bytes written. If nothing was written, the
 * result of the failing write is returned (0, or negative on error).
 * After net_http_finish (write_eof) 0 is returned.
 */
ssize_t net_http_writev(HttpStream *st, struct iovec *iovec, int iovcnt) {
    if(st->write_eof) return 0;
    
    if(st->chunked_enc) {
        ssize_t total = 0;
        for(int i=0;i<iovcnt;i++) {
            size_t len = iovec[i].iov_len;
            if(len == 0) {
                continue; // nothing to send for this entry
            }
            ssize_t w = net_http_write(st, iovec[i].iov_base, len);
            if(w <= 0) {
                return total > 0 ? total : w;
            }
            total += w;
            if((size_t)w < len) {
                break; // short write: the caller has to retry the rest
            }
        }
        return total;
    }
    
    IOStream *fd = st->fd;
    ssize_t w = fd->writev(fd, iovec, iovcnt);
    if(w > 0) {
        st->written += w;
    } else if(w < 0) {
        st->st.io_errno = fd->io_errno;
    }
    return w;
}

/*
 * Reads request body data from an HttpStream without chunked encoding.
 *
 * At most max_read bytes (set by httpstream_set_max_read) are returned in
 * total; each read is limited to the remaining amount, so bytes after the
 * body are never consumed from the connection.
 *
 * Returns the number of bytes read, 0 (and sets read_eof) once max_read
 * bytes have been read, or the negative result of the underlying read
 * (its io_errno is then copied to st).
 */
ssize_t net_http_read(HttpStream *st, void *buf, size_t nbytes) {
    if(st->read >= st->max_read) {
        st->read_eof = WS_TRUE;
        return 0;
    }
    int64_t remaining = st->max_read - st->read; // > 0 here
    if((uint64_t)remaining < nbytes) {
        nbytes = (size_t)remaining;
    }
    ssize_t r = st->fd->read(st->fd, buf, nbytes);
    if(r < 0) {
        // do not add an error result to the byte counter
        st->st.io_errno = st->fd->io_errno;
    } else {
        st->read += r;
    }
    return r;
}

#define BUF_UNNEEDED_DIFF 64
/*
 * Copies up to nbytes from the shared read buffer (st->readbuf with fill
 * level *st->buflen and read position *st->bufpos) to buf. If *perform_io
 * is true and the conditions below allow it, the read buffer is refilled
 * once from st->fd and *perform_io is set to false.
 *
 * read_data: true when reading chunk payload (the copy is limited to the
 *            rest of the current chunk and counted in st->read), false
 *            when reading chunk header bytes.
 *
 * Returns the number of bytes copied to buf (0 if nothing was available).
 * A failed refill sets st->st.io_errno and leaves the read buffer empty.
 */
static ssize_t net_http_read_buffered(HttpStream *st, char *buf, size_t nbytes, WSBool read_data, WSBool *perform_io) {
    ssize_t r = 0;
    
    // copy available data from st->readbuf to buf
    int pos = *st->bufpos;
    int fill = *st->buflen;
    size_t buf_available = fill > pos ? (size_t)(fill - pos) : 0;
    if(buf_available) {
        size_t cplen = buf_available > nbytes ? nbytes : buf_available;
        if(read_data) {
            // if we read data (and not a chunk header), we limit the
            // amount of bytes we copy
            size_t chunk_available = st->max_read - st->read;
            cplen = cplen > chunk_available ? chunk_available : cplen;
            st->read += cplen;
        }
        memcpy(buf, st->readbuf + pos, cplen);
        *st->bufpos += cplen;
        r += cplen;
        buf += cplen;
        nbytes -= cplen;
    }
    
    // maybe perform IO and refill the read buffer
    // if we read data (read_data == true), make sure not to perform IO,
    // when a chunk is completed
    //
    // if we read a chunk header (read_data == false) it is very important
    // to not perform IO, if we have previously copied data from readbuf
    // this ensures we never override non-chunk-header data
    if(*perform_io && ((read_data && nbytes > 0 && st->max_read - st->read) || (!read_data && r == 0))) {
        if(*st->buflen - *st->bufpos > 0) {
            log_ereport(LOG_WARN, "%s", "net_http_read_buffered: read buffer is refilled although it still contains data");
        }
        // fill buffer again
        ssize_t rlen = st->fd->read(st->fd, st->readbuf, st->bufsize);
        // never store a negative fill level: the next call computes the
        // number of available bytes from *st->buflen
        *st->buflen = rlen > 0 ? (int)rlen : 0;
        *st->bufpos = 0;
        *perform_io = WS_FALSE;
        if(rlen < 0) {
            st->st.io_errno = st->fd->io_errno;
        }
        
        if(rlen > 0) {
            // call func again to get data from buffer (no IO will be performed)
            r += net_http_read_buffered(st, buf, nbytes, read_data, perform_io);
        }
    }
    
    return r;
}


/*
 * Parses a chunk header (the chunk-size line) at the start of str.
 *
 * str/len:  the bytes received so far. Unless first is true, the header
 *           must be preceded by the CRLF (or LF) that terminates the
 *           previous chunk's data.
 * chunklen: receives the parsed chunk size.
 *
 * The chunk size must start with a hex digit. After the digits only
 * optional spaces/tabs and a chunk extension (";...", ignored) may follow.
 * For the last chunk (size 0) the following empty line is consumed as well
 * (trailer fields are not supported).
 *
 * return:  0 if the data is incomplete
 *         -1 if an error occured
 *         >0 chunk header length
 */
int http_stream_parse_chunk_header(char *str, int len, WSBool first, int64_t *chunklen) {
    char *hdr_start = NULL;
    char *hdr_end = NULL;
    int i = 0;
    if(first) {
        hdr_start = str;
    } else {
        if(len < 3) {
            return 0;
        }
        if(str[0] == '\r' && str[1] == '\n') {
            hdr_start = str+2;
            i = 2;
        } else if(str[0] == '\n') {
            hdr_start = str+1;
            i = 1;
        } else {
            return -1;
        }
    }
    
    for(;i<len;i++) {
        char c = str[i];
        if(c == '\r' || c == '\n') {
            hdr_end = str+i;
            break;
        }
    }
    if(!hdr_end) {
        return 0; // incomplete
    }
    
    if(*hdr_end == '\r') {
        // we also need '\n', which may not have been received yet:
        // never look beyond the received data
        if(i + 1 >= len) {
            return 0; // incomplete
        }
        if(hdr_end[1] != '\n') {
            return -1;
        }
        i++; // '\n' found
    }
    
    // The size must start with a hex digit. This rejects empty sizes, signs
    // and leading whitespace, which strtoll would otherwise accept (a
    // negative size would disable the chunk length limit of the reader).
    char c0 = *hdr_start;
    int starts_with_hex = (c0 >= '0' && c0 <= '9')
            || (c0 >= 'a' && c0 <= 'f')
            || (c0 >= 'A' && c0 <= 'F');
    if(!starts_with_hex) {
        return -1;
    }
    
    // parse
    char save_c = *hdr_end;
    *hdr_end = '\0';
    char *end;
    int64_t clen;
    errno = 0;
    clen = strtoll(hdr_start, &end, 16);
    int parse_errno = errno;
    // only optional whitespace and a chunk extension may follow the digits
    while(*end == ' ' || *end == '\t') {
        end++;
    }
    int invalid_tail = *end != '\0' && *end != ';';
    *hdr_end = save_c;
    if(parse_errno || invalid_tail || clen < 0) {
        return -1;
    }
    i++;
    
    if(clen == 0) {
        // chunk length of 0 indicates the end
        // an additional \r\n is required (we also accept \n)
        if(i >= len) {
            return 0;
        }
        if(str[i] == '\n') {
            i++;
        } else if(str[i] == '\r') {
            if(++i >= len) {
                return 0;
            }
            if(str[i] == '\n') {
                i++;
            } else {
                return -1;
            }
        } else {
            return -1;
        }
    }
    
    *chunklen = clen;
    return i;
}

/*
 * Read function for HttpStreams with chunked transfer decoding, installed
 * by httpstream_enable_chunked_read.
 *
 * Alternates between collecting chunk header bytes in st->chunk_buf (parsed
 * with http_stream_parse_chunk_header) and copying chunk payload to buf.
 * Both are served from the shared read buffer by net_http_read_buffered.
 * st->max_read is the size of the current chunk and st->read the number of
 * its payload bytes consumed so far.
 *
 * Reads from the underlying stream are limited: perform_io is cleared by
 * each refill and only re-armed while no payload byte has been copied yet
 * (rd == 0).
 *
 * Returns the number of payload bytes copied to buf. Returns 0 when the
 * body has ended (read_eof is set) or when no payload could be obtained.
 * Returns -1 for a malformed or oversized chunk header; read_eof is set
 * in that case.
 */
ssize_t net_http_read_chunked(HttpStream *st, void *buf, size_t nbytes) {
    if(st->read_eof) {
        return 0;
    }
    
    char *rbuf = buf; // buffer pos
    size_t rd = 0; // number of bytes read
    size_t rbuflen = nbytes; // number of bytes until end of buf
    WSBool perform_io = WS_TRUE; // we do only 1 read before we abort
    // continue while buf has room and either a refill is still allowed or
    // the current chunk has unread payload
    while(rd < nbytes && (perform_io || (st->max_read - st->read) > 0)) {
        // how many bytes are available in the current chunk
        size_t chunk_available = st->max_read - st->read;
        if(chunk_available > 0) {
            // copy payload of the current chunk to buf
            ssize_t r = net_http_read_buffered(st, rbuf, rbuflen, TRUE, &perform_io);
            if(r == 0) {
                break;
            }
            rd += r;
            st->read_total += r;
            rbuf += r;
            rbuflen -= r;
        } else {
            // the current chunk is exhausted: the next bytes are a header
            int chunkbuf_avail = HTTP_STREAM_CBUF_SIZE - st->chunk_buf_pos;
            if(chunkbuf_avail == 0) {
                // for some reason HTTP_STREAM_CBUF_SIZE is not enough
                // to store the chunk header
                // this indicates that something has gone wrong (or this is an attack)
                st->read_eof = WS_TRUE;
                return -1;
            }
            // fill st->chunk_buf
            ssize_t r = net_http_read_buffered(st, &st->chunk_buf[st->chunk_buf_pos], chunkbuf_avail, FALSE, &perform_io);
            if(r == 0) {
                break;
            }
            int chunkbuf_len = st->chunk_buf_pos + r;
            int64_t chunklen;
            // only the first header is not preceded by a CRLF; read_total
            // stays 0 until payload has been read
            int ret = http_stream_parse_chunk_header(st->chunk_buf, chunkbuf_len, st->read_total > 0 ? FALSE : TRUE, &chunklen);
            if(ret == 0) {
                // incomplete chunk header
                st->chunk_buf_pos = chunkbuf_len;
            } else if(ret < 0) {
                // error
                st->read_eof = WS_TRUE;
                return -1;
            } else if(ret > 0) {
                // header complete: start the new chunk
                st->max_read = chunklen;
                st->read = 0;
                int remaining_len = chunkbuf_len - ret;
                if(remaining_len > 0) {
                    // we have read more into chunk_buf than the chunk_header
                    // it is safe to just move bufpos back, so that these
                    // bytes are served again from the read buffer
                    *st->bufpos -= remaining_len;
                }
                //st->remaining_len = chunkbuf_len - ret;
                st->chunk_buf_pos = 0;
                
                if(chunklen == 0) {
                    // last chunk: end of the body
                    st->read_eof = WS_TRUE;
                    break;
                }
            }
        }
        
        // allow another refill only while no payload has been returned
        if(!perform_io && rd == 0) {
            perform_io = WS_TRUE;
        }
    }
    
    return rd;
}

/*
 * sendfile for HttpStreams.
 *
 * If chunked encoding is off and the underlying stream supports sendfile,
 * the data is sent with it and counted in st->written.
 * Otherwise the data is copied with net_fallback_sendfile through this
 * stream's own write function (net_http_write). That applies chunked
 * encoding and already adds the payload to st->written, so it must not
 * be counted again here.
 *
 * Returns the number of bytes sent or a negative error value; 0 after
 * net_http_finish (write_eof).
 */
ssize_t net_http_sendfile(HttpStream *st, sendfiledata *sfd) {
    if(st->write_eof) return 0;
    
    IOStream *fd = st->fd;
    if(st->chunked_enc || !fd->sendfile) {
        return net_fallback_sendfile((IOStream*)st, sfd);
    }
    
    ssize_t ret = fd->sendfile(fd, sfd);
    st->written += ret > 0 ? ret : 0;
    return ret;
}

/* Closes the stream wrapped by the HttpStream. */
void net_http_close(HttpStream *st) {
    IOStream *inner = st->fd;
    inner->close(inner);
}

/*
 * Ends the output of an HttpStream. With chunked encoding (and if not
 * already finished) the last chunk "0\r\n\r\n" is written to the wrapped
 * stream. Afterwards write_eof is set, so later writes return 0.
 */
void net_http_finish(HttpStream *st) {
    if(!st->write_eof && st->chunked_enc) {
        IOStream *inner = st->fd;
        inner->write(inner, "0\r\n\r\n", 5);
    }
    st->write_eof = WS_TRUE;
}

/* Forwards the blocking mode change to the wrapped stream. */
void net_http_setmode(HttpStream *st, int mode) {
    IOStream *inner = st->fd;
    inner->setmode(inner, mode);
}

/* Forwards poll registration to the wrapped stream and returns its result. */
int net_http_poll(HttpStream *st, EventHandler *ev, int events, Event *cb) {
    IOStream *inner = st->fd;
    return inner->poll(inner, ev, events, cb);
}


/*
 * SSLStream implementation
 */

/*
 * Creates a stream for the SSL connection ssl, allocated from pool and using
 * ssl_io_funcs.
 *
 * Returns the new stream, or NULL if the allocation fails.
 */
IOStream* sslstream_new(pool_handle_t *pool, SSL *ssl) {
    SSLStream *st = pool_malloc(pool, sizeof(SSLStream));
    if(!st) {
        return NULL;
    }
    st->st = ssl_io_funcs;
    st->ssl = ssl;
    st->error = 0;
    return (IOStream*)st;
}

/*
 * Writes up to nbytes with SSL_write.
 *
 * Returns the number of bytes written, or -1 on failure. On failure
 * st->error holds the SSL_get_error code and io_errno is EWOULDBLOCK if
 * the operation should be retried (WANT_READ/WANT_WRITE), -1 otherwise.
 */
ssize_t net_ssl_write(SSLStream *st, const void *buf, size_t nbytes) {
    int written = SSL_write(st->ssl, buf, nbytes);
    if(written > 0) {
        return written;
    }
    
    st->error = SSL_get_error(st->ssl, written);
    int retry = st->error == SSL_ERROR_WANT_WRITE || st->error == SSL_ERROR_WANT_READ;
    st->st.io_errno = retry ? EWOULDBLOCK : -1;
    return -1;
}

/*
 * Writes the iovec entries one after another with SSL_write.
 *
 * Empty entries are skipped, so SSL_write is never called with a length of
 * 0. Writing stops at the first failed or short write, so the unsent part
 * of an entry is never skipped.
 *
 * Returns the number of bytes written (0 if there was nothing to write).
 * Returns -1 if the first write failed: st->error then holds the
 * SSL_get_error code, and io_errno is EWOULDBLOCK for WANT_READ/WANT_WRITE
 * and -1 otherwise.
 */
ssize_t net_ssl_writev(SSLStream *st, struct iovec *iovec, int iovcnt) {
    ssize_t r = 0;
    for(int i=0;i<iovcnt;i++) {
        size_t len = iovec[i].iov_len;
        if(len == 0) {
            continue;
        }
        int ret = SSL_write(st->ssl, iovec[i].iov_base, len);
        if(ret <= 0) {
            if(r == 0) {
                st->error = SSL_get_error(st->ssl, ret);
                if(st->error == SSL_ERROR_WANT_WRITE || st->error == SSL_ERROR_WANT_READ) {
                    st->st.io_errno = EWOULDBLOCK;
                } else {
                    st->st.io_errno = -1;
                }
                return -1;
            }
            break;
        }
        r += ret;
        if((size_t)ret < len) {
            break; // partial write: the caller has to send the rest
        }
    }
    return r;
}

/*
 * Reads up to nbytes with SSL_read and returns its result.
 *
 * On a result <= 0, st->error holds the SSL_get_error code and io_errno is
 * set like in net_ssl_write: EWOULDBLOCK for WANT_READ/WANT_WRITE, 0 for a
 * clean TLS shutdown by the peer (SSL_ERROR_ZERO_RETURN), -1 otherwise.
 * Wrapping streams copy io_errno after a failed read (see net_http_read),
 * so it must be set here.
 */
ssize_t net_ssl_read(SSLStream *st, void *buf, size_t nbytes) {
    int ret = SSL_read(st->ssl, buf, nbytes);
    if(ret <= 0) {
        st->error = SSL_get_error(st->ssl, ret);
        if(st->error == SSL_ERROR_WANT_READ || st->error == SSL_ERROR_WANT_WRITE) {
            st->st.io_errno = EWOULDBLOCK;
        } else if(st->error == SSL_ERROR_ZERO_RETURN) {
            st->st.io_errno = 0;
        } else {
            st->st.io_errno = -1;
        }
    }
    return ret;
}

/*
 * Sends the TLS close_notify with SSL_shutdown and closes the socket.
 *
 * SSL_shutdown returns 1 when the shutdown is complete and 0 when the
 * close_notify was sent but the peer's has not been received yet; neither
 * is an error. Only a negative result is recorded in st->error, because
 * SSL_get_error is only meaningful for that case.
 */
void net_ssl_close(SSLStream *st) {
    int ret = SSL_shutdown(st->ssl);
    if(ret < 0) {
        st->error = SSL_get_error(st->ssl, ret);
    }
    system_close(SSL_get_fd(st->ssl));
}

/*
 * finish implementation for SSLStreams. Intentionally empty: the SSLStream
 * functions in this file write directly with SSL_write and keep no
 * framing state that would need to be completed.
 */
void net_ssl_finish(SSLStream *st) {
    
}

/*
 * Switches the SSL connection's socket between blocking (IO_MODE_BLOCKING)
 * and non-blocking (IO_MODE_NONBLOCKING) mode by clearing or setting
 * O_NONBLOCK. Other mode values are ignored. fcntl failures are only
 * reported via perror.
 */
void net_ssl_setmode(SSLStream *st, int mode) {
    int sock = SSL_get_fd(st->ssl);
    int cur_flags = fcntl(sock, F_GETFL, 0);
    if(cur_flags == -1) {
        cur_flags = 0;
    }
    
    int new_flags;
    if(mode == IO_MODE_BLOCKING) {
        new_flags = cur_flags & ~O_NONBLOCK;
    } else if(mode == IO_MODE_NONBLOCKING) {
        new_flags = cur_flags | O_NONBLOCK;
    } else {
        return;
    }
    
    if(fcntl(sock, F_SETFL, new_flags) != 0) {
        perror("fcntl");
        // TODO: error
    }
}

/*
 * Registers the SSL connection's socket with the event handler for input
 * or output readiness, or removes it (IO_POLL_NONE).
 * Returns -1 for unsupported combinations (IN|OUT is not implemented).
 */
int net_ssl_poll(SSLStream *st, EventHandler *ev, int events, Event *cb) {
    int sock = SSL_get_fd(st->ssl);
    switch(events) {
        case IO_POLL_NONE:
            return ev_remove_poll(ev, sock);
        case IO_POLL_IN:
            return ev_pollin(ev, sock, cb);
        case IO_POLL_OUT:
            return ev_pollout(ev, sock, cb);
        case IO_POLL_IN | IO_POLL_OUT:
            return -1; // TODO: implement
        default:
            return -1;
    }
}

/* -------------------- public nsapi network functions -------------------- */

/*
 * NSAPI net_read: reads up to nbytes from the stream fd.
 * Returns the number of bytes read, IO_EOF at end of stream, or IO_ERROR
 * (io_errno is then set from errno).
 */
ssize_t net_read(SYS_NETFD fd, void *buf, size_t nbytes) {
    IOStream *stream = fd;
    ssize_t r = stream->read(fd, buf, nbytes);
    if(r > 0) {
        return r;
    }
    if(r == 0) {
        return IO_EOF;
    }
    stream->io_errno = errno;
    return IO_ERROR;
}

/*
 * NSAPI net_write: writes nbytes from buf to the stream fd, repeating
 * short writes up to net_write_max_attempts times.
 *
 * Returns the number of bytes written (possibly fewer than nbytes), or
 * IO_ERROR if the first write failed.
 */
ssize_t net_write(SYS_NETFD fd, const void *buf, size_t nbytes) {
    IOStream *stream = fd;
    const char *pos = buf;
    size_t done = 0;
    ssize_t last = 0;
    for(int n = 0; done < nbytes && n < net_write_max_attempts; n++) {
        last = stream->write(fd, pos, nbytes - done);
        if(last <= 0) {
            break;
        }
        done += last;
        pos += last;
    }
    if(last < 0 && done == 0) {
        return IO_ERROR;
    }
    return done;
}

/*
 * NSAPI net_writev: writes the iovec array to the stream fd.
 * Returns the number of bytes written or IO_ERROR (io_errno set from errno).
 */
ssize_t net_writev(SYS_NETFD fd, struct iovec *iovec, int iovcnt) {
    IOStream *stream = fd;
    ssize_t written = stream->writev(fd, iovec, iovcnt);
    if(written >= 0) {
        return written;
    }
    stream->io_errno = errno;
    return IO_ERROR;
}

/*
 * NSAPI net_printf: formats the arguments and writes the result with
 * net_write.
 *
 * Returns the result of net_write (0 if the formatted string is empty).
 * On error io_errno is set from the errno captured right after the write,
 * before free() can modify it.
 */
ssize_t net_printf(SYS_NETFD fd, char *format, ...) {
    va_list arg;
    va_start(arg, format);
    cxmutstr buf = cx_vasprintf_a(cxDefaultAllocator, format, arg);
    va_end(arg);
    
    ssize_t r = buf.length > 0 ? net_write(fd, buf.ptr, buf.length) : 0;
    int err = errno; // save before free(), which may overwrite errno
    free(buf.ptr);
    if(r < 0) {
        ((IOStream*)fd)->io_errno = err;
    }
    return r;
}

/*
 * NSAPI net_sendfile: sends header, file range and trailer described by
 * sfd to the stream fd.
 *
 * Uses the stream's sendfile if it has one and the file has a native
 * descriptor; otherwise net_fallback_sendfile copies the data.
 * Returns the number of bytes sent or IO_ERROR.
 */
ssize_t net_sendfile(SYS_NETFD fd, sendfiledata *sfd) {
    IOStream *out = fd;
    int use_native = out->sendfile && sfd->fd && sfd->fd->fd != -1;
    if(!use_native) {
        // stream/file does not support sendfile
        // do regular copy
        return net_fallback_sendfile(out, sfd);
    }
    
    ssize_t r = out->sendfile(fd, sfd);
    if(r >= 0) {
        return r;
    }
    out->io_errno = errno;
    return IO_ERROR;
}

// private
/*
 * Sends sfd->header, sfd->len bytes of the file sfd->fd starting at
 * sfd->offset, and sfd->trailer by reading the file into a buffer and
 * writing it with fd->write. Used when no native sendfile is available.
 * A NULL header or trailer is treated as empty.
 *
 * Returns the number of bytes sent (header + file data + trailer), or
 * IO_ERROR with fd->io_errno set from errno. A file that ends before
 * sfd->len bytes could be read is an error.
 */
ssize_t net_fallback_sendfile(IOStream *fd, sendfiledata *sfd) {
    enum { FALLBACK_BUFSIZE = 4096 };
    
    char *buf = malloc(FALLBACK_BUFSIZE);
    if(!buf) {
        fd->io_errno = errno;
        return IO_ERROR;
    }
    char *header = (char*)sfd->header;
    int hlen = header ? sfd->hlen : 0;
    char *trailer = (char*)sfd->trailer;
    int tlen = trailer ? sfd->tlen : 0;
    // byte count reported on success: only the parts that are really sent
    ssize_t total = (ssize_t)hlen + (ssize_t)sfd->len + (ssize_t)tlen;
    
    ssize_t r;
    while(hlen > 0) {
        r = fd->write(fd, header, hlen);
        if(r <= 0) {
            goto error;
        }
        header += r;
        hlen -= r;
    }
    
    if(system_lseek(sfd->fd, sfd->offset, SEEK_SET) == -1) {
        goto error;
    }
    
    size_t length = sfd->len;
    while(length > 0) {
        // never read more than the rest of the requested range: a full
        // buffer would send bytes beyond offset + len and underflow length
        size_t want = length < FALLBACK_BUFSIZE ? length : FALLBACK_BUFSIZE;
        if((r = system_fread(sfd->fd, buf, want)) <= 0) {
            break;
        }
        char *write_buf = buf;
        while(r > 0) {
            ssize_t w = fd->write(fd, write_buf, r);
            // TODO: remove
            if(w > r) {
                log_ereport(LOG_WARN, "net_fallback_sendfile: w > r, %zd > %zd", w, r);
                w = 0;
            }
            
            if(w <= 0) {
                goto error;
            }
            r -= w;
            length -= w;
            write_buf += w;
        }
    }
    if(length > 0) {
        // the file ended or could not be read before len bytes were sent
        goto error;
    }
    
    while(tlen > 0) {
        r = fd->write(fd, trailer, tlen);
        if(r <= 0) {
            goto error;
        }
        trailer += r;
        tlen -= r;
    }
    
    free(buf);
    return total;
    
error:
    {
        int err = errno; // save before free(), which may overwrite errno
        free(buf);
        fd->io_errno = err;
    }
    return IO_ERROR;
}

/*
 * NSAPI net_flush. Not implemented yet: does nothing and always
 * returns 0.
 */
int net_flush(SYS_NETFD sd) {
    // TODO: implement
    return 0;
}

/* NSAPI net_close: closes the stream via its close function. */
void net_close(SYS_NETFD fd) {
    IOStream *stream = fd;
    stream->close(fd);
}

/*
 * NSAPI net_setnonblock: switches the stream to non-blocking mode if
 * nonblock is non-zero, to blocking mode otherwise. Always returns 0.
 */
int net_setnonblock(SYS_NETFD fd, int nonblock) {
    int mode = IO_MODE_BLOCKING;
    if(nonblock) {
        mode = IO_MODE_NONBLOCKING;
    }
    IOStream *stream = fd;
    stream->setmode(fd, mode);
    return 0;
}

/* Returns the error code last stored in the stream's io_errno. */
int net_errno(SYS_NETFD fd) {
    const IOStream *stream = fd;
    return stream->io_errno;
}

// private
/*
 * Calls the stream's finish function, e.g. net_http_finish, which writes
 * the last chunk of a chunked HttpStream.
 * Streams without a finish function are left untouched; native_io_funcs
 * has a NULL finish entry, so calling it unconditionally would crash.
 */
void net_finish(SYS_NETFD fd) {
    IOStream *stream = fd;
    if(stream->finish) {
        stream->finish(fd);
    }
}

mercurial