#ifndef _COMMON_H #define _COMMON_H #include <openssl/ssl.h> #include <openssl/err.h> #include <poll.h> #include <unistd.h> #define _cleanup_(x) __attribute__((cleanup(x))) #ifndef COMMON_LOG #include <stdio.h> #define COMMON_LOG(prio, msg, ...) (fprintf(stderr, msg "\n", ## __VA_ARGS__ )) #endif static void __attribute__((unused)) free_ssl_ctx(SSL_CTX **ctxp) { if (*ctxp) { SSL_CTX_free(*ctxp); } } static void __attribute__((unused)) free_ssl(SSL **sslp) { if (*sslp) { SSL_free(*sslp); } } static void __attribute__((unused)) free_fd(int *fdp) { if (*fdp != -1) { close(*fdp); } } static void __attribute__((unused)) free_file(FILE **filep) { if (*filep != NULL) { fclose(*filep); } } static void __attribute__((unused)) free_string(char **ptr) { if (*ptr != NULL) { free(*ptr); } } static void __attribute__((unused)) psslerror(char *str) { COMMON_LOG(LOG_ERR, "%s:", str); unsigned long ssl_err; while ((ssl_err = ERR_get_error())) { COMMON_LOG(LOG_ERR, "ssl error: %lud:%s:%s:%s", ssl_err, ERR_lib_error_string(ssl_err), ERR_func_error_string(ssl_err), ERR_reason_error_string(ssl_err)); } } static int wait_rd_with_timeout(int fd, int timeout) { struct pollfd pollfd = { .fd = fd, .events = POLLIN }; int status = poll(&pollfd, 1, timeout); if (status == -1) return -1; if (status == 0) { errno = ETIMEDOUT; return -1; } return 0; } static char *ssl_err(int e) { switch(e) { case SSL_ERROR_NONE: return("SSL_ERROR_NONE"); case SSL_ERROR_ZERO_RETURN: return("SSL_ERROR_ZERO_RETURN"); case SSL_ERROR_WANT_READ: return("SSL_ERROR_WANT_READ"); case SSL_ERROR_WANT_WRITE: return("SSL_ERROR_WANT_WRITE"); case SSL_ERROR_WANT_CONNECT: return("SSL_ERROR_WANT_CONNECT"); case SSL_ERROR_WANT_ACCEPT: return("SSL_ERROR_WANT_ACCEPT"); case SSL_ERROR_WANT_X509_LOOKUP: return("SSL_ERROR_WANT_X509_LOOKUP"); case SSL_ERROR_WANT_ASYNC: return("SSL_ERROR_WANT_ASYNC"); case SSL_ERROR_WANT_ASYNC_JOB: return("SSL_ERROR_WANT_ASYNC_JOB"); case SSL_ERROR_WANT_CLIENT_HELLO_CB: return("SSL_ERROR_WANT_CLIENT_HELLO_CB"); case SSL_ERROR_SYSCALL: return("SSL_ERROR_SYSCALL"); case SSL_ERROR_SSL: return("SSL_ERROR_SSL"); } return"?"; } static int __attribute__((unused)) ssl_write_with_timeout(SSL *ssl, int fd, char *data, size_t datalen, int timeout) { while (1) { int status = SSL_write(ssl, data, datalen); if (status > 0) { if (status == datalen) return 0; COMMON_LOG(LOG_ERR, "%s: unexpected partial write. requested %lud bytes, returned: %d", __func__, datalen, status); errno = EPIPE; return -1; } int ssl_error = SSL_get_error(ssl, status); switch (ssl_error) { case SSL_ERROR_WANT_READ: status = wait_rd_with_timeout(fd, timeout); if (status == -1) { COMMON_LOG(LOG_ERR, "%s: %m", __func__); return -1; } continue; case SSL_ERROR_SYSCALL: COMMON_LOG(LOG_ERR, "%s: %m", __func__); return -1; case SSL_ERROR_SSL: psslerror(""); errno = EPROTO; return -1; default: COMMON_LOG(LOG_ERR, "%s: %s unimplemented", __func__, ssl_err(ssl_error)); errno = ENOSYS; return -1; } } } static int __attribute__((unused)) ssl_read_with_timeout(SSL *ssl, int fd, void *buf, size_t num, int timeout){ errno = 0; /* see commit message */ while (1) { int status = SSL_read(ssl, buf, num); if (status > 0) return status; int ssl_error = SSL_get_error(ssl, status); switch (ssl_error) { case SSL_ERROR_WANT_READ: status = wait_rd_with_timeout(fd, timeout); if (status == -1) { COMMON_LOG(LOG_ERR, "%s: %m", __func__); return -1; } continue; case SSL_ERROR_SYSCALL: if (errno == 0) { COMMON_LOG(LOG_ERR, "%s: unexpected EOF from peer", __func__); errno = ECONNABORTED; return -1; } COMMON_LOG(LOG_ERR, "%s: %m", __func__); return -1; case SSL_ERROR_ZERO_RETURN: return 0; default: COMMON_LOG(LOG_ERR, "%s: %s unimplemented", __func__, ssl_err(ssl_error)); errno = ENOSYS; return -1; } } } static int __attribute__((unused)) ssl_accept_with_timeout(SSL *ssl, int fd, int timeout) { errno = 0; /* see commit message */ while (1) { int status = SSL_accept(ssl); if (status == 1) return 1; int ssl_error = SSL_get_error(ssl, status); switch (ssl_error) { case SSL_ERROR_WANT_READ: status = wait_rd_with_timeout(fd, timeout); if (status == -1) { COMMON_LOG(LOG_ERR, "%s: %m", __func__); return -1; } continue; case SSL_ERROR_SYSCALL: if (errno == 0) { COMMON_LOG(LOG_ERR, "%s: unexpected EOF from peer", __func__); errno = ECONNABORTED; return -1; } COMMON_LOG(LOG_ERR, "%s: %m", __func__); return -1; case SSL_ERROR_SSL: psslerror(""); errno = EPROTO; return -1; default: COMMON_LOG(LOG_ERR, "%s: %s unimplemented", __func__, ssl_err(ssl_error)); errno = ENOSYS; return -1; } } } #endif /* _COMMON_H */