diff --git a/include/SslConnection.hpp b/include/SslConnection.hpp index a0a1526..b48edd1 100644 --- a/include/SslConnection.hpp +++ b/include/SslConnection.hpp @@ -38,6 +38,10 @@ public: bool connect(); bool disconnect(); + bool initServerContext( const std::string certificateFile, + const std::string privateKeyFile ); + bool initClientContext(); + bool send( const void* message, const size_t length ); bool receive(); @@ -52,8 +56,10 @@ private: SslConnection(const SslConnection&); SslConnection& operator=(const SslConnection&); - bool initHandlers(); + bool initHandle(); std::string getSslError(const std::string &msg); + bool loadCertificates( const std::string certificateFile, + const std::string keyFile ); TcpConnection m_tcpConnection; diff --git a/other/sslclient_main.cpp b/other/sslclient_main.cpp index 3f39f0a..d3607c2 100644 --- a/other/sslclient_main.cpp +++ b/other/sslclient_main.cpp @@ -78,6 +78,7 @@ int main(int argc, char* argv[] ) SimpleMessage msg(&finished); SslConnection conn(argv[1], StrToT(argv[2]), &msg); + conn.initClientContext(); SocketClient socketClient(&conn); if ( !socketClient.connect() ) { diff --git a/other/sslserver_main.cpp b/other/sslserver_main.cpp index 360147d..b5dc4f8 100644 --- a/other/sslserver_main.cpp +++ b/other/sslserver_main.cpp @@ -81,8 +81,8 @@ void signalHandler(int s) int main(int argc, char* argv[] ) { - if ( argc != 3 ) { - std::cerr << "Usage: " << argv[0] << " " << std::endl; + if ( argc != 5 ) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; return 1; } @@ -100,6 +100,7 @@ int main(int argc, char* argv[] ) EchoMessage msg; SslConnection conn(argv[1], StrToT(argv[2]), &msg); + conn.initServerContext(argv[3], argv[4]); socketServer = new SocketServer(&conn); if ( !socketServer->start() ) { diff --git a/other/tcpclient_main.cpp b/other/tcpclient_main.cpp index 1e328bd..46e0227 100644 --- a/other/tcpclient_main.cpp +++ b/other/tcpclient_main.cpp @@ -82,7 +82,7 @@ int main(int argc, char* argv[] ) SocketClient socketClient(&conn); if ( !socketClient.connect() ) { - LOG( Logger::ERR, "Couldn't connect to server, exiting..." ); + LOG_STATIC( Logger::ERR, "Couldn't connect to server, exiting..." ); Logger::destroy(); return 1; } @@ -93,7 +93,7 @@ int main(int argc, char* argv[] ) // send message to server std::string msg1(argv[3]); if ( !socketClient.send( msg1.c_str(), msg1.length()) ) { - LOG( Logger::ERR, "Couldn't send message to server, exiting..." ); + LOG_STATIC( Logger::ERR, "Couldn't send message to server, exiting..." ); Logger::destroy(); return 1; } diff --git a/other/tcpserver_main.cpp b/other/tcpserver_main.cpp index 4f6df99..3708d5e 100644 --- a/other/tcpserver_main.cpp +++ b/other/tcpserver_main.cpp @@ -83,7 +83,7 @@ int main(int argc, char* argv[] ) SocketServer socketServer(&conn); if ( !socketServer.start() ) { - LOG( Logger::ERR, "Failed to start TCP server, exiting..."); + LOG_STATIC( Logger::ERR, "Failed to start TCP server, exiting..."); Logger::destroy(); return 1; } diff --git a/src/Poll.cpp b/src/Poll.cpp index f4d4b1c..e35174b 100644 --- a/src/Poll.cpp +++ b/src/Poll.cpp @@ -50,7 +50,6 @@ void Poll::startPolling() /// @todo reconnect return; } - if ( ret == 0 ) // timeout continue; @@ -82,15 +81,10 @@ void Poll::acceptClient() { TRACE; -// sockaddr clientAddr; -// socklen_t clientAddrLen; -// int client_socket = accept( m_connection->getSocket(), -// &clientAddr, &clientAddrLen ) ; int client_socket = m_connection->accept(); if ( client_socket == -1 ) { - LOG( Logger::ERR, errnoToString("ERROR accepting. ").c_str() ); return; } diff --git a/src/SslConnection.cpp b/src/SslConnection.cpp index 31a0fed..3f874ae 100644 --- a/src/SslConnection.cpp +++ b/src/SslConnection.cpp @@ -77,10 +77,8 @@ Connection* SslConnection::clone(const int socket) { TRACE; - Connection *conn = new SslConnection( socket, - m_message->clone(), - m_bufferLength ); - + SslConnection *conn = new SslConnection( socket, m_message->clone(), m_bufferLength ); + conn->initClientContext(); return conn; } @@ -92,13 +90,20 @@ bool SslConnection::connect() if ( !m_tcpConnection.connect() ) return false; - if ( !initHandlers() ) - return false; +// if ( !initHandlers() ) +// return false; + + if ( SSL_set_fd(m_sslHandle, m_tcpConnection.getSocket() ) == 0 ) { + getSslError("SSL set connection socket failed. "); + return -1; + } + LOG( Logger::INFO, "itt" ); if ( SSL_connect (m_sslHandle) != 1 ) { LOG (Logger::ERR, getSslError("SSL handshake failed. ").c_str() ); return false; } + LOG( Logger::INFO, "de itt mar nem?" ); return true; } @@ -111,8 +116,8 @@ bool SslConnection::bind() if ( !m_tcpConnection.bind() ) return false; - if ( !initHandlers() ) - return false; +// if ( !initHandlers() ) +// return false; return true; @@ -134,16 +139,22 @@ int SslConnection::accept() if ( client_socket == -1) return client_socket; - if ( SSL_accept(m_sslHandle) == -1 ) { - getSslError("SSL accept failed. "); - return -1; - } + LOG( Logger::INFO, "server itt"); if ( SSL_set_fd(m_sslHandle, client_socket) == 0 ) { getSslError("SSL set connection socket failed. "); return -1; } + LOG( Logger::INFO, "server itt 2"); + + if ( SSL_accept(m_sslHandle) == -1 ) { + getSslError("SSL accept failed. "); + return -1; + } + + LOG( Logger::INFO, "server itt 3"); + return client_socket; } @@ -187,6 +198,38 @@ bool SslConnection::disconnect() } +bool SslConnection::initServerContext( const std::string certificateFile, + const std::string privateKeyFile ) +{ + TRACE; + + m_sslContext = SSL_CTX_new (SSLv2_server_method ()); + if ( m_sslContext == NULL ) { + LOG (Logger::ERR, getSslError("Creating SSL context failed. ").c_str() ); + return false; + } + + if ( !loadCertificates(certificateFile, privateKeyFile) ) + return false; + + return initHandle(); +} + + +bool SslConnection::initClientContext() +{ + TRACE; + + m_sslContext = SSL_CTX_new (SSLv23_client_method ()); + if ( m_sslContext == NULL ) { + LOG (Logger::ERR, getSslError("Creating SSL context failed. ").c_str() ); + return false; + } + + return initHandle(); +} + + bool SslConnection::send( const void* message, const size_t length ) { TRACE; @@ -234,16 +277,10 @@ int SslConnection::getSocket() const } -bool SslConnection::initHandlers() +bool SslConnection::initHandle() { TRACE; - m_sslContext = SSL_CTX_new (SSLv23_client_method ()); - if ( m_sslContext == NULL ) { - LOG (Logger::ERR, getSslError("Creating SSL context failed. ").c_str() ); - return false; - } - m_sslHandle = SSL_new (m_sslContext); if ( m_sslHandle == NULL ) { LOG (Logger::ERR, getSslError("Creating SSL structure for connection failed. ").c_str() ); @@ -269,3 +306,51 @@ std::string SslConnection::getSslError(const std::string &msg) return std::string(msg).append(buffer); } + + +bool SslConnection::loadCertificates( const std::string certificateFile, + const std::string privateKeyFile ) +{ + if ( SSL_CTX_use_certificate_file(m_sslContext, certificateFile.c_str(), SSL_FILETYPE_PEM) != 1 ) + { + getSslError("SSL certificate file loading failed. "); + return false; + } + + if ( SSL_CTX_use_PrivateKey_file(m_sslContext, privateKeyFile.c_str(), SSL_FILETYPE_PEM) != 1 ) + { + getSslError("SSL private Key file loading failed. "); + return false; + } + + if ( SSL_CTX_check_private_key(m_sslContext) != 1 ) + { + LOG( Logger::ERR, "Private key does not match the public certificate\n"); + return false; + } + + return true; +} + +/*---------------------------------------------------------------------*/ +/*--- ShowCerts - print out certificates. ---*/ +/*---------------------------------------------------------------------*/ +// void showCertificates(SSL* ssl) +// { X509 *cert; +// char *line; +// +// cert = SSL_get_peer_certificate(ssl); /* Get certificates (if available) */ +// if ( cert != NULL ) +// { +// printf("Server certificates:\n"); +// line = X509_NAME_oneline(X509_get_subject_name(cert), 0, 0); +// printf("Subject: %s\n", line); +// free(line); +// line = X509_NAME_oneline(X509_get_issuer_name(cert), 0, 0); +// printf("Issuer: %s\n", line); +// free(line); +// X509_free(cert); +// } +// else +// printf("No certificates.\n"); +// } diff --git a/src/TcpConnection.cpp b/src/TcpConnection.cpp index f7575ce..6e7e6b1 100644 --- a/src/TcpConnection.cpp +++ b/src/TcpConnection.cpp @@ -83,10 +83,18 @@ bool TcpConnection::listen( const int maxPendingQueueLen ) int TcpConnection::accept() { + TRACE; sockaddr clientAddr; socklen_t clientAddrLen; - return ::accept( getSocket(), &clientAddr, &clientAddrLen ) ; + int client_socket = ::accept( getSocket(), &clientAddr, &clientAddrLen ) ; + + if ( client_socket == -1 ) { + LOG( Logger::ERR, errnoToString("ERROR accepting. ").c_str() ); + return -1; + } + + return client_socket; }