src/server/daemon/websocket.c

Sat, 03 Dec 2022 16:31:08 +0100

author
Olaf Wintermann <olaf.wintermann@gmail.com>
date
Sat, 03 Dec 2022 16:31:08 +0100
changeset 448
02b003f7560c
parent 415
d938228c382e
permissions
-rw-r--r--

use separate buffer for chunked transfer encoding, not inbuf

/*
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER.
 *
 * Copyright 2016 Olaf Wintermann. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *   1. Redistributions of source code must retain the above copyright
 *      notice, this list of conditions and the following disclaimer.
 *
 *   2. Redistributions in binary form must reproduce the above copyright
 *      notice, this list of conditions and the following disclaimer in the
 *      documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <stdio.h>
#include <stdlib.h>

#include "websocket.h"

#include "../util/io.h"
#include "../util/pblock.h"
#include "../util/util.h"
#include "../util/strbuf.h"
#include <cx/string.h>

// size of the buffer used for each net_read() call in http_handle_websocket
#define WS_BUFFER_LEN 2048

/*
 * Returns 1 if the comma-separated header value contains token, compared
 * ASCII case-insensitively with surrounding spaces/tabs ignored, else 0.
 * Example: ws_header_has_token("keep-alive, Upgrade", "upgrade") == 1
 */
static int ws_header_has_token(const char *value, const char *token) {
    const char *p = value;
    while(*p) {
        // skip list separators and leading whitespace
        while(*p == ',' || *p == ' ' || *p == '\t') {
            p++;
        }
        const char *start = p;
        while(*p && *p != ',') {
            p++;
        }
        // trim trailing whitespace of the element
        const char *end = p;
        while(end > start && (end[-1] == ' ' || end[-1] == '\t')) {
            end--;
        }
        if(end > start &&
           !cx_strcasecmp(cx_strn(start, (size_t)(end - start)), cx_str(token)))
        {
            return 1;
        }
    }
    return 0;
}

/*
 * Performs the websocket opening handshake for rq and then processes
 * incoming websocket data until the connection is closed or a protocol
 * error occurs. Every complete message is passed to websocket->on_message.
 *
 * Returns REQ_NOACTION if the request is not a usable websocket upgrade
 * request (missing Connection/Upgrade/Sec-WebSocket-Key headers) or the
 * handshake response could not be built. Once the 101 response has been
 * sent, REQ_PROCEED is returned.
 */
NSAPI_PUBLIC int http_handle_websocket(Session *sn, Request *rq, WebSocket *websocket) {
    char *connection = pblock_findkeyval(pb_key_connection, rq->headers);
    char *upgrade = pblock_findval("upgrade", rq->headers);
    char *wskey = pblock_findval("sec-websocket-key", rq->headers);
    
    // the key is required to compute Sec-WebSocket-Accept
    if(!connection || !upgrade || !wskey) {
        return REQ_NOACTION;
    }
    
    // the headers are token lists, e.g. "Connection: keep-alive, Upgrade"
    if(!ws_header_has_token(connection, "upgrade")) {
        return REQ_NOACTION;
    }
    if(!ws_header_has_token(upgrade, "websocket")) {
        return REQ_NOACTION;
    }
    
    // Sec-WebSocket-Accept = base64(SHA1(key + fixed GUID))
    cxmutstr wsaccept = cx_strcat(2, cx_str(wskey), (cxstring)CX_STR("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));
    if(!wsaccept.ptr) {
        return REQ_NOACTION;
    }
    unsigned char hash[20];
    SHA1((const unsigned char*)wsaccept.ptr, wsaccept.length, hash);
    free(wsaccept.ptr);
    
    char *websocket_accept = util_base64encode((char*)hash, 20);
    if(!websocket_accept) {
        return REQ_NOACTION;
    }
    
    sbuf_t *response = sbuf_new(512);
    if(!response) {
        free(websocket_accept);
        return REQ_NOACTION;
    }
    sbuf_append(response, (cxstring)CX_STR("HTTP/1.1 101 Switching Protocols\r\n"));
    sbuf_append(response, (cxstring)CX_STR("Upgrade: websocket\r\n"));
    sbuf_append(response, (cxstring)CX_STR("Connection: Upgrade\r\n"));
    sbuf_append(response, (cxstring)CX_STR("Sec-WebSocket-Accept: "));
    sbuf_puts(response, websocket_accept);
    sbuf_append(response, (cxstring)CX_STR("\r\n\r\n"));
    
    net_write(sn->csd, response->ptr, response->length);
    sbuf_free(response);
    free(websocket_accept);
    
    // start websocket I/O
    int ret = REQ_PROCEED;
    WSParser *parser = websocket_parser(sn);
    char *inbuf = pool_malloc(sn->pool, WS_BUFFER_LEN);
    if(!parser || !inbuf) {
        // the upgrade response was already sent, the request is finished
        log_ereport(LOG_FAILURE, "websocket: out of memory");
        return ret;
    }
    
    ssize_t r;
    while((r = net_read(sn->csd, inbuf, WS_BUFFER_LEN)) > 0) {
        websocket_input(parser, inbuf, r);
        WSMessage *msg;
        // must be initialized: a NULL result does not imply error was set
        int error = 0;
        while((msg = websocket_get_message(parser, &error)) != NULL) {
            websocket->on_message(websocket, msg);
        }
        if(error) {
            log_ereport(LOG_FAILURE, "websocket protocol error");
            break;
        }
    }
    
    return ret;
}


/*
 * Allocates a new, zero-initialized websocket parser from the session pool.
 * The parser remembers the pool for later message allocations.
 * Returns NULL if the allocation fails.
 */
WSParser* websocket_parser(Session *sn) {
    WSParser *p = pool_malloc(sn->pool, sizeof(WSParser));
    if(p) {
        ZERO(p, sizeof(WSParser));
        p->pool = sn->pool;
    }
    return p;
}

/*
 * Hands a new chunk of received bytes to the parser. The buffer is not
 * copied: data must stay valid while websocket_get_message() is called for
 * it. Partially parsed state from earlier chunks is kept.
 */
void websocket_input(WSParser *parser, const char *data, size_t length) {
    // restart reading at the beginning of the new chunk
    parser->pos = 0;
    parser->length = length;
    parser->inbuf = data;
}

/*
 * Extracts the next complete message from the input set by websocket_input().
 *
 * Parser state kept between calls:
 *   state == 0: expecting a frame header; a partially received header is
 *               stored in parser->tmpbuf / parser->tmplen
 *   state == 1: inside the payload of parser->current; parser->cur_plen
 *               bytes of it have been received so far
 *
 * Returns a pool-allocated message whose payload is already unmasked, or
 * NULL if more input is needed or an error occurred. *error is set to 0,
 * or to 1 on a protocol error or allocation failure.
 *
 * Frames with an empty payload are consumed without producing a message.
 * Fragmented messages are not reassembled: every frame is returned as its
 * own message with the frame opcode as type.
 */
WSMessage* websocket_get_message(WSParser *parser, int *error) {
    WSFrame rframe;
    WSMessage *retmsg = NULL;
    
    *error = 0;
    
    while(parser->pos < parser->length) {
        const char *inbuf = parser->inbuf + parser->pos;
        size_t length = parser->length - parser->pos;
        
        if(parser->state == 0) {
            WSFrame frame;
            ZERO(&frame, sizeof(WSFrame));
            
            /*
             * small buffer for a websocket frame header without payload data
             * (not zero copy, but it keeps header parsing simple)
             */
            char frame_data[WS_FRAMEHEADER_BUFLEN];
            size_t prevlen = parser->tmplen;
            size_t flen = 0;
            
            /*
             * if the previous call ended in the middle of a frame header,
             * tmpbuf contains those bytes: put them in front of the new input
             */
            if(prevlen > 0) {
                memcpy(frame_data, parser->tmpbuf, prevlen);
                flen = prevlen;
            }
            size_t cp_remaining = length < WS_FRAMEHEADER_BUFLEN-flen ?
                length : WS_FRAMEHEADER_BUFLEN-flen;
            memcpy(&frame_data[flen], inbuf, cp_remaining);
            flen += cp_remaining;
            
            // ready to parse the frame
            ssize_t frame_hlen = websocket_get_frameheader(
                    &frame,
                    frame_data,
                    flen);
            
            if(frame_hlen == -1) {
                // protocol error, abort
                *error = 1;
                return NULL;
            }
            if(frame_hlen == 0) {
                if(flen >= WS_FRAMEHEADER_BUFLEN) {
                    // header does not fit into the buffer, no progress possible
                    *error = 1;
                    return NULL;
                }
                /*
                 * incomplete header: flen < buffer size means the whole
                 * remaining input was copied; keep it for the next call
                 * (assumes tmpbuf holds WS_FRAMEHEADER_BUFLEN bytes)
                 */
                memcpy(parser->tmpbuf, frame_data, flen);
                parser->tmplen = flen;
                parser->pos += cp_remaining;
                continue;
            }
            
            /*
             * frame_hlen > prevlen (prevlen bytes were not enough for the
             * header), so only the remaining header bytes come from inbuf
             */
            size_t consumed = (size_t)frame_hlen - prevlen;
            parser->tmplen = 0;
            inbuf += consumed;
            length -= consumed;
            parser->pos += consumed;
            
            if(frame.payload_length == 0) {
                // nothing to deliver for an empty frame
                continue;
            }
            if(frame.payload_length > (uint64_t)SIZE_MAX) {
                *error = 1;
                return NULL;
            }
            size_t plen = (size_t)frame.payload_length;
            
            // frame header complete, create a message object
            WSMessage *msg = pool_malloc(parser->pool, sizeof(WSMessage));
            if(!msg) {
                *error = 1;
                return NULL;
            }
            msg->data = pool_malloc(parser->pool, plen);
            if(!msg->data) {
                *error = 1;
                return NULL;
            }
            msg->length = plen;
            msg->next = NULL;
            msg->type = frame.opcode;
            
            if(plen <= length) {
                // the whole payload is available: message complete
                memcpy(msg->data, inbuf, plen);
                parser->pos += plen;
                
                rframe = frame;
                retmsg = msg;
                break;
            } else {
                // payload continues in the next input buffer
                memcpy(msg->data, inbuf, length);
                parser->pos += length;
                parser->state = 1;
                parser->current = msg;
                parser->cur_plen = length;
                parser->frame = frame;
                return NULL;
            }
        } else {
            WSMessage *msg = parser->current;
            size_t missing = msg->length - parser->cur_plen;
            if(length < missing) {
                // still incomplete message
                memcpy(msg->data + parser->cur_plen, inbuf, length);
                parser->cur_plen += length;
                parser->pos += length;
                return NULL;
            }
            
            memcpy(msg->data + parser->cur_plen, inbuf, missing);
            parser->pos += missing;
            parser->state = 0;
            parser->current = NULL;
            parser->cur_plen = 0;
            
            rframe = parser->frame;
            retmsg = msg;
            break;
        }
    }
    
    if(retmsg && rframe.mask) {
        websocket_mask_data(retmsg->data, retmsg->length, rframe.masking_key);
    }
    return retmsg;
}


/*
 * Parses a websocket frame header (RFC 6455, section 5.2) from buf.
 *
 * Returns the header length in bytes and fills *frame on success, 0 if buf
 * does not yet contain the complete header (frame is not modified), or -1
 * on a protocol error (64 bit payload length with the most significant bit
 * set).
 *
 * The 16/64 bit extended payload lengths are read in network byte order.
 * The masking key is stored as the raw 4 key bytes in memory order, which
 * is the layout websocket_mask_data() expects.
 */
ssize_t websocket_get_frameheader(WSFrame *frame, const char *buf, size_t len) {
    const unsigned char *b = (const unsigned char*)buf;
    
    if(len < 2) {
        return 0; // too small for anything
    }
    
    size_t msglen = 2; // minimal length
    
    uint8_t fin = (b[0] & 0x80) != 0;
    uint8_t opcode = b[0] & 0x0f;
    
    uint8_t mask = (b[1] & 0x80) != 0;
    uint8_t payload_len = b[1] & 0x7f;
    
    uint64_t payload_length = payload_len;
    if(payload_len == 126) {
        msglen += 2;
        if(len < msglen) {
            return 0;
        }
        payload_length = ((uint64_t)b[2] << 8) | (uint64_t)b[3];
    } else if(payload_len == 127) {
        msglen += 8;
        if(len < msglen) {
            return 0;
        }
        payload_length = 0;
        for(int i = 0; i < 8; i++) {
            payload_length = (payload_length << 8) | (uint64_t)b[2 + i];
        }
        // RFC 6455: the most significant bit must be 0
        if(payload_length & UINT64_C(0x8000000000000000)) {
            return -1;
        }
    }
    
    uint32_t masking_key = 0;
    if(mask) {
        msglen += 4;
        if(len < msglen) {
            return 0;
        }
        // memcpy: buf may be unaligned
        memcpy(&masking_key, buf + msglen - 4, sizeof(masking_key));
    }
    
    frame->header_complete = TRUE;
    frame->fin = fin;
    frame->opcode = opcode;
    frame->mask = mask;
    frame->masking_key = masking_key;
    frame->payload_length = payload_length;
    
    return msglen;
}

/*
 * XORs len bytes of buf in place with the 4 byte masking key (RFC 6455,
 * section 5.3). mask holds the key bytes in memory order, so byte i of buf
 * is combined with byte (i % 4) of the in-memory representation of mask.
 * Applying the function twice restores the original data.
 *
 * Works byte-wise so buf needs no particular alignment and is never
 * accessed through an incompatible pointer type.
 */
void websocket_mask_data(char *buf, size_t len, uint32_t mask) {
    unsigned char key[4];
    memcpy(key, &mask, sizeof(key));
    
    unsigned char *data = (unsigned char*)buf;
    for(size_t i = 0; i < len; i++) {
        data[i] ^= key[i & 3];
    }
}

/* ------------------------------ public API ------------------------------*/

/*
 * Sends msg (len bytes) as a single unmasked text frame (FIN set, opcode 1).
 * The payload length is encoded in 7, 16 or 64 bits in network byte order
 * (RFC 6455, section 5.2). Partial writes are continued until the whole
 * frame is sent.
 *
 * Returns 0 on success, 1 if a write fails.
 */
NSAPI_PUBLIC int websocket_send_text(SYS_NETFD csd, char *msg, size_t len) {
    char frame[WS_FRAMEHEADER_BUFLEN];
    frame[0] = (char)0x81; // 0b10000001: FIN + text opcode
    size_t hlen;
    if(len < 126) {
        frame[1] = (char)len;
        hlen = 2;
    } else if(len < 65536) {
        frame[1] = 126;
        frame[2] = (char)((len >> 8) & 0xff);
        frame[3] = (char)(len & 0xff);
        hlen = 4;
    } else {
        frame[1] = 127;
        uint64_t plen = (uint64_t)len;
        for(int i = 0; i < 8; i++) {
            frame[2 + i] = (char)((plen >> (56 - 8 * i)) & 0xff);
        }
        hlen = 10;
    }
    
    struct iovec iov[2];
    iov[0].iov_base = frame;
    iov[0].iov_len = hlen;
    
    iov[1].iov_base = msg;
    iov[1].iov_len = len;
    
    // write until header and payload are sent completely
    struct iovec *v = iov;
    int vcnt = 2;
    while(vcnt > 0) {
        ssize_t w = net_writev(csd, v, vcnt);
        if(w <= 0) {
            return 1;
        }
        size_t n = (size_t)w;
        // drop all fully written vectors
        while(vcnt > 0 && n >= v->iov_len) {
            n -= v->iov_len;
            v++;
            vcnt--;
        }
        // advance into the partially written vector
        if(vcnt > 0) {
            v->iov_base = (char*)v->iov_base + n;
            v->iov_len -= n;
        }
    }
    return 0;
}

mercurial