diff --git a/include/crow/settings.h b/include/crow/settings.h index 5958e3544..3c5d64865 100644 --- a/include/crow/settings.h +++ b/include/crow/settings.h @@ -11,6 +11,9 @@ /* #ifdef - enables ssl */ //#define CROW_ENABLE_SSL +/* #ifdef - enables websocket compression */ +//#define CROW_ENABLE_WEBSOCKET_COMPRESSION + /* #define - specifies log level */ /* Debug = 0 diff --git a/include/crow/websocket.h b/include/crow/websocket.h index c2dc3f81f..65d00e0c2 100644 --- a/include/crow/websocket.h +++ b/include/crow/websocket.h @@ -1,9 +1,11 @@ #pragma once #include #include +#include #include "crow/socket_adaptors.h" #include "crow/http_request.h" #include "crow/TinySHA1.hpp" +#include "crow/zlib.hpp" namespace crow { @@ -61,7 +63,18 @@ namespace crow return; } } - +#ifdef CROW_ENABLE_WEBSOCKET_COMPRESSION + std::string extensionsHeader = req.get_header_value("Sec-WebSocket-Extensions"); + std::vector extensions; + boost::split(extensions, extensionsHeader, boost::is_any_of(";")); + if (std::find(extensions.begin(), extensions.end(), "permessage-deflate") != extensions.end()) + { + bool reset_compressor_on_send_ = std::find(extensions.begin(), extensions.end(), "server_no_context_takeover") != extensions.end(); + compressor_.reset(new zlib_compressor(reset_compressor_on_send_, true, 15, Z_BEST_COMPRESSION, 8, Z_DEFAULT_STRATEGY)); + bool reset_decompressor_on_send_ = std::find(extensions.begin(), extensions.end(), "client_no_context_takeover") != extensions.end(); + decompressor_.reset(new zlib_decompressor(reset_decompressor_on_send_, true, 15)); + } +#endif // Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== // Sec-WebSocket-Version: 13 std::string magic = req.get_header_value("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -98,9 +111,11 @@ namespace crow void send_binary(const std::string& msg) override { dispatch([this, msg]{ - auto header = build_header(2, msg.size()); + std::string msg_ = compressor_ ? compressor_->compress(msg) : msg; + auto header = build_header(2, msg_.size()); + if (compressor_) header[0] += 0x40; write_buffers_.emplace_back(std::move(header)); - write_buffers_.emplace_back(msg); + write_buffers_.emplace_back(msg_); do_write(); }); } @@ -108,9 +123,11 @@ namespace crow void send_text(const std::string& msg) override { dispatch([this, msg]{ - auto header = build_header(1, msg.size()); + std::string msg_ = compressor_ ? compressor_->compress(msg) : msg; + auto header = build_header(1, msg_.size()); + if (compressor_) header[0] += 0x40; write_buffers_.emplace_back(std::move(header)); - write_buffers_.emplace_back(msg); + write_buffers_.emplace_back(msg_); do_write(); }); } @@ -167,6 +184,16 @@ namespace crow write_buffers_.emplace_back(header); write_buffers_.emplace_back(std::move(hello)); write_buffers_.emplace_back(crlf); + if (compressor_ && decompressor_) { + write_buffers_.emplace_back( + "Sec-WebSocket-Extensions: permessage-deflate" + "; server_max_window_bits=" + std::to_string(decompressor_->window_bits) + + "; client_max_window_bits=" + std::to_string(compressor_->window_bits) + + (compressor_->reset_before_compress ? "; server_no_context_takeover": "") + + (decompressor_->reset_before_decompress ? "; client_no_context_takeover": "") + ); + write_buffers_.emplace_back(crlf); + } write_buffers_.emplace_back(crlf); do_write(); if (open_handler_) @@ -368,6 +395,11 @@ namespace crow return mini_header_ & 0x8000; } + bool is_compressed() + { + return mini_header_ & 0x4000; + } + int opcode() { return (mini_header_ & 0x0f00) >> 8; @@ -387,7 +419,7 @@ namespace crow if (is_FIN()) { if (message_handler_) - message_handler_(*this, message_, is_binary_); + message_handler_(*this, (is_compressed() && decompressor_) ? decompressor_->decompress(message_) : message_, is_binary_); message_.clear(); } } @@ -398,7 +430,7 @@ namespace crow if (is_FIN()) { if (message_handler_) - message_handler_(*this, message_, is_binary_); + message_handler_(*this, (is_compressed() && decompressor_) ? decompressor_->decompress(message_) : message_, is_binary_); message_.clear(); } } @@ -410,7 +442,7 @@ namespace crow if (is_FIN()) { if (message_handler_) - message_handler_(*this, message_, is_binary_); + message_handler_(*this, (is_compressed() && decompressor_) ? decompressor_->decompress(message_) : message_, is_binary_); message_.clear(); } } @@ -514,6 +546,11 @@ namespace crow bool pong_received_{false}; bool is_close_handler_called_{false}; + bool reset_compressor_on_send_{false}; + bool reset_decompressor_on_send_{false}; + std::unique_ptr compressor_; + std::unique_ptr decompressor_; + std::function open_handler_; std::function message_handler_; std::function close_handler_; diff --git a/include/crow/zlib.hpp b/include/crow/zlib.hpp new file mode 100644 index 000000000..a3058c5f6 --- /dev/null +++ b/include/crow/zlib.hpp @@ -0,0 +1,117 @@ +#pragma once +#include "crow/settings.h" +#include +#include +#include + +namespace crow { + class zlib_compressor { + public: + zlib_compressor(bool reset_before_compress, bool noheader, int window_bits, int level, int mem_level, int strategy) + : reset_before_compress(reset_before_compress) + , window_bits(window_bits) { + stream = std::make_unique(); + stream->zalloc = 0; + stream->zfree = 0; + stream->opaque = 0; + + ::deflateInit2(stream.get(), + level, + Z_DEFLATED, + (noheader ? -window_bits : window_bits), + mem_level, + strategy + ); + } + + ~zlib_compressor() { + ::deflateEnd(stream.get()); + } + + std::string compress(const std::string& src) { + if(reset_before_compress) + ::deflateReset(stream.get()); + + stream->next_in = reinterpret_cast(const_cast(src.c_str())); + stream->avail_in = src.size(); + + const uint64_t bufferSize = 256; + boost::asio::streambuf buffer; + do { + boost::asio::streambuf::mutable_buffers_type chunk = buffer.prepare(bufferSize); + + uint8_t* next_out = boost::asio::buffer_cast(chunk); + + stream->next_out = next_out; + stream->avail_out = bufferSize; + + ::deflate(stream.get(), reset_before_compress ? Z_FINISH : Z_SYNC_FLUSH); + + uint64_t outputSize = stream->next_out - next_out; + buffer.commit(outputSize); + } while(stream->avail_out == 0); + + uint64_t buffer_size = buffer.size(); + if(!reset_before_compress) buffer_size -= 4; + + return std::string(boost::asio::buffer_cast(buffer.data()), buffer_size); + } + + std::unique_ptr stream; + + bool reset_before_compress; + int window_bits; + }; + + class zlib_decompressor { + public: + zlib_decompressor(bool reset_before_decompress, bool noheader, int window_bits) + : reset_before_decompress(reset_before_decompress) + , window_bits(window_bits) { + stream = std::make_unique(); + stream->zalloc = 0; + stream->zfree = 0; + stream->opaque = 0; + + ::inflateInit2(stream.get(), (noheader ? -window_bits : window_bits)); + } + + ~zlib_decompressor() { + inflateEnd(stream.get()); + } + + std::string decompress(std::string src) { + if(reset_before_decompress) + inflateReset(stream.get()); + + src.push_back('\x00'); + src.push_back('\x00'); + src.push_back('\xff'); + src.push_back('\xff'); + + stream->next_in = reinterpret_cast(const_cast(src.c_str())); + stream->avail_in = src.size(); + + const uint64_t bufferSize = 256; + boost::asio::streambuf buffer; + do { + boost::asio::streambuf::mutable_buffers_type chunk = buffer.prepare(bufferSize); + + uint8_t* next_out = boost::asio::buffer_cast(chunk); + + stream->next_out = next_out; + stream->avail_out = bufferSize; + + int ret = ::inflate(stream.get(), reset_before_decompress ? Z_FINISH : Z_SYNC_FLUSH); + buffer.commit(stream->next_out - next_out); + } while(stream->avail_out == 0); + + return std::string(boost::asio::buffer_cast(buffer.data()), buffer.size()); + } + + std::unique_ptr stream; + + bool reset_before_decompress; + int window_bits; + }; +}