/*
 * Copyright Staffan Gimåker 2010.
 *
 * ---
 *
 * Distributed under the Boost Software License, Version 1.0.
 * (See accompanying file LICENSE_1_0.txt or copy at
 * http://www.boost.org/LICENSE_1_0.txt)
 */

#include "Connection.hh"
#include "Protocol.hh"
#include "Version.hh"
#include "Bytesex.hh"
#include "Action.hh"
#include "serialization/SerializationInterface.hh"
#include "serialization/DeserializationInterface.hh"
#include "serialization/SmartPtrs.hh"

#include <cassert>
#include <boost/bind.hpp>
#include <stdexcept>
#include <iostream>

extern "C" {
#include <liblzf/lzf.h>
}


using namespace peekabot;


/**
 * Creates a connection whose socket and authentication timer are bound to
 * the given io_service. The connection starts out unauthenticated, with both
 * traffic counters at zero.
 */
Connection::Connection(boost::asio::io_service &io_service)
    : m_socket(io_service)
    , m_auth_timeout(io_service)
    , m_is_authenticated(false)
    , m_bytes_sent(0)
    , m_bytes_recv(0)
{
    // Nothing else to do; the socket is opened/connected elsewhere.
}


/*void Connection::connect(const std::string &host, const std::string &port)
{
    boost::asio::ip::tcp::resolver resolver(m_io_service);
    boost::asio::ip::tcp::resolver::query query(host, port);
    boost::asio::ip::tcp::resolver::iterator endpoint_iterator = resolver.resolve(query);

    boost::system::error_code error = boost::asio::error::host_not_found;
    while( error && endpoint_iterator != boost::asio::ip::tcp::resolver::iterator() )
    {
        socket().close();
        socket().connect(*endpoint_iterator++, error);
    }

    if( error )
    {
        //std::cout << "  Failed! Retrying in 5 seconds..." << std::endl;
        // TODO
    }
    else
    {
        boost::mutex::scoped_lock lock(m_connect_mutex);
        authenticate();
        m_io_service.post(boost::bind(&Connection::connect_done, this));
        m_connect_cond.wait();
    }
}*/


void Connection::close()
{
    m_socket.close();
}


/// Returns the io_service the socket belongs to; this is where the
/// connection posts its internal asynchronous work.
boost::asio::io_service &Connection::get_io_service()
{
    boost::asio::ip::tcp::socket &sock = m_socket;
    return sock.io_service();
}


/// Mutable access to the underlying TCP socket.
boost::asio::ip::tcp::socket &Connection::get_socket()
{
    boost::asio::ip::tcp::socket &sock = m_socket;
    return sock;
}


/// Read-only access to the underlying TCP socket.
const boost::asio::ip::tcp::socket &Connection::get_socket() const
{
    const boost::asio::ip::tcp::socket &sock = m_socket;
    return sock;
}


/// True once both sides have accepted the authentication handshake.
bool Connection::is_authenticated() const
{
    const bool authenticated = m_is_authenticated;
    return authenticated;
}


/// True if the underlying socket is open.
bool Connection::is_open() const
{
    const bool open = m_socket.is_open();
    return open;
}


/// A connection is usable for actions only when its socket is open AND the
/// handshake has succeeded.
bool Connection::is_operational() const
{
    return is_open() && is_authenticated();
}


/// Time elapsed since m_up_since, measured against the local wall clock.
boost::posix_time::time_duration Connection::get_uptime() const
{
    const boost::posix_time::ptime now =
        boost::posix_time::microsec_clock::local_time();
    return now - m_up_since;
}


/// Total number of bytes written by completed action writes, headers
/// included.
boost::uint64_t Connection::get_bytes_sent() const
{
    const boost::uint64_t sent = m_bytes_sent;
    return sent;
}


/// Total number of action bytes received so far, headers included.
boost::uint64_t Connection::get_bytes_received() const
{
    const boost::uint64_t received = m_bytes_recv;
    return received;
}


/**
 * Reports the connection's current throughput in each direction.
 *
 * Not implemented yet: both outputs are set to zero so that callers never
 * read indeterminate values from them.
 *
 * @param in  Receives the incoming throughput (currently always 0).
 * @param out Receives the outgoing throughput (currently always 0).
 */
void Connection::get_throughput(float &in, float &out) const
{
    // TODO: compute real figures (units still to be decided), e.g. from
    // m_bytes_recv / m_bytes_sent sampled over time.
    in = 0.0f;
    out = 0.0f;
}


/**
 * Unpacks the peekabot version the peer announced during authentication.
 *
 * The packed value has the layout 0xCCMMmmrr: release status (rc) in the
 * top byte, followed by the major, minor and revision bytes.
 *
 * @param major Receives the major version.
 * @param minor Receives the minor version.
 * @param rev   Receives the revision.
 * @param rc    Receives the release status byte.
 * @throws std::logic_error if the connection is not open and authenticated.
 */
void Connection::get_remote_peekabot_version(
    boost::uint8_t &major,
    boost::uint8_t &minor,
    boost::uint8_t &rev,
    boost::uint8_t &rc)
{
    if( !is_operational() )
        throw std::logic_error("Connection not open and authenticated");

    major = (m_remote_peekabot_version>>16) & 0xFF;
    minor = (m_remote_peekabot_version>>8) & 0xFF;
    // The revision is the lowest byte; shifting by 8 would yield the minor
    // version again.
    rev = m_remote_peekabot_version & 0xFF;
    rc = (m_remote_peekabot_version>>24) & 0xFF;
}


std::string Connection::get_remote_address() const
{
    // TODO
}


/**
 * Schedules the asynchronous reception of one action frame. The read is
 * started from within the io_service, and the decoded action is eventually
 * handed to action_read().
 *
 * @throws std::logic_error if the connection is not open and authenticated.
 */
void Connection::read_action()
{
    const bool ready = is_operational();
    if( ready )
    {
        get_io_service().post(boost::bind(&Connection::read_action_, ptr()));
        return;
    }

    throw std::logic_error("Connection not open and authenticated");
}


void Connection::read_action_()
{
    boost::shared_ptr<boost::uint8_t> ctrl_byte(new boost::uint8_t);
    boost::asio::async_read(
        m_socket, boost::asio::buffer(ctrl_byte.get(), 1),
        boost::bind(
            &Connection::handle_action_pre_header_read, ptr(), _1, ctrl_byte));
}


/**
 * Completion handler for the 1-byte control byte that begins an action
 * frame. The control byte selects the header layout that follows:
 *   0 - uncompressed:   4-byte uncompressed length
 *   1 - LZF compressed: 4-byte uncompressed length, 4-byte compressed length
 * For uncompressed frames the same length object is passed as both lengths
 * to handle_action_header_read().
 *
 * Any other control byte means the stream cannot be parsed, so the
 * connection is closed. Relying on an assert would abort debug builds on
 * peer-supplied data, and would silently stall the read chain in release
 * builds.
 */
void Connection::handle_action_pre_header_read(
    const boost::system::error_code &e,
    boost::shared_ptr<boost::uint8_t> ctrl_byte)
{
    if( e )
    {
        // TODO: do this properly
        close();
        return;
    }

    m_bytes_recv += 1;

    if( *ctrl_byte == 0 ) // uncompressed
    {
        boost::shared_ptr<boost::uint32_t> uncomp_len(new boost::uint32_t);
        boost::asio::async_read(
            m_socket, boost::asio::buffer(uncomp_len.get(), 4),
            boost::bind(
                &Connection::handle_action_header_read, ptr(), _1,
                uncomp_len, uncomp_len));
    }
    else if( *ctrl_byte == 1 ) // LZF compression
    {
        std::vector<boost::asio::mutable_buffer> bufs;
        boost::shared_ptr<boost::uint32_t> uncomp_len(new boost::uint32_t);
        boost::shared_ptr<boost::uint32_t> comp_len(new boost::uint32_t);
        bufs.push_back(boost::asio::buffer(uncomp_len.get(), 4));
        bufs.push_back(boost::asio::buffer(comp_len.get(), 4));
        boost::asio::async_read(
            m_socket, bufs,
            boost::bind(
                &Connection::handle_action_header_read, ptr(), _1,
                uncomp_len, comp_len));
    }
    else
    {
        // Unknown frame type: the rest of the stream cannot be interpreted.
        close();
    }
}


/**
 * Completion handler for an action frame's length header. It allocates a
 * buffer for the body and starts reading it.
 *
 * @param uncomp_len Size of the serialized action after decompression.
 * @param comp_len   Size of the body on the wire. For uncompressed frames
 *                   this points to the same object as uncomp_len.
 *
 * The peer writes the length fields in its own native byte order (see
 * write_action_()), so they are converted when that differs from ours.
 */
void Connection::handle_action_header_read(
    const boost::system::error_code &e,
    boost::shared_ptr<boost::uint32_t> uncomp_len,
    boost::shared_ptr<boost::uint32_t> comp_len)
{
    if( e )
    {
        // TODO: do this properly
        close();
        return;
    }

    if( m_remote_is_big_endian !=
        (PEEKABOT_BYTE_ORDER == PEEKABOT_BIG_ENDIAN) )
    {
        switch_byte_order(uncomp_len.get(), 1);
        // In uncompressed frames both pointers alias one value, which must
        // be converted exactly once.
        if( comp_len != uncomp_len )
            switch_byte_order(comp_len.get(), 1);
    }

    const bool compressed = ( *comp_len != *uncomp_len );

    // Header size: 4 bytes (uncomp_len), plus 4 more (comp_len) if compressed
    m_bytes_recv += compressed ? 8 : 4;

    // Number of body bytes actually on the wire
    const boost::uint32_t body_len = compressed ? *comp_len : *uncomp_len;

    // NOTE(review): body_len is peer-controlled and not bounded here.
    boost::shared_array<boost::uint8_t> buf(new boost::uint8_t[body_len]);
    boost::asio::async_read(
        m_socket, boost::asio::buffer(buf.get(), body_len),
        boost::bind(
            &Connection::handle_action_body_read, ptr(), _1,
            buf, *uncomp_len, *comp_len));
}


/**
 * Completion handler for an action frame's body.
 *
 * If the frame is compressed (comp_len != uncomp_len), the payload is first
 * LZF-decompressed. The result is then deserialized into an Action using the
 * peer's byte order and passed to action_read().
 *
 * Every failure closes the connection: a socket error, a payload that does
 * not decompress to exactly uncomp_len bytes, or an exception during
 * deserialization/dispatch. Failures are not thrown out of the completion
 * handler, because that would propagate out of the io_service's run() call.
 */
void Connection::handle_action_body_read(
    const boost::system::error_code &e,
    boost::shared_array<boost::uint8_t> buf,
    boost::uint32_t uncomp_len,
    boost::uint32_t comp_len)
{
    if( e )
    {
        // TODO: do this properly
        close();
        return;
    }

    m_bytes_recv += comp_len;

    if( comp_len != uncomp_len ) // <=> compressed
    {
        boost::shared_array<boost::uint8_t> uncomp(
            new boost::uint8_t[uncomp_len]);

        // Decompress. Anything but an exact fit indicates corrupt data.
        if( lzf_decompress(buf.get(), comp_len,
                           uncomp.get(), uncomp_len) != uncomp_len )
        {
            close();
            return;
        }

        std::swap(buf, uncomp);
    }

    try
    {
        // Deserialize the action, honouring the peer's byte order
        MemDeserializationBuffer deser_buf(buf.get(), uncomp_len);
        serialization::DeserializationInterface ar(
            deser_buf, m_remote_is_big_endian);

        boost::shared_ptr<Action> action;
        ar >> action;
        action_read(action);
    }
    catch(...)
    {
        // Malformed action data, or a failing action_read(). The stream can
        // no longer be trusted, so drop the connection rather than silently
        // carrying on.
        close();
    }
}


/**
 * Schedules an action to be serialized and sent to the peer. The actual
 * work happens in write_action_(), which runs inside the io_service.
 *
 * @throws std::logic_error if the connection is not open and authenticated.
 */
void Connection::write_action(const boost::shared_ptr<Action> &action)
{
    const bool ready = is_operational();
    if( ready )
    {
        get_io_service().post(
            boost::bind(&Connection::write_action_, ptr(), action));
        return;
    }

    throw std::logic_error("Connection not open and authenticated");
}


/**
 * Posted by write_action(): serializes the action and writes one frame:
 *   [1 byte control byte = 0][4 byte length, native byte order][payload]
 *
 * All buffers live in an ActionWriteData object. That object is kept alive
 * by the completion handler until the asynchronous write has finished.
 * Serialization failures trip an assert (and are ignored in release builds).
 */
void Connection::write_action_(boost::shared_ptr<Action> action)
{
    boost::shared_ptr<ActionWriteData> frame(new ActionWriteData);
    serialization::SerializationInterface ar(frame->m_buf);

    try
    {
        ar << action;

        frame->m_ctrl_byte = 0; // uncompressed
        frame->m_uncomp_len = frame->m_buf.size();

        // Compress data?
        /*if( awd->m_uncomp_len > 128 )
        {
            // require at least 5% compression, or it's not worth the effort
            boost::scoped_array<boost::uint8_t> comp(
                new boost::uint8_t[95*uncomp_len/100]);

            boost::uint32_t comp_len = (boost::uint32_t)lzf_compress(
                awd->m_buf.get(), awd->m_uncomp_len,
                comp.get(), 95*awd->m_uncomp_len/100-1);

            if( comp_len != 0 )
            {
              awd->m_comp_len = comp_len;
              awd->m_ctrl_byte = 1;
            }
            // else: failed to compress, fall back to sending uncompressed
        }*/

        std::vector<boost::asio::const_buffer> wire;
        wire.push_back(boost::asio::buffer(&frame->m_ctrl_byte, 1));  // pre-header
        wire.push_back(boost::asio::buffer(&frame->m_uncomp_len, 4)); // header
        if( frame->m_uncomp_len > 0 )                                 // body
            wire.push_back(
                boost::asio::buffer(frame->m_buf.get(), frame->m_buf.size()));

        boost::asio::async_write(
            m_socket, wire,
            boost::bind(&Connection::handle_action_write, ptr(), _1, frame));
    }
    catch(...)
    {
        assert( false );
    }
}


/**
 * Completion handler for an action frame write. On error the connection is
 * closed. Otherwise the bytes put on the wire, headers included, are added
 * to the sent-bytes counter.
 */
void Connection::handle_action_write(
    const boost::system::error_code &e,
    boost::shared_ptr<ActionWriteData> awd)
{
    if( e )
    {
        // TODO: do this properly?
        close();
        return;
    }

    const bool compressed = ( awd->m_ctrl_byte != 0 );
    if( compressed )
        m_bytes_sent += 9 + awd->m_comp_len;   // 1+4+4 byte header
    else
        m_bytes_sent += 5 + awd->m_uncomp_len; // 1+4 byte header
}


/**
 * Schedules the authentication handshake on an open, not yet authenticated
 * connection. The handshake runs inside the io_service via authenticate_().
 *
 * @throws std::logic_error if the socket is closed or the connection is
 *         already authenticated.
 */
void Connection::authenticate()
{
    const char *problem = 0;
    if( !is_open() )
        problem = "Connection not open";
    else if( is_authenticated() )
        problem = "Already authenticated";

    if( problem )
        throw std::logic_error(problem);

    get_io_service().post(boost::bind(&Connection::authenticate_, ptr()));
}


/**
 * Posted by authenticate(): performs the first half of the handshake.
 *
 * Sends this side's handshake data and concurrently starts reading the
 * peer's, in the same field order:
 *   1. endianness flag (1 byte, nonzero = big endian)
 *   2. unique identifier
 *   3. protocol version
 *   4. peekabot version (release status in the top byte)
 * The peer's data is validated in handle_auth_data_read().
 */
void Connection::authenticate_()
{
    // Authentication timeout -- currently disabled.
    /*m_auth_timeout.expires_from_now(
        boost::posix_time::seconds(protocol::AUTHENTICATION_TIMEOUT));
    m_auth_timeout.async_wait(
        boost::bind(
            &Connection::send_auth_status, ptr(),
            protocol::AUTH_AUTHENTICATION_TIMED_OUT, true));*/

    // Static so the buffers handed to async_write stay valid after this
    // function returns.
    static const boost::uint8_t is_be =
        PEEKABOT_BYTE_ORDER == PEEKABOT_BIG_ENDIAN;
    static const boost::uint32_t peekabot_version =
        PEEKABOT_VERSION | (PEEKABOT_RELEASE_STATUS<<24);

    std::vector<boost::asio::const_buffer> out;
    // Send endianness flag
    out.push_back(
        boost::asio::buffer(&is_be, sizeof(is_be)));
    // Send unique identifier
    // NOTE(review): the reader below expects 4 bytes, so this assumes
    // sizeof(protocol::UNIQUE_ID) == 4 -- confirm in Protocol.hh.
    out.push_back(
        boost::asio::buffer(
            &protocol::UNIQUE_ID, sizeof(protocol::UNIQUE_ID)));
    // Send protocol version (same 4-byte assumption as above)
    out.push_back(
        boost::asio::buffer(
            &protocol::PROTOCOL_VERSION, sizeof(protocol::PROTOCOL_VERSION)));
    // Send peekabot version
    out.push_back(
        boost::asio::buffer(&peekabot_version, sizeof(peekabot_version)));

    boost::asio::async_write(
        m_socket, out,
        boost::bind(&Connection::handle_auth_write, ptr(), _1));

    // ---

    std::vector<boost::asio::mutable_buffer> in;

    // Read endianness flag directly into the member used later for
    // byte-order decisions
    in.push_back(
        boost::asio::buffer(
            &m_remote_is_big_endian, sizeof(m_remote_is_big_endian)));
    // Read unique identifier (heap-allocated: must outlive this call)
    boost::shared_ptr<boost::uint32_t> unique_id_in(new boost::uint32_t);
    in.push_back(
        boost::asio::buffer(unique_id_in.get(), 4));
    // Read protocol version
    boost::shared_ptr<boost::uint32_t> protocol_ver(new boost::uint32_t);
    in.push_back(
        boost::asio::buffer(protocol_ver.get(), 4));
    // Read peekabot version (still in the peer's byte order at this point)
    in.push_back(
        boost::asio::buffer(
            &m_remote_peekabot_version, sizeof(m_remote_peekabot_version)));

    boost::asio::async_read(
        m_socket, in,
        boost::bind(
            &Connection::handle_auth_data_read, ptr(),
            _1, unique_id_in, protocol_ver));
}


/// Completion handler for handshake writes. A failed write is reported as
/// the connection having been closed by the peer.
void Connection::handle_auth_write(const boost::system::error_code &e)
{
    if( !e )
        return;

    on_authentication_failed(protocol::AUTH_CONNECTION_CLOSED_BY_PEER);
}


/**
 * Completion handler for the peer's handshake data, as read by
 * authenticate_(): endianness flag, unique identifier, protocol version and
 * peekabot version.
 *
 * Validates the data and sends this side's verdict (stored in
 * m_auth_res_local) to the peer. It then starts reading the peer's verdict,
 * which is evaluated in handle_auth_res_read().
 *
 * @param unique_id    The peer's identifier, in the peer's byte order.
 * @param protocol_ver The peer's protocol version, in the peer's byte order.
 */
void Connection::handle_auth_data_read(
    const boost::system::error_code &e,
    boost::shared_ptr<boost::uint32_t> unique_id,
    boost::shared_ptr<boost::uint32_t> protocol_ver)
{
    if( e )
    {
        on_authentication_failed(protocol::AUTH_CONNECTION_CLOSED_BY_PEER);
        return;
    }

    // Convert multi-byte fields to our byte order if the peer's differs.
    // This happens before the flag itself is validated below; that is
    // harmless, because an invalid flag is rejected regardless.
    if( m_remote_is_big_endian !=
        (PEEKABOT_BYTE_ORDER == PEEKABOT_BIG_ENDIAN) )
    {
        switch_byte_order(unique_id.get(), 1);
        switch_byte_order(protocol_ver.get(), 1);
        switch_byte_order(&m_remote_peekabot_version, 1);
    }

    // Checks are ordered; the first failing one determines the result.
    protocol::AuthenticationResult auth_res;
    if( m_remote_is_big_endian != 0 && m_remote_is_big_endian != 1 )
    {
        // Endianness flag is not a boolean
        auth_res = protocol::AUTH_UNEXPECTED_DATA_RECEIVED;
    }
    else if( *unique_id != protocol::UNIQUE_ID )
    {
        // Identifier mismatch
        auth_res = protocol::AUTH_UNEXPECTED_DATA_RECEIVED;
    }
    else if( *protocol_ver != protocol::PROTOCOL_VERSION )
    {
        auth_res = protocol::AUTH_INCOMPAT_PROTOCOL_VERSION;
    }
    else if(
        (m_remote_peekabot_version & 0x00FFFFFF) > PEEKABOT_VERSION ||
        (m_remote_peekabot_version & 0x00FFFFFF) < PEEKABOT_COMPATIBLE_VERSION )
    {
        // Peer version (release status byte masked off) must lie within
        // [PEEKABOT_COMPATIBLE_VERSION, PEEKABOT_VERSION]
        auth_res = protocol::AUTH_INCOMPAT_PEEKABOT_VERSION;
    }
    else
    {
        auth_res = protocol::AUTH_SUCCEEDED;
    }

    // Send our verdict. It is stored in a member rather than a local so that
    // the write buffer stays valid after this function returns.
    // NOTE(review): this write is not explicitly ordered after the handshake
    // write started in authenticate_(). Asio does not allow overlapping
    // composed writes on one socket, so confirm that write has completed.
    m_auth_res_local = boost::uint8_t(auth_res);
    boost::asio::async_write(
        m_socket,
        boost::asio::buffer(&m_auth_res_local, sizeof(m_auth_res_local)),
        boost::bind(&Connection::handle_auth_write, ptr(), _1));

    // Read the peer's verdict on our handshake data
    boost::asio::async_read(
        m_socket,
        boost::asio::buffer(&m_auth_res_remote, sizeof(m_auth_res_remote)),
        boost::bind(&Connection::handle_auth_res_read, ptr(), _1));
}


/**
 * Completion handler for the peer's authentication verdict; this finishes
 * the handshake.
 *
 * The connection counts as authenticated only if both the local and the
 * remote verdict are AUTH_SUCCEEDED. On failure the local verdict is
 * reported if it failed, otherwise the remote one. A read error is recorded
 * as the connection having been closed by the peer.
 */
void Connection::handle_auth_res_read(
    const boost::system::error_code &e)
{
    if( e )
        m_auth_res_local = protocol::AUTH_CONNECTION_CLOSED_BY_PEER;

    const bool remote_ok = ( m_auth_res_remote == protocol::AUTH_SUCCEEDED );
    const bool local_ok = ( m_auth_res_local == protocol::AUTH_SUCCEEDED );
    m_is_authenticated = remote_ok && local_ok;

    if( m_is_authenticated )
    {
        on_authenticated();
        return;
    }

    // Prefer our own verdict as the failure reason; fall back to the peer's.
    if( !local_ok )
        on_authentication_failed(
            static_cast<protocol::AuthenticationResult>(m_auth_res_local));
    else
        on_authentication_failed(
            static_cast<protocol::AuthenticationResult>(m_auth_res_remote));
}
