Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
apply openssl 3x shutdown fix to dtls
  • Loading branch information
Fabian Mauchle committed Jan 20, 2023
1 parent 7b6c965 commit 5ede34f
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 175 deletions.
123 changes: 31 additions & 92 deletions dtls.c
Expand Up @@ -94,78 +94,11 @@ void dtlssetsrcres() {
AF_UNSPEC, NULL, protodefs.socktype);
}

int dtlsread(SSL *ssl, unsigned char *buf, int num, int timeout, pthread_mutex_t *lock) {
int len, cnt, sockerr = 0;
socklen_t errlen = sizeof(sockerr);
struct pollfd fds[1];
unsigned long error;
assert(lock);

pthread_mutex_lock(lock);

for (len = 0; len < num; len += cnt) {
if (!SSL_pending(ssl)) {
fds[0].fd = BIO_get_fd(SSL_get_rbio(ssl), NULL);
fds[0].events = POLLIN;

pthread_mutex_unlock(lock);

cnt = poll(fds, 1, timeout? timeout * 1000 : -1);
if (cnt == 0)
return cnt;

pthread_mutex_lock(lock);
if (cnt < 0 || fds[0].revents & (POLLERR | POLLHUP | POLLNVAL)) {
if (fds[0].revents & POLLERR) {
if(!getsockopt(BIO_get_fd(SSL_get_rbio(ssl), NULL), SOL_SOCKET, SO_ERROR, (void *)&sockerr, &errlen))
debug(DBG_INFO, "DTLS Connection lost: %s", strerror(sockerr));
else
debug(DBG_INFO, "DTLS Connection lost: unknown error");
} else if (fds[0].revents & POLLHUP) {
debug(DBG_INFO, "DTLS Connection lost: hang up");
} else if (fds[0].revents & POLLNVAL) {
debug(DBG_INFO, "DTLS Connection error: fd not open");
}

SSL_shutdown(ssl);
pthread_mutex_unlock(lock);
return -1;
}
}

cnt = SSL_read(ssl, buf + len, num - len);
if (cnt <= 0)
switch (cnt = SSL_get_error(ssl, cnt)) {
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
cnt = 0;
continue;
case SSL_ERROR_ZERO_RETURN:
debug(DBG_DBG, "dtlsread: got ssl shutdown");
/* fallthrough */
default:
while ((error = ERR_get_error()))
debug(DBG_ERR, "dtlsread: SSL: %s", ERR_error_string(error, NULL));
if (cnt == 0)
debug(DBG_INFO, "dtlsread: connection closed by remote host");
else {
debugerrno(errno, DBG_ERR, "dtlsread: connection lost");
}
/* snsure ssl connection is shutdown */
SSL_shutdown(ssl);
pthread_mutex_unlock(lock);
return -1;
}
}
pthread_mutex_unlock(lock);
return num;
}

unsigned char *raddtlsget(SSL *ssl, int timeout, pthread_mutex_t *lock) {
int cnt, len;
unsigned char buf[4], *rad;

cnt = dtlsread(ssl, buf, 4, timeout, lock);
cnt = sslreadtimeout(ssl, buf, 4, timeout, lock);
if (cnt < 1)
return NULL;

Expand All @@ -184,7 +117,7 @@ unsigned char *raddtlsget(SSL *ssl, int timeout, pthread_mutex_t *lock) {
}
memcpy(rad, buf, 4);

cnt = dtlsread(ssl, rad + 4, len - 4, timeout, lock);
cnt = sslreadtimeout(ssl, rad + 4, len - 4, timeout, lock);
if (cnt < 1) {
free(rad);
return NULL;
Expand All @@ -207,7 +140,7 @@ void *dtlsserverwr(void *arg) {
for (;;) {
pthread_mutex_lock(&replyq->mutex);
while (!list_first(replyq->entries)) {
if (client->ssl) {
if (!SSL_get_shutdown(client->ssl)) {
debug(DBG_DBG, "dtlsserverwr: waiting for signal");
pthread_cond_wait(&replyq->cond, &replyq->mutex);
debug(DBG_DBG, "dtlsserverwr: got signal");
Expand All @@ -219,12 +152,12 @@ void *dtlsserverwr(void *arg) {
pthread_mutex_unlock(&replyq->mutex);

pthread_mutex_lock(&client->lock);
if (!client->ssl) {
if (SSL_get_shutdown(client->ssl)) {
/* ssl might have changed while waiting */
pthread_mutex_unlock(&client->lock);
if (reply)
freerq(reply);
debug(DBG_DBG, "tlsserverwr: exiting as requested");
debug(DBG_DBG, "dtlsserverwr: ssl connection shutdown; exiting as requested");
pthread_exit(NULL);
}

Expand Down Expand Up @@ -263,29 +196,34 @@ void dtlsserverrd(struct client *client) {
}

for (;;) {
buf = raddtlsget(client->ssl, IDLE_TIMEOUT * 3, &client->lock);
if (!buf) {
debug(DBG_ERR, "dtlsserverrd: connection from %s lost", addr2string(client->addr, tmp, sizeof(tmp)));
break;
}
debug(DBG_DBG, "dtlsserverrd: got Radius message from %s", addr2string(client->addr, tmp, sizeof(tmp)));
rq = newrequest();
if (!rq) {
free(buf);
continue;
}
rq->buf = buf;
rq->from = client;
if (!radsrv(rq)) {
debug(DBG_ERR, "dtlsserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr, tmp, sizeof(tmp)));
break;
}
buf = raddtlsget(client->ssl, IDLE_TIMEOUT * 3, &client->lock);
if (!buf) {
pthread_mutex_lock(&client->lock);
if (SSL_get_shutdown(client->ssl))
debug(DBG_ERR, "dtlsserverrd: connection from %s lost", addr2string(client->addr, tmp, sizeof(tmp)));
else {
debug(DBG_WARN, "tlsserverrd: timeout from %s, client %s (no requests), closing connection", addr2string(client->addr, tmp, sizeof(tmp)), client->conf->name);
SSL_shutdown(client->ssl);
}
SSL_set_shutdown(client->ssl, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);
pthread_mutex_unlock(&client->lock);
break;
}
debug(DBG_DBG, "dtlsserverrd: got Radius message from %s", addr2string(client->addr, tmp, sizeof(tmp)));
rq = newrequest();
if (!rq) {
free(buf);
continue;
}
rq->buf = buf;
rq->from = client;
if (!radsrv(rq)) {
debug(DBG_ERR, "dtlsserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr, tmp, sizeof(tmp)));
break;
}
}

/* stop writer by setting ssl to NULL and give signal in case waiting for data */
pthread_mutex_lock(&client->lock);
client->ssl = NULL;
pthread_mutex_unlock(&client->lock);
pthread_mutex_lock(&client->replyq->mutex);
pthread_cond_signal(&client->replyq->cond);
pthread_mutex_unlock(&client->replyq->mutex);
Expand Down Expand Up @@ -314,6 +252,7 @@ void *dtlsservernew(void *arg) {
if (!conf)
goto exit;

memset(&tmpsrvaddr, 0, sizeof(struct addrinfo));
tmpsrvaddr.ai_addr = (struct sockaddr *)&params->bind;
tmpsrvaddr.ai_addrlen = SOCKADDR_SIZE(params->bind);
tmpsrvaddr.ai_family = params->bind.ss_family;
Expand Down
83 changes: 0 additions & 83 deletions tls.c
Expand Up @@ -20,7 +20,6 @@
#include <pthread.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <assert.h>
#include "radsecproxy.h"
#include "hostport.h"
#include "debug.h"
Expand Down Expand Up @@ -219,88 +218,6 @@ int tlsconnect(struct server *server, int timeout, char *text) {
return 1;
}

/* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
/* returns 0 on timeout, -1 on error and num if ok */
int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout, pthread_mutex_t *lock) {
int ndesc, cnt = 0, len, sockerr = 0;
socklen_t errlen = sizeof(sockerr);
struct pollfd fds[1];
unsigned long error;
uint8_t want_write = 0;
assert(lock);

pthread_mutex_lock(lock);

for (len = 0; len < num; len += cnt) {
if (SSL_pending(ssl) == 0) {
fds[0].fd = SSL_get_fd(ssl);
fds[0].events = POLLIN;
if (want_write) {
fds[0].events |= POLLOUT;
want_write = 0;
}
pthread_mutex_unlock(lock);

ndesc = poll(fds, 1, timeout ? timeout * 1000 : -1);
if (ndesc == 0)
return ndesc;

pthread_mutex_lock(lock);
if (ndesc < 0 || fds[0].revents & (POLLERR | POLLHUP | POLLNVAL)) {
if (fds[0].revents & POLLERR) {
if(!getsockopt(SSL_get_fd(ssl), SOL_SOCKET, SO_ERROR, (void *)&sockerr, &errlen))
debug(DBG_INFO, "sslreadtimeout: connection lost: %s", strerror(sockerr));
else
debug(DBG_INFO, "sslreadtimeout: connection lost: unknown error");
} else if (fds[0].revents & POLLHUP) {
debug(DBG_INFO, "sslreadtimeout: connection lost: hang up");
} else if (fds[0].revents & POLLNVAL) {
debug(DBG_ERR, "sslreadtimeout: connection error: fd not open");
}

SSL_set_shutdown(ssl, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);
pthread_mutex_unlock(lock);
return -1;
}
}

cnt = SSL_read(ssl, buf + len, num - len);
if (cnt <= 0) {
switch (SSL_get_error(ssl, cnt)) {
case SSL_ERROR_WANT_WRITE:
want_write = 1;
/* fallthrough */
case SSL_ERROR_WANT_READ:
cnt = 0;
continue;
case SSL_ERROR_ZERO_RETURN:
debug(DBG_DBG, "sslreadtimeout: got ssl shutdown");
SSL_shutdown(ssl);
break;
case SSL_ERROR_SYSCALL:
if (errno)
debugerrno(errno, DBG_INFO, "sslreadtimeout: connection lost");
else
debug(DBG_INFO, "sslreadtimeout: connection lost: EOF");
/* fallthrough */
case SSL_ERROR_SSL:
while ((error = ERR_get_error()))
debug(DBG_ERR, "sslreadtimeout: SSL: %s", ERR_error_string(error, NULL));
SSL_set_shutdown(ssl, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);
break;
default:
debug(DBG_ERR, "sslreadtimeout: uncaught SSL error");
SSL_shutdown(ssl);
SSL_set_shutdown(ssl, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);
}
pthread_mutex_unlock(lock);
return -1;
}
}
pthread_mutex_unlock(lock);
return cnt;
}

/* timeout in seconds, 0 means no timeout (blocking) */
unsigned char *radtlsget(SSL *ssl, int timeout, pthread_mutex_t *lock) {
int cnt, len;
Expand Down
92 changes: 92 additions & 0 deletions tlscommon.c
Expand Up @@ -27,6 +27,7 @@
#include <openssl/err.h>
#include <openssl/md5.h>
#include <openssl/x509v3.h>
#include <assert.h>
#include "debug.h"
#include "hash.h"
#include "util.h"
Expand Down Expand Up @@ -1263,6 +1264,97 @@ int sslconnecttimeout(SSL *ssl, int timeout) {
return r;
}

/**
* @brief read from ssl connection with timeout.
* In case of error, ssl connection will be closed and shutdown state is set.
*
* @param ssl SSL connection
* @param buf destination buffer
* @param num number of bytes to read
* @param timeout maximum time to wait for data, 0 waits indefinetely
* @param lock the lock to aquire before performing any operation on the ssl conneciton
* @return number of bytes received, 0 on timeout, -1 on error (connection lost)
*/
int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout, pthread_mutex_t *lock) {
int ndesc, cnt = 0, len, sockerr = 0;
socklen_t errlen = sizeof(sockerr);
struct pollfd fds[1];
unsigned long error;
uint8_t want_write = 0;
assert(lock);

pthread_mutex_lock(lock);

for (len = 0; len < num; len += cnt) {
if (SSL_pending(ssl) == 0) {
fds[0].fd = SSL_get_fd(ssl);
fds[0].events = POLLIN;
if (want_write) {
fds[0].events |= POLLOUT;
want_write = 0;
}
pthread_mutex_unlock(lock);

ndesc = poll(fds, 1, timeout ? timeout * 1000 : -1);
if (ndesc == 0)
return ndesc;

pthread_mutex_lock(lock);
if (ndesc < 0 || fds[0].revents & (POLLERR | POLLHUP | POLLNVAL)) {
if (fds[0].revents & POLLERR) {
if(!getsockopt(SSL_get_fd(ssl), SOL_SOCKET, SO_ERROR, (void *)&sockerr, &errlen))
debug(DBG_INFO, "sslreadtimeout: connection lost: %s", strerror(sockerr));
else
debug(DBG_INFO, "sslreadtimeout: connection lost: unknown error");
} else if (fds[0].revents & POLLHUP) {
debug(DBG_INFO, "sslreadtimeout: connection lost: hang up");
} else if (fds[0].revents & POLLNVAL) {
debug(DBG_ERR, "sslreadtimeout: connection error: fd not open");
}

SSL_set_shutdown(ssl, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);
pthread_mutex_unlock(lock);
return -1;
}
}

cnt = SSL_read(ssl, buf + len, num - len);
if (cnt <= 0) {
switch (SSL_get_error(ssl, cnt)) {
case SSL_ERROR_WANT_WRITE:
want_write = 1;
/* fallthrough */
case SSL_ERROR_WANT_READ:
cnt = 0;
continue;
case SSL_ERROR_ZERO_RETURN:
debug(DBG_DBG, "sslreadtimeout: got ssl shutdown");
SSL_shutdown(ssl);
break;
case SSL_ERROR_SYSCALL:
if (errno)
debugerrno(errno, DBG_INFO, "sslreadtimeout: connection lost");
else
debug(DBG_INFO, "sslreadtimeout: connection lost: EOF");
/* fallthrough */
case SSL_ERROR_SSL:
while ((error = ERR_get_error()))
debug(DBG_ERR, "sslreadtimeout: SSL: %s", ERR_error_string(error, NULL));
SSL_set_shutdown(ssl, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);
break;
default:
debug(DBG_ERR, "sslreadtimeout: uncaught SSL error");
SSL_shutdown(ssl);
SSL_set_shutdown(ssl, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);
}
pthread_mutex_unlock(lock);
return -1;
}
}
pthread_mutex_unlock(lock);
return cnt;
}

#else
/* Just to makes file non-empty, should rather avoid compiling this file when not needed */
static void tlsdummy() {
Expand Down
1 change: 1 addition & 0 deletions tlscommon.h
Expand Up @@ -53,6 +53,7 @@ void freematchcertattr(struct clsrvconf *conf);
void tlsreloadcrls();
int sslconnecttimeout(SSL *ssl, int timeout);
int sslaccepttimeout (SSL *ssl, int timeout);
int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout, pthread_mutex_t *lock);
#endif

/* Local Variables: */
Expand Down

0 comments on commit 5ede34f

Please sign in to comment.