From 6d3b520232d471c3ead9623406ba98b304fd9e64 Mon Sep 17 00:00:00 2001 From: Fabian Mauchle Date: Wed, 2 May 2018 19:04:22 +0200 Subject: [PATCH] Rework DTLS code and implement locking similar to TLS --- ChangeLog | 1 + dtls.c | 853 ++++++++++++++++++++++++-------------------------- dtls.h | 3 + radsecproxy.c | 26 +- radsecproxy.h | 6 +- tlscommon.c | 89 ++++++ tlscommon.h | 1 + 7 files changed, 510 insertions(+), 469 deletions(-) diff --git a/ChangeLog b/ChangeLog index 3ab85d7..f6adba7 100644 --- a/ChangeLog +++ b/ChangeLog @@ -22,6 +22,7 @@ Changes between 1.6.9 and the master branch - Replace several server status bits with a single state enum. (RADSECPROXY-71) - Use poll instead of select to allow > 1000 concurrent connections. + - Rework DTLS code. Bug fixes: - Detect the presence of docbook2x-man correctly. diff --git a/dtls.c b/dtls.c index 822f00b..055a2af 100644 --- a/dtls.c +++ b/dtls.c @@ -2,6 +2,8 @@ * Copyright (c) 2012,2016-2017, NORDUnet A/S */ /* See LICENSE for licensing information. */ +#define _GNU_SOURCE + #include #include #include @@ -22,6 +24,7 @@ #include #include #include +#include #include "hash.h" #include "radsecproxy.h" @@ -32,7 +35,7 @@ static void setprotoopts(struct commonprotoopts *opts); static char **getlistenerargs(); -void *udpdtlsserverrd(void *arg); +void *dtlslistener(void *arg); int dtlsconnect(struct server *server, struct timeval *when, int timeout, char *text); void *dtlsclientrd(void *arg); int clientradputdtls(struct server *server, unsigned char *rad); @@ -52,18 +55,16 @@ static const struct protodefs protodefs = { DUPLICATE_INTERVAL, /* duplicateintervaldefault */ setprotoopts, /* setprotoopts */ getlistenerargs, /* getlistenerargs */ - udpdtlsserverrd, /* listener */ + dtlslistener, /* listener */ dtlsconnect, /* connecter */ dtlsclientrd, /* clientconnreader */ clientradputdtls, /* clientradput */ NULL, /* addclient */ - addserverextradtls, /* addserverextra */ + NULL, /* addserverextra */ dtlssetsrcres, /* setsrcres */ - initextradtls /* initextra */ + NULL /* initextra */ }; -static int client4_sock = -1; -static int client6_sock = -1; static struct addrinfo *srcres = NULL; static uint8_t handle; static struct commonprotoopts *protoopts = NULL; @@ -81,16 +82,10 @@ static char **getlistenerargs() { return protoopts ? protoopts->listenargs : NULL; } -struct sessioncacheentry { - pthread_mutex_t mutex; - struct gqueue *rbios; - struct timeval expiry; -}; - struct dtlsservernewparams { - struct sessioncacheentry *sesscache; - int sock; struct sockaddr_storage addr; + struct sockaddr_storage bind; + SSL *ssl; }; void dtlssetsrcres() { @@ -100,164 +95,83 @@ void dtlssetsrcres() { AF_UNSPEC, NULL, protodefs.socktype); } -int udp2bio(int s, struct gqueue *q, int cnt) { - unsigned char *buf; - BIO *rbio; - - if (cnt < 1) - return 0; - - buf = malloc(cnt); - if (!buf) { - unsigned char err; - debug(DBG_ERR, "udp2bio: malloc failed"); - recv(s, &err, 1, 0); - return 0; - } - - cnt = recv(s, buf, cnt, 0); - if (cnt < 1) { - debug(DBG_WARN, "udp2bio: recv failed"); - free(buf); - return 0; - } - - rbio = BIO_new_mem_buf(buf, cnt); - BIO_set_mem_eof_return(rbio, -1); - - pthread_mutex_lock(&q->mutex); - if (!list_push(q->entries, rbio)) { - BIO_free(rbio); - pthread_mutex_unlock(&q->mutex); - return 0; - } - pthread_cond_signal(&q->cond); - pthread_mutex_unlock(&q->mutex); - return 1; -} - -BIO *getrbio(SSL *ssl, struct gqueue *q, int timeout) { - BIO *rbio; - struct timeval now; - struct timespec to; - - pthread_mutex_lock(&q->mutex); - if (!(rbio = (BIO *)list_shift(q->entries))) { - if (timeout) { - gettimeofday(&now, NULL); - memset(&to, 0, sizeof(struct timespec)); - to.tv_sec = now.tv_sec + timeout; - pthread_cond_timedwait(&q->cond, &q->mutex, &to); - } else - pthread_cond_wait(&q->cond, &q->mutex); - rbio = (BIO *)list_shift(q->entries); - } - pthread_mutex_unlock(&q->mutex); - return rbio; -} - -int dtlsread(SSL *ssl, struct gqueue *q, unsigned char *buf, int num, int timeout) { +int dtlsread(SSL *ssl, unsigned char *buf, int num, int timeout, pthread_mutex_t *lock) { int len, cnt; - BIO *rbio; + struct pollfd fds[1]; + unsigned long error; + assert(lock); + + pthread_mutex_lock(lock); for (len = 0; len < num; len += cnt) { - cnt = SSL_read(ssl, buf + len, num - len); - if (cnt <= 0) - switch (cnt = SSL_get_error(ssl, cnt)) { - case SSL_ERROR_WANT_READ: - rbio = getrbio(ssl, q, timeout); - if (!rbio) - return 0; - BIO_free(SSL_get_rbio(ssl)); - SSL_set_bio(ssl, rbio, SSL_get_wbio(ssl)); - cnt = 0; - continue; - case SSL_ERROR_WANT_WRITE: - cnt = 0; - continue; - case SSL_ERROR_ZERO_RETURN: - /* remote end sent close_notify, send one back */ - SSL_shutdown(ssl); - return -1; - default: - return -1; - } - } - return num; -} + if (!SSL_pending(ssl)) { + fds[0].fd = BIO_get_fd(SSL_get_rbio(ssl), NULL); + fds[0].events = POLLIN; + + pthread_mutex_unlock(lock); + + cnt = poll(fds, 1, timeout? timeout * 1000 : -1); + if (cnt < 1) + return cnt; + if (fds[0].revents & (POLLERR | POLLHUP | POLLNVAL)) { + pthread_mutex_lock(lock); + SSL_shutdown(ssl); + pthread_mutex_unlock(lock); + return -1; + } + + pthread_mutex_lock(lock); + } -/* accept if acc == 1, else connect */ -SSL *dtlsacccon(uint8_t acc, SSL_CTX *ctx, int s, struct sockaddr *addr, struct gqueue *rbios) { - SSL *ssl; - int i, res; - unsigned long error; - BIO *mem0bio, *wbio; - - ssl = SSL_new(ctx); - if (!ssl) - return NULL; - - mem0bio = BIO_new(BIO_s_mem()); - BIO_set_mem_eof_return(mem0bio, -1); - wbio = BIO_new_dgram(s, BIO_NOCLOSE); - i = BIO_dgram_set_peer(wbio, addr); /* i just to avoid warning */ - SSL_set_bio(ssl, mem0bio, wbio); - - for (i = 0; i < 5; i++) { - res = acc ? SSL_accept(ssl) : SSL_connect(ssl); - if (res > 0) - return ssl; - if (res == 0) - break; - if (SSL_get_error(ssl, res) == SSL_ERROR_WANT_READ) { - BIO_free(SSL_get_rbio(ssl)); - SSL_set_bio(ssl, getrbio(ssl, rbios, 5), SSL_get_wbio(ssl)); - if (!SSL_get_rbio(ssl)) - break; + cnt = SSL_read(ssl, buf + len, num - len); + if (cnt <= 0) + switch (cnt = SSL_get_error(ssl, cnt)) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + cnt = 0; + continue; + case SSL_ERROR_ZERO_RETURN: + debug(DBG_DBG, "dtlsread: got ssl shutdown"); + default: + while ((error = ERR_get_error())) + debug(DBG_ERR, "dtlsread: SSL: %s", ERR_error_string(error, NULL)); + /* snsure ssl connection is shutdown */ + SSL_shutdown(ssl); + pthread_mutex_unlock(lock); + return -1; } - while ((error = ERR_get_error())) - debug(DBG_ERR, "dtls%st: DTLS: %s", acc ? "accep" : "connec", ERR_error_string(error, NULL)); } - - SSL_free(ssl); - return NULL; + pthread_mutex_unlock(lock); + return num; } -unsigned char *raddtlsget(SSL *ssl, struct gqueue *rbios, int timeout) { +unsigned char *raddtlsget(SSL *ssl, int timeout, pthread_mutex_t *lock) { int cnt, len; unsigned char buf[4], *rad; - for (;;) { - cnt = dtlsread(ssl, rbios, buf, 4, timeout); - if (cnt < 1) { - debug(DBG_DBG, cnt ? "raddtlsget: connection lost" : "raddtlsget: timeout"); - return NULL; - } - - len = RADLEN(buf); - if (len < 4) { - debug(DBG_ERR, "raddtlsget: length too small"); - continue; - } - rad = malloc(len); - if (!rad) { - debug(DBG_ERR, "raddtlsget: malloc failed"); - continue; - } - memcpy(rad, buf, 4); - - cnt = dtlsread(ssl, rbios, rad + 4, len - 4, timeout); - if (cnt < 1) { - debug(DBG_DBG, cnt ? "raddtlsget: connection lost" : "raddtlsget: timeout"); - free(rad); - return NULL; - } + cnt = dtlsread(ssl, buf, 4, timeout, lock); + if (cnt < 1) { + debug(DBG_DBG, cnt ? "raddtlsget: connection lost" : "raddtlsget: timeout"); + return NULL; + } - if (len >= 20) - break; + len = RADLEN(buf); + if (len < 20) { + debug(DBG_ERR, "raddtlsget: length too small, malformed packet! closing conneciton!"); + return NULL; + } + rad = malloc(len); + if (!rad) { + debug(DBG_ERR, "raddtlsget: malloc failed"); + return NULL; + } + memcpy(rad, buf, 4); + cnt = dtlsread(ssl, rad + 4, len - 4, timeout, lock); + if (cnt < 1) { + debug(DBG_DBG, cnt ? "raddtlsget: connection lost" : "raddtlsget: timeout"); free(rad); - debug(DBG_WARN, "raddtlsget: packet smaller than minimum radius size"); + return NULL; } debug(DBG_DBG, "raddtlsget: got %d bytes", len); @@ -274,30 +188,47 @@ void *dtlsserverwr(void *arg) { debug(DBG_DBG, "dtlsserverwr: starting for %s", addr2string(client->addr)); replyq = client->replyq; for (;;) { - pthread_mutex_lock(&replyq->mutex); - while (!list_first(replyq->entries)) { - if (client->ssl) { - debug(DBG_DBG, "dtlsserverwr: waiting for signal"); - pthread_cond_wait(&replyq->cond, &replyq->mutex); - debug(DBG_DBG, "dtlsserverwr: got signal"); - } - if (!client->ssl) { - /* ssl might have changed while waiting */ - pthread_mutex_unlock(&replyq->mutex); - debug(DBG_DBG, "dtlsserverwr: exiting as requested"); - pthread_exit(NULL); - } - } - reply = (struct request *)list_shift(replyq->entries); - pthread_mutex_unlock(&replyq->mutex); - cnt = SSL_write(client->ssl, reply->replybuf, RADLEN(reply->replybuf)); - if (cnt > 0) - debug(DBG_DBG, "dtlsserverwr: sent %d bytes, Radius packet of length %d to %s", - cnt, RADLEN(reply->replybuf), addr2string(client->addr)); - else - while ((error = ERR_get_error())) - debug(DBG_ERR, "dtlsserverwr: SSL: %s", ERR_error_string(error, NULL)); - freerq(reply); + pthread_mutex_lock(&replyq->mutex); + while (!list_first(replyq->entries)) { + if (client->ssl) { + debug(DBG_DBG, "dtlsserverwr: waiting for signal"); + pthread_cond_wait(&replyq->cond, &replyq->mutex); + debug(DBG_DBG, "dtlsserverwr: got signal"); + } else + break; + } + + reply = (struct request *)list_shift(replyq->entries); + pthread_mutex_unlock(&replyq->mutex); + + pthread_mutex_lock(&client->lock); + if (!client->ssl) { + /* ssl might have changed while waiting */ + pthread_mutex_unlock(&client->lock); + if (reply) + freerq(reply); + debug(DBG_DBG, "tlsserverwr: exiting as requested"); + pthread_exit(NULL); + } + + while ((cnt = SSL_write(client->ssl, reply->replybuf, RADLEN(reply->replybuf))) <= 0) { + switch (SSL_get_error(client->ssl, cnt)) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + continue; + default: + while ((error = ERR_get_error())) + debug(DBG_ERR, "dtlsserverwr: SSL: %s", ERR_error_string(error, NULL)); + pthread_mutex_unlock(&client->lock); + freerq(reply); + debug(DBG_DBG, "tlsserverwr: SSL error. exiting."); + pthread_exit(NULL); + } + } + debug(DBG_DBG, "dtlsserverwr: sent %d bytes, Radius packet of length %d to %s", + cnt, RADLEN(reply->replybuf), addr2string(client->addr)); + pthread_mutex_unlock(&client->lock); + freerq(reply); } } @@ -314,7 +245,7 @@ void dtlsserverrd(struct client *client) { } for (;;) { - buf = raddtlsget(client->ssl, client->rbios, IDLE_TIMEOUT); + buf = raddtlsget(client->ssl, IDLE_TIMEOUT * 3, &client->lock); if (!buf) { debug(DBG_ERR, "dtlsserverrd: connection from %s lost", addr2string(client->addr)); break; @@ -334,8 +265,9 @@ void dtlsserverrd(struct client *client) { } /* stop writer by setting ssl to NULL and give signal in case waiting for data */ + pthread_mutex_lock(&client->lock); client->ssl = NULL; - + pthread_mutex_unlock(&client->lock); pthread_mutex_lock(&client->replyq->mutex); pthread_cond_signal(&client->replyq->cond); pthread_mutex_unlock(&client->replyq->mutex); @@ -349,245 +281,297 @@ void *dtlsservernew(void *arg) { struct client *client; struct clsrvconf *conf; struct list_node *cur = NULL; - SSL *ssl = NULL; X509 *cert = NULL; - SSL_CTX *ctx = NULL; - uint8_t delay = 60; struct tls *accepted_tls = NULL; + int s; + unsigned long error; + struct timeval timeout; + struct addrinfo tmpsrvaddr; - debug(DBG_DBG, "dtlsservernew: starting"); - conf = find_clconf(handle, (struct sockaddr *)¶ms->addr, NULL); - if (conf) { - ctx = tlsgetctx(handle, conf->tlsconf); - if (!ctx) - goto exit; - ssl = dtlsacccon(1, ctx, params->sock, (struct sockaddr *)¶ms->addr, params->sesscache->rbios); - if (!ssl) - goto exit; - cert = verifytlscert(ssl); - if (!cert) - goto exit; - accepted_tls = conf->tlsconf; + debug(DBG_WARN, "dtlsservernew: incoming DTLS connection from %s", addr2string((struct sockaddr *)¶ms->addr)); + + if (!srcres) + dtlssetsrcres(); + memcpy(&tmpsrvaddr, srcres, sizeof(struct addrinfo)); + tmpsrvaddr.ai_addr = (struct sockaddr *)¶ms->bind; + tmpsrvaddr.ai_addrlen = SOCKADDR_SIZE(params->bind); + if ((s = bindtoaddr(&tmpsrvaddr, params->addr.ss_family, 1)) < 0) + goto exit; + if (connect(s, (struct sockaddr *)¶ms->addr, SOCKADDR_SIZE(params->addr))) + goto exit; + + BIO_set_fd(SSL_get_rbio(params->ssl), s, BIO_NOCLOSE); + BIO_ctrl(SSL_get_rbio(params->ssl), BIO_CTRL_DGRAM_SET_CONNECTED, 0,(struct sockaddr *)¶ms->addr); + + if (SSL_accept(params->ssl) <= 0) { + while ((error = ERR_get_error())) + debug(DBG_ERR, "dtlsservernew: SSL: %s", ERR_error_string(error, NULL)); + debug(DBG_ERR, "dtlsservernew: SSL_accept failed"); + goto exit; } + timeout.tv_sec = 5; + timeout.tv_usec = 0; + BIO_ctrl(SSL_get_rbio(params->ssl), BIO_CTRL_DGRAM_SET_RECV_TIMEOUT, 0, &timeout); + + conf = find_clconf(handle, (struct sockaddr *)¶ms->addr, NULL); + if (!conf) + goto exit; + + cert = verifytlscert(params->ssl); + if (!cert) + goto exit; + accepted_tls = conf->tlsconf; while (conf) { - if (accepted_tls == conf->tlsconf && verifyconfcert(cert, conf)) { - X509_free(cert); - client = addclient(conf, 1); - if (client) { - client->sock = params->sock; - client->addr = addr_copy((struct sockaddr *)¶ms->addr); - client->rbios = params->sesscache->rbios; - client->ssl = ssl; - dtlsserverrd(client); - removeclient(client); - delay = 0; - } else { - debug(DBG_WARN, "dtlsservernew: failed to create new client instance"); - } - goto exit; - } - conf = find_clconf(handle, (struct sockaddr *)¶ms->addr, &cur); + if (accepted_tls == conf->tlsconf && verifyconfcert(cert, conf)) { + X509_free(cert); + client = addclient(conf, 1); + if (client) { + client->sock = s; + client->addr = addr_copy((struct sockaddr *)¶ms->addr); + client->ssl = params->ssl; + dtlsserverrd(client); + removeclient(client); + } else { + debug(DBG_WARN, "dtlsservernew: failed to create new client instance"); + } + goto exit; + } + conf = find_clconf(handle, (struct sockaddr *)¶ms->addr, &cur); } debug(DBG_WARN, "dtlsservernew: ignoring request, no matching TLS client"); if (cert) - X509_free(cert); + X509_free(cert); exit: - if (ssl) { - SSL_shutdown(ssl); - SSL_free(ssl); + if (params->ssl) { + SSL_shutdown(params->ssl); + SSL_free(params->ssl); } - pthread_mutex_lock(¶ms->sesscache->mutex); - freebios(params->sesscache->rbios); - params->sesscache->rbios = NULL; - gettimeofday(¶ms->sesscache->expiry, NULL); - params->sesscache->expiry.tv_sec += delay; - pthread_mutex_unlock(¶ms->sesscache->mutex); + if(s >= 0) + close(s); free(params); - pthread_exit(NULL); debug(DBG_DBG, "dtlsservernew: exiting"); + pthread_exit(NULL); } -void cacheexpire(struct hash *cache, struct timeval *last) { - struct timeval now; - struct hash_entry *he; - struct sessioncacheentry *e; - - gettimeofday(&now, NULL); - if (now.tv_sec - last->tv_sec < 19) - return; - - for (he = hash_first(cache); he; he = hash_next(he)) { - e = (struct sessioncacheentry *)he->data; - pthread_mutex_lock(&e->mutex); - if (!e->expiry.tv_sec || e->expiry.tv_sec > now.tv_sec) { - pthread_mutex_unlock(&e->mutex); - continue; - } - debug(DBG_DBG, "cacheexpire: freeing entry"); - hash_extract(cache, he->key, he->keylen); - if (e->rbios) { - freebios(e->rbios); - e->rbios = NULL; - } - pthread_mutex_unlock(&e->mutex); - pthread_mutex_destroy(&e->mutex); +int getConnectionInfo(int socket, struct sockaddr *from, socklen_t fromlen, struct sockaddr *to, socklen_t tolen) { + uint8_t controlbuf[128]; + int offset = 0, ret, toaddrfound = 0; + struct cmsghdr *ctrlhdr; + struct msghdr msghdr; + struct in6_pktinfo *info6; + + char tmp[48]; + + msghdr.msg_name = from; + msghdr.msg_namelen = fromlen; + msghdr.msg_iov = NULL; + msghdr.msg_iovlen = 0; + msghdr.msg_control = controlbuf; + msghdr.msg_controllen = sizeof(controlbuf); + msghdr.msg_flags = 0; + + if ((ret = recvmsg(socket, &msghdr, MSG_PEEK | MSG_TRUNC)) < 0) + return ret; + + debug(DBG_DBG, "udp packet from %s", addr2string(from)); + + getsockname(socket, to, &tolen); + while (offset < msghdr.msg_controllen) { + ctrlhdr = (struct cmsghdr *)(controlbuf+offset); + if(ctrlhdr->cmsg_level == IPPROTO_IP && ctrlhdr->cmsg_type == IP_PKTINFO) { + debug(DBG_DBG, "udp packet to: %s", inet_ntop(AF_INET, &((struct in_pktinfo *)(ctrlhdr->__cmsg_data))->ipi_addr, tmp, sizeof(tmp))); + + ((struct sockaddr_in *)to)->sin_addr = ((struct in_pktinfo *)(ctrlhdr->__cmsg_data))->ipi_addr; + toaddrfound = 1; + } else if(ctrlhdr->cmsg_level == IPPROTO_IPV6 && ctrlhdr->cmsg_type == IPV6_RECVPKTINFO) { + info6 = (struct in6_pktinfo *)ctrlhdr->__cmsg_data; + debug(DBG_DBG, "udp packet to: %x", inet_ntop(AF_INET6, &info6->ipi6_addr, tmp, sizeof(tmp))); + + ((struct sockaddr_in6 *)to)->sin6_addr = info6->ipi6_addr; + ((struct sockaddr_in6 *)to)->sin6_scope_id = info6->ipi6_ifindex; + toaddrfound = 1; + } + offset += ctrlhdr->cmsg_len; } - last->tv_sec = now.tv_sec; + return toaddrfound ? ret : -1; } -void *udpdtlsserverrd(void *arg) { - int ndesc, cnt, s = *(int *)arg; +void *dtlslistener(void *arg) { + int ndesc, s = *(int *)arg; unsigned char buf[4]; - struct sockaddr_storage from; - socklen_t fromlen = sizeof(from); + struct sockaddr_storage from, to; struct dtlsservernewparams *params; struct pollfd fds[1]; - struct timeval lastexpiry; pthread_t dtlsserverth; - struct hash *sessioncache; - struct sessioncacheentry *cacheentry; + BIO *bio; + struct clsrvconf *conf; + SSL *ssl; + SSL_CTX *ctx; - sessioncache = hash_create(); - if (!sessioncache) - debugx(1, DBG_ERR, "udpdtlsserverrd: malloc failed"); - gettimeofday(&lastexpiry, NULL); + + + debug(DBG_DBG, "dtlslistener: starting"); for (;;) { - fds[0].fd = s; - fds[0].events = POLLIN; - ndesc = poll(fds, 1, 60000); - if (ndesc < 1) { - cacheexpire(sessioncache, &lastexpiry); - continue; - } - cnt = recvfrom(s, buf, 4, MSG_PEEK | MSG_TRUNC, (struct sockaddr *)&from, &fromlen); - if (cnt == -1) { - debug(DBG_WARN, "udpdtlsserverrd: recv failed"); - cacheexpire(sessioncache, &lastexpiry); - continue; - } - cacheentry = hash_read(sessioncache, &from, fromlen); - if (cacheentry) { - debug(DBG_DBG, "udpdtlsserverrd: cache hit"); - pthread_mutex_lock(&cacheentry->mutex); - if (cacheentry->rbios) { - if (udp2bio(s, cacheentry->rbios, cnt)) - debug(DBG_DBG, "udpdtlsserverrd: got DTLS in UDP from %s", addr2string((struct sockaddr *)&from)); - } else - recv(s, buf, 1, 0); - pthread_mutex_unlock(&cacheentry->mutex); - cacheexpire(sessioncache, &lastexpiry); - continue; - } + fds[0].fd = s; + fds[0].events = POLLIN; + ndesc = poll(fds, 1, -1); + if (ndesc < 0) + continue; + + if (getConnectionInfo(s, (struct sockaddr *)&from, sizeof(from), (struct sockaddr *)&to, sizeof(to)) < 0) { + debug(DBG_DBG, "udptlsserverrd: getConnectionInfo failed"); + continue; + } - /* from new source */ - debug(DBG_DBG, "udpdtlsserverrd: cache miss"); - params = malloc(sizeof(struct dtlsservernewparams)); - if (!params) { - cacheexpire(sessioncache, &lastexpiry); - recv(s, buf, 1, 0); - continue; - } - memset(params, 0, sizeof(struct dtlsservernewparams)); - params->sesscache = malloc(sizeof(struct sessioncacheentry)); - if (!params->sesscache) { - free(params); - cacheexpire(sessioncache, &lastexpiry); - recv(s, buf, 1, 0); - continue; - } - memset(params->sesscache, 0, sizeof(struct sessioncacheentry)); - pthread_mutex_init(¶ms->sesscache->mutex, NULL); - params->sesscache->rbios = newqueue(); - if (hash_insert(sessioncache, &from, fromlen, params->sesscache)) { - params->sock = s; - memcpy(¶ms->addr, &from, fromlen); - - if (udp2bio(s, params->sesscache->rbios, cnt)) { - debug(DBG_DBG, "udpdtlsserverrd: got DTLS in UDP from %s", addr2string((struct sockaddr *)&from)); - if (!pthread_create(&dtlsserverth, &pthread_attr, dtlsservernew, (void *)params)) { - pthread_detach(dtlsserverth); - cacheexpire(sessioncache, &lastexpiry); - continue; - } - debug(DBG_ERR, "udpdtlsserverrd: pthread_create failed"); - } - hash_extract(sessioncache, &from, fromlen); - } - freebios(params->sesscache->rbios); - pthread_mutex_destroy(¶ms->sesscache->mutex); - free(params->sesscache); - free(params); - cacheexpire(sessioncache, &lastexpiry); + conf = find_clconf(handle, (struct sockaddr *)&from, NULL); + if (!conf) { + debug(DBG_INFO, "udpdtlsserverrd: got UDP from unknown peer %s, ignoring", addr2string((struct sockaddr *)&from)); + recv(s, buf, 1, 0); + continue; + } + + pthread_mutex_lock(&conf->tlsconf->lock); + if (!conf->tlsconf->dtlssslprep) { + ctx = tlsgetctx(handle, conf->tlsconf); + if (!ctx) { + pthread_mutex_unlock(&conf->tlsconf->lock); + continue; + } + ssl = SSL_new(ctx); + if (!ssl) { + pthread_mutex_unlock(&conf->tlsconf->lock); + continue; + } + bio = BIO_new_dgram(s, BIO_NOCLOSE); + SSL_set_bio(ssl, bio, bio); + SSL_set_options(ssl, SSL_OP_COOKIE_EXCHANGE); + conf->tlsconf->dtlssslprep = ssl; + } else { + BIO_set_fd(SSL_get_rbio(conf->tlsconf->dtlssslprep), s, BIO_NOCLOSE); + } + + if(DTLSv1_listen(ssl, &from)) { + params = malloc(sizeof(struct dtlsservernewparams)); + memcpy(¶ms->addr, &from, sizeof(from)); + memcpy(¶ms->bind, &to, sizeof(to)); + params->ssl = conf->tlsconf->dtlssslprep;; + if (!pthread_create(&dtlsserverth, &pthread_attr, dtlsservernew, (void *)params)) { + pthread_detach(dtlsserverth); + conf->tlsconf->dtlssslprep = NULL; + pthread_mutex_unlock(&conf->tlsconf->lock); + continue; + } else { + free(params); + } + } + pthread_mutex_unlock(&conf->tlsconf->lock); } + return NULL; } int dtlsconnect(struct server *server, struct timeval *when, int timeout, char *text) { - struct timeval now; + struct timeval socktimeout, now, start = {0,0}; time_t elapsed; X509 *cert; SSL_CTX *ctx = NULL; struct hostportres *hp; + unsigned long error; + BIO *bio; debug(DBG_DBG, "dtlsconnect: called from %s", text); pthread_mutex_lock(&server->lock); - if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) { - /* already reconnected, nothing to do */ - debug(DBG_DBG, "dtlsconnect(%s): seems already reconnected", text); - pthread_mutex_unlock(&server->lock); - return 1; - } + + if (server->state == RSP_SERVER_STATE_CONNECTED) + server->state = RSP_SERVER_STATE_RECONNECTING; + hp = (struct hostportres *)list_first(server->conf->hostports)->data; + + gettimeofday(&now, NULL); + if (when && (now.tv_sec - when->tv_sec) < 60 ) + start.tv_sec = now.tv_sec - (60 - (now.tv_sec - when->tv_sec)); + for (;;) { - gettimeofday(&now, NULL); - elapsed = now.tv_sec - server->lastconnecttry.tv_sec; - - if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) { - debug(DBG_DBG, "dtlsconnect: timeout"); - SSL_free(server->ssl); - server->ssl = NULL; - pthread_mutex_unlock(&server->lock); - return 0; - } + /* ensure preioius connection is properly closed */ + if (server->ssl) + SSL_shutdown(server->ssl); + if (server->sock >= 0) + close(server->sock); + if (server->ssl) + SSL_free(server->ssl); + server->ssl = NULL; + + /* no sleep at startup or at first try */ + if (start.tv_sec) { + gettimeofday(&now, NULL); + elapsed = now.tv_sec - start.tv_sec; + + if (timeout && elapsed > timeout) { + debug(DBG_DBG, "tlsconnect: timeout"); + pthread_mutex_unlock(&server->lock); + return 0; + } + + /* give up lock while sleeping for next try */ + pthread_mutex_unlock(&server->lock); + if (elapsed < 1) + sleep(2); + else { + debug(DBG_INFO, "Next connection attempt in %lds", elapsed < 60 ? elapsed : 60); + sleep(elapsed < 60 ? elapsed : 60); + } + pthread_mutex_lock(&server->lock); + debug(DBG_INFO, "tlsconnect: retry connecting"); + } else { + gettimeofday(&start, NULL); + } + /* done sleeping */ - if (server->state == RSP_SERVER_STATE_CONNECTED) { - server->state = RSP_SERVER_STATE_RECONNECTING; - sleep(2); - } else if (elapsed < 1) - sleep(2); - else if (elapsed < 60) { - debug(DBG_INFO, "dtlsconnect: sleeping %lds", elapsed); - sleep(elapsed); - } else if (elapsed < 100000) { - debug(DBG_INFO, "dtlsconnect: sleeping %ds", 60); - sleep(60); - } else - server->lastconnecttry.tv_sec = now.tv_sec; /* no sleep at startup */ - debug(DBG_WARN, "dtlsconnect: trying to open DTLS connection to %s port %s", hp->host, hp->port); - - SSL_free(server->ssl); - server->ssl = NULL; - ctx = tlsgetctx(handle, server->conf->tlsconf); - if (!ctx) - continue; - server->ssl = dtlsacccon(0, ctx, server->sock, hp->addrinfo->ai_addr, server->rbios); - if (!server->ssl) - continue; - debug(DBG_DBG, "dtlsconnect: DTLS: ok"); + debug(DBG_WARN, "dtlsconnect: trying to open DTLS connection to %s port %s", hp->host, hp->port); - cert = verifytlscert(server->ssl); - if (!cert) - continue; + if ((server->sock = bindtoaddr(srcres, hp->addrinfo->ai_family, 0)) < 0) + continue; + if (connect(server->sock, hp->addrinfo->ai_addr, hp->addrinfo->ai_addrlen)) + continue; - if (verifyconfcert(cert, server->conf)) - break; - X509_free(cert); + pthread_mutex_lock(&server->conf->tlsconf->lock); + if (!(ctx = tlsgetctx(handle, server->conf->tlsconf))){ + pthread_mutex_unlock(&server->conf->tlsconf->lock); + continue; + } + + server->ssl = SSL_new(ctx); + pthread_mutex_unlock(&server->conf->tlsconf->lock); + if (!server->ssl) + continue; + + bio = BIO_new_dgram(server->sock, BIO_CLOSE); + BIO_ctrl(bio, BIO_CTRL_DGRAM_SET_CONNECTED, 0, hp->addrinfo->ai_addr); + SSL_set_bio(server->ssl, bio, bio); + if (SSL_connect(server->ssl) <= 0) { + while ((error = ERR_get_error())) + debug(DBG_ERR, "tlsconnect: DTLS: %s", ERR_error_string(error, NULL)); + continue; + } + socktimeout.tv_sec = 5; + socktimeout.tv_usec = 0; + BIO_ctrl(bio, BIO_CTRL_DGRAM_SET_RECV_TIMEOUT, 0, &socktimeout); + + debug(DBG_DBG, "dtlsconnect: DTLS: ok"); + + cert = verifytlscert(server->ssl); + if (!cert) + continue; + if (verifyconfcert(cert, server->conf)) { + X509_free(cert); + break; + } + X509_free(cert); } - X509_free(cert); debug(DBG_WARN, "dtlsconnect: DTLS connection to %s port %s up", hp->host, hp->port); server->state = RSP_SERVER_STATE_CONNECTED; gettimeofday(&server->lastconnecttry, NULL); @@ -600,103 +584,74 @@ int clientradputdtls(struct server *server, unsigned char *rad) { size_t len; unsigned long error; struct clsrvconf *conf = server->conf; + struct timespec timeout; + + timeout.tv_sec = 0; + timeout.tv_nsec = 1000000; if (server->state != RSP_SERVER_STATE_CONNECTED) - return 0; + return 0; + + if (pthread_mutex_timedlock(&server->lock, &timeout)) + return 0; + if (server->state != RSP_SERVER_STATE_CONNECTED) { + pthread_mutex_unlock(&server->lock); + return 0; + } + len = RADLEN(rad); - if ((cnt = SSL_write(server->ssl, rad, len)) <= 0) { - while ((error = ERR_get_error())) - debug(DBG_ERR, "clientradputdtls: DTLS: %s", ERR_error_string(error, NULL)); - return 0; + while ((cnt = SSL_write(server->ssl, rad, len)) <= 0) { + switch (SSL_get_error(server->ssl, cnt)) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + continue; + default: + while ((error = ERR_get_error())) + debug(DBG_ERR, "clientradputdtls: DTLS: %s", ERR_error_string(error, NULL)); + pthread_mutex_unlock(&server->lock); + return 0; + } } debug(DBG_DBG, "clientradputdtls: Sent %d bytes, Radius packet of length %d to DTLS peer %s", cnt, len, conf->name); + pthread_mutex_unlock(&server->lock); return 1; } -/* reads UDP containing DTLS and passes it on to dtlsclientrd */ -void *udpdtlsclientrd(void *arg) { - int cnt, s = *(int *)arg; - unsigned char buf[4]; - struct sockaddr_storage from; - socklen_t fromlen = sizeof(from); - struct clsrvconf *conf; - - for (;;) { - cnt = recvfrom(s, buf, 4, MSG_PEEK | MSG_TRUNC, (struct sockaddr *)&from, &fromlen); - if (cnt == -1) { - debug(DBG_WARN, "udpdtlsclientrd: recv failed"); - continue; - } - - conf = find_srvconf(handle, (struct sockaddr *)&from, NULL); - if (!conf) { - debug(DBG_WARN, "udpdtlsclientrd: got packet from wrong or unknown DTLS peer %s, ignoring", addr2string((struct sockaddr *)&from)); - recv(s, buf, 4, 0); - continue; - } - if (udp2bio(s, conf->servers->rbios, cnt)) - debug(DBG_DBG, "radudpget: got DTLS in UDP from %s", addr2string((struct sockaddr *)&from)); - } -} - void *dtlsclientrd(void *arg) { struct server *server = (struct server *)arg; unsigned char *buf; struct timeval lastconnecttry; - int secs; for (;;) { /* yes, lastconnecttry is really necessary */ lastconnecttry = server->lastconnecttry; - for (secs = 0; !(buf = raddtlsget(server->ssl, server->rbios, 10)) && !server->lostrqs && secs < IDLE_TIMEOUT; secs += 10); + buf = raddtlsget(server->ssl, 5, &server->lock); if (!buf) { - dtlsconnect(server, &lastconnecttry, 0, "dtlsclientrd"); + if(SSL_get_shutdown(server->ssl) || server->lostrqs) { + if (server->lostrqs) + debug (DBG_WARN, "dtlsclientrd: server %s did not respond, closing connection.", server->conf->name); + dtlsconnect(server, &lastconnecttry, 0, "dtlsclientrd"); + server->lostrqs = 0; + } continue; } replyh(server, buf); } + + debug(DBG_INFO, "dtlsclientrd: exiting for %s", server->conf->name); + pthread_mutex_lock(&server->lock); + SSL_shutdown(server->ssl); + close(server->sock); + + /* Wake up clientwr(). */ server->clientrdgone = 1; + pthread_mutex_lock(&server->newrq_mutex); + pthread_cond_signal(&server->newrq_cond); + pthread_mutex_unlock(&server->newrq_mutex); + pthread_mutex_unlock(&server->lock); return NULL; } -void addserverextradtls(struct clsrvconf *conf) { - switch (((struct hostportres *)list_first(conf->hostports)->data)->addrinfo->ai_family) { - case AF_INET: - if (client4_sock < 0) { - client4_sock = bindtoaddr(srcres, AF_INET, 0); - if (client4_sock < 0) - debugx(1, DBG_ERR, "addserver: failed to create client socket for server %s", conf->name); - } - conf->servers->sock = client4_sock; - break; - case AF_INET6: - if (client6_sock < 0) { - client6_sock = bindtoaddr(srcres, AF_INET6, 0); - if (client6_sock < 0) - debugx(1, DBG_ERR, "addserver: failed to create client socket for server %s", conf->name); - } - conf->servers->sock = client6_sock; - break; - default: - debugx(1, DBG_ERR, "addserver: unsupported address family"); - } -} - -void initextradtls() { - pthread_t cl4th, cl6th; - - if (srcres) { - freeaddrinfo(srcres); - srcres = NULL; - } - - if (client4_sock >= 0) - if (pthread_create(&cl4th, &pthread_attr, udpdtlsclientrd, (void *)&client4_sock)) - debugx(1, DBG_ERR, "pthread_create failed"); - if (client6_sock >= 0) - if (pthread_create(&cl6th, &pthread_attr, udpdtlsclientrd, (void *)&client6_sock)) - debugx(1, DBG_ERR, "pthread_create failed"); -} #else const struct protodefs *dtlsinit(uint8_t h) { return NULL; diff --git a/dtls.h b/dtls.h index 2b6a336..70a1ce7 100644 --- a/dtls.h +++ b/dtls.h @@ -1,8 +1,11 @@ /* Copyright (c) 2008, UNINETT AS */ /* See LICENSE for licensing information. */ +#include + const struct protodefs *dtlsinit(uint8_t h); + /* Local Variables: */ /* c-file-style: "stroustrup" */ /* End: */ diff --git a/radsecproxy.c b/radsecproxy.c index bbb8ae4..f9d39c8 100644 --- a/radsecproxy.c +++ b/radsecproxy.c @@ -68,7 +68,6 @@ #include "udp.h" #include "tcp.h" #include "tls.h" -#include "dtls.h" #include "fticks.h" #include "fticks_hashmac.h" @@ -183,16 +182,6 @@ void removequeue(struct gqueue *q) { free(q); } -void freebios(struct gqueue *q) { - BIO *bio; - - pthread_mutex_lock(&q->mutex); - while ((bio = (BIO *)list_shift(q->entries))) - BIO_free(bio); - pthread_mutex_unlock(&q->mutex); - removequeue(q); -} - struct client *addclient(struct clsrvconf *conf, uint8_t lock) { struct client *new = NULL; @@ -304,8 +293,6 @@ void freeserver(struct server *server, uint8_t destroymutex) { } free(server->requests); } - if (server->rbios) - freebios(server->rbios); free(server->dynamiclookuparg); if (server->ssl) { SSL_free(server->ssl); @@ -334,10 +321,6 @@ int addserver(struct clsrvconf *conf) { memset(conf->servers, 0, sizeof(struct server)); conf->servers->conf = conf; -#ifdef RADPROT_DTLS - if (conf->type == RAD_DTLS) - conf->servers->rbios = newqueue(); -#endif conf->pdef->setsrcres(); conf->servers->sock = -1; @@ -1967,6 +1950,15 @@ void createlistener(uint8_t type, char *arg) { if (setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) == -1) debugerrno(errno, DBG_WARN, "createlistener: IPV6_V6ONLY"); #endif + if (res->ai_socktype == SOCK_DGRAM) { + if (res->ai_family == AF_INET6) { + if (setsockopt(s, IPPROTO_IPV6, IPV6_RECVPKTINFO, &on, sizeof(on)) == -1) + debugerrno(errno, DBG_WARN, "craetelistener: IPV6_RECVPKTINFO"); + } else if (res->ai_family == AF_INET) { + if (setsockopt(s, IPPROTO_IP, IP_PKTINFO, &on, sizeof(on)) == -1) + debugerrno(errno, DBG_WARN, "createlistener: IP_PKTINFO"); + } + } if (bind(s, res->ai_addr, res->ai_addrlen)) { debugerrno(errno, DBG_WARN, "createlistener: bind failed"); close(s); diff --git a/radsecproxy.h b/radsecproxy.h index 64a4090..b124ba9 100644 --- a/radsecproxy.h +++ b/radsecproxy.h @@ -10,6 +10,9 @@ #include "tlv11.h" #include "radmsg.h" #include "gconfig.h" +#ifdef RADPROT_DTLS +#include "dtls.h" +#endif #define DEBUG_LEVEL 2 @@ -167,7 +170,6 @@ struct client { pthread_mutex_t lock; struct request *rqs[MAX_REQUESTS]; struct gqueue *replyq; - struct gqueue *rbios; /* for dtls */ struct sockaddr *addr; time_t expiry; /* for udp */ }; @@ -190,7 +192,6 @@ struct server { uint8_t newrq; pthread_mutex_t newrq_mutex; pthread_cond_t newrq_cond; - struct gqueue *rbios; /* for dtls */ }; struct realm { @@ -256,7 +257,6 @@ struct client *addclient(struct clsrvconf *conf, uint8_t lock); void removelockedclient(struct client *client); void removeclient(struct client *client); struct gqueue *newqueue(); -void freebios(struct gqueue *q); struct request *newrequest(); void freerq(struct request *rq); int radsrv(struct request *rq); diff --git a/tlscommon.c b/tlscommon.c index fab629c..d0e27d6 100644 --- a/tlscommon.c +++ b/tlscommon.c @@ -36,6 +36,11 @@ static struct hash *tlsconfs = NULL; +#define COOKIE_SECRET_LENGTH 16 +static unsigned char cookie_secret[COOKIE_SECRET_LENGTH]; +static uint8_t cookie_secret_initialized = 0; + + /* callbacks for making OpenSSL < 1.1 thread safe */ #if OPENSSL_VERSION_NUMBER < 0x10100000 static pthread_mutex_t *ssl_locks = NULL; @@ -136,6 +141,88 @@ static int verify_cb(int ok, X509_STORE_CTX *ctx) { return ok; } +static int cookie_calculate_hash(struct sockaddr *peer, time_t time, uint8_t *result, unsigned int *resultlength) { + uint8_t *buf; + int length; + + length = SOCKADDRP_SIZE(peer) + sizeof(time_t); + buf = OPENSSL_malloc(length); + if (!buf) { + debug(DBG_ERR, "cookie_calculate_hash: malloc failed"); + return 0; + } + + memcpy(buf, &time, sizeof(time_t)); + memcpy(buf+sizeof(time_t), peer, SOCKADDRP_SIZE(peer)); + + HMAC(EVP_sha256(), (const void*) cookie_secret, COOKIE_SECRET_LENGTH, + buf, length, result, resultlength); + OPENSSL_free(buf); + return 1; +} + +static int cookie_generate_cb(SSL *ssl, unsigned char *cookie, unsigned int *cookie_len) { + struct sockaddr_storage peer; + struct timeval now; + uint8_t result[EVP_MAX_MD_SIZE]; + unsigned int resultlength; + + if (!cookie_secret_initialized) { + if (!RAND_bytes(cookie_secret, COOKIE_SECRET_LENGTH)) + debugx(1,DBG_ERR, "cookie_generate_cg: error generating random secret"); + cookie_secret_initialized = 1; + } + + BIO_dgram_get_peer(SSL_get_rbio(ssl), &peer); + gettimeofday(&now, NULL); + if (!cookie_calculate_hash((struct sockaddr *)&peer, now.tv_sec, result, &resultlength)) + return 0; + + memcpy(cookie, &now.tv_sec, sizeof(time_t)); + memcpy(cookie + sizeof(time_t), result, resultlength); + *cookie_len = resultlength + sizeof(time_t); + + return 1; +} + +static int cookie_verify_cb(SSL *ssl, unsigned char *cookie, unsigned int cookie_len) { + struct sockaddr_storage peer; + struct timeval now; + time_t cookie_time; + uint8_t result[EVP_MAX_MD_SIZE]; + unsigned int resultlength; + + if (!cookie_secret_initialized) + return 0; + + if (cookie_len < sizeof(time_t)) { + debug(DBG_DBG, "cookie_verify_cb: cookie too short. ignoring."); + return 0; + } + + gettimeofday(&now, NULL); + cookie_time = *(time_t *)cookie; + if (now.tv_sec - cookie_time > 5) { + debug(DBG_DBG, "cookie_verify_cb: cookie invalid or older than 5s. ignoring."); + return 0; + } + + BIO_dgram_get_peer(SSL_get_rbio(ssl), &peer); + if (!cookie_calculate_hash((struct sockaddr *)&peer, cookie_time, result, &resultlength)) + return 0; + + if (resultlength + sizeof(time_t) != cookie_len) { + debug(DBG_DBG, "cookie_verify_cb: invalid cookie length. ignoring."); + return 0; + } + + if (memcmp(cookie + sizeof(time_t), result, resultlength)) { + debug(DBG_DBG, "cookie_verify_cb: cookie not valid. ignoring."); + return 0; + } + return 1; +} + #ifdef DEBUG static void ssl_info_callback(const SSL *ssl, int where, int ret) { const char *s; @@ -222,6 +309,8 @@ static int tlsaddcacrl(SSL_CTX *ctx, struct tls *conf) { SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb); SSL_CTX_set_verify_depth(ctx, MAX_CERT_DEPTH + 1); + SSL_CTX_set_cookie_generate_cb(ctx, cookie_generate_cb); + SSL_CTX_set_cookie_verify_cb(ctx, cookie_verify_cb); if (conf->crlcheck || conf->vpm) { x509_s = SSL_CTX_get_cert_store(ctx); diff --git a/tlscommon.h b/tlscommon.h index b273148..b7a0ba2 100644 --- a/tlscommon.h +++ b/tlscommon.h @@ -24,6 +24,7 @@ struct tls { X509_VERIFY_PARAM *vpm; SSL_CTX *tlsctx; SSL_CTX *dtlsctx; + SSL *dtlssslprep; pthread_mutex_t lock; };