/*
 * Copyright Staffan Gimåker 2008-2010.
 *
 * ---
 *
 * Distributed under the Boost Software License, Version 1.0.
 * (See accompanying file LICENSE_1_0.txt or copy at
 * http://www.boost.org/LICENSE_1_0.txt)
 */

#include <boost/bind.hpp>
#include <boost/cstdint.hpp>
#include <boost/format.hpp>
#include <boost/lexical_cast.hpp>
#include <boost/scoped_array.hpp>
#include <boost/thread/recursive_mutex.hpp>
#include <boost/thread/thread.hpp>
#include <iostream>

#include "Exceptions.hh"
#include "ClientImpl.hh"
#include "ServerConnection.hh"
#include "../Sockets.hh"
#include "../Bytesex.hh"
#include "../serialization/SerializationInterface.hh"
#include "../serialization/DeserializationInterface.hh"
#include "../serialization/ChunkedBufferAdapter.hh"
#include "../serialization/MemAdapters.hh"
#include "../ChunkedBuffer.hh"
#include "../Protocol.hh"
#include "../Action.hh"
#include "../Version.hh"

extern "C" {
#include <liblzf/lzf.h>
}


using namespace peekabot;
using namespace peekabot::client;


namespace
{
    // File-local constants.
    //
    // NOTE(review): neither constant appears to be referenced in the code
    // visible here - confirm they are unused before removing them.
    const size_t SENDBUF_SIZE = 8 * 1024;

    // True when this build targets a big-endian host.
    const bool IS_BIG_ENDIAN =
        PEEKABOT_BYTE_ORDER == PEEKABOT_BIG_ENDIAN;
}



/**
 * Creates an unconnected connection object bound to the given client.
 *
 * The TX/RX worker threads are only spawned by connect(); until then both
 * thread handles stay null, which is_connected() reports as "not connected".
 */
ServerConnection::ServerConnection(boost::shared_ptr<ClientImpl> client)
    : Transport(client)
    , m_tx_thread(0)
    , m_rx_thread(0)
    , m_stop_signal(false)
{
}


/**
 * Tears the connection down.
 *
 * If both worker threads are running, disconnect() stops them and waits for
 * them. Otherwise we may be part-way through a disconnect (one worker has
 * already exited, the other is still winding down), so wait until both
 * workers have cleared their own thread handles.
 *
 * NOTE(review): rx_thread() calls disconnected() *after* clearing
 * m_rx_thread, so this destructor may return while that callback is still
 * running on the RX thread - confirm the callback does not touch *this.
 */
ServerConnection::~ServerConnection()
{
    if( is_connected() )
        disconnect();
    else
    {
        // Yield rather than spinning on an empty loop body. The empty loop
        // burns a whole core. Because nothing inside it can change the
        // pointers, the compiler may also hoist the loads out of the loop and
        // spin forever. The opaque call forces a fresh read on every
        // iteration in practice.
        while( m_tx_thread != 0 || m_rx_thread != 0 )
            boost::this_thread::yield();
    }
}


/**
 * Connects to a peekabot server and starts the TX and RX worker threads.
 *
 * Blocks while the TCP connection is established and the authentication
 * handshake runs (see _connect()).
 *
 * @param hostname          Server host name or address.
 * @param port              Server TCP port.
 * @param low_latency_mode  Forwarded to _connect() (enables TCP_NODELAY).
 * @throws std::runtime_error if already connected. Otherwise whatever
 *         _connect() throws is propagated, and no threads are started.
 */
void ServerConnection::connect(
    const std::string &hostname,
    unsigned int port,
    bool low_latency_mode)
{
    if( !is_connected() )
    {
        // Open the socket and authenticate - blocks
        _connect(hostname, port, low_latency_mode);

        // Clear any stop request left over from a previous connection
        // before the workers start polling the flag
        m_stop_signal = false;

        m_tx_thread = new boost::thread(
            boost::bind(&ServerConnection::tx_thread, this));
        m_rx_thread = new boost::thread(
            boost::bind(&ServerConnection::rx_thread, this));
        return;
    }

    throw std::runtime_error("Already connected");
}


/**
 * Stops the TX and RX worker threads and closes the socket.
 *
 * Does nothing unless both worker threads are running. Blocks until both
 * workers have exited and cleared their thread handles.
 *
 * Must not be called from the TX or RX thread itself: that thread cannot
 * clear its handle while it is stuck in the wait below, so the call would
 * never return.
 */
void ServerConnection::disconnect()
{
    if( !is_connected() )
        return;

    // Signal TX and RX threads to halt
    m_stop_signal = true;

    // Close the socket - this will interrupt blocking send()
    // operations in the TX thread
    {
        boost::recursive_mutex::scoped_lock lock(m_sockfd_mutex);
        sockets::SocketType tmp = m_sockfd;
        m_sockfd = INVALID_SOCKET;
        sockets::close_socket(tmp);
    }

    // Wake the TX thread from waiting on outbound actions
    m_outbound_push_cond.notify_all();

    // Wait for RX and TX threads to finish. Yield instead of spinning on an
    // empty body: the empty loop burns a whole core and lets the compiler
    // hoist the pointer loads out of the loop (an infinite spin).
    while( m_tx_thread != 0 || m_rx_thread != 0 )
        boost::this_thread::yield();
}


/**
 * Queues an action for transmission and wakes the TX thread.
 *
 * The action is only enqueued here; tx_thread() serializes and sends it.
 */
void ServerConnection::dispatch_action(boost::shared_ptr<Action> action)
{
    boost::recursive_mutex::scoped_lock queue_lock(m_outbound_mutex);

    m_outbound.push(action);

    // The TX thread may be blocked waiting for outbound work
    m_outbound_push_cond.notify_all();
}


/**
 * Returns the time elapsed since m_up_since, which _connect() sets as soon
 * as the TCP connection is established.
 */
boost::posix_time::time_duration ServerConnection::get_uptime() const throw()
{
    const boost::posix_time::ptime now =
        boost::posix_time::microsec_clock::local_time();
    return now - m_up_since;
}


/**
 * Blocks until the outbound queue has been drained.
 *
 * Returns once every queued action has been taken off the queue, either by
 * the TX thread or by discard_unsent(). "Taken off" means handed to
 * serialize_and_send(), so the last action may still be mid-send when this
 * returns.
 *
 * NOTE(review): if the TX thread has already exited and actions were queued
 * afterwards, nothing pops them and this call blocks indefinitely.
 */
void ServerConnection::flush()
{
    boost::recursive_mutex::scoped_lock lock(m_outbound_mutex);

    // Loop instead of waiting once. m_outbound_pop_cond is signalled after
    // *every* single pop, and wait() may also wake spuriously, so one wait
    // could return while later actions were still queued.
    while( !m_outbound.empty() )
        m_outbound_pop_cond.wait(lock);
}

/**
 * Drops every action still waiting in the outbound queue.
 *
 * Signals m_outbound_pop_cond afterwards so that flush() callers waiting for
 * the queue to drain are released.
 */
void ServerConnection::discard_unsent() throw()
{
    boost::recursive_mutex::scoped_lock queue_lock(m_outbound_mutex);

    for( ; !m_outbound.empty(); m_outbound.pop() )
    {
    }

    m_outbound_pop_cond.notify_all();
}


/**
 * Reports whether the connection is up.
 *
 * The connection counts as up only while *both* worker threads exist. Each
 * worker clears its own handle when it exits.
 */
bool ServerConnection::is_connected() const
{
    return !( m_rx_thread == 0 || m_tx_thread == 0 );
}


/**
 * Body of the transmit (TX) worker thread.
 *
 * Pops queued actions one at a time and sends them via serialize_and_send()
 * until m_stop_signal is raised or sending fails. On exit it discards any
 * unsent actions, deletes its own thread handle and clears m_tx_thread,
 * which is what disconnect() and the destructor wait for.
 */
void ServerConnection::tx_thread()
{
    while( !m_stop_signal )
    {
        boost::shared_ptr<Action> action;

        // Waiting for data to send... (or for us to be explicitly woken up)
        {
            boost::recursive_mutex::scoped_lock lock(m_outbound_mutex);
            if( m_outbound.empty() )
            {
                // Use a bounded wait. disconnect() and rx_thread() raise
                // m_stop_signal and notify without holding m_outbound_mutex,
                // so their wake-up can land between the loop's stop check and
                // this wait and be lost. With an unbounded wait the thread
                // would then sleep until the next dispatch_action(), and
                // disconnect()/the destructor (which wait for m_tx_thread to
                // be cleared) would hang.
                if( !m_stop_signal )
                    m_outbound_push_cond.timed_wait(
                        lock, boost::posix_time::milliseconds(100));

                if( m_outbound.empty() )
                    continue;
            }

            action = m_outbound.front();
            m_outbound.pop();
            m_outbound_pop_cond.notify_all();
        }

        try
        {
            serialize_and_send(action);
        }
        catch(...)
        {
            // Sending failed (socket closed locally by disconnect() or by
            // the peer) - this connection is finished
            m_stop_signal = true;
            break;
        }
    }

    // Delete queued but unsent actions
    discard_unsent();

    delete m_tx_thread;
    m_tx_thread = 0;
}


/**
 * Body of the receive (RX) worker thread.
 *
 * Polls the socket in 100 ms slices and accumulates bytes in recv_buf. Each
 * inbound message is a header followed by a body:
 *   - header: one control byte and a 32-bit body size;
 *   - body:   one serialized Action, deserialized and passed to
 *             execute_action().
 * The deserializer is constructed with m_server_is_big_endian (recorded
 * during authentication), presumably so multi-byte fields are read in the
 * server's byte order.
 *
 * The loop ends when m_stop_signal is raised, when receiving fails (socket
 * closed locally or by the peer), or when an unsupported (compressed)
 * message arrives. On exit the thread deletes its own handle, clears
 * m_rx_thread and then reports the disconnect via disconnected().
 *
 * NOTE(review): if the peer closes the connection, nothing in this path
 * closes m_sockfd (disconnect() returns early once is_connected() is false).
 * Confirm whether the socket is released elsewhere.
 */
void ServerConnection::rx_thread()
{
    ChunkedBuffer recv_buf(4*1024);
    uint8_t ibuf[4*1024];

    ChunkedBufferAdapter adapter(recv_buf);
    DeserializationInterface buf(adapter, m_server_is_big_endian);

    uint8_t ctrl_byte;
    uint32_t body_size;
    bool reading_header = true;

    while( !m_stop_signal )
    {
        try
        {
            size_t n_read = timed_recv(ibuf, sizeof(ibuf), 100);
            recv_buf.write(ibuf, n_read);
        }
        catch(...)
        {
            // Connection closed
            m_stop_signal = true;
            // Wake the tx_thread
            m_outbound_push_cond.notify_all();
            break;
        }


        // Consume as many complete headers/bodies as the buffer holds; stop
        // once a pass makes no progress (i.e. more data is needed)
        size_t size;
        do
        {
            size = recv_buf.get_size();

            if( reading_header && recv_buf.get_size() >= sizeof(uint32_t)+1  )
            {
                // Entire header read
                buf >> ctrl_byte >> body_size;
                reading_header = false;

                // Compressed messages (non-zero control byte) are not
                // supported. Validate this at run time rather than with an
                // assert. Assert is compiled out in release builds, where the
                // compressed payload would otherwise be misparsed and the
                // stream left desynchronised. The stream cannot be recovered,
                // so shut the connection down.
                if( ctrl_byte != 0 )
                {
                    std::cerr << "WARNING: peekabot client received a "
                              << "compressed message, which is not "
                              << "supported - disconnecting" << std::endl;
                    m_stop_signal = true;
                    // Wake the tx_thread
                    m_outbound_push_cond.notify_all();
                    break; // leaves the do-while; the outer loop then exits
                }
            }


            if( !reading_header &&
                recv_buf.get_size() >= body_size )
            {
                // Entire body read
                reading_header = true;

                try
                {
                    Action *action = 0;
                    buf >> action;
                    execute_action(boost::shared_ptr<Action>(action));
                }
                catch(std::exception &e)
                {
                    std::cerr << "WARNING: peekabot client failed to "
                              << "deserialize/construct incoming action\n"
                              << "  what(): " << e.what() << std::endl;
                }
                catch(...)
                {
                    std::cerr << "WARNING: peekabot client failed to "
                              << "deserialize/construct incoming action"
                              << std::endl;
                }
            }
        }
        while( size-recv_buf.get_size() != 0 );
    }

    delete m_rx_thread;
    m_rx_thread = 0;

    disconnected();
}


/**
 * Serializes an action and sends it as one framed message.
 *
 * Wire format (multi-byte fields in host byte order; the client's byte order
 * is announced to the server during authentication):
 *   - uncompressed: ctrl byte 0, uint32 payload length, payload;
 *   - compressed:   ctrl byte 1, uint32 uncompressed length,
 *                   uint32 compressed length, LZF-compressed payload.
 * Payloads over 128 bytes are LZF-compressed when that saves at least 5%;
 * otherwise they are sent uncompressed.
 *
 * Serialization failures are logged to std::cerr and the action is dropped.
 * Send failures propagate from blocking_send() (std::runtime_error when the
 * connection is closed).
 */
void ServerConnection::serialize_and_send(boost::shared_ptr<Action> action)
{
    MemSerializationBuffer uncomp;
    SerializationInterface ar(uncomp);

    try
    {
        ar << action.get();
    }
    catch(std::exception &e)
    {
        std::cerr << "WARNING: peekabot client failed to "
                  << "serialize outbound action\n"
                  << "  what(): " << e.what() << std::endl;
        return;
    }
    catch(...)
    {
        std::cerr << "WARNING: peekabot client failed to serialize "
                  << "outbound action, caught unknown exception"
                  << std::endl;
        return;
    }

    uint32_t uncomp_len = (uint32_t)uncomp.size();
    assert( uncomp_len > 0 );

    // Compress data?
    if( uncomp_len > 128 )
    {
        // Require at least 5% compression, or it's not worth the effort.
        //
        // Compute the budget in 64-bit arithmetic: 95*uncomp_len overflows
        // uint32_t for payloads above ~45 MB. The wrapped value could then be
        // 0, giving a zero-sized buffer while 0-1 == 0xFFFFFFFF is passed to
        // lzf_compress() as the output limit (a heap overflow).
        // Since uncomp_len > 128, max_comp_len >= 122, so the "-1" below
        // cannot underflow.
        const uint32_t max_comp_len =
            (uint32_t)((boost::uint64_t)uncomp_len * 95 / 100);

        boost::scoped_array<boost::uint8_t> comp(
            new boost::uint8_t[max_comp_len]);

        uint32_t comp_len = (uint32_t)lzf_compress(
            uncomp.get(), uncomp_len, comp.get(), max_comp_len-1);

        if( comp_len != 0 )
        {
            // header
            boost::uint8_t ctrl_byte = 1;
            blocking_send(&ctrl_byte, 1);
            blocking_send(&uncomp_len, sizeof(uint32_t));
            blocking_send(&comp_len, sizeof(uint32_t));

            // body
            blocking_send(comp.get(), comp_len);

            return;
        }
        // else: failed to compress, fall back to sending uncompressed
    }

    // Send header
    boost::uint8_t ctrl_byte = 0;
    blocking_send(&ctrl_byte, 1);
    blocking_send(&uncomp_len, sizeof(uint32_t));

    // Send body
    blocking_send(uncomp.get(), uncomp_len);
}


/**
 * Sends all n bytes of buf, retrying timed_send() in 100 ms slices until
 * everything has been written.
 *
 * Only returns early by exception: timed_send() throws std::runtime_error
 * once the socket is closed (locally or by the peer).
 */
void ServerConnection::blocking_send(const void *buf, size_t n)
{
    const uint8_t *cursor = static_cast<const uint8_t *>(buf);
    std::size_t remaining = n;

    while( remaining > 0 )
    {
        const std::size_t sent = timed_send(cursor, remaining, 100);
        cursor += sent;
        remaining -= sent;
    }
}


/**
 * Sends up to n bytes from buf, waiting at most timeout_ms for the socket to
 * become writable.
 *
 * @return Number of bytes sent. Returns 0 if the socket did not become
 *         writable in time or select() failed (e.g. was interrupted).
 * @throws std::runtime_error if the socket has been closed locally
 *         (m_sockfd is INVALID_SOCKET), or if send() fails or sends nothing.
 */
size_t ServerConnection::timed_send(
    const void *buf, size_t n, uint32_t timeout_ms)
{
    // Split the timeout into whole seconds plus the sub-second remainder.
    // tv_usec must stay below 1000000; select() may reject larger values
    // with EINVAL.
    timeval tv;
    tv.tv_sec = timeout_ms / 1000;
    tv.tv_usec = (timeout_ms % 1000) * 1000;

    // Take a snapshot of the descriptor; disconnect() may invalidate
    // m_sockfd concurrently
    sockets::SocketType sockfd;
    {
        boost::recursive_mutex::scoped_lock lock(m_sockfd_mutex);
        if( m_sockfd == INVALID_SOCKET )
            throw std::runtime_error("Connection closed by local host");
        sockfd = m_sockfd;
    }

    fd_set fds;
    FD_ZERO(&fds);
    FD_SET(sockfd, &fds);
    int n_ready = select(sockfd+1, 0, &fds, 0, &tv);

    if( n_ready < 1 )
        return 0;

    ssize_t n_sent = send(sockfd, (sockets::BufType)buf, n, MSG_NOSIGNAL);

    if( n_sent < 1 )
        throw std::runtime_error("Connection closed by remote host");

    return (size_t)n_sent;
}


/**
 * Receives up to n bytes into buf, waiting at most timeout_ms for data.
 *
 * @return Number of bytes received. Returns 0 if no data arrived in time or
 *         select() failed (e.g. was interrupted).
 * @throws std::runtime_error if the socket has been closed locally
 *         (m_sockfd is INVALID_SOCKET), or if recv() fails or reports an
 *         orderly shutdown (returns 0).
 */
size_t ServerConnection::timed_recv(void *buf, size_t n, uint32_t timeout_ms)
{
    // Split the timeout into whole seconds plus the sub-second remainder.
    // tv_usec must stay below 1000000; select() may reject larger values
    // with EINVAL.
    timeval tv;
    tv.tv_sec = timeout_ms / 1000;
    tv.tv_usec = (timeout_ms % 1000) * 1000;

    // Take a snapshot of the descriptor; disconnect() may invalidate
    // m_sockfd concurrently
    sockets::SocketType sockfd;
    {
        boost::recursive_mutex::scoped_lock lock(m_sockfd_mutex);
        if( m_sockfd == INVALID_SOCKET )
            throw std::runtime_error("Connection closed by local host");
        sockfd = m_sockfd;
    }

    fd_set fds;
    FD_ZERO(&fds);
    FD_SET(sockfd, &fds);
    int n_ready = select(sockfd+1, &fds, 0, 0, &tv);

    if( n_ready < 1 )
        return 0;

    ssize_t n_recv = recv(sockfd, (sockets::BufType)buf, n, 0);

    if( n_recv < 1 )
        throw std::runtime_error("Connection closed by remote host");

    return (size_t)n_recv;
}


void ServerConnection::_connect(
    const std::string &hostname,
    unsigned int port,
    bool low_latency_mode)
{
    boost::recursive_mutex::scoped_lock lock(m_sockfd_mutex);

    if( (m_sockfd = socket(PF_INET, SOCK_STREAM, 0)) == SOCKET_ERROR )
        throw std::runtime_error("Could not initialize socket");

    // Disable blocking on the socket
    sockets::set_nonblocking(m_sockfd);

    // Try to connect to the remote host
    timed_connect(hostname, port, 10000);

    m_up_since = boost::posix_time::microsec_clock::local_time();

    if( low_latency_mode )
    {
        // Disable Nagel's algorithm for lower latency
        int yes = 1;
        if( setsockopt(m_sockfd, IPPROTO_TCP, TCP_NODELAY, 
                       (sockets::BufType)&yes, sizeof(int)) == SOCKET_ERROR )
            std::cerr << "WARNING: setsockopt failed to enable TCP_NODELAY";
    }

    try
    {
        perform_authentication(low_latency_mode);
    }
    catch(std::exception &e)
    {
        // Authentication failed, shut the socket to that mofo down!
        sockets::close_socket(m_sockfd);
        throw;
    }
    // else: Authentication OK! We're set to go!
}



void ServerConnection::timed_connect(
    const std::string &hostname, int port, size_t timeout_ms)
{
    sockaddr_in dest_addr;   // will hold the destination addr

    hostent *h;
    if( (h = gethostbyname(hostname.c_str())) == 0 )
        throw HostResolveFailed("Failed to resolve hostname");

    memset(&dest_addr, 0, sizeof(dest_addr));
    dest_addr.sin_family = AF_INET;     // host byte order
    dest_addr.sin_port = htons(port);   // short, network byte order
    memcpy(&dest_addr.sin_addr.s_addr,
           h->h_addr,
           h->h_length);
    memset(&(dest_addr.sin_zero), '\0', 8);  // zero the rest of the struct


    boost::posix_time::ptime elapsed =
        boost::posix_time::microsec_clock::local_time();


    if( ::connect(m_sockfd, (sockaddr *)&dest_addr,
                  sizeof(sockaddr)) == SOCKET_ERROR )
    {
        if( sockets::get_socket_errno() == EINPROGRESS )
        {
            while(true)
            {
                timeval tv;
                tv.tv_sec = 1;
                tv.tv_usec = 0;

                fd_set fds;
                FD_ZERO(&fds);
                FD_SET(m_sockfd, &fds);

                int res = select(m_sockfd+1, 0, &fds, 0, &tv);

                if( res < 0 && sockets::get_socket_errno() != EINTR )
                {
                    throw ConnectionRefused(
                        std::string("Connection refused: ") +
                        strerror(sockets::get_socket_errno()));
                }
                else if( res > 0 )
                {
                    int valopt;
                    socklen_t optlen = sizeof(int);
                    // Socket selected for write
                    if( getsockopt(m_sockfd, SOL_SOCKET, SO_ERROR,
                                   (sockets::BufType)&valopt, &optlen) < 0 )
                        throw ConnectionRefused(
                            std::string("Connection refused: ") +
                            strerror(sockets::get_socket_errno()));

                    // Check the value returned...
                    if( valopt )
                        throw ConnectionRefused("Connection refused");

                    break; // Connect ok!
                }
                else
                {
                    // Timeout, disconnect ?
                    if( (size_t)
                        (boost::posix_time::microsec_clock::local_time() -
                         elapsed).total_milliseconds() >= timeout_ms )
                        throw ConnectionRefused("Connection timed out");
                }
            }
        }
        else
        {
            throw ConnectionRefused(
                std::string("Connection refused: ") +
                strerror(sockets::get_socket_errno()));
        }
    }
}



/**
 * Runs the client side of the authentication handshake on the freshly
 * connected socket m_sockfd.
 *
 * Sequence, as implemented below:
 *   1. Send an endianness flag (1 byte), protocol::UNIQUE_ID,
 *      protocol::PROTOCOL_VERSION and the client version word
 *      (PEEKABOT_VERSION | PEEKABOT_RELEASE_STATUS << 24), all as raw
 *      host-order bytes.
 *   2. Receive the server's counterpart: an endianness flag plus three
 *      32-bit values, byte-swapped here if the server's byte order differs
 *      from ours.
 *   3. Validate the server's data, send our one-byte
 *      protocol::AuthenticationResult, then read the server's verdict byte.
 *
 * Each step is polled in 100 ms slices. The handshake is aborted once
 * get_uptime() (time since _connect() set m_up_since) reaches
 * protocol::AUTHENTICATION_TIMEOUT.
 *
 * Side effect: records the server's byte order in m_server_is_big_endian.
 *
 * @param low_latency_mode  Currently unused; the line that would transmit
 *                          it is commented out.
 * @throws AuthenticationFailed on timeout, or when either side rejects the
 *         handshake. Exceptions from timed_send()/timed_recv() (such as
 *         std::runtime_error on a closed connection) propagate unchanged.
 */
void ServerConnection::perform_authentication(
    bool low_latency_mode) throw(std::exception)
{
    {
        //
        // Send the client's authentication data:
        //
        ChunkedBuffer auth_data(256);

        uint8_t is_big_endian = (PEEKABOT_BYTE_ORDER == PEEKABOT_BIG_ENDIAN);

        auth_data.write(&is_big_endian, sizeof(uint8_t));
        auth_data.write(&protocol::UNIQUE_ID, sizeof(protocol::UNIQUE_ID));
        auth_data.write(&protocol::PROTOCOL_VERSION, sizeof(protocol::PROTOCOL_VERSION));
        // Version word layout: RC<<24 | MAJOR<<16 | MINOR<<8 | REV
        uint32_t ver = PEEKABOT_VERSION | (PEEKABOT_RELEASE_STATUS << 24);
        auth_data.write(&ver, sizeof(uint32_t));
        //auth_data.write((uint8_t *)&low_latency_mode, sizeof(uint8_t));

        // Drain the buffer into a flat array and send it in full, retrying
        // partial/timed-out sends until done or the auth timeout expires
        uint8_t ibuf[256];
        size_t bytes = auth_data.read(ibuf, 256);
        size_t n_sent = 0;
        while( bytes-n_sent > 0 )
        {
            n_sent += timed_send(ibuf+n_sent, bytes-n_sent, 100);

            if( get_uptime().total_milliseconds() >=
                protocol::AUTHENTICATION_TIMEOUT )
                throw AuthenticationFailed("Authentication timed out");
        }
    }

    protocol::AuthenticationResult auth_res = protocol::AUTH_SUCCEEDED;
    {
        //
        // Receive authentication data (part one):
        //
        // NOTE(review): assumes UNIQUE_ID and PROTOCOL_VERSION are 4 bytes
        // on the wire - confirm against the protocol::* declarations.
        const size_t SIZE_OF_PART_ONE = 1+4+4+4;


        // Request only the bytes still missing so nothing beyond part one
        // is consumed from the socket
        ChunkedBuffer auth_response(256);
        while( auth_response.get_size() < SIZE_OF_PART_ONE )
        {
            uint8_t ibuf[256];
            size_t n = timed_recv(
                ibuf, SIZE_OF_PART_ONE - auth_response.get_size(), 100);
            auth_response.write(ibuf, n);

            if( get_uptime().total_milliseconds() >=
                protocol::AUTHENTICATION_TIMEOUT )
                throw AuthenticationFailed("Authentication timed out");
        }
        // Part one read


        // Extract read information
        uint8_t is_big_endian;
        uint32_t unique_id;
        uint32_t protocol_version;
        uint32_t peekabot_version; // RC<<24 | MAJOR<<16 | MINOR<<8 | REV

        auth_response.read(&is_big_endian, sizeof(uint8_t));
        auth_response.read(&unique_id, sizeof(uint32_t));
        auth_response.read(&protocol_version, sizeof(uint32_t));
        auth_response.read(&peekabot_version, sizeof(uint32_t));

        // switch byte order if needed
        if( is_big_endian != (PEEKABOT_BYTE_ORDER == PEEKABOT_BIG_ENDIAN) )
        {
            switch_byte_order(&unique_id, 1);
            switch_byte_order(&protocol_version, 1);
            switch_byte_order(&peekabot_version, 1);
        }

        // Remember the endianness (used by rx_thread's deserializer)
        m_server_is_big_endian = is_big_endian;

        // Check identifier
        if( unique_id != protocol::UNIQUE_ID )
            // Erroneous unique identifier received
            auth_res = protocol::AUTH_CLIENT_RECEIVED_UNEXPECTED_DATA;
        // Check for protocol compatibility
        else if( protocol_version != protocol::PROTOCOL_VERSION )
            auth_res = protocol::AUTH_INCOMPAT_PROTOCOL_VERSION;
        // Check for peekabot version compatibility (release status byte is
        // masked off)
        else if( (peekabot_version & 0xFFFFFF) < PEEKABOT_COMPATIBLE_VERSION )
            auth_res = protocol::AUTH_INCOMPAT_PEEKABOT_VERSION;

        //
        // Send/read authentication status
        //
        // Our verdict is sent even on failure, so the server learns why;
        // the same byte is then reused to receive the server's verdict.
        uint8_t tmp = (uint8_t)auth_res;

        while( timed_send(&tmp, 1, 100) == 0 )
        {
            if( get_uptime().total_milliseconds() >=
                protocol::AUTHENTICATION_TIMEOUT )
                throw AuthenticationFailed("Authentication timed out");
        }

        while( timed_recv(&tmp, 1, 100) == 0 )
        {
            if( get_uptime().total_milliseconds() >=
                protocol::AUTHENTICATION_TIMEOUT )
                throw AuthenticationFailed("Authentication timed out");
        }

        // Our own failure reason, if any, takes precedence over the
        // server's verdict
        if( auth_res == protocol::AUTH_SUCCEEDED )
            auth_res = (protocol::AuthenticationResult)tmp;

        const std::string server_ver_str = make_peekabot_version_string(
            (peekabot_version >> 16) & 255,
            (peekabot_version >> 8) & 255,
            peekabot_version & 255,
            (peekabot_version >> 24) & 255);

        //
        // Check the authentication result and throw if an error occurred
        //
        if( auth_res != protocol::AUTH_SUCCEEDED )
        {
            switch( auth_res )
            {
                case protocol::AUTH_NO_MORE_CONNECTIONS:
                    throw AuthenticationFailed(
                        "No more client connections allowed");
                    break;
                case protocol::AUTH_INCOMPAT_PEEKABOT_VERSION:
                    throw AuthenticationFailed(
                        "Server and client library are of incompatible "
                        "versions (server is version " + server_ver_str +
                        ", client is " + PEEKABOT_VERSION_STRING +
                        ")");
                    break;
                case protocol::AUTH_INCOMPAT_PROTOCOL_VERSION:
                    throw AuthenticationFailed(
                        "Server and client library use incompatible protocol "
                        "versions (server uses protocol version " +
                        boost::lexical_cast<std::string>(protocol_version) +
                        ", client uses " + boost::lexical_cast<std::string>(
                            protocol::PROTOCOL_VERSION) + ")");
                    break;
                case protocol::AUTH_CLIENT_RECEIVED_UNEXPECTED_DATA:
                    throw AuthenticationFailed(
                        "Unexpected data received from "
                        "server during authentication");
                    break;
                case protocol::AUTH_SERVER_RECEIVED_UNEXPECTED_DATA:
                    throw AuthenticationFailed(
                        "Rejected by server due to unexpected "
                        "authentication data");
                    break;
                default:
                    throw AuthenticationFailed(
                        "Authentication failed, reason unknown");
                    break;
            }
        }
    }
}
