Sat, 21 Jan 2017 12:46:31 +0100
fixes socket fd leak when SSL_accept fails
/* * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS HEADER. * * Copyright 2016 Olaf Wintermann. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * POSSIBILITY OF SUCH DAMAGE. */ #include <stdio.h> #include <stdlib.h> #include "websocket.h" #include "../util/io.h" #include "../util/pblock.h" #include "../util/util.h" #include "../util/strbuf.h" #include <ucx/string.h> #define WS_BUFFER_LEN 2048 NSAPI_PUBLIC int http_handle_websocket(Session *sn, Request *rq, WebSocket *websocket) { char *connection = pblock_findkeyval(pb_key_connection, rq->headers); char *upgrade = pblock_findval("upgrade", rq->headers); char *origin = pblock_findval("origin", rq->headers); char *wskey = pblock_findval("sec-websocket-key", rq->headers); char *wsprot = pblock_findval("sec-websocket-protocol", rq->headers); char *wsv = pblock_findval("sec-websocket-version", rq->headers); if(!connection || !upgrade) { return REQ_NOACTION; } if(sstrcasecmp(sstr(connection), S("upgrade"))) { return REQ_NOACTION; } if(sstrcasecmp(sstr(upgrade), S("websocket"))) { return REQ_NOACTION; } sstr_t wsaccept = sstrcat(2, sstr(wskey), S("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")); unsigned char hash[20]; SHA1((const unsigned char*)wsaccept.ptr, wsaccept.length, hash); char *websocket_accept = util_base64encode((char*)hash, 20); sbuf_t *response = sbuf_new(512); sbuf_append(response, S("HTTP/1.1 101 Switching Protocols\r\n")); sbuf_append(response, S("Upgrade: websocket\r\n")); sbuf_append(response, S("Connection: Upgrade\r\n")); sbuf_append(response, S("Sec-WebSocket-Accept: ")); sbuf_puts(response, websocket_accept); sbuf_append(response, S("\r\n\r\n")); net_write(sn->csd, response->ptr, response->length); sbuf_free(response); free(websocket_accept); free(wsaccept.ptr); // start websocket I/O WSParser *parser = websocket_parser(sn); WSFrame frame; int ret = REQ_PROCEED; char *inbuf = pool_malloc(sn->pool, WS_BUFFER_LEN); ssize_t r = 0; while((r = net_read(sn->csd, inbuf, WS_BUFFER_LEN)) > 0) { websocket_input(parser, inbuf, r); WSMessage *msg; int error; while((msg = websocket_get_message(parser, &error)) != NULL) { websocket->on_message(websocket, msg); } if(error) { log_ereport(LOG_FAILURE, "websocket protocol error"); break; } } return ret; } WSParser* websocket_parser(Session *sn) { WSParser *parser = pool_malloc(sn->pool, sizeof(WSParser)); if(!parser) { return NULL; } ZERO(parser, sizeof(WSParser)); parser->pool = sn->pool; return parser; } void websocket_input(WSParser *parser, const char *data, size_t length) { parser->inbuf = data; parser->length = length; parser->pos = 0; } WSMessage* websocket_get_message(WSParser *parser, int *error) { WSFrame rframe; WSMessage *retmsg = NULL; while(parser->pos < parser->length) { const char *inbuf = parser->inbuf + parser->pos; size_t length = parser->length - parser->pos; if(parser->state == 0) { WSFrame frame; ZERO(&frame, sizeof(WSFrame)); /* * small buffer for a websocket frame without payload data * I know using so many buffers it not zero copy but * it makes things a little bit easier :) */ char frame_data[WS_FRAMEHEADER_BUFLEN]; size_t flen = 0; /* * when the last call of websocket_get_message didn't completed * a frame header, the tmpbuf contains the remaining bytes * in this case we combine tmpbuf and inputbuf */ if(parser->tmplen > 0) { memcpy(parser->tmpbuf, frame_data, parser->tmplen); flen = parser->tmplen; } size_t cp_remaining = length < WS_FRAMEHEADER_BUFLEN-flen ? length : WS_FRAMEHEADER_BUFLEN-flen; memcpy(&frame_data[flen], inbuf, cp_remaining); flen += cp_remaining; // ready to parse the frame ssize_t frame_hlen = websocket_get_frameheader( &frame, frame_data, flen); if(frame_hlen == -1) { // protocol error, abort *error = 1; return NULL; } if(frame_hlen == 0) { memcpy(parser->tmpbuf, frame_data, flen); } else { inbuf += frame_hlen; length -= frame_hlen; parser->pos += frame_hlen; // frame complete, create a message object if(frame.payload_length > 0) { WSMessage *msg = pool_malloc(parser->pool, sizeof(WSMessage)); msg->data = pool_malloc(parser->pool, frame.payload_length); msg->length = frame.payload_length; msg->next = NULL; msg->type = frame.opcode; if(frame.payload_length >= length) { // message complete memcpy(msg->data, inbuf, frame.payload_length); parser->pos += frame.payload_length; rframe = frame; retmsg = msg; break; } else { memcpy(msg->data, inbuf, length); parser->state = 1; parser->current = msg; parser->cur_plen = length; parser->frame = frame; return NULL; } } } } else { WSMessage *msg = parser->current; if(msg->length >= parser->cur_plen + length) { // still incomplete message memcpy(msg->data + parser->cur_plen, inbuf, length); parser->cur_plen += length; return NULL; } else { size_t cplen = msg->length - parser->cur_plen; memcpy(msg->data + parser->cur_plen, inbuf, cplen); parser->pos += cplen; parser->state = 0; parser->current = NULL; rframe = parser->frame; retmsg = msg; break; } } } if(retmsg && rframe.mask) { websocket_mask_data(retmsg->data, retmsg->length, rframe.masking_key); } return retmsg; } ssize_t websocket_get_frameheader(WSFrame *frame, const char *buf, size_t len) { if(len < 2) { return 0; // too small for anything } /* printf("websocket_get_frameheader: "); for(int i=0;i<len;i++) { printf("%x ", buf[i]); if(len > 15) { break; } } printf("\n"); */ size_t msglen = 2; // minimal length uint8_t fin = (buf[0] & 0x80) != 0; uint8_t opcode = buf[0] & 0xf; uint8_t mask = (buf[1] & 0x80) != 0; uint8_t payload_len = buf[1] & 0x7f; uint64_t payload_length = payload_len; if(payload_len == 126) { msglen += 2; if(len < msglen) { return 0; } payload_length = *((uint16_t*)(buf+2)); } else if(payload_len == 127) { msglen += 8; if(len < msglen) { return 0; } payload_length = *((uint64_t*)(buf+2)); } else if(payload_len > 127) { return -1; } uint32_t masking_key = 0; if(mask) { msglen += 4; if(len < msglen) { return 0; } masking_key = *((uint32_t*)(buf+msglen-4)); } frame->header_complete = TRUE; frame->fin = fin; frame->opcode = opcode; frame->mask = mask; frame->masking_key = masking_key; frame->payload_length = payload_length; return msglen; } void websocket_mask_data(char *buf, size_t len, uint32_t mask) { size_t m = len % 4; size_t alen = (len - m) / 4; uint32_t *data = (uint32_t*)buf; for(int i=0;i<alen;i++) { data[i] = data[i] ^ mask; } int j = 0; char *cmask = (char*)&mask; for(int i=len-m;i<len;i++) { buf[i] = buf[i] ^ cmask[j]; j++; } } /* ------------------------------ public API ------------------------------*/ NSAPI_PUBLIC int websocket_send_text(SYS_NETFD csd, char *msg, size_t len) { char frame[WS_FRAMEHEADER_BUFLEN]; frame[0] = 0b10000001; size_t hlen; if(len < 126) { frame[1] = (char)len; hlen = 2; } else if(len < 65536) { frame[1] = 126; uint16_t plen = htons(len); memcpy(frame + 2, &plen, 2); hlen = 4; } else { frame[1] = 127; // TODO hlen = 10; } struct iovec iov[2]; iov[0].iov_base = frame; iov[0].iov_len = hlen; iov[1].iov_base = msg; iov[1].iov_len = len; ssize_t w = net_writev(csd, iov, 2); if(w > 0) { return 0; } else { return 1; } }