diff --git a/.gitignore b/.gitignore index 99d0414..f39e4b7 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,6 @@ test/testCppUtils.out *.kate-swp other/client other/server +other/sslclient +other/sslserver + diff --git a/doxyfile b/doxyfile index 8f4915e..dc3c583 100644 --- a/doxyfile +++ b/doxyfile @@ -239,7 +239,7 @@ EXTENSION_MAPPING = # func(std::string) {}). This also makes the inheritance and collaboration # diagrams that involve STL classes more complete and accurate. -BUILTIN_STL_SUPPORT = YES +BUILTIN_STL_SUPPORT = NO # If you use Microsoft's C++/CLI language, you should set this option to YES to # enable parsing support. diff --git a/include/Common.hpp b/include/Common.hpp index 76d3f81..630c07f 100644 --- a/include/Common.hpp +++ b/include/Common.hpp @@ -108,11 +108,20 @@ inline std::string TToStr(const T t) } +// template +// inline void StrToT( T &t, const std::string s ) +// { +// std::stringstream ss(s); +// ss >> t; +// } + template -inline void StrToT( T &t, const std::string s ) +inline T StrToT( const std::string s ) { std::stringstream ss(s); + T t; ss >> t; + return t; } diff --git a/include/Connection.hpp b/include/Connection.hpp new file mode 100644 index 0000000..c652d9e --- /dev/null +++ b/include/Connection.hpp @@ -0,0 +1,42 @@ +#ifndef CONNECTION_HPP +#define CONNECTION_HPP + + +#include "string" + +class Connection +{ +public: + + virtual ~Connection(); + virtual Connection* clone(const int socket) = 0; + + virtual bool bind() = 0; + + virtual bool send( const void* message, const size_t length ) = 0; + virtual bool receive() = 0; + + std::string getHost() const; + int getPort() const; + + void setHost(const std::string host); + void setPort(const int port); + + virtual int getSocket() const = 0; + + +protected: + + Connection(std::string host = std::string("invalid"), int port = -1); + + std::string m_host; + int m_port; + + +private: + + Connection(const Connection &); + Connection& operator= (const Connection &); +}; + +#endif // CONNECTION_HPP diff --git a/include/Logger.hpp b/include/Logger.hpp index 21b47e6..31df661 100644 --- a/include/Logger.hpp +++ b/include/Logger.hpp @@ -35,12 +35,13 @@ public: inline static LogLevel getLoglevel() { return m_logLevel; } - static void log_pointer( const void* msg, + static void log_pointer( const void* pointer, const char* file, const int line, const char* function); static void log_string( const int level, + const void* pointer, const char* msg, const char* file, const int line, @@ -65,6 +66,7 @@ private: #define TRACE (void)0 #define TRACE_STATIC (void)0 #define LOG(level, msg) (void)0 + #define LOG_STATIC(level, msg) (void)0 #else @@ -83,7 +85,13 @@ private: #define LOG(level, msg) \ if ( Logger::getInstance()->getLoglevel() >= level ) \ Logger::getInstance()->log_string( \ - level, msg, __FILE__, __LINE__, __PRETTY_FUNCTION__); \ + level, this, msg, __FILE__, __LINE__, __PRETTY_FUNCTION__); \ + else (void)0 + + #define LOG_STATIC(level, msg) \ + if ( Logger::getInstance()->getLoglevel() >= level ) \ + Logger::getInstance()->log_string( \ + level, 0, msg, __FILE__, __LINE__, __PRETTY_FUNCTION__); \ else (void)0 #endif diff --git a/include/Message.hpp b/include/Message.hpp index 4abf82a..62f2df9 100644 --- a/include/Message.hpp +++ b/include/Message.hpp @@ -12,14 +12,14 @@ * getExpectedLength(). */ -class SocketConnection; +class Connection; class Message { public: - Message( SocketConnection *connection, + Message( Connection *connection, void *msgParam = 0 ) : m_connection(connection) , m_param(msgParam) @@ -43,7 +43,7 @@ public: const size_t msgLen ) = 0; virtual void onMessageReady() = 0; - void setConnection(SocketConnection* conn ) + void setConnection(Connection* conn ) { TRACE; m_connection = conn; @@ -54,7 +54,7 @@ protected: virtual size_t getExpectedLength() = 0; - SocketConnection *m_connection; + Connection *m_connection; void *m_param; std::string m_buffer; diff --git a/include/Poll.hpp b/include/Poll.hpp index a9b677c..28cc0f0 100644 --- a/include/Poll.hpp +++ b/include/Poll.hpp @@ -1,7 +1,7 @@ #ifndef POLL_HPP #define POLL_HPP -#include "SocketConnection.hpp" +#include "Connection.hpp" #include #include @@ -12,8 +12,8 @@ class Poll { public: - Poll( SocketConnection *connection, - const nfds_t maxClient = 10 ); + Poll( Connection *connection, + const nfds_t maxClient = 10 ); virtual ~Poll(); @@ -41,15 +41,15 @@ private: bool removeFd( const int socket ); - typedef typename std::map< int, SocketConnection* > ConnectionPool; + typedef typename std::map< int, Connection* > ConnectionPool; - SocketConnection *m_connection; - volatile bool m_polling; - ConnectionPool m_connectionPool; + Connection *m_connection; + volatile bool m_polling; + ConnectionPool m_connectionPool; - nfds_t m_maxclients; - pollfd *m_fds; - nfds_t m_num_of_fds; + nfds_t m_maxclients; + pollfd *m_fds; + nfds_t m_num_of_fds; }; diff --git a/include/Socket.hpp b/include/Socket.hpp index 21c10eb..59ebf10 100644 --- a/include/Socket.hpp +++ b/include/Socket.hpp @@ -20,7 +20,7 @@ public: virtual ~Socket(); bool createSocket(); - void closeSocket(); + bool closeSocket(); bool connectToHost( const std::string host, const std::string port ); diff --git a/include/SocketClient.hpp b/include/SocketClient.hpp index e161057..705d23b 100644 --- a/include/SocketClient.hpp +++ b/include/SocketClient.hpp @@ -2,7 +2,7 @@ #define SOCKET_CLIENT_HPP -#include "SocketConnection.hpp" +#include "StreamConnection.hpp" #include "Thread.hpp" #include "Poll.hpp" @@ -46,7 +46,7 @@ private: public: - SocketClient (SocketConnection *connection ); + SocketClient (StreamConnection *connection ); virtual ~SocketClient(); @@ -63,7 +63,7 @@ private: SocketClient& operator=(const SocketClient& ); - SocketConnection *m_connection; + StreamConnection *m_connection; PollerThread m_watcher; }; diff --git a/include/SocketConnection.hpp b/include/SocketConnection.hpp deleted file mode 100644 index 6ec771f..0000000 --- a/include/SocketConnection.hpp +++ /dev/null @@ -1,56 +0,0 @@ -#ifndef SOCKET_CONNECTION_HPP -#define SOCKET_CONNECTION_HPP - -#include "Socket.hpp" -#include "Message.hpp" - -#include - - -class SocketConnection -{ -public: - - SocketConnection ( const int socket, - Message *message, - const size_t bufferLength = 1024 ); - - SocketConnection ( const std::string host, - const std::string port, - Message *message, - const size_t bufferLength = 1024 ); - - virtual ~SocketConnection(); - - virtual SocketConnection* clone(const int socket) = 0; - virtual bool connectToHost() = 0; - virtual bool bindToHost() = 0; - virtual bool listen( const int maxPendingQueueLen = 64 ) = 0; - virtual void closeConnection() = 0; - - virtual bool send( const void* message, const size_t length ) = 0; - virtual bool receive() = 0; - - int getSocket() const; - std::string getHost() const; - std::string getPort() const; - -protected: - - Socket m_socket; - std::string m_host; - std::string m_port; - Message *m_message; - - unsigned char *m_buffer; - size_t m_bufferLength; - -private: - - SocketConnection(const SocketConnection&); - SocketConnection& operator=(const SocketConnection&); - -}; - - -#endif // SOCKET_CONNECTION_HPP diff --git a/include/SocketServer.hpp b/include/SocketServer.hpp index 21cd8b9..70b22e7 100644 --- a/include/SocketServer.hpp +++ b/include/SocketServer.hpp @@ -1,7 +1,7 @@ #ifndef SOCKET_SERVER_HPP #define SOCKET_SERVER_HPP -#include "SocketConnection.hpp" +#include "StreamConnection.hpp" #include "Poll.hpp" @@ -9,9 +9,9 @@ class SocketServer { public: - SocketServer ( SocketConnection *connection, - const int maxClients = 5, - const int maxPendingQueueLen = 10 ); + SocketServer ( StreamConnection *connection, + const int maxClients = 5, + const int maxPendingQueueLen = 10 ); virtual ~SocketServer(); @@ -24,7 +24,7 @@ private: SocketServer(const SocketServer&); SocketServer& operator=(const SocketServer&); - SocketConnection *m_connection; + StreamConnection *m_connection; Poll m_poll; const int m_maxPendingQueueLen; }; diff --git a/include/SslConnection.hpp b/include/SslConnection.hpp index cc5313a..91c3b0a 100644 --- a/include/SslConnection.hpp +++ b/include/SslConnection.hpp @@ -2,7 +2,7 @@ #define SSL_CONNECTION_HPP -#include "SocketConnection.hpp" +#include "StreamConnection.hpp" #include "TcpConnection.hpp" #include "Message.hpp" @@ -15,7 +15,7 @@ /// @note Call init/destroy before/after usage -class SslConnection : public SocketConnection +class SslConnection : public StreamConnection { public: @@ -27,33 +27,38 @@ public: const size_t bufferLength = 1024 ); SslConnection ( const std::string host, - const std::string port, + const int port, Message *message, const size_t bufferLength = 1024 ); virtual ~SslConnection(); - SocketConnection* clone(const int socket); + Connection* clone(const int socket); - bool connectToHost(); - bool bindToHost(); - bool listen( const int maxPendingQueueLen = 64 ); - void closeConnection(); + bool connect(); + bool disconnect(); bool send( const void* message, const size_t length ); bool receive(); + bool bind(); + bool listen( const int maxPendingQueueLen = 64 ); + + int getSocket() const; private: SslConnection(const SslConnection&); SslConnection& operator=(const SslConnection&); - bool connect(); + bool initHandlers(); std::string getSslError(const std::string &msg); TcpConnection m_tcpConnection; + Message *m_message; + unsigned char *m_buffer; + size_t m_bufferLength; SSL *m_sslHandle; SSL_CTX *m_sslContext; }; diff --git a/include/StreamConnection.hpp b/include/StreamConnection.hpp new file mode 100644 index 0000000..5a97f28 --- /dev/null +++ b/include/StreamConnection.hpp @@ -0,0 +1,38 @@ +#ifndef STREAM_CONNECTION_HPP +#define STREAM_CONNECTION_HPP + + +#include "Connection.hpp" + +#include "string" + +class StreamConnection : public Connection +{ +public: + + virtual ~StreamConnection() {}; + virtual Connection* clone(const int socket) = 0; + + virtual bool connect() = 0; + virtual bool disconnect() = 0; + + virtual bool listen( const int maxPendingQueueLen = 64 ) = 0; + + /// @todo move accept and poll here +// virtual bool accept() = 0; +// virtual bool poll() = 0; + + +protected: + + StreamConnection(std::string host = std::string("invalid"), int port = -1) + : Connection(host, port) {}; + + +private: + + StreamConnection(const StreamConnection &); + StreamConnection& operator= (const StreamConnection &); +}; + +#endif // STREAM_CONNECTION_HPP diff --git a/include/TcpConnection.hpp b/include/TcpConnection.hpp index fdce5c8..b7980cb 100644 --- a/include/TcpConnection.hpp +++ b/include/TcpConnection.hpp @@ -2,13 +2,14 @@ #define TCP_CONNECTION_HPP -#include "SocketConnection.hpp" +#include "StreamConnection.hpp" #include "Message.hpp" +#include "Socket.hpp" #include -class TcpConnection : public SocketConnection +class TcpConnection : public StreamConnection { public: @@ -17,27 +18,34 @@ public: const size_t bufferLength = 1024 ); TcpConnection ( const std::string host, - const std::string port, + const int port, Message *message, const size_t bufferLength = 1024 ); virtual ~TcpConnection(); - SocketConnection* clone(const int socket); + Connection* clone(const int socket); - bool connectToHost(); - bool bindToHost(); - bool listen( const int maxPendingQueueLen = 64 ); - void closeConnection(); + bool connect(); + bool disconnect(); bool send( const void* message, const size_t length ); bool receive(); + int getSocket() const; + + bool bind(); + bool listen( const int maxPendingQueueLen = 64 ); private: TcpConnection(const TcpConnection&); TcpConnection& operator=(const TcpConnection&); + + Socket m_socket; + Message *m_message; + unsigned char *m_buffer; + size_t m_bufferLength; }; diff --git a/include/Thread.hpp b/include/Thread.hpp index 592d212..a9d4407 100644 --- a/include/Thread.hpp +++ b/include/Thread.hpp @@ -19,7 +19,7 @@ public: protected: - bool m_isRunning; + volatile bool m_isRunning; private: diff --git a/other/sslclient_main.cpp b/other/sslclient_main.cpp index e7ede3c..3f39f0a 100644 --- a/other/sslclient_main.cpp +++ b/other/sslclient_main.cpp @@ -1,4 +1,5 @@ -// gpp sslclient_main.cpp -o sslclient -I../include ../src/Logger.cpp ../src/Thread.cpp ../src/Socket.cpp -lpthread ../src/SocketClient.cpp ../src/Poll.cpp ../src/SocketConnection.cpp ../src/SslConnection.cpp -lssl -lcrypto ../src/TcpConnection.cpp +// gpp sslclient_main.cpp -o sslclient -I../include ../src/Logger.cpp ../src/Thread.cpp ../src/Socket.cpp -lpthread ../src/SocketClient.cpp ../src/Poll.cpp ../src/Connection.cpp ../src/SslConnection.cpp -lssl -lcrypto ../src/TcpConnection.cpp + #include "Logger.hpp" @@ -11,6 +12,7 @@ #include #include // nanosleep +#include @@ -75,11 +77,12 @@ int main(int argc, char* argv[] ) bool finished = false; SimpleMessage msg(&finished); - SslConnection conn(argv[1], argv[2], &msg); + SslConnection conn(argv[1], StrToT(argv[2]), &msg); SocketClient socketClient(&conn); if ( !socketClient.connect() ) { - LOG( Logger::ERR, "Couldn't connect to server, exiting..." ); + LOG_STATIC( Logger::ERR, "Couldn't connect to server, exiting..." ); + SslConnection::destroy(); Logger::destroy(); return 1; } @@ -90,7 +93,8 @@ int main(int argc, char* argv[] ) // send message to server std::string msg1(argv[3]); if ( !socketClient.send( msg1.c_str(), msg1.length()) ) { - LOG( Logger::ERR, "Couldn't send message to server, exiting..." ); + LOG_STATIC( Logger::ERR, "Couldn't send message to server, exiting..." ); + SslConnection::destroy(); Logger::destroy(); return 1; } @@ -100,7 +104,7 @@ int main(int argc, char* argv[] ) while ( !finished && socketClient.isPolling() ) nanosleep(&tm, &tm) ; - socketClient.disconnect(); +// socketClient.disconnect(); SslConnection::destroy(); Logger::destroy(); return 0; diff --git a/other/sslserver_main.cpp b/other/sslserver_main.cpp new file mode 100644 index 0000000..8976feb --- /dev/null +++ b/other/sslserver_main.cpp @@ -0,0 +1,96 @@ +// gpp sslserver_main.cpp -o sslserver -I../include ../src/Logger.cpp ../src/Socket.cpp -ggdb ../src/SocketServer.cpp ../src/Connection.cpp ../src/Poll.cpp ../src/TcpConnection.cpp ../src/SslConnection.cpp -lssl -lcrypto + +#include "Logger.hpp" +#include "Common.hpp" + +#include "Message.hpp" +#include "SslConnection.hpp" +#include "SocketServer.hpp" + + +#include +#include + +class EchoMessage : public Message +{ +public: + + EchoMessage( void *msgParam = 0) + : Message(msgParam) + { + TRACE; + } + + bool buildMessage( const void *msgPart, + const size_t msgLen ) + { + TRACE; + m_buffer = std::string( (const char*) msgPart, msgLen ); + onMessageReady(); + return true; + } + + void onMessageReady() + { + TRACE; + + LOG( Logger::INFO, std::string("Got message: \""). + append(m_buffer).append("\" from: "). + append(m_connection->getHost().append(":"). + append(TToStr(m_connection->getPort())) ).c_str() ); + + std::string reply("Got your message, "); + reply.append(m_connection->getHost()).append(":"). + append(TToStr(m_connection->getPort())). + append(" \"").append(m_buffer).append("\""); + + m_connection->send( reply.c_str(), reply.length() ); + } + + Message* clone() + { + TRACE; + return new EchoMessage(m_param); + } + +protected: + + size_t getExpectedLength() + { + TRACE; + return 0; + } +}; + + +int main(int argc, char* argv[] ) +{ + if ( argc != 3 ) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return 1; + } + + Logger::createInstance(); + Logger::init(std::cout); + Logger::setLogLevel(Logger::FINEST); +// Logger::setNoPrefix(); + SslConnection::init(); + + EchoMessage msg; + SslConnection conn(argv[1], StrToT(argv[2]), &msg); + SocketServer socketServer(&conn); + + if ( !socketServer.start() ) { + LOG( Logger::ERR, "Failed to start TCP server, exiting..."); + Logger::destroy(); + return 1; + } + + // never reached + sleep(1); + + socketServer.stop(); + SslConnection::destroy(); + Logger::destroy(); + return 0; +} \ No newline at end of file diff --git a/other/tcpclient_main.cpp b/other/tcpclient_main.cpp index 06d6588..1e328bd 100644 --- a/other/tcpclient_main.cpp +++ b/other/tcpclient_main.cpp @@ -2,6 +2,7 @@ #include "Logger.hpp" +#include "Common.hpp" #include "Message.hpp" #include "TcpConnection.hpp" @@ -68,6 +69,7 @@ int main(int argc, char* argv[] ) return 1; } + Logger::createInstance(); Logger::init(std::cout); Logger::setLogLevel(Logger::FINEST); @@ -75,7 +77,8 @@ int main(int argc, char* argv[] ) bool finished = false; SimpleMessage msg(&finished); - TcpConnection conn(argv[1], argv[2], &msg); + + TcpConnection conn(argv[1], StrToT(argv[2]), &msg); SocketClient socketClient(&conn); if ( !socketClient.connect() ) { diff --git a/other/tcpserver_main.cpp b/other/tcpserver_main.cpp index af8d2fe..4f6df99 100644 --- a/other/tcpserver_main.cpp +++ b/other/tcpserver_main.cpp @@ -35,14 +35,16 @@ public: { TRACE; + std::cout << "buffer: " << m_buffer << std::endl; + LOG( Logger::INFO, std::string("Got message: \""). append(m_buffer).append("\" from: "). append(m_connection->getHost().append(":"). - append(m_connection->getPort()) ).c_str() ); + append(TToStr(m_connection->getPort())) ).c_str() ); std::string reply("Got your message, "); reply.append(m_connection->getHost()).append(":"). - append(m_connection->getPort()). + append(TToStr(m_connection->getPort())). append(" \"").append(m_buffer).append("\""); m_connection->send( reply.c_str(), reply.length() ); @@ -77,7 +79,7 @@ int main(int argc, char* argv[] ) // Logger::setNoPrefix(); EchoMessage msg; - TcpConnection conn(argv[1], argv[2], &msg); + TcpConnection conn(argv[1], StrToT(argv[2]), &msg); SocketServer socketServer(&conn); if ( !socketServer.start() ) { diff --git a/src/Connection.cpp b/src/Connection.cpp new file mode 100644 index 0000000..c75ceb7 --- /dev/null +++ b/src/Connection.cpp @@ -0,0 +1,47 @@ +#include "Connection.hpp" + +#include "Logger.hpp" + + +Connection::Connection(std::string host, int port) + : m_host(host) + , m_port(port) +{ + TRACE; +} + + +Connection::~Connection() +{ + TRACE; +} + + +std::string Connection::getHost() const +{ + TRACE; + return m_host; +} + + +int Connection::getPort() const +{ + TRACE; + return m_port; +} + + +void Connection::setHost(const std::string host) +{ + TRACE; + m_host = host; +} + + +void Connection::setPort(const int port) +{ + TRACE; + m_port = port; +} + + diff --git a/src/Logger.cpp b/src/Logger.cpp index 0877094..21ea1ea 100644 --- a/src/Logger.cpp +++ b/src/Logger.cpp @@ -24,13 +24,13 @@ void Logger::setNoPrefix () } -void Logger::log_pointer( const void* msg, +void Logger::log_pointer( const void* pointer, const char* file, const int line, const char* function) { if ( !m_usePrefix ) { - *m_ostream << msg << std::endl; + *m_ostream << pointer << std::endl; return; } @@ -39,12 +39,13 @@ void Logger::log_pointer( const void* msg, << COLOR_RESET << ":" << COLOR( FG_BROWN ) << line << COLOR_RESET << " " << COLOR( FG_CYAN ) << function << COLOR_RESET << " " - << COLOR( FG_BLUE ) << "\"" << msg << "\"" + << COLOR( FG_BLUE ) << "\"" << pointer << "\"" << COLOR_RESET << std::endl; } void Logger::log_string( const int level, + const void* pointer, const char* msg, const char* file, const int line, @@ -66,6 +67,7 @@ void Logger::log_string( const int level, << COLOR( FG_BROWN ) << line << COLOR_RESET << " " << COLOR( FG_CYAN ) << function << COLOR_RESET << " " << color << "\"" << msg << "\"" + << COLOR( FG_BLUE ) << "\"" << pointer << "\"" << COLOR_RESET << std::endl; } diff --git a/src/Poll.cpp b/src/Poll.cpp index 8f31efd..53fd4b0 100644 --- a/src/Poll.cpp +++ b/src/Poll.cpp @@ -4,8 +4,13 @@ #include "Common.hpp" -Poll::Poll( SocketConnection *connection, - const nfds_t maxClient ) +#include +#include + + + +Poll::Poll( Connection *connection, + const nfds_t maxClient ) : m_connection(connection) , m_polling(false) , m_connectionPool() @@ -36,6 +41,8 @@ void Poll::startPolling() while ( m_polling ) { nanosleep(&tm, &tm) ; + + /// @todo put poll into Socket class int ret = poll( m_fds , m_maxclients, 1000); if ( ret == -1 ) { @@ -77,6 +84,8 @@ void Poll::acceptClient() sockaddr clientAddr; socklen_t clientAddrLen; + + /// @todo put accept into Socket class int client_socket = accept( m_connection->getSocket(), &clientAddr, &clientAddrLen ) ; @@ -85,11 +94,11 @@ void Poll::acceptClient() return; } - SocketConnection *connection = m_connection->clone(client_socket); + Connection *connection = m_connection->clone(client_socket); LOG( Logger::INFO, std::string("New client connected: "). append(connection->getHost()).append(":"). - append(connection->getPort()).c_str() ); + append(TToStr(connection->getPort())).c_str() ); m_connectionPool[client_socket] = connection; addFd( client_socket, POLLIN | POLLPRI ); diff --git a/src/Socket.cpp b/src/Socket.cpp index 736680e..8dad957 100644 --- a/src/Socket.cpp +++ b/src/Socket.cpp @@ -57,7 +57,7 @@ bool Socket::createSocket() } -void Socket::closeSocket() +bool Socket::closeSocket() { TRACE; @@ -65,6 +65,8 @@ void Socket::closeSocket() shutdown(m_socket, SHUT_RDWR); close(m_socket); m_socket = -1; + + return true; } @@ -230,7 +232,7 @@ bool Socket::getHostInfo( const std::string host, int status = getaddrinfo(host.c_str(), port.c_str(), &hints, &results); if (status != 0) { - LOG( Logger::ERR, std::string("Error at network address translation: "). + LOG_STATIC( Logger::ERR, std::string("Error at network address translation: "). append(gai_strerror(status)).c_str() ) ; return false; } @@ -263,7 +265,7 @@ void Socket::printHostDetails(struct addrinfo *servinfo) char ipstr[INET6_ADDRSTRLEN]; inet_ntop( it->ai_family, addr, ipstr, sizeof ipstr ); - LOG( Logger::DEBUG, std::string(TToStr(counter)).append(". address is "). + LOG_STATIC( Logger::DEBUG, std::string(TToStr(counter)).append(". address is "). append(ipver).append(": "). append(ipstr).c_str() ); } @@ -286,7 +288,7 @@ bool Socket::convertNameInfo(sockaddr* addr, NI_NAMEREQD ); if ( status != 0 ) { - LOG( Logger::WARNING, std::string("Could not resolve hostname. "). + LOG_STATIC( Logger::WARNING, std::string("Could not resolve hostname. "). append(gai_strerror(status)).c_str() ); return false; } diff --git a/src/SocketClient.cpp b/src/SocketClient.cpp index 1fe7192..800377f 100644 --- a/src/SocketClient.cpp +++ b/src/SocketClient.cpp @@ -48,7 +48,7 @@ void* SocketClient::PollerThread::run() // SocketClient -SocketClient::SocketClient (SocketConnection *connection ) +SocketClient::SocketClient (StreamConnection *connection ) : m_connection (connection) , m_watcher(this) { @@ -67,7 +67,7 @@ bool SocketClient::connect() { TRACE; - if ( !m_connection->connectToHost() ) + if ( !m_connection->connect() ) return false; m_watcher.start(); @@ -84,7 +84,7 @@ void SocketClient::disconnect() m_watcher.join(); } - m_connection->closeConnection(); + m_connection->disconnect(); } diff --git a/src/SocketConnection.cpp b/src/SocketConnection.cpp deleted file mode 100644 index dd93cc9..0000000 --- a/src/SocketConnection.cpp +++ /dev/null @@ -1,69 +0,0 @@ -#include "SocketConnection.hpp" - -#include "Logger.hpp" -#include "Common.hpp" - - -SocketConnection::SocketConnection ( const int socket, - Message *message, - const size_t bufferLength ) - : m_socket(socket) - , m_host() - , m_port() - , m_message(message) - , m_buffer(0) - , m_bufferLength(bufferLength) -{ - TRACE; - - m_socket.getPeerName(m_host, m_port); - m_buffer = new unsigned char[m_bufferLength]; - m_message->setConnection(this); -} - - -SocketConnection::SocketConnection ( const std::string host, - const std::string port, - Message *message, - const size_t bufferLength ) - : m_socket(AF_INET, SOCK_STREAM) - , m_host(host) - , m_port(port) - , m_message(message) - , m_buffer(0) - , m_bufferLength(bufferLength) -{ - TRACE; - m_socket.createSocket(); - m_buffer = new unsigned char[m_bufferLength]; - m_message->setConnection(this); -} - - -SocketConnection::~SocketConnection() -{ - TRACE; - m_socket.closeSocket(); - delete[] m_buffer; -} - - -int SocketConnection::getSocket() const -{ - TRACE; - return m_socket.getSocket(); -} - - -std::string SocketConnection::getHost() const -{ - TRACE; - return m_host; -} - - -std::string SocketConnection::getPort() const -{ - TRACE; - return m_port; -} diff --git a/src/SocketServer.cpp b/src/SocketServer.cpp index 77ac4c9..f734e09 100644 --- a/src/SocketServer.cpp +++ b/src/SocketServer.cpp @@ -2,9 +2,10 @@ #include "Logger.hpp" -SocketServer::SocketServer ( SocketConnection *connection, - const int maxClients, - const int maxPendingQueueLen ) + +SocketServer::SocketServer ( StreamConnection *connection, + const int maxClients, + const int maxPendingQueueLen ) : m_connection(connection) , m_poll( m_connection, maxClients) , m_maxPendingQueueLen(maxPendingQueueLen) @@ -23,7 +24,7 @@ bool SocketServer::start() { TRACE; - if ( !m_connection->bindToHost() ) + if ( !m_connection->bind() ) return false; if ( m_connection->listen( m_maxPendingQueueLen ) == -1 ) { @@ -39,5 +40,5 @@ void SocketServer::stop() { TRACE; m_poll.stopPolling(); - m_connection->closeConnection(); + m_connection->disconnect(); } diff --git a/src/SslConnection.cpp b/src/SslConnection.cpp index 0978645..d5f4321 100644 --- a/src/SslConnection.cpp +++ b/src/SslConnection.cpp @@ -28,64 +28,91 @@ void SslConnection::destroy() SslConnection::SslConnection ( const int socket, Message *message, const size_t bufferLength ) - : SocketConnection(socket, message, bufferLength) - , m_tcpConnection(socket, 0, 0) + : StreamConnection() + , m_tcpConnection(socket, message, 0) + , m_message(message) + , m_buffer(0) + , m_bufferLength(bufferLength) , m_sslHandle(0) , m_sslContext(0) { TRACE; + + setHost(m_tcpConnection.getHost()); + setPort(m_tcpConnection.getPort()); + + m_buffer = new unsigned char[m_bufferLength]; + m_message->setConnection(this); } SslConnection::SslConnection ( const std::string host, - const std::string port, + const int port, Message *message, const size_t bufferLength ) - : SocketConnection(host, port, message, bufferLength) - , m_tcpConnection(host, port, 0, 0) + : StreamConnection(host, port) + , m_tcpConnection(host, port, message, 0) + , m_message(message) + , m_buffer(0) + , m_bufferLength(bufferLength) , m_sslHandle(0) , m_sslContext(0) { TRACE; + m_buffer = new unsigned char[m_bufferLength]; + m_message->setConnection(this); } SslConnection::~SslConnection() { TRACE; - closeConnection(); + disconnect(); + delete m_buffer; } -SocketConnection* SslConnection::clone(const int socket) +Connection* SslConnection::clone(const int socket) { - SocketConnection *conn = new SslConnection(socket, - m_message->clone(), - m_bufferLength ); + Connection *conn = new SslConnection( socket, + m_message->clone(), + m_bufferLength ); return conn; } -bool SslConnection::connectToHost() +bool SslConnection::connect() { TRACE; - if ( !m_tcpConnection.connectToHost() ) + if ( !m_tcpConnection.connect() ) + return false; + + if ( !initHandlers() ) return false; - return connect(); + if ( SSL_connect (m_sslHandle) != 1 ) { + LOG (Logger::ERR, getSslError("Handshake with SSL server failed. ").c_str() ); + return false; + } + + return true; } -bool SslConnection::bindToHost() +bool SslConnection::bind() { TRACE; - if ( !m_tcpConnection.bindToHost() ) + if ( !m_tcpConnection.bind() ) + return false; + + if ( !initHandlers() ) return false; - return connect(); + + return true; } @@ -96,21 +123,25 @@ bool SslConnection::listen( const int maxPendingQueueLen ) } -void SslConnection::closeConnection() +/// @todo this function shall be refactored +bool SslConnection::disconnect() { TRACE; /// @note do I have to call this? - m_tcpConnection.closeConnection(); + if ( m_tcpConnection.getSocket() != -1 ) + m_tcpConnection.disconnect(); - int ret = SSL_shutdown(m_sslHandle); + if ( m_sslHandle == 0 || m_sslContext == 0 ) + return false; + int ret = SSL_shutdown(m_sslHandle); if ( ret == 0 ) { LOG( Logger::INFO, "\"close notify\" alert was sent and the peer's " "\"close notify\" alert was received."); } else if (ret == 1 ) { - LOG( Logger::WARNING, "\"The shutdown is not yet finished. " + LOG( Logger::WARNING, "The shutdown is not yet finished. " "Calling SSL_shutdown() for a second time..."); SSL_shutdown(m_sslHandle); } @@ -118,8 +149,17 @@ void SslConnection::closeConnection() LOG (Logger::ERR, getSslError("The shutdown was not successful. ").c_str() ); } - SSL_free(m_sslHandle); - SSL_CTX_free(m_sslContext); + /// @note I have to check the ref count?! This stinks + if (m_sslHandle && m_sslHandle->references > 0) + SSL_free(m_sslHandle); + + if (m_sslHandle && m_sslContext->references > 0) + SSL_CTX_free(m_sslContext); + + m_sslHandle = 0; + m_sslContext = 0; + + return true; } @@ -163,7 +203,14 @@ bool SslConnection::receive() } -bool SslConnection::connect() +int SslConnection::getSocket() const +{ + TRACE; + return m_tcpConnection.getSocket(); +} + + +bool SslConnection::initHandlers() { TRACE; @@ -185,12 +232,6 @@ bool SslConnection::connect() return false; } - - if ( SSL_connect (m_sslHandle) != 1 ) { - LOG (Logger::ERR, getSslError("Handshake with SSL server failed. ").c_str() ); - return false; - } - return true; } diff --git a/src/TcpConnection.cpp b/src/TcpConnection.cpp index 5543934..98e08df 100644 --- a/src/TcpConnection.cpp +++ b/src/TcpConnection.cpp @@ -7,49 +7,70 @@ TcpConnection::TcpConnection ( const int socket, Message *message, const size_t bufferLength ) - : SocketConnection(socket, message, bufferLength) + : StreamConnection() + , m_socket(socket) + , m_message(message) + , m_buffer(0) + , m_bufferLength(bufferLength) { TRACE; + + std::string host, port; + m_socket.getPeerName(host, port); + setHost(host); + setPort(StrToT(port)); + + m_buffer = new unsigned char[m_bufferLength]; + m_message->setConnection(this); } TcpConnection::TcpConnection ( const std::string host, - const std::string port, + const int port, Message *message, const size_t bufferLength ) - : SocketConnection(host, port, message, bufferLength) + : StreamConnection(host, port) + , m_socket(AF_INET, SOCK_STREAM) // or AF_INET6 for IPv6 + , m_message(message) + , m_buffer(0) + , m_bufferLength(bufferLength) { TRACE; + m_socket.createSocket(); + m_buffer = new unsigned char[m_bufferLength]; + m_message->setConnection(this); } TcpConnection::~TcpConnection() { TRACE; + disconnect(); + delete m_buffer; } -SocketConnection* TcpConnection::clone(const int socket) +Connection* TcpConnection::clone(const int socket) { - SocketConnection *conn = new TcpConnection(socket, - m_message->clone(), - m_bufferLength ); + Connection *conn = new TcpConnection(socket, + m_message->clone(), + m_bufferLength ); return conn; } -bool TcpConnection::connectToHost() +bool TcpConnection::connect() { TRACE; - return m_socket.connectToHost(m_host, m_port); + return m_socket.connectToHost(m_host, TToStr(m_port)); } -bool TcpConnection::bindToHost() +bool TcpConnection::bind() { TRACE; - return m_socket.bindToHost(m_host, m_port); + return m_socket.bindToHost(m_host, TToStr(m_port)); } @@ -60,10 +81,13 @@ bool TcpConnection::listen( const int maxPendingQueueLen ) } -void TcpConnection::closeConnection() +bool TcpConnection::disconnect() { TRACE; - m_socket.closeSocket(); + if ( getSocket() == -1 ) + return false; + + return m_socket.closeSocket(); } @@ -85,14 +109,22 @@ bool TcpConnection::receive() } else if (length == 0) { LOG( Logger::INFO, std::string("Connection closed by "). - append(m_host).append(":").append(m_port).c_str() ); + append(m_host).append(":").append(TToStr(m_port)).c_str() ); } - return false; + return false; } LOG ( Logger::DEBUG, std::string("Received: "). append(TToStr(length)).append(" bytes from: "). - append(m_host).append(":").append(m_port).c_str() ); + append(m_host).append(":").append(TToStr(m_port)).c_str() ); return m_message->buildMessage( (void*)m_buffer, (size_t)length); } + + +int TcpConnection::getSocket() const +{ + TRACE; + + return m_socket.getSocket(); +}