#include <cerrno>
#include <cstring>
#include <iostream>
#include <memory>
#include <ostream>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <openssl/err.h>
#include <openssl/ssl.h>

using std::cout, std::endl, std::string;

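/*
 * Build as C++17 or newer and link against OpenSSL, e.g.:
 *   g++ -std=c++17 <this file> -L/usr/lib -lssl -lcrypto
 * At runtime the server loads ./ssl/server.crt and ./ssl/server.key,
 * resolved relative to the current working directory.
 */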

int main() {

    SSL_library_init();
    SSL_load_error_strings();

    SSL_CTX *ctx = SSL_CTX_new(TLS_server_method());
    if (!ctx) {
        cout << "Creation of SSL context failed" << endl;
    }

    // Load certificate
    if (SSL_CTX_use_certificate_file(ctx, "./ssl/server.crt",
                                     SSL_FILETYPE_PEM) <= 0) {
        unsigned long errCode = ERR_get_error();
        char errBuffer[128];
        ERR_error_string_n(errCode, errBuffer, sizeof(errBuffer));
        cout << "Certificate load failed: " << errBuffer << endl;
    }

    // Load certificate private key
    if (SSL_CTX_use_PrivateKey_file(ctx, "./ssl/server.key",
                                    SSL_FILETYPE_PEM) <= 0) {
        unsigned long errCode = ERR_get_error();
        char errBuffer[128];
        ERR_error_string_n(errCode, errBuffer, sizeof(errBuffer));
        cout << "Private key load failed: " << errBuffer << endl;
    }

    cout << "Initializing server" << endl;
    int serverSocket = socket(AF_INET, SOCK_STREAM, 0);
    if (serverSocket == -1) {
        cout << "Server socket failed" << endl;
    }

    sockaddr_in serverAddress;
    serverAddress.sin_family = AF_INET;
    serverAddress.sin_port = htons(8080);
    serverAddress.sin_addr.s_addr = inet_addr("127.0.0.50");

    int bindStatus = bind(serverSocket, (struct sockaddr *)&serverAddress,
                          sizeof(serverAddress));
    if (bindStatus == -1) {
        cout << "Bind failed" << endl;
    }

    int listenStatus = listen(serverSocket, 5);
    if (listenStatus == -1) {
        cout << "Listen failed" << endl;
    }

    int clientSocket = accept(serverSocket, nullptr, nullptr);
    if (clientSocket == -1) {
        cout << "Client socket failed" << endl;
    }

    SSL *ssl = SSL_new(ctx);
    SSL_set_fd(ssl, clientSocket);

    // Perform the SSL negotiation
    int sslAcceptCode = SSL_accept(ssl);
    if (sslAcceptCode <= 0) {
        int sslError = SSL_get_error(ssl, sslAcceptCode);
        char *errStr = ERR_error_string(ERR_get_error(), nullptr);
        cout << "SSL_accept failed with SSL error: " << sslError << endl;
        cout << "OpenSSL error: " << errStr << endl;
    } else {
        cout << "negotiated SSL" << endl;
    }

    while (true) {
        char buffer[1024] = {0};
        SSL_read(ssl, buffer, sizeof(buffer) - 1);

        if (strlen(buffer) <= 0) {
            cout << "Client seems to have just straight up left :(" << endl;
            break;
        }

        cout << buffer << endl;
    }

    SSL_free(ssl);
    SSL_CTX_free(ctx);
    ERR_free_strings();
    close(serverSocket);

    return 0;
}