/************************************************************************************
 *    This file is part of the MynahSA streaming and archiving toolkit              *
 *    Copyright (C) 2006 Mynah-Software Ltd. All Rights Reserved.                   *
 *                                                                                  *
 *    This program is free software; you can redistribute it and/or modify          *
 *    it under the terms of the GNU General Public License, version 2               *
 *    as published by the Free Software Foundation.                                 *
 *                                                                                  *
 *    This program is distributed in the hope that it will be useful,               *
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of                *
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the                 *
 *    GNU General Public License for more details.                                  *
 *                                                                                  *
 *    You should have received a copy of the GNU General Public License along       *
 *    with this program; if not, write to the Free Software Foundation, Inc.,       *
 *    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.                   *
 *                                                                                  *
 ************************************************************************************/


#ifndef WIN32
#include <sys/socket.h>
#include <arpa/inet.h>
#include <resolv.h>
#include <unistd.h>
#else
#include <winsock2.h>
#include <windows.h>
#endif

#include <stdio.h>
#include <string.h>

#include <mynahsa/thread.hpp>

#include <signal.h>
#include <iostream>
#include <mynahsa/sslinit.hpp>
#include <fcntl.h>

#include <mynahsa/sslconnectionmanager.hpp>

#include <mynahsa/sslserver.hpp>
#include <mynahsa/serverexception.hpp>

#include <iostream>
using namespace std;


namespace MynahSA { 
  
  
  /**
   * Builds an SSL server that loads its certificate chain and private key
   * from files, then binds and listens on the given port.
   *
   * @param sobj   connection manager invoked for every accepted SSL session.
   * @param cFile  path of the server certificate chain file.
   * @param kFile  path of the PEM private key file.
   * @param port   TCP port to listen on (all interfaces).
   * @param caFile CA file used to verify clients; when non-empty, clients
   *               must present a certificate.
   * @throws ServerException if context creation, certificate/key loading or
   *         socket setup fails.
   */
  SSLServer::SSLServer(SSLConnectionManager& sobj, 
                       std::string cFile, 
                       std::string kFile, 
                       int port,
                       std::string caFile) : _serverObject(sobj), _port(port) {
    // Start from a null context so the cleanup below is safe even if
    // createCTX throws before it assigns one.
    _ctx = 0;

    try {
      createCTX(caFile);
      loadCerts(cFile, kFile);

      // Get the server started.
      bindPort();
    } catch (...) {
      // A constructor that throws never runs ~SSLServer, so the SSL context
      // must be released here or it leaks. SSL_CTX_free(NULL) is a no-op.
      SSL_CTX_free(_ctx);
      _ctx = 0;
      throw;
    }

#ifndef WIN32
    // Ignore SIGPIPE process-wide. A write to a connection the peer has
    // already closed then fails with an error instead of terminating the
    // process, which is the POSIX default action. Windows has no SIGPIPE.
    signal(SIGPIPE, SIG_IGN);
#endif
  }

  /**
   * Builds an SSL server from an in-memory certificate and private key, then
   * binds and listens on the given port.
   *
   * @param sobj   connection manager invoked for every accepted SSL session.
   * @param certp  server certificate installed into the SSL context.
   * @param pkeyp  private key matching certp.
   * @param port   TCP port to listen on (all interfaces).
   * @param caFile CA file used to verify clients; when non-empty, clients
   *               must present a certificate.
   * @throws ServerException if context creation, certificate/key installation,
   *         the key/certificate consistency check or socket setup fails.
   */
  SSLServer::SSLServer(SSLConnectionManager& sobj, 
                       X509* certp, 
                       EVP_PKEY* pkeyp, 
                       int port,
                       std::string caFile) : _serverObject(sobj), _port(port) {
    // Start from a null context so the cleanup below is safe even if
    // createCTX throws before it assigns one.
    _ctx = 0;

    try {
      createCTX(caFile);

      if ( SSL_CTX_use_certificate(_ctx, certp) <= 0) {
        throw ServerException(std::string("Error in using certificate: ")+ERR_error_string(ERR_get_error(),0));
      }

      if ( SSL_CTX_use_PrivateKey(_ctx,pkeyp) <= 0 ) {
        throw ServerException(std::string("Error in using private key: ")+ERR_error_string(ERR_get_error(),0));
      }

      // Verify that the certificate and private key belong together.
      if ( !SSL_CTX_check_private_key(_ctx) ) {
        throw ServerException("Private key is invalid");
      }

      bindPort();
    } catch (...) {
      // A constructor that throws never runs ~SSLServer, so the SSL context
      // must be released here or it leaks. SSL_CTX_free(NULL) is a no-op.
      SSL_CTX_free(_ctx);
      _ctx = 0;
      throw;
    }

#ifndef WIN32
    // Ignore SIGPIPE process-wide. A write to a connection the peer has
    // already closed then fails with an error instead of terminating the
    // process, which is the POSIX default action. Windows has no SIGPIPE.
    signal(SIGPIPE, SIG_IGN);
#endif
  }
  
  // Tears the server down: closes the listening socket opened by bindPort()
  // and frees the SSL context created by createCTX().
  SSLServer::~SSLServer() {
#ifdef WIN32
    closesocket(_master);
#else
    close(_master);
#endif
    // The context goes last; the socket no longer needs it.
    SSL_CTX_free(_ctx);
  }

  // Nullary callable handed to the thread object in checkClients(): calling
  // it runs the connection manager on a single accepted SSL session.
  struct ssl_thread_hider {
    ssl_thread_hider(SSLConnectionManager& manager, SSL* session)
      : _obj(manager),
        _ssl(session)
    {}

    // Thread body: forward the stored session to the manager.
    void operator()() { _obj(_ssl); }

    // Held by reference, so the manager must outlive this functor and any
    // copies of it.
    SSLConnectionManager& _obj;
    // Only the pointer value is copied; the SSL object itself is shared.
    SSL* _ssl;
  };
        
  void SSLServer::checkClients(int wait_t) {
#ifdef DEBUG
    std::cerr << "*************** SSLServer::checkClients ***************" << std::endl;
#endif
  
    //! file descriptor set
    fd_set fdset;

    struct timeval tv;
      
    // Set how long to block.
    tv.tv_sec = wait_t;
    tv.tv_usec = 0;
    
    FD_ZERO(&fdset);
    FD_SET(_master, &fdset);
    
    fd_set errorSet;
    FD_ZERO(&errorSet);
    FD_SET(_master, &errorSet);
    
    select(_master+1, &fdset, NULL, &errorSet, (struct timeval *)&tv);
    
    // If _master is set then someone is trying to connect, and no errors
    if(FD_ISSET(_master, &fdset) && !FD_ISSET(_master, &errorSet)) {
      SSL *ssl;
      
      // Open up new connection
      struct sockaddr_in addr;
      int len = sizeof(addr);
#ifdef WIN32
      int client = accept(_master, (struct sockaddr *)&addr, &len);
#else
#ifdef DEBUG
      cerr << "Entering accept" << endl;
#endif
      int client = accept(_master, (struct sockaddr *)&addr, (socklen_t *)&len);
#ifdef DEBUG
      cerr << "Back from accept" << endl;
#endif
#endif
        
      if (client == -1) { 
        throw ServerConnectionException("Failed to accept connection!");
      }
  
#ifdef DEBUG
      struct in_addr ip_address;

      // memcpy is used to bypass type casting
      memcpy(&ip_address, &addr.sin_addr.s_addr, 4);
      
      std::cerr << "\n\n---------------------------------------------\n";
      std::cerr << "Connection from: " << inet_ntoa(ip_address) << "  (" 
                << ntohs(addr.sin_port) << ")\n";
#endif
        
      ssl = SSL_new(_ctx);
      if (!ssl) { 
        throw ServerConnectionException("Failed to create an SSL instance");
      }

      SSL_set_fd(ssl, client);
      
      // spawn thread, bind ssl parameter onto first of server object call
      // to operator()

#ifdef ENABLE_THREADED_SERVER
#ifdef DEBUG
      std::cerr << "Creating Thread" << std::endl;
#endif /* DEBUG */
#ifdef DEBUG
      std::cerr << "Using Threaded code path" << std::endl;
#endif
      // todo: replace with pthreads direct implementation
              
      ssl_thread_hider f(_serverObject, ssl);
#ifdef MYNAHSA_USE_BOOST
      boost::thread myThread(f);
#else
      Thread< ssl_thread_hider > myThread(f);
#endif
#else
#ifdef DEBUG
      std::cerr << "Using single-thread code path" << std::endl;
#endif
      // enable this and disable above for single thread implementation
      _serverObject(ssl);
#endif
    }
  }

  void SSLServer::createCTX(std::string serverCAFile) {
    // The method describes which SSL protocol we will be using.
    SSL_METHOD *method;
    
    sslInit();
    
    // Compatible with SSLv2, SSLv3 and TLSv1
    method = SSLv3_server_method();
    if (method == 0) { 
      throw ServerException("Failed to create SSL method.");
    }

    // Create new context from method.
    _ctx = SSL_CTX_new(method);
    if(_ctx == NULL) {
      throw ServerException(std::string("Failed to create SSL context: ")+ERR_error_string(ERR_get_error(),0));
    }    
    
    //SSL_CTX_set_default_verify_paths(_ctx);
    if (serverCAFile != "") { 
      if (!SSL_CTX_load_verify_locations(_ctx, serverCAFile.c_str(), 0)) { 
        throw ServerException(std::string("Failed to open certificate authority: ")+serverCAFile);
      }
#ifdef DEBUG
      std::cerr << "Calling SSL_CTX_set_verify - will force client authentication" << std::endl;
#endif
      SSL_CTX_set_verify(_ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,0);
    }  
  }
    
  /**
   * Loads the server certificate chain and private key into _ctx and checks
   * that the two match.
   *
   * @param cFile path of the certificate chain file (PEM).
   * @param kFile path of the private key file (PEM).
   * @throws ServerException if either file cannot be loaded or the key does
   *         not match the certificate.
   */
  void SSLServer::loadCerts(std::string cFile, std::string kFile) {
    if ( SSL_CTX_use_certificate_chain_file(_ctx, cFile.c_str()) <= 0) {
      throw ServerException(std::string("Certificate load failed: ")+ERR_error_string(ERR_get_error(),0));
    }

    if ( SSL_CTX_use_PrivateKey_file(_ctx, kFile.c_str(), SSL_FILETYPE_PEM) <= 0) {
      throw ServerException(std::string("Private key load failed: ")+ERR_error_string(ERR_get_error(),0));
    }

    // Verify that the certificate and private key belong together.
    if ( !SSL_CTX_check_private_key(_ctx) ) {
      throw ServerException("Private key is invalid.");
    }
  }
  
  
  // Closes the socket bindPort() opened, then throws ServerException(msg).
  // bindPort() runs from the constructors, and a throwing constructor never
  // runs ~SSLServer, so the descriptor would otherwise leak.
  static void bindPortFailure(int sock, const char* msg) {
#ifndef WIN32
    close(sock);
#else
    closesocket(sock);
#endif
    throw ServerException(msg);
  }

  /**
   * Creates the listening TCP socket in _master and does the following:
   * enables SO_REUSEADDR, binds to _port on all IPv4 interfaces and starts
   * listening with a backlog of 5.
   *
   * @throws ServerException if any step fails. The socket is closed before
   *         throwing if it was already created.
   */
  void SSLServer::bindPort(void)  {
    _master = socket(AF_INET, SOCK_STREAM, 0);
#ifdef WIN32
    const bool socketFailed = (_master == INVALID_SOCKET);
#else
    const bool socketFailed = (_master == -1);
#endif
    if (socketFailed) {
      throw ServerException("Socket creation failed");
    }

    // Allow rebinding the port straight after a restart.
    int on = 1;
#ifdef WIN32
    int sockres = setsockopt(_master, SOL_SOCKET, SO_REUSEADDR, (const char*) &on, sizeof(on));
#else
    int sockres = setsockopt(_master, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
#endif
    if (sockres != 0) {
      bindPortFailure(_master, "Setsockopt failed");
    }

    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_port = htons(_port);
    addr.sin_addr.s_addr = htonl(INADDR_ANY);

    if (bind(_master, (struct sockaddr *)&addr, sizeof(addr)) != 0) {
      bindPortFailure(_master, "Bind failed");
    }

    // Set a limit on the pending-connection queue.
    if (listen(_master, 5) != 0) {
      bindPortFailure(_master, "Listen failed");
    }
  }
}; // close namespace MynahSA

