diff --git a/include/MysqlClient.hpp b/include/MysqlClient.hpp index 211a9c5..540a3ac 100644 --- a/include/MysqlClient.hpp +++ b/include/MysqlClient.hpp @@ -6,12 +6,16 @@ #include #include + +/// @note Call init/destroy before/after usage class MysqlClient { public: - /// @note Call init_client_errs() / finish_client_errs() before / after + static void init(); + static void destroy(); + MysqlClient ( const char *host = NULL, const char *user = NULL, const char *passwd = NULL, diff --git a/include/SslConnection.hpp b/include/SslConnection.hpp new file mode 100644 index 0000000..cc5313a --- /dev/null +++ b/include/SslConnection.hpp @@ -0,0 +1,62 @@ +#ifndef SSL_CONNECTION_HPP +#define SSL_CONNECTION_HPP + + +#include "SocketConnection.hpp" +#include "TcpConnection.hpp" +#include "Message.hpp" + +#include + +// #include +#include +// #include + + + +/// @note Call init/destroy before/after usage +class SslConnection : public SocketConnection +{ +public: + + static void init(); + static void destroy(); + + SslConnection ( const int socket, + Message *message, + const size_t bufferLength = 1024 ); + + SslConnection ( const std::string host, + const std::string port, + Message *message, + const size_t bufferLength = 1024 ); + + virtual ~SslConnection(); + + SocketConnection* clone(const int socket); + + bool connectToHost(); + bool bindToHost(); + bool listen( const int maxPendingQueueLen = 64 ); + void closeConnection(); + + bool send( const void* message, const size_t length ); + bool receive(); + + +private: + + SslConnection(const SslConnection&); + SslConnection& operator=(const SslConnection&); + + bool connect(); + std::string getSslError(const std::string &msg); + + + TcpConnection m_tcpConnection; + SSL *m_sslHandle; + SSL_CTX *m_sslContext; +}; + + +#endif // SSL_CONNECTION_HPP diff --git a/include/TcpConnection.hpp b/include/TcpConnection.hpp index f799bd9..fdce5c8 100644 --- a/include/TcpConnection.hpp +++ b/include/TcpConnection.hpp @@ -3,7 +3,6 @@ #include "SocketConnection.hpp" -#include "Socket.hpp" #include "Message.hpp" #include diff --git a/other/mysqlclient_main.cpp b/other/mysqlclient_main.cpp index 0dd9d9a..26c95b3 100644 --- a/other/mysqlclient_main.cpp +++ b/other/mysqlclient_main.cpp @@ -130,7 +130,7 @@ int main(int argc, char* argv[] ) return 1; // init - init_client_errs(); + MysqlClient::init(); MysqlConnectionPool cp ( argParse.foundArg("--host") ? host.c_str() : NULL, argParse.foundArg("-u, --user") ? user.c_str() : NULL, @@ -153,7 +153,7 @@ int main(int argc, char* argv[] ) // end cp.clear(); - finish_client_errs(); + MysqlClient::destroy(); Logger::destroy(); return 0; } diff --git a/other/mysqlclient_tcpwrapper.cpp b/other/mysqlclient_tcpwrapper.cpp index 23a3364..f47bc21 100644 --- a/other/mysqlclient_tcpwrapper.cpp +++ b/other/mysqlclient_tcpwrapper.cpp @@ -183,10 +183,9 @@ int main(int argc, char* argv[] ) conns, port, clients, pending, threads ) ) return 1; - /* // init MySQL connection pool - init_client_errs(); + MysqlClient::init(); MysqlConnectionPool mysqlConnectionPool ( argParse.foundArg("--host") ? host.c_str() : NULL, argParse.foundArg("-u, --user") ? user.c_str() : NULL, @@ -227,9 +226,7 @@ int main(int argc, char* argv[] ) // end mysqlConnectionPool.clear(); - finish_client_errs(); - - */ + MysqlClient::destroy(); Logger::destroy(); return 0; } diff --git a/other/sslclient_main.cpp b/other/sslclient_main.cpp new file mode 100644 index 0000000..e7ede3c --- /dev/null +++ b/other/sslclient_main.cpp @@ -0,0 +1,107 @@ +// gpp sslclient_main.cpp -o sslclient -I../include ../src/Logger.cpp ../src/Thread.cpp ../src/Socket.cpp -lpthread ../src/SocketClient.cpp ../src/Poll.cpp ../src/SocketConnection.cpp ../src/SslConnection.cpp -lssl -lcrypto ../src/TcpConnection.cpp + +#include "Logger.hpp" + +#include "Message.hpp" +#include "SslConnection.hpp" +#include "SocketClient.hpp" + + +#include +#include + +#include // nanosleep + + + + +class SimpleMessage : public Message +{ +public: + + SimpleMessage( void *msgParam = 0) + : Message(msgParam) + { + TRACE; + } + + bool buildMessage( const void *msgPart, + const size_t msgLen ) + { + TRACE; + m_buffer = std::string( (const char*) msgPart, msgLen ); + onMessageReady(); + return true; + } + + void onMessageReady() + { + TRACE; + + LOG( Logger::INFO, std::string("Got reply from server: "). + append(m_buffer).c_str() ); + + *( static_cast(m_param) ) = true; + } + + Message* clone() + { + TRACE; + return new SimpleMessage(m_param); + } + +protected: + + size_t getExpectedLength() + { + TRACE; + return 0; + } +}; + + +int main(int argc, char* argv[] ) +{ + if ( argc != 4 ) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return 1; + } + + Logger::createInstance(); + Logger::init(std::cout); + Logger::setLogLevel(Logger::FINEST); + SslConnection::init(); + + bool finished = false; + + SimpleMessage msg(&finished); + SslConnection conn(argv[1], argv[2], &msg); + SocketClient socketClient(&conn); + + if ( !socketClient.connect() ) { + LOG( Logger::ERR, "Couldn't connect to server, exiting..." ); + Logger::destroy(); + return 1; + } + + // wait for thread creation + sleep(1); + + // send message to server + std::string msg1(argv[3]); + if ( !socketClient.send( msg1.c_str(), msg1.length()) ) { + LOG( Logger::ERR, "Couldn't send message to server, exiting..." ); + Logger::destroy(); + return 1; + } + + // wait for the complate &handled reply + struct timespec tm = {0,1000}; + while ( !finished && socketClient.isPolling() ) + nanosleep(&tm, &tm) ; + + socketClient.disconnect(); + SslConnection::destroy(); + Logger::destroy(); + return 0; +} \ No newline at end of file diff --git a/src/MysqlClient.cpp b/src/MysqlClient.cpp index 0c0f5a3..01507ae 100644 --- a/src/MysqlClient.cpp +++ b/src/MysqlClient.cpp @@ -6,6 +6,19 @@ #include "Common.hpp" +void MysqlClient::init() +{ + TRACE_STATIC; + init_client_errs(); +} + +void MysqlClient::destroy() +{ + TRACE_STATIC; + finish_client_errs(); +} + + MysqlClient::MysqlClient( const char *host, const char *user, const char *passwd, diff --git a/src/SslConnection.cpp b/src/SslConnection.cpp new file mode 100644 index 0000000..0978645 --- /dev/null +++ b/src/SslConnection.cpp @@ -0,0 +1,206 @@ +#include "SslConnection.hpp" + +#include "Logger.hpp" +#include "Common.hpp" + +#include +#include +#include + + + +void SslConnection::init() +{ + TRACE_STATIC; + + SSL_load_error_strings(); + SSL_library_init(); +} + +void SslConnection::destroy() +{ + TRACE_STATIC; + + ERR_free_strings(); +} + + +SslConnection::SslConnection ( const int socket, + Message *message, + const size_t bufferLength ) + : SocketConnection(socket, message, bufferLength) + , m_tcpConnection(socket, 0, 0) + , m_sslHandle(0) + , m_sslContext(0) +{ + TRACE; +} + + +SslConnection::SslConnection ( const std::string host, + const std::string port, + Message *message, + const size_t bufferLength ) + : SocketConnection(host, port, message, bufferLength) + , m_tcpConnection(host, port, 0, 0) + , m_sslHandle(0) + , m_sslContext(0) +{ + TRACE; +} + + +SslConnection::~SslConnection() +{ + TRACE; + closeConnection(); +} + + +SocketConnection* SslConnection::clone(const int socket) +{ + SocketConnection *conn = new SslConnection(socket, + m_message->clone(), + m_bufferLength ); + + return conn; +} + + +bool SslConnection::connectToHost() +{ + TRACE; + + if ( !m_tcpConnection.connectToHost() ) + return false; + + return connect(); +} + + +bool SslConnection::bindToHost() +{ + TRACE; + + if ( !m_tcpConnection.bindToHost() ) + return false; + + return connect(); +} + + +bool SslConnection::listen( const int maxPendingQueueLen ) +{ + TRACE; + return m_tcpConnection.listen(maxPendingQueueLen); +} + + +void SslConnection::closeConnection() +{ + TRACE; + + /// @note do I have to call this? + m_tcpConnection.closeConnection(); + + int ret = SSL_shutdown(m_sslHandle); + + if ( ret == 0 ) { + LOG( Logger::INFO, "\"close notify\" alert was sent and the peer's " + "\"close notify\" alert was received."); + } + else if (ret == 1 ) { + LOG( Logger::WARNING, "\"The shutdown is not yet finished. " + "Calling SSL_shutdown() for a second time..."); + SSL_shutdown(m_sslHandle); + } + else if ( ret == 2 ) { + LOG (Logger::ERR, getSslError("The shutdown was not successful. ").c_str() ); + } + + SSL_free(m_sslHandle); + SSL_CTX_free(m_sslContext); +} + + +bool SslConnection::send( const void* message, const size_t length ) +{ + TRACE; + + int ret = SSL_write(m_sslHandle, message, length); + + if ( ret > 0 ) + return true; + + unsigned long sslErrNo = ERR_peek_error(); + if ( ret == 0 && sslErrNo == SSL_ERROR_ZERO_RETURN ) { + LOG( Logger::INFO, "Underlying connection has been closed."); + return true; + } + + LOG (Logger::ERR, getSslError("SSL write failed. ").c_str() ); + return false; +} + + +bool SslConnection::receive() +{ + TRACE; + + int length = SSL_read(m_sslHandle, m_buffer, m_bufferLength); + + if ( length > 0 ) + return m_message->buildMessage( (void*)m_buffer, (size_t)length); + + unsigned long sslErrNo = ERR_peek_error(); + if ( length == 0 && sslErrNo == SSL_ERROR_ZERO_RETURN ) { + LOG( Logger::INFO, "Underlying connection has been closed."); + return true; + } + + LOG (Logger::ERR, getSslError("SSL read failed. ").c_str() ); + return false; +} + + +bool SslConnection::connect() +{ + TRACE; + + m_sslContext = SSL_CTX_new (SSLv23_client_method ()); + if ( m_sslContext == NULL ) { + LOG (Logger::ERR, getSslError("Creating SSL context failed. ").c_str() ); + return false; + } + + m_sslHandle = SSL_new (m_sslContext); + if ( m_sslHandle == NULL ) { + LOG (Logger::ERR, getSslError("Creating SSL structure for connection failed. ").c_str() ); + return false; + } + + + if ( !SSL_set_fd (m_sslHandle, m_tcpConnection.getSocket()) ) { + LOG (Logger::ERR, getSslError("Connect the SSL object with a file descriptor failed. ").c_str() ); + return false; + } + + + if ( SSL_connect (m_sslHandle) != 1 ) { + LOG (Logger::ERR, getSslError("Handshake with SSL server failed. ").c_str() ); + return false; + } + + return true; +} + + +std::string SslConnection::getSslError(const std::string &msg) +{ + char buffer[130]; + unsigned long sslErrNo = ERR_get_error(); + + ERR_error_string(sslErrNo, buffer); + + return std::string(msg).append(buffer); +}