#ifndef _COMMON_H
#define _COMMON_H

#include <errno.h>
#include <limits.h>
#include <poll.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

#include <openssl/err.h>
#include <openssl/ssl.h>

/*
 * Shorthand for the GCC/Clang `cleanup` variable attribute: a variable
 * declared as `_cleanup_(fn) T var = ...;` has fn(&var) invoked when it goes
 * out of scope. The free_* helpers in this header take a pointer to the
 * variable, which is the signature that attribute requires.
 */
#define _cleanup_(x) __attribute__((cleanup(x)))

/*
 * Fallback logger, used only when the includer has not already defined
 * COMMON_LOG. This version ignores `prio` and prints `msg` (which must be a
 * printf-style string literal, since "\n" is pasted onto it) to stderr.
 * `## __VA_ARGS__` is the GNU comma-swallowing extension, so the macro can be
 * invoked without any variadic arguments.
 */
#ifndef COMMON_LOG
#include <stdio.h>
#define COMMON_LOG(prio, msg, ...) (fprintf(stderr, msg "\n", ## __VA_ARGS__ ))
#endif

/* Cleanup handler for SSL_CTX pointers: frees *ctxp unless it is NULL. */
static void __attribute__((unused)) free_ssl_ctx(SSL_CTX **ctxp) {
    SSL_CTX *ctx = *ctxp;

    if (ctx == NULL)
        return;
    SSL_CTX_free(ctx);
}

/* Cleanup handler for SSL pointers: frees *sslp unless it is NULL. */
static void __attribute__((unused)) free_ssl(SSL **sslp) {
    SSL *ssl = *sslp;

    if (ssl == NULL)
        return;
    SSL_free(ssl);
}

/*
 * Cleanup handler for file descriptors: closes *fdp unless it holds -1,
 * the value used to mean "no descriptor". The result of close() is ignored.
 */
static void __attribute__((unused)) free_fd(int *fdp) {
    int fd = *fdp;

    if (fd == -1)
        return;
    close(fd);
}

/*
 * Cleanup handler for stdio streams: fclose()s *filep unless it is NULL.
 * The result of fclose() is ignored.
 */
static void __attribute__((unused)) free_file(FILE **filep) {
    FILE *fp = *filep;

    if (!fp)
        return;
    fclose(fp);
}

/* Cleanup handler for heap-allocated strings: releases *ptr. */
static void __attribute__((unused)) free_string(char **ptr) {
    /* free(NULL) is defined to do nothing, so no explicit NULL test is needed. */
    free(*ptr);
}

/*
 * Log `str` and then every entry on the OpenSSL error queue, one
 * COMMON_LOG line per entry. ERR_get_error() removes each entry it returns,
 * so the queue is empty when this function returns.
 *
 * Each line has the form "ssl error: <code>:<library>:<function>:<reason>".
 * The string lookups may return NULL (no string registered for the code);
 * such fields are logged as empty strings rather than being passed to "%s".
 */
static void __attribute__((unused)) psslerror(const char *str) {
    COMMON_LOG(LOG_ERR, "%s:", str);
    unsigned long code;
    while ((code = ERR_get_error()) != 0) {
        const char *lib = ERR_lib_error_string(code);
#if OPENSSL_VERSION_NUMBER < 0x30000000L
        const char *func = ERR_func_error_string(code);
#else
        /* OpenSSL 3.0 no longer encodes a function in the error code and
         * deprecates ERR_func_error_string(); keep the field, but empty. */
        const char *func = NULL;
#endif
        const char *reason = ERR_reason_error_string(code);
        COMMON_LOG(LOG_ERR, "ssl error: %lu:%s:%s:%s",
                code,
                lib ? lib : "",
                func ? func : "",
                reason ? reason : "");
    }
}

/*
 * Wait until `fd` reports an event for a POLLIN poll, or `timeout`
 * (passed straight to poll(), i.e. milliseconds) runs out.
 *
 * Returns 0 when poll() reports an event on the descriptor, -1 otherwise.
 * On timeout errno is set to ETIMEDOUT; when poll() itself fails its errno
 * is left in place.
 */
static int wait_rd_with_timeout(int fd, int timeout) {
    struct pollfd pfd = { .fd = fd, .events = POLLIN };

    switch (poll(&pfd, 1, timeout)) {
        case -1:
            return -1;
        case 0:
            errno = ETIMEDOUT;
            return -1;
        default:
            return 0;
    }
}

/*
 * Map an SSL_get_error() result code to the name of its SSL_ERROR_*
 * constant. Codes not listed here yield "?". The returned pointer refers to
 * a string literal and must not be modified or freed.
 */
static char *ssl_err(int e) {
    static const struct {
        int code;
        char *name;
    } names[] = {
        { SSL_ERROR_NONE,                 "SSL_ERROR_NONE" },
        { SSL_ERROR_ZERO_RETURN,          "SSL_ERROR_ZERO_RETURN" },
        { SSL_ERROR_WANT_READ,            "SSL_ERROR_WANT_READ" },
        { SSL_ERROR_WANT_WRITE,           "SSL_ERROR_WANT_WRITE" },
        { SSL_ERROR_WANT_CONNECT,         "SSL_ERROR_WANT_CONNECT" },
        { SSL_ERROR_WANT_ACCEPT,          "SSL_ERROR_WANT_ACCEPT" },
        { SSL_ERROR_WANT_X509_LOOKUP,     "SSL_ERROR_WANT_X509_LOOKUP" },
        { SSL_ERROR_WANT_ASYNC,           "SSL_ERROR_WANT_ASYNC" },
        { SSL_ERROR_WANT_ASYNC_JOB,       "SSL_ERROR_WANT_ASYNC_JOB" },
        { SSL_ERROR_WANT_CLIENT_HELLO_CB, "SSL_ERROR_WANT_CLIENT_HELLO_CB" },
        { SSL_ERROR_SYSCALL,              "SSL_ERROR_SYSCALL" },
        { SSL_ERROR_SSL,                  "SSL_ERROR_SSL" },
    };
    unsigned int i;

    for (i = 0; i < sizeof names / sizeof names[0]; i++) {
        if (names[i].code == e)
            return names[i].name;
    }
    return "?";
}

/*
 * Wait until `fd` is writable or `timeout` (poll() milliseconds) runs out.
 * Returns 0 when poll() reports an event, -1 otherwise; errno is ETIMEDOUT on
 * timeout, or whatever poll() set on failure.
 */
static int __attribute__((unused)) wait_wr_with_timeout(int fd, int timeout) {
    struct pollfd pollfd = { .fd = fd, .events = POLLOUT };
    int status = poll(&pollfd, 1, timeout);
    if (status == -1)
        return -1;
    if (status == 0) {
        errno = ETIMEDOUT;
        return -1;
    }
    return 0;
}

/*
 * Write all `datalen` bytes of `data` to `ssl`, waiting on `fd` (the socket
 * underlying `ssl`) whenever OpenSSL reports that it needs the socket to
 * become readable or writable first.
 *
 * `timeout` is handed to poll() for every individual wait, so it bounds each
 * wait rather than the whole call.
 *
 * Returns 0 when everything was written (trivially so for datalen == 0),
 * or -1 with errno set:
 *   EINVAL        datalen does not fit in the int that SSL_write() takes
 *   ETIMEDOUT     the socket did not become ready in time
 *   EPIPE         SSL_write() reported a short write
 *   ECONNABORTED  SSL_ERROR_SYSCALL without an errno (peer vanished)
 *   EPROTO        TLS-level failure (details logged via psslerror)
 *   ENOSYS        an SSL_get_error() code this helper does not handle
 *   other         errno left by the failing system call
 */
static int __attribute__((unused)) ssl_write_with_timeout(SSL *ssl, int fd, char *data, size_t datalen, int timeout) {
    /* SSL_write() with a length of 0 is documented as undefined behaviour. */
    if (datalen == 0)
        return 0;
    /* SSL_write() takes an int; never let a size_t be silently narrowed. */
    if (datalen > (size_t)INT_MAX) {
        COMMON_LOG(LOG_ERR, "%s: cannot write %zu bytes in one call", __func__, datalen);
        errno = EINVAL;
        return -1;
    }
    /* Start from a clean errno so that SSL_ERROR_SYSCALL with errno == 0
     * can be told apart from a genuine system call failure. */
    errno = 0;
    while (1) {
        /* A retry must repeat the call with the same buffer and length. */
        int status = SSL_write(ssl, data, (int)datalen);
        if (status > 0) {
            if ((size_t)status == datalen)
                return 0;
            COMMON_LOG(LOG_ERR, "%s: unexpected partial write. requested %zu bytes, returned: %d", __func__, datalen, status);
            errno = EPIPE;
            return -1;
        }
        int ssl_error = SSL_get_error(ssl, status);
        switch (ssl_error) {
            case SSL_ERROR_WANT_READ:
                status = wait_rd_with_timeout(fd, timeout);
                if (status == -1) {
                    COMMON_LOG(LOG_ERR, "%s: %m", __func__);
                    return -1;
                }
                continue;
            case SSL_ERROR_WANT_WRITE:
                /* Socket send buffer is full: wait for room, then retry. */
                status = wait_wr_with_timeout(fd, timeout);
                if (status == -1) {
                    COMMON_LOG(LOG_ERR, "%s: %m", __func__);
                    return -1;
                }
                continue;
            case SSL_ERROR_SYSCALL:
                if (errno == 0) {
                    COMMON_LOG(LOG_ERR, "%s: unexpected EOF from peer", __func__);
                    errno = ECONNABORTED;
                    return -1;
                }
                COMMON_LOG(LOG_ERR, "%s: %m", __func__);
                return -1;
            case SSL_ERROR_SSL:
                psslerror("");
                errno = EPROTO;
                return -1;
            default:
                COMMON_LOG(LOG_ERR, "%s: %s unimplemented", __func__, ssl_err(ssl_error));
                errno = ENOSYS;
                return -1;
        }
    }
}

/*
 * Read up to `num` bytes from `ssl` into `buf`, waiting on `fd` (the socket
 * underlying `ssl`) whenever OpenSSL needs more input first.
 *
 * `timeout` is handed to poll() for every individual wait, so it bounds each
 * wait rather than the whole call.
 *
 * Returns the number of bytes read (> 0), 0 when SSL_get_error() reports
 * SSL_ERROR_ZERO_RETURN (the peer closed the TLS connection cleanly), or -1
 * with errno set:
 *   ETIMEDOUT     no data arrived in time
 *   ECONNABORTED  SSL_ERROR_SYSCALL without an errno (peer vanished)
 *   EPROTO        TLS-level failure (details logged via psslerror)
 *   ENOSYS        an SSL_get_error() code this helper does not handle
 *   other         errno left by the failing system call
 */
static int __attribute__((unused)) ssl_read_with_timeout(SSL *ssl, int fd, void *buf, size_t num, int timeout){
    /* SSL_read() takes an int. A read may legitimately return fewer bytes
     * than requested, so clamp instead of silently narrowing the size_t. */
    int want = num > (size_t)INT_MAX ? INT_MAX : (int)num;

    /* Start from a clean errno so that SSL_ERROR_SYSCALL with errno == 0
     * can be recognised as an EOF rather than a system call failure. */
    errno = 0;
    while (1) {
        int status = SSL_read(ssl, buf, want);
        if (status > 0)
            return status;
        int ssl_error = SSL_get_error(ssl, status);
        switch (ssl_error) {
            case SSL_ERROR_WANT_READ:
                status = wait_rd_with_timeout(fd, timeout);
                if (status == -1) {
                    COMMON_LOG(LOG_ERR, "%s: %m", __func__);
                    return -1;
                }
                continue;
            case SSL_ERROR_SYSCALL:
                if (errno == 0) {
                    COMMON_LOG(LOG_ERR, "%s: unexpected EOF from peer", __func__);
                    errno = ECONNABORTED;
                    return -1;
                }
                COMMON_LOG(LOG_ERR, "%s: %m", __func__);
                return -1;
            case SSL_ERROR_SSL:
                /* Log and drain the error queue, matching the write and
                 * accept helpers, so stale entries cannot confuse the next
                 * SSL_get_error() call. */
                psslerror("");
                errno = EPROTO;
                return -1;
            case SSL_ERROR_ZERO_RETURN:
                return 0;
            default:
                COMMON_LOG(LOG_ERR, "%s: %s unimplemented", __func__, ssl_err(ssl_error));
                errno = ENOSYS;
                return -1;
        }
    }
}

/*
 * Run the server side of the TLS handshake on `ssl`, waiting on `fd` for
 * input whenever OpenSSL asks for more data. `timeout` is handed to poll()
 * for every individual wait.
 *
 * Returns 1 once SSL_accept() succeeds, otherwise -1 with errno set:
 * ETIMEDOUT (or poll()'s errno) when waiting fails, ECONNABORTED for
 * SSL_ERROR_SYSCALL with no errno, the system errno for other
 * SSL_ERROR_SYSCALL results, EPROTO for SSL_ERROR_SSL, and ENOSYS for any
 * other SSL_get_error() code.
 */
static int __attribute__((unused))  ssl_accept_with_timeout(SSL *ssl, int fd, int timeout) {
    /* Cleared up front so an SSL_ERROR_SYSCALL that leaves errno at 0 can be
     * recognised as an EOF below. */
    errno = 0;
    for (;;) {
        int rc = SSL_accept(ssl);
        if (rc == 1)
            return 1;

        int reason = SSL_get_error(ssl, rc);

        if (reason == SSL_ERROR_WANT_READ) {
            /* Handshake needs more bytes from the peer: wait, then retry. */
            if (wait_rd_with_timeout(fd, timeout) == -1) {
                COMMON_LOG(LOG_ERR, "%s: %m", __func__);
                return -1;
            }
            continue;
        }

        if (reason == SSL_ERROR_SYSCALL) {
            if (errno == 0) {
                COMMON_LOG(LOG_ERR, "%s: unexpected EOF from peer", __func__);
                errno = ECONNABORTED;
            } else {
                COMMON_LOG(LOG_ERR, "%s: %m", __func__);
            }
            return -1;
        }

        if (reason == SSL_ERROR_SSL) {
            psslerror("");
            errno = EPROTO;
            return -1;
        }

        COMMON_LOG(LOG_ERR, "%s: %s unimplemented", __func__, ssl_err(reason));
        errno = ENOSYS;
        return -1;
    }
}

#endif /* _COMMON_H */