From f8457bff9d31c264843dc4a568582ea910350401 Mon Sep 17 00:00:00 2001 From: Denes Matetelki Date: Fri, 25 Nov 2011 17:35:54 +0100 Subject: [PATCH] accept moved to StreamConnection --- include/Poll.hpp | 16 ++++++++-------- include/SslConnection.hpp | 1 + include/StreamConnection.hpp | 2 +- include/TcpConnection.hpp | 1 + other/sslserver_main.cpp | 36 ++++++++++++++++++++++++++++-------- src/Logger.cpp | 6 +++--- src/Poll.cpp | 14 +++++++------- src/SslConnection.cpp | 26 +++++++++++++++++++++++++- src/TcpConnection.cpp | 9 +++++++++ 9 files changed, 83 insertions(+), 28 deletions(-) diff --git a/include/Poll.hpp b/include/Poll.hpp index 28cc0f0..31150ca 100644 --- a/include/Poll.hpp +++ b/include/Poll.hpp @@ -1,7 +1,7 @@ #ifndef POLL_HPP #define POLL_HPP -#include "Connection.hpp" +#include "StreamConnection.hpp" #include #include @@ -12,7 +12,7 @@ class Poll { public: - Poll( Connection *connection, + Poll( StreamConnection *connection, const nfds_t maxClient = 10 ); virtual ~Poll(); @@ -43,13 +43,13 @@ private: typedef typename std::map< int, Connection* > ConnectionPool; - Connection *m_connection; - volatile bool m_polling; - ConnectionPool m_connectionPool; + StreamConnection *m_connection; + volatile bool m_polling; + ConnectionPool m_connectionPool; - nfds_t m_maxclients; - pollfd *m_fds; - nfds_t m_num_of_fds; + nfds_t m_maxclients; + pollfd *m_fds; + nfds_t m_num_of_fds; }; diff --git a/include/SslConnection.hpp b/include/SslConnection.hpp index 91c3b0a..a0a1526 100644 --- a/include/SslConnection.hpp +++ b/include/SslConnection.hpp @@ -43,6 +43,7 @@ public: bool bind(); bool listen( const int maxPendingQueueLen = 64 ); + int accept(); int getSocket() const; diff --git a/include/StreamConnection.hpp b/include/StreamConnection.hpp index 5a97f28..f2dfe47 100644 --- a/include/StreamConnection.hpp +++ b/include/StreamConnection.hpp @@ -19,7 +19,7 @@ public: virtual bool listen( const int maxPendingQueueLen = 64 ) = 0; /// @todo move accept and poll here -// virtual bool accept() = 0; + virtual int accept() = 0; // virtual bool poll() = 0; diff --git a/include/TcpConnection.hpp b/include/TcpConnection.hpp index b7980cb..b278184 100644 --- a/include/TcpConnection.hpp +++ b/include/TcpConnection.hpp @@ -36,6 +36,7 @@ public: bool bind(); bool listen( const int maxPendingQueueLen = 64 ); + int accept(); private: diff --git a/other/sslserver_main.cpp b/other/sslserver_main.cpp index 8976feb..360147d 100644 --- a/other/sslserver_main.cpp +++ b/other/sslserver_main.cpp @@ -11,6 +11,8 @@ #include #include +#include + class EchoMessage : public Message { public: @@ -63,6 +65,20 @@ protected: }; +SocketServer *socketServer; + +void signalHandler(int s) +{ + LOG_STATIC( Logger::INFO, std::string("Exiting after receiving signal: "). + append(TToStr(s)).c_str() ); + socketServer->stop(); + delete socketServer; + SslConnection::destroy(); + Logger::destroy(); + exit(1); +} + + int main(int argc, char* argv[] ) { if ( argc != 3 ) { @@ -70,27 +86,31 @@ int main(int argc, char* argv[] ) return 1; } + struct sigaction sigIntHandler; + sigIntHandler.sa_handler = signalHandler; + sigemptyset(&sigIntHandler.sa_mask); + sigIntHandler.sa_flags = 0; + sigaction(SIGINT, &sigIntHandler, NULL); + + Logger::createInstance(); Logger::init(std::cout); Logger::setLogLevel(Logger::FINEST); -// Logger::setNoPrefix(); SslConnection::init(); EchoMessage msg; SslConnection conn(argv[1], StrToT(argv[2]), &msg); - SocketServer socketServer(&conn); + socketServer = new SocketServer(&conn); - if ( !socketServer.start() ) { - LOG( Logger::ERR, "Failed to start TCP server, exiting..."); + if ( !socketServer->start() ) { + LOG_STATIC( Logger::ERR, "Failed to start TCP server, exiting..."); + delete socketServer; + SslConnection::destroy(); Logger::destroy(); return 1; } // never reached - sleep(1); - socketServer.stop(); - SslConnection::destroy(); - Logger::destroy(); return 0; } \ No newline at end of file diff --git a/src/Logger.cpp b/src/Logger.cpp index 21ea1ea..cf970e6 100644 --- a/src/Logger.cpp +++ b/src/Logger.cpp @@ -39,7 +39,7 @@ void Logger::log_pointer( const void* pointer, << COLOR_RESET << ":" << COLOR( FG_BROWN ) << line << COLOR_RESET << " " << COLOR( FG_CYAN ) << function << COLOR_RESET << " " - << COLOR( FG_BLUE ) << "\"" << pointer << "\"" + << COLOR( FG_BLUE ) << pointer << COLOR_RESET << std::endl; } @@ -66,8 +66,8 @@ void Logger::log_string( const int level, << COLOR_RESET << ":" << COLOR( FG_BROWN ) << line << COLOR_RESET << " " << COLOR( FG_CYAN ) << function << COLOR_RESET << " " - << color << "\"" << msg << "\"" - << COLOR( FG_BLUE ) << "\"" << pointer << "\"" + << color << "\"" << msg << "\" " + << COLOR( FG_BLUE ) << pointer << COLOR_RESET << std::endl; } diff --git a/src/Poll.cpp b/src/Poll.cpp index 53fd4b0..f4d4b1c 100644 --- a/src/Poll.cpp +++ b/src/Poll.cpp @@ -9,8 +9,8 @@ -Poll::Poll( Connection *connection, - const nfds_t maxClient ) +Poll::Poll( StreamConnection *connection, + const nfds_t maxClient ) : m_connection(connection) , m_polling(false) , m_connectionPool() @@ -82,12 +82,12 @@ void Poll::acceptClient() { TRACE; - sockaddr clientAddr; - socklen_t clientAddrLen; +// sockaddr clientAddr; +// socklen_t clientAddrLen; +// int client_socket = accept( m_connection->getSocket(), +// &clientAddr, &clientAddrLen ) ; + int client_socket = m_connection->accept(); - /// @todo put accept into Socket class - int client_socket = accept( m_connection->getSocket(), - &clientAddr, &clientAddrLen ) ; if ( client_socket == -1 ) { LOG( Logger::ERR, errnoToString("ERROR accepting. ").c_str() ); diff --git a/src/SslConnection.cpp b/src/SslConnection.cpp index d5f4321..31a0fed 100644 --- a/src/SslConnection.cpp +++ b/src/SslConnection.cpp @@ -15,6 +15,7 @@ void SslConnection::init() SSL_load_error_strings(); SSL_library_init(); + OpenSSL_add_all_algorithms(); } void SslConnection::destroy() @@ -74,6 +75,8 @@ SslConnection::~SslConnection() Connection* SslConnection::clone(const int socket) { + TRACE; + Connection *conn = new SslConnection( socket, m_message->clone(), m_bufferLength ); @@ -93,7 +96,7 @@ bool SslConnection::connect() return false; if ( SSL_connect (m_sslHandle) != 1 ) { - LOG (Logger::ERR, getSslError("Handshake with SSL server failed. ").c_str() ); + LOG (Logger::ERR, getSslError("SSL handshake failed. ").c_str() ); return false; } @@ -123,6 +126,27 @@ bool SslConnection::listen( const int maxPendingQueueLen ) } +int SslConnection::accept() +{ + TRACE; + + int client_socket = m_tcpConnection.accept(); + if ( client_socket == -1) + return client_socket; + + if ( SSL_accept(m_sslHandle) == -1 ) { + getSslError("SSL accept failed. "); + return -1; + } + + if ( SSL_set_fd(m_sslHandle, client_socket) == 0 ) { + getSslError("SSL set connection socket failed. "); + return -1; + } + + return client_socket; +} + /// @todo this function shall be refactored bool SslConnection::disconnect() { diff --git a/src/TcpConnection.cpp b/src/TcpConnection.cpp index 98e08df..f7575ce 100644 --- a/src/TcpConnection.cpp +++ b/src/TcpConnection.cpp @@ -81,6 +81,15 @@ bool TcpConnection::listen( const int maxPendingQueueLen ) } +int TcpConnection::accept() +{ + sockaddr clientAddr; + socklen_t clientAddrLen; + + return ::accept( getSocket(), &clientAddr, &clientAddrLen ) ; +} + + bool TcpConnection::disconnect() { TRACE;