//  This file is distributed as part of the bit-babbler package.
//  Copyright 1998 - 2016,  Ron <ron@debian.org>

#ifndef _BB_SOCKET_H
#define _BB_SOCKET_H

#include <bit-babbler/log.h>

#if EM_PLATFORM_POSIX

    #include <sys/socket.h>
    #include <sys/un.h>
    #include <netinet/in.h>
    #include <netdb.h>

#elif EM_PLATFORM_MSW

    #include <ws2tcpip.h>

#else

    #error Unsupported platform

#endif


namespace BitB
{
    // Storage for a socket address of any of the supported address families.
    //
    // This is a union, so all of the members below overlay the same memory.
    // An address can be written through one member and then read back through
    // whichever other member matches the type an API expects.
    union sockaddr_any_t
    {
        struct sockaddr         any; // Generic socket address.
        struct sockaddr_storage ss;  // Largest available socket address space.

        struct sockaddr_in      in;  // IPv4 domain socket address.
        struct sockaddr_in6     in6; // IPv6 domain socket address.

       // The Unix domain member is only declared for POSIX platform builds.
       #if EM_PLATFORM_POSIX
        struct sockaddr_un      un;  // Unix domain socket address.
       #endif
    };


    // A socket address given as a 'host:service' string, together with the
    // concrete address data that it resolves to.
    //
    // Constructing a SockAddr only parses the string.  The addr_* members and
    // addr are zero until GetAddrInfo() has been called successfully to resolve
    // the host and service.
    struct SockAddr
    { //{{{

        std::string     host;           // Host part of the address, empty if none was given.
        std::string     service;        // Service part of the address, never empty once constructed.

        int             addr_type;      // ai_socktype of the resolved address.
        int             addr_protocol;  // ai_protocol of the resolved address.
        socklen_t       addr_len;       // Number of bytes of addr that hold the resolved address.
        sockaddr_any_t  addr;           // The resolved address, zero padded after addr_len bytes.


        // Parse an address string of the form 'host:service'
        // where the host part (but not the colon) is optional.
        //
        // The split is made at the last colon, so a host which itself contains
        // colons (such as a numeric IPv6 address) can be wrapped in square
        // brackets, like '[::1]:80', and the brackets are then stripped.
        //
        // If no host is provided, GetAddrInfo() will pass a NULL node name to
        // getaddrinfo, which selects the wildcard address when AI_PASSIVE is
        // included in its flags, and the loopback address otherwise.
        //
        // Throws Error if there is no colon, or nothing follows the last colon.
        SockAddr( const std::string &addrstr )
            : addr_type( 0 )
            , addr_protocol( 0 )
            , addr_len( 0 )
        { //{{{

            // Don't leave the address data indeterminate before it is resolved.
            memset( &addr, 0, sizeof(addr) );

            size_t  n = addrstr.rfind(':');

            if( n != std::string::npos && n + 1 < addrstr.size() )
            {
                service = addrstr.substr( n + 1 );

                // Only strip the brackets if they are actually a matched pair.
                // An unterminated '[' is kept as part of the host, so that it
                // fails to resolve rather than silently dropping a character.
                if( addrstr[0] == '[' && n > 2 && addrstr[n - 1] == ']' )
                    host = addrstr.substr( 1, n - 2 );
                else
                    host = addrstr.substr( 0, n );
            }

            if( service.empty() )
                throw Error( _("SockAddr( '%s' ): no service address"),
                                                    addrstr.c_str() );
        } //}}}


        // Return the address in 'host:service' form again.
        // A host which contains a colon is wrapped in square brackets.
        std::string AddrStr() const
        { //{{{

            if( host.find(':') != std::string::npos )
                return '[' + host + "]:" + service;

            return host + ':' + service;

        } //}}}

        // Resolve host and service with getaddrinfo, and store the first
        // address returned in addr_type, addr_protocol, addr_len, and addr.
        //
        // socktype is used as the ai_socktype hint and flags as the ai_flags
        // hint.  The address family hint is AF_UNSPEC.  Any results after the
        // first one in the returned list are ignored.
        //
        // Throws Error if the lookup fails, or if the address returned is
        // larger than a sockaddr_storage.
        void GetAddrInfo( int socktype, int flags )
        { //{{{

            addrinfo    hints;
            addrinfo   *addrinf = NULL;

            // All the hint fields that are not set explicitly below must be zero.
            memset( &hints, 0, sizeof(addrinfo) );

            hints.ai_flags      = flags;
            hints.ai_family     = AF_UNSPEC;
            hints.ai_socktype   = socktype;

            int err = ::getaddrinfo( host.empty() ? NULL : host.c_str(),
                                     service.c_str(), &hints, &addrinf );
            if( err )
                throw Error( _("SockAddr( '%s' ): failed to get address: %s"),
                                        AddrStr().c_str(), gai_strerror( err ) );

            if( addrinf->ai_addrlen > sizeof(sockaddr_storage) )
            {
                // Take a copy of the length before the list is released, it
                // must not be read from addrinf after freeaddrinfo returns.
                // The cast makes the argument match the %u conversion on
                // platforms where ai_addrlen is not an unsigned int.
                const unsigned  len = static_cast<unsigned>( addrinf->ai_addrlen );

                freeaddrinfo( addrinf );
                throw Error( _("SockAddr( '%s' ): ai_addrlen %u > sockaddr_storage %zu"),
                            AddrStr().c_str(), len, sizeof(sockaddr_storage) );
            }

            addr_type       = addrinf->ai_socktype;
            addr_protocol   = addrinf->ai_protocol;
            addr_len        = addrinf->ai_addrlen;

            // Copy the address, then zero whatever remains of the storage after it.
            memcpy( &addr, addrinf->ai_addr, addr_len );
            memset( reinterpret_cast<char*>(&addr) + addr_len, 0,
                    sizeof(sockaddr_storage) - addr_len );

            freeaddrinfo( addrinf );

        } //}}}

    }; //}}}



   #if EM_PLATFORM_MSW

    // An Error which appends the system message text for a Winsock error code
    // to the message supplied by the caller.
    //
    // Unless a code is passed explicitly, the code reported is the value of
    // WSAGetLastError() at the time the exception object is constructed.
    class SocketError : public Error
    { //{{{
    private:

        int     m_errno;            // The error code being reported.
        char    m_errmsg[65536];    // Buffer for the system message text.


        // Fetch the system message text for m_errno into m_errmsg.
        // The buffer is left as an empty string if FormatMessageA writes nothing.
        char *GetSysMsg()
        {
            m_errmsg[0] = '\0';
            FormatMessageA( FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
                            0, m_errno, 0, m_errmsg, sizeof(m_errmsg), NULL );
            return m_errmsg;
        }

        // Return the system message with every '%' doubled, so that it is
        // safe to splice into a printf style format string.
        //
        // FORMAT_MESSAGE_IGNORE_INSERTS leaves insert sequences like %1 in the
        // message text unchanged, and those must not be interpreted as being
        // conversion specifications for the caller's argument list.
        std::string GetSysMsgEscaped()
        {
            std::string     escaped;

            for( const char *c = GetSysMsg(); *c; ++c )
            {
                if( *c == '%' )
                    escaped += '%';

                escaped += *c;
            }

            return escaped;
        }


    public:

        // Report the last Winsock error with a generic prefix.
        SocketError() throw()
            : m_errno( WSAGetLastError() )
        {
            SetMessage( "Socket Error: %s", GetSysMsg() );
        }

        // Report the last Winsock error, prefixed with the literal text of msg.
        SocketError( const std::string &msg ) throw()
            : m_errno( WSAGetLastError() )
        {
            SetMessage( "%s", (msg + ": " + GetSysMsg()).c_str() );
        }

        // Report the last Winsock error, prefixed with a printf style
        // formatted message.
        BB_PRINTF_FORMAT(2,3)
        SocketError( const char *format, ... ) throw()
            : m_errno( WSAGetLastError() )
        {
            va_list         arglist;
            std::string     msg( format );

            // The system text becomes part of the format string here,
            // so it must have any '%' characters in it escaped.
            msg.append( ": " ).append( GetSysMsgEscaped() );

            va_start( arglist, format );
            SetMessage( msg.c_str(), arglist );
            va_end( arglist );
        }

        // Report the given error code, rather than the last Winsock error,
        // prefixed with a printf style formatted message.
        BB_PRINTF_FORMAT(3,4)
        SocketError( int code, const char *format, ... ) throw()
            : m_errno( code )
        {
            va_list         arglist;
            std::string     msg( format );

            // The system text becomes part of the format string here,
            // so it must have any '%' characters in it escaped.
            msg.append( ": " ).append( GetSysMsgEscaped() );

            va_start( arglist, format );
            SetMessage( msg.c_str(), arglist );
            va_end( arglist );
        }


        // Return the error code that this exception is reporting.
        int GetErrorCode() const { return m_errno; }

    }; //}}}


    // Log a printf style formatted message, followed by the system message
    // text for the last Winsock error.
    //
    // A single trailing newline is removed from format, if it has one, so the
    // system message is appended on the same line, then a newline is added to
    // the end again.  N is passed on as the template parameter of Logv
    // (presumably selecting the log level, that is not defined in this file).
    template< int N >
    BB_PRINTF_FORMAT(1,2)
    void LogSocketErr( const char *format, ... )
    { //{{{

        // Read the error code before doing anything else, as a precaution
        // against some other call made here changing the last error value.
        const int       errcode = WSAGetLastError();

        va_list         arglist;
        std::string     fmt( format );
        char            errmsg[65536];

        if( fmt.size() && fmt[fmt.size() - 1] == '\n' )
            fmt.erase( fmt.size() - 1 );

        errmsg[0] = '\0';
        FormatMessageA( FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
                        0, errcode, 0, errmsg, sizeof(errmsg), NULL );

        fmt.append(": ");

        // The system text becomes part of the format string, so double any '%'
        // in it.  FORMAT_MESSAGE_IGNORE_INSERTS leaves insert sequences like
        // %1 in the text, which must not be read as conversion specifications.
        for( const char *c = errmsg; *c; ++c )
        {
            if( *c == '%' )
                fmt += '%';

            fmt += *c;
        }

        fmt.append(1,'\n');

        va_start( arglist, format );
        Logv<N>( fmt.c_str(), arglist );
        va_end( arglist );

    } //}}}


    class WinsockScope
    { //{{{
    private:

        WSADATA m_wsa;

    public:

        WinsockScope()
        {
            int ret = WSAStartup(MAKEWORD(2,2), &m_wsa);

            if( ret )
                throw Error( "WSAStartup failed with error %d" );
        }

        ~WinsockScope()
        {
            WSACleanup();
        }

    }; //}}}

   #else  // ! EM_PLATFORM_MSW

    // There is no separate socket error type or logging function defined here
    // when not building for Windows.  The socket specific names are simply
    // aliases for SystemError and LogErr (neither is defined in this file).
    #define SocketError     SystemError
    #define LogSocketErr    LogErr

    // Empty stand-in with a constructor that does nothing, so that code can
    // create a WinsockScope without needing a platform conditional of its own.
    struct WinsockScope { WinsockScope() {} };

   #endif

}   // BitB namespace

#endif // _BB_SOCKET_H

// vi:sts=4:sw=4:et:foldmethod=marker
