From 619ad7461d257a582f50be9f3e776f5c4d249ca1 Mon Sep 17 00:00:00 2001 From: lganzzzo Date: Sun, 25 Nov 2018 23:14:34 +0200 Subject: [PATCH] RequestHeadersReader added --- web/protocol/http/Http.cpp | 34 +++++++ web/protocol/http/Http.hpp | 5 + .../http/incoming/RequestHeadersReader.cpp | 95 ++++++++++++++++++- .../http/incoming/RequestHeadersReader.hpp | 11 +++ 4 files changed, 144 insertions(+), 1 deletion(-) diff --git a/web/protocol/http/Http.cpp b/web/protocol/http/Http.cpp index 672e734d..8c9ba872 100644 --- a/web/protocol/http/Http.cpp +++ b/web/protocol/http/Http.cpp @@ -367,6 +367,40 @@ std::shared_ptr Protocol::parseHeaders(oatpp::parser::Parsing } +void Protocol::parseRequestStartingLineStruct(RequestStartingLineStruct& line, + const std::shared_ptr& headersText, + oatpp::parser::ParsingCaret& caret, + Status& error) { + + oatpp::parser::ParsingCaret::Label methodLabel(caret); + if(caret.findChar(' ')){ + line.method = oatpp::data::share::MemoryLabel(headersText, methodLabel.getData(), methodLabel.getSize()); + caret.inc(); + } else { + error = Status::CODE_400; + return; + } + + oatpp::parser::ParsingCaret::Label pathLabel(caret); + if(caret.findChar(' ')){ + line.path = oatpp::data::share::MemoryLabel(headersText, pathLabel.getData(), pathLabel.getSize()); + caret.inc(); + } else { + error = Status::CODE_400; + return; + } + + oatpp::parser::ParsingCaret::Label protocolLabel(caret); + if(caret.findRN()){ + line.protocol = oatpp::data::share::MemoryLabel(headersText, protocolLabel.getData(), protocolLabel.getSize()); + caret.skipRN(); + } else { + error = Status::CODE_400; + return; + } + +} + void Protocol::parseOneHeaderLabel(HeadersLabels& headers, const std::shared_ptr& headersText, oatpp::parser::ParsingCaret& caret, diff --git a/web/protocol/http/Http.hpp b/web/protocol/http/Http.hpp index 398fafa6..ddf8e63b 100644 --- a/web/protocol/http/Http.hpp +++ b/web/protocol/http/Http.hpp @@ -291,6 +291,11 @@ public: static void parseOneHeader(Headers& headers, oatpp::parser::ParsingCaret& caret, Status& error); static std::shared_ptr parseHeaders(oatpp::parser::ParsingCaret& caret, Status& error); + static void parseRequestStartingLineStruct(RequestStartingLineStruct& line, + const std::shared_ptr& headersText, + oatpp::parser::ParsingCaret& caret, + Status& error); + static void parseOneHeaderLabel(HeadersLabels& headers, const std::shared_ptr& headersText, oatpp::parser::ParsingCaret& caret, diff --git a/web/protocol/http/incoming/RequestHeadersReader.cpp b/web/protocol/http/incoming/RequestHeadersReader.cpp index 659fc4a0..23209212 100644 --- a/web/protocol/http/incoming/RequestHeadersReader.cpp +++ b/web/protocol/http/incoming/RequestHeadersReader.cpp @@ -41,6 +41,9 @@ os::io::Library::v_size RequestHeadersReader::readHeadersSection(const std::shar v_int32 desiredToRead = m_bufferSize; if(progress + desiredToRead > m_maxHeadersSize) { desiredToRead = m_maxHeadersSize - progress; + if(desiredToRead <= 0) { + return -1; + } } res = connection->read(m_buffer, desiredToRead); @@ -81,11 +84,101 @@ RequestHeadersReader::Result RequestHeadersReader::readHeaders(const std::shared auto headersText = buffer.toString(); oatpp::parser::ParsingCaret caret (headersText); http::Status status; - http::Protocol::parseHeadersLabels(result.headers, headersText.getPtr(), caret, status); + http::Protocol::parseRequestStartingLineStruct(result.startingLine, headersText.getPtr(), caret, status); + if(status.code == 0) { + http::Protocol::parseHeadersLabels(result.headers, headersText.getPtr(), caret, status); + } } return result; } + + +RequestHeadersReader::Action RequestHeadersReader::readHeadersAsync(oatpp::async::AbstractCoroutine* parentCoroutine, + AsyncCallback callback, + const std::shared_ptr& connection) +{ + + class ReaderCoroutine : public oatpp::async::CoroutineWithResult { + private: + std::shared_ptr m_connection; + p_char8 m_buffer; + v_int32 m_bufferSize; + v_int32 m_maxHeadersSize; + v_word32 m_accumulator; + v_int32 m_progress; + RequestHeadersReader::Result m_result; + oatpp::data::stream::ChunkedBuffer m_bufferStream; + public: + + ReaderCoroutine(const std::shared_ptr& connection, + p_char8 buffer, v_int32 bufferSize, v_int32 maxHeadersSize) + : m_connection(connection) + , m_buffer(buffer) + , m_bufferSize(bufferSize) + , m_maxHeadersSize(maxHeadersSize) + , m_accumulator(0) + , m_progress(0) + {} + + Action act() override { + + v_int32 desiredToRead = m_bufferSize; + if(m_progress + desiredToRead > m_maxHeadersSize) { + desiredToRead = m_maxHeadersSize - m_progress; + if(desiredToRead <= 0) { + return error("Headers section is too large"); + } + } + + auto res = m_connection->read(m_buffer, desiredToRead); + if(res > 0) { + m_bufferStream.write(m_buffer, res); + + for(v_int32 i = 0; i < res; i ++) { + m_accumulator <<= 8; + m_accumulator |= m_buffer[i]; + if(m_accumulator == SECTION_END) { + m_result.bufferPosStart = i + 1; + m_result.bufferPosEnd = (v_int32) res; + return yieldTo(&ReaderCoroutine::parseHeaders); + } + } + + return waitRetry(); + + } else if(res == oatpp::data::stream::Errors::ERROR_IO_WAIT_RETRY || res == oatpp::data::stream::Errors::ERROR_IO_RETRY) { + return waitRetry(); + } else { + return abort(); + } + + } + + Action parseHeaders() { + + auto headersText = m_bufferStream.toString(); + oatpp::parser::ParsingCaret caret (headersText); + http::Status status; + http::Protocol::parseRequestStartingLineStruct(m_result.startingLine, headersText.getPtr(), caret, status); + if(status.code == 0) { + http::Protocol::parseHeadersLabels(m_result.headers, headersText.getPtr(), caret, status); + if(status.code == 0) { + return _return(m_result); + } else { + return error("error occurred while parsing headers"); + } + } else { + return error("can't parse starting line"); + } + + } + + }; + + return parentCoroutine->startCoroutineForResult(callback, connection, m_buffer, m_bufferSize, m_maxHeadersSize); + +} }}}}} diff --git a/web/protocol/http/incoming/RequestHeadersReader.hpp b/web/protocol/http/incoming/RequestHeadersReader.hpp index f4b5c966..c183af5f 100644 --- a/web/protocol/http/incoming/RequestHeadersReader.hpp +++ b/web/protocol/http/incoming/RequestHeadersReader.hpp @@ -26,10 +26,15 @@ #define oatpp_web_protocol_http_incoming_RequestHeadersReader_hpp #include "oatpp/web/protocol/http/Http.hpp" +#include "oatpp/core/async/Coroutine.hpp" namespace oatpp { namespace web { namespace protocol { namespace http { namespace incoming { class RequestHeadersReader { +public: + typedef oatpp::async::Action Action; +private: + static constexpr v_int32 SECTION_END = ('\r' << 24) | ('\n' << 16) | ('\r' << 8) | ('\n'); public: struct Result { @@ -38,6 +43,9 @@ public: v_int32 bufferPosStart; v_int32 bufferPosEnd; }; + +public: + typedef Action (oatpp::async::AbstractCoroutine::*AsyncCallback)(const Result&); private: os::io::Library::v_size readHeadersSection(const std::shared_ptr& connection, oatpp::data::stream::OutputStream* bufferStream, @@ -49,6 +57,9 @@ private: public: Result readHeaders(const std::shared_ptr& connection, http::Status& error); + Action readHeadersAsync(oatpp::async::AbstractCoroutine* parentCoroutine, + AsyncCallback callback, + const std::shared_ptr& connection); };