#include <ctype.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "websocket.h"
#include "../util/io.h"
#include "../util/pblock.h"
#include "../util/util.h"
#include "../util/strbuf.h"
#include <cx/string.h>
#define WS_BUFFER_LEN 2048
/*
 * Returns nonzero if the comma-separated header value contains token,
 * compared ASCII case-insensitively, ignoring blanks around list items.
 * Needed because the Connection header is a token list, e.g. browsers send
 * "Connection: keep-alive, Upgrade".
 */
static int ws_header_has_token(const char *value, const char *token) {
    size_t toklen = strlen(token);
    const char *p = value;
    while(*p) {
        /* skip separators and leading blanks of the next item */
        while(*p == ' ' || *p == '\t' || *p == ',') {
            p++;
        }
        const char *start = p;
        while(*p && *p != ',') {
            p++;
        }
        const char *end = p;
        while(end > start && (end[-1] == ' ' || end[-1] == '\t')) {
            end--;
        }
        if((size_t)(end - start) == toklen) {
            size_t i;
            for(i = 0; i < toklen; i++) {
                if(tolower((unsigned char)start[i]) != tolower((unsigned char)token[i])) {
                    break;
                }
            }
            if(i == toklen) {
                return 1;
            }
        }
    }
    return 0;
}

/*
 * Performs the WebSocket opening handshake for rq and then reads from the
 * client connection until net_read() stops returning data or a protocol
 * error occurs, passing every complete message to websocket->on_message.
 *
 * Returns REQ_NOACTION if the request is not a WebSocket upgrade request
 * (Connection, Upgrade or Sec-WebSocket-Key missing or not matching);
 * otherwise REQ_PROCEED once the connection has been handled.
 */
NSAPI_PUBLIC int http_handle_websocket(Session *sn, Request *rq, WebSocket *websocket) {
    char *connection = pblock_findkeyval(pb_key_connection, rq->headers);
    char *upgrade = pblock_findval("upgrade", rq->headers);
    char *wskey = pblock_findval("sec-websocket-key", rq->headers);
    /* NOTE(review): Sec-WebSocket-Version, Sec-WebSocket-Protocol and Origin
       are currently neither validated nor negotiated. */

    /* wskey is required: it is passed to cx_str() below */
    if(!connection || !upgrade || !wskey) {
        return REQ_NOACTION;
    }
    if(!ws_header_has_token(connection, "upgrade")) {
        return REQ_NOACTION;
    }
    if(cx_strcasecmp(cx_str(upgrade), (cxstring)CX_STR("websocket"))) {
        return REQ_NOACTION;
    }

    /* Sec-WebSocket-Accept = base64(SHA1(key + fixed GUID)) */
    cxmutstr wsaccept = cx_strcat(
            2,
            cx_str(wskey),
            (cxstring)CX_STR("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));
    unsigned char hash[20];
    SHA1((const unsigned char*)wsaccept.ptr, wsaccept.length, hash);
    char *websocket_accept = util_base64encode((char*)hash, 20);

    sbuf_t *response = sbuf_new(512);
    sbuf_append(response, (cxstring)CX_STR("HTTP/1.1 101 Switching Protocols\r\n"));
    sbuf_append(response, (cxstring)CX_STR("Upgrade: websocket\r\n"));
    sbuf_append(response, (cxstring)CX_STR("Connection: Upgrade\r\n"));
    sbuf_append(response, (cxstring)CX_STR("Sec-WebSocket-Accept: "));
    sbuf_puts(response, websocket_accept);
    sbuf_append(response, (cxstring)CX_STR("\r\n\r\n"));

    ssize_t w = net_write(sn->csd, response->ptr, response->length);
    int handshake_ok = w >= 0 && (size_t)w == (size_t)response->length;
    sbuf_free(response);
    free(websocket_accept);
    free(wsaccept.ptr);

    int ret = REQ_PROCEED;
    if(!handshake_ok) {
        log_ereport(LOG_FAILURE, "websocket handshake write failed");
        return ret;
    }

    WSParser *parser = websocket_parser(sn);
    char *inbuf = pool_malloc(sn->pool, WS_BUFFER_LEN);
    if(!parser || !inbuf) {
        log_ereport(LOG_FAILURE, "websocket: out of memory");
        return ret;
    }

    ssize_t r;
    while((r = net_read(sn->csd, inbuf, WS_BUFFER_LEN)) > 0) {
        websocket_input(parser, inbuf, r);
        WSMessage *msg;
        /* initialized here: the parser may only ever write 1 to it */
        int error = 0;
        while((msg = websocket_get_message(parser, &error)) != NULL) {
            websocket->on_message(websocket, msg);
        }
        if(error) {
            log_ereport(LOG_FAILURE, "websocket protocol error");
            break;
        }
    }
    return ret;
}
/*
 * Creates a frame parser for the session: the struct is taken from
 * sn->pool, cleared with ZERO() and bound to that pool for later message
 * allocations. Yields NULL when pool_malloc() fails.
 */
WSParser* websocket_parser(Session *sn) {
    WSParser *p = pool_malloc(sn->pool, sizeof(WSParser));
    if(p != NULL) {
        ZERO(p, sizeof(WSParser));
        p->pool = sn->pool;
    }
    return p;
}
/*
 * Hands the parser a new chunk of received bytes and rewinds its read
 * position to the start of that chunk. Only the pointer is stored (no
 * copy), so data must remain valid while websocket_get_message() consumes it.
 */
void websocket_input(WSParser *parser, const char *data, size_t length) {
    parser->pos = 0;
    parser->length = length;
    parser->inbuf = data;
}
WSMessage* websocket_get_message(WSParser *parser,
int *error) {
WSFrame rframe;
WSMessage *retmsg =
NULL;
while(parser->pos < parser->length) {
const char *inbuf = parser->inbuf + parser->pos;
size_t length = parser->length - parser->pos;
if(parser->state ==
0) {
WSFrame frame;
ZERO(&frame,
sizeof(WSFrame));
char frame_data[
WS_FRAMEHEADER_BUFLEN];
size_t flen =
0;
if(parser->tmplen >
0) {
memcpy(parser->tmpbuf, frame_data, parser->tmplen);
flen = parser->tmplen;
}
size_t cp_remaining = length <
WS_FRAMEHEADER_BUFLEN-flen ?
length :
WS_FRAMEHEADER_BUFLEN-flen;
memcpy(&frame_data[flen], inbuf, cp_remaining);
flen += cp_remaining;
ssize_t frame_hlen = websocket_get_frameheader(
&frame,
frame_data,
flen);
if(frame_hlen == -
1) {
*error =
1;
return NULL;
}
if(frame_hlen ==
0) {
memcpy(parser->tmpbuf, frame_data, flen);
}
else {
inbuf += frame_hlen;
length -= frame_hlen;
parser->pos += frame_hlen;
if(frame.payload_length >
0) {
WSMessage *msg = pool_malloc(parser->pool,
sizeof(WSMessage));
msg->data = pool_malloc(parser->pool, frame.payload_length);
msg->length = frame.payload_length;
msg->next =
NULL;
msg->type = frame.opcode;
if(frame.payload_length >= length) {
memcpy(msg->data, inbuf, frame.payload_length);
parser->pos += frame.payload_length;
rframe = frame;
retmsg = msg;
break;
}
else {
memcpy(msg->data, inbuf, length);
parser->state =
1;
parser->current = msg;
parser->cur_plen = length;
parser->frame = frame;
return NULL;
}
}
}
}
else {
WSMessage *msg = parser->current;
if(msg->length >= parser->cur_plen + length) {
memcpy(msg->data + parser->cur_plen, inbuf, length);
parser->cur_plen += length;
return NULL;
}
else {
size_t cplen = msg->length - parser->cur_plen;
memcpy(msg->data + parser->cur_plen, inbuf, cplen);
parser->pos += cplen;
parser->state =
0;
parser->current =
NULL;
rframe = parser->frame;
retmsg = msg;
break;
}
}
}
if(retmsg && rframe.mask) {
websocket_mask_data(retmsg->data, retmsg->length, rframe.masking_key);
}
return retmsg;
}
/*
 * Parses a WebSocket frame header (RFC 6455, section 5.2) from buf.
 *
 * buf/len: the bytes received so far for this frame, starting at the first
 *          header byte.
 * frame:   filled in only when the complete header is available.
 *
 * Returns the header length in bytes (2..14) on success, 0 if len is too
 * short to hold the whole header, or -1 on a protocol violation.
 */
ssize_t websocket_get_frameheader(WSFrame *frame, const char *buf, size_t len) {
    const unsigned char *b = (const unsigned char*)buf;
    if(len < 2) {
        return 0;
    }

    size_t hlen = 2;
    uint8_t fin = (b[0] & 0x80) != 0;
    uint8_t opcode = b[0] & 0x0f;
    uint8_t mask = (b[1] & 0x80) != 0;
    uint8_t len7 = b[1] & 0x7f;
    uint64_t payload_length = len7;

    if(len7 == 126) {
        hlen += 2;
        if(len < hlen) {
            return 0;
        }
        /* 16-bit extended length in network byte order */
        payload_length = ((uint64_t)b[2] << 8) | (uint64_t)b[3];
    } else if(len7 == 127) {
        hlen += 8;
        if(len < hlen) {
            return 0;
        }
        /* 64-bit extended length in network byte order */
        payload_length = 0;
        for(int i = 2; i < 10; i++) {
            payload_length = (payload_length << 8) | (uint64_t)b[i];
        }
        /* RFC 6455: the most significant bit MUST be 0 */
        if(payload_length & UINT64_C(0x8000000000000000)) {
            return -1;
        }
    }

    uint32_t masking_key = 0;
    if(mask) {
        hlen += 4;
        if(len < hlen) {
            return 0;
        }
        /* memcpy keeps the key bytes in wire order in memory (no alignment or
           aliasing issues); the masking code XORs using that byte layout */
        memcpy(&masking_key, buf + hlen - 4, sizeof(masking_key));
    }

    frame->header_complete = TRUE;
    frame->fin = fin;
    frame->opcode = opcode;
    frame->mask = mask;
    frame->masking_key = masking_key;
    frame->payload_length = payload_length;
    return (ssize_t)hlen;
}
/*
 * XORs buf[0..len) in place with a 4-byte masking key (RFC 6455, 5.3).
 *
 * mask carries the key bytes in wire order in its in-memory representation
 * (i.e. copied byte-for-byte from the frame header), so byte i of buf is
 * XORed with byte (i % 4) of that representation. Applying the function
 * twice with the same key restores the original data.
 *
 * Works byte-wise, so buf needs no particular alignment and len may exceed
 * INT_MAX.
 */
void websocket_mask_data(char *buf, size_t len, uint32_t mask) {
    unsigned char key[4];
    memcpy(key, &mask, sizeof(key));
    unsigned char *p = (unsigned char*)buf;
    for(size_t i = 0; i < len; i++) {
        p[i] ^= key[i & 3];
    }
}
/*
 * Sends msg[0..len) to csd as a single, unmasked text frame (FIN set,
 * opcode 0x1), as a server must send it (RFC 6455, 5.1 and 5.2).
 *
 * The payload length uses the 7-bit, 16-bit or 64-bit encoding depending on
 * len; extended lengths are written in network byte order.
 *
 * Returns 0 if the whole frame (header + payload) was written, 1 otherwise
 * (including a short write, which leaves the stream unusable).
 */
NSAPI_PUBLIC int websocket_send_text(SYS_NETFD csd, char *msg, size_t len) {
    unsigned char header[WS_FRAMEHEADER_BUFLEN];
    size_t hlen;

    header[0] = 0x81; /* FIN + text opcode */
    if(len < 126) {
        header[1] = (unsigned char)len;
        hlen = 2;
    } else if(len <= 0xffff) {
        header[1] = 126;
        header[2] = (unsigned char)(len >> 8);
        header[3] = (unsigned char)len;
        hlen = 4;
    } else {
        header[1] = 127;
        uint64_t plen = (uint64_t)len;
        for(int i = 0; i < 8; i++) {
            header[2 + i] = (unsigned char)(plen >> (56 - 8 * i));
        }
        hlen = 10;
    }

    struct iovec iov[2];
    iov[0].iov_base = header;
    iov[0].iov_len = hlen;
    iov[1].iov_base = msg;
    iov[1].iov_len = len;

    ssize_t w = net_writev(csd, iov, 2);
    if(w >= 0 && (size_t)w == hlen + len) {
        return 0;
    }
    return 1;
}