diff --git a/radsecproxy.c b/radsecproxy.c index 5b51669..d9d0888 100644 --- a/radsecproxy.c +++ b/radsecproxy.c @@ -215,7 +215,8 @@ struct client *addclient(struct clsrvconf *conf, uint8_t lock) { if (conf->pdef->addclient) conf->pdef->addclient(new); else - new->replyq = newqueue(); + new->replyq = newqueue(); + pthread_mutex_init(&new->lock, NULL); list_push(conf->clients, new); if (lock) pthread_mutex_unlock(conf->lock); @@ -261,6 +262,7 @@ void removelockedclient(struct client *client) { removeclientrqs(client); removequeue(client->replyq); list_removedata(conf->clients, client); + pthread_mutex_destroy(&client->lock); free(client->addr); free(client); } diff --git a/radsecproxy.h b/radsecproxy.h index ace2e01..45e23ab 100644 --- a/radsecproxy.h +++ b/radsecproxy.h @@ -157,6 +157,7 @@ struct client { struct clsrvconf *conf; int sock; SSL *ssl; + pthread_mutex_t lock; struct request *rqs[MAX_REQUESTS]; struct gqueue *replyq; struct gqueue *rbios; /* for dtls */ diff --git a/tls.c b/tls.c index 3f2132c..aaeec1b 100644 --- a/tls.c +++ b/tls.c @@ -84,7 +84,7 @@ void tlssetsrcres() { } int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text) { - struct timeval now; + struct timeval now, start = {0,0}; time_t elapsed; X509 *cert; SSL_CTX *ctx = NULL; @@ -92,70 +92,79 @@ int tlsconnect(struct server *server, struct timeval *when, int timeout, char *t debug(DBG_DBG, "tlsconnect: called from %s", text); pthread_mutex_lock(&server->lock); - if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) { - /* already reconnected, nothing to do */ - debug(DBG_DBG, "tlsconnect(%s): seems already reconnected", text); - pthread_mutex_unlock(&server->lock); - return 1; - } - for (;;) { - gettimeofday(&now, NULL); - elapsed = now.tv_sec - server->lastconnecttry.tv_sec; - if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) { - debug(DBG_DBG, "tlsconnect: timeout"); - if (server->sock >= 0) - close(server->sock); - SSL_free(server->ssl); - server->ssl = NULL; - pthread_mutex_unlock(&server->lock); - return 0; - } - if (server->state == RSP_SERVER_STATE_CONNECTED) { - server->state = RSP_SERVER_STATE_RECONNECTING; - sleep(2); - } else if (elapsed < 1) - sleep(2); - else if (elapsed < 60) { - debug(DBG_INFO, "tlsconnect: sleeping %lds", elapsed); - sleep(elapsed); - } else if (elapsed < 100000) { - debug(DBG_INFO, "tlsconnect: sleeping %ds", 60); - sleep(60); - } else - server->lastconnecttry.tv_sec = now.tv_sec; /* no sleep at startup */ - - if (server->sock >= 0) - close(server->sock); - if ((server->sock = connecttcphostlist(server->conf->hostports, srcres)) < 0) - continue; + if (server->state == RSP_SERVER_STATE_CONNECTED) + server->state = RSP_SERVER_STATE_RECONNECTING; - if (server->conf->keepalive) - enable_keepalive(server->sock); + gettimeofday(&now, NULL); + if (when && (now.tv_sec - when->tv_sec) < 60 ) + start.tv_sec = now.tv_sec - (60 - (now.tv_sec - when->tv_sec)); - SSL_free(server->ssl); - server->ssl = NULL; - ctx = tlsgetctx(handle, server->conf->tlsconf); - if (!ctx) - continue; - server->ssl = SSL_new(ctx); - if (!server->ssl) - continue; + for (;;) { + /* ensure preioius connection is properly closed */ + if (server->ssl) + SSL_shutdown(server->ssl); + if (server->sock >= 0) + close(server->sock); + if (server->ssl) + SSL_free(server->ssl); + server->ssl = NULL; + + /* no sleep at startup or at first try */ + if (start.tv_sec) { + gettimeofday(&now, NULL); + elapsed = now.tv_sec - start.tv_sec; + + if (timeout && elapsed > timeout) { + debug(DBG_DBG, "tlsconnect: timeout"); + pthread_mutex_unlock(&server->lock); + return 0; + } + + /* give up lock while sleeping for next try */ + pthread_mutex_unlock(&server->lock); + if (elapsed < 1) + sleep(2); + else { + debug(DBG_INFO, "Next connection attempt in %lds", elapsed < 60 ? elapsed : 60); + sleep(elapsed < 60 ? elapsed : 60); + } + pthread_mutex_lock(&server->lock); + debug(DBG_INFO, "tlsconnect: retry connecting"); + } else { + gettimeofday(&start, NULL); + } + /* done sleeping */ + + if ((server->sock = connecttcphostlist(server->conf->hostports, srcres)) < 0) + continue; + + pthread_mutex_lock(&server->conf->tlsconf->lock); + if (!(ctx = tlsgetctx(handle, server->conf->tlsconf))){ + pthread_mutex_unlock(&server->conf->tlsconf->lock); + continue; + } + + server->ssl = SSL_new(ctx); + pthread_mutex_unlock(&server->conf->tlsconf->lock); + if (!server->ssl) + continue; + + SSL_set_fd(server->ssl, server->sock); + if (SSL_connect(server->ssl) <= 0) { + while ((error = ERR_get_error())) + debug(DBG_ERR, "tlsconnect: TLS: %s", ERR_error_string(error, NULL)); + continue; + } - SSL_set_fd(server->ssl, server->sock); - if (SSL_connect(server->ssl) <= 0) { - while ((error = ERR_get_error())) - debug(DBG_ERR, "tlsconnect: TLS: %s", ERR_error_string(error, NULL)); - continue; - } - cert = verifytlscert(server->ssl); - if (!cert) - continue; - if (verifyconfcert(cert, server->conf)) { - X509_free(cert); - break; - } - X509_free(cert); + cert = verifytlscert(server->ssl); + if (!cert) + continue; + if (verifyconfcert(cert, server->conf)) { + X509_free(cert); + break; + } + X509_free(cert); } debug(DBG_WARN, "tlsconnect: TLS connection to %s up", server->conf->name); server->state = RSP_SERVER_STATE_CONNECTED; @@ -166,50 +175,68 @@ int tlsconnect(struct server *server, struct timeval *when, int timeout, char *t /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */ /* returns 0 on timeout, -1 on error and num if ok */ -int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout) { +int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout, pthread_mutex_t *lock) { int s, ndesc, cnt, len; struct pollfd fds[1]; + if (lock) + pthread_mutex_lock(lock); + s = SSL_get_fd(ssl); - if (s < 0) - return -1; + if (s < 0){ + if (lock) + pthread_mutex_unlock(lock); + return -1; + } + /* make socket non-blocking? */ for (len = 0; len < num; len += cnt) { - if (SSL_pending(ssl) == 0) { - fds[0].fd = s; - fds[0].events = POLLIN; - ndesc = poll(fds, 1, timeout ? timeout * 1000 : -1); - if (ndesc < 1) - return ndesc; - if (fds[0].revents & (POLLERR | POLLHUP | POLLNVAL)) - return -1; - } + if (SSL_pending(ssl) == 0) { + if (lock) + pthread_mutex_unlock(lock); + + fds[0].fd = s; + fds[0].events = POLLIN; + ndesc = poll(fds, 1, timeout ? timeout * 1000 : -1); + if (ndesc < 1) + return ndesc; + if (fds[0].revents & (POLLERR | POLLHUP | POLLNVAL)) + return -1; + + if (lock) + pthread_mutex_lock(lock); + } - cnt = SSL_read(ssl, buf + len, num - len); - if (cnt <= 0) - switch (SSL_get_error(ssl, cnt)) { - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - cnt = 0; - continue; - case SSL_ERROR_ZERO_RETURN: - /* remote end sent close_notify, send one back */ - SSL_shutdown(ssl); - return -1; - default: - return -1; - } + cnt = SSL_read(ssl, buf + len, num - len); + if (cnt <= 0) { + switch (SSL_get_error(ssl, cnt)) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + cnt = 0; + continue; + case SSL_ERROR_ZERO_RETURN: + /* remote end sent close_notify, send one back */ + debug(DBG_DBG, "sslreadtimeout: got ssl shutdown"); + SSL_shutdown(ssl); + default: + if (lock) + pthread_mutex_unlock(lock); + return -1; + } + } } - return num; + if (lock) + pthread_mutex_unlock(lock); + return cnt; } /* timeout in seconds, 0 means no timeout (blocking) */ -unsigned char *radtlsget(SSL *ssl, int timeout) { +unsigned char *radtlsget(SSL *ssl, int timeout, pthread_mutex_t *lock) { int cnt, len; unsigned char buf[4], *rad; for (;;) { - cnt = sslreadtimeout(ssl, buf, 4, timeout); + cnt = sslreadtimeout(ssl, buf, 4, timeout, lock); if (cnt < 1) { debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout"); return NULL; @@ -227,7 +254,7 @@ unsigned char *radtlsget(SSL *ssl, int timeout) { } memcpy(rad, buf, 4); - cnt = sslreadtimeout(ssl, rad + 4, len - 4, timeout); + cnt = sslreadtimeout(ssl, rad + 4, len - 4, timeout, lock); if (cnt < 1) { debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout"); free(rad); @@ -245,22 +272,53 @@ unsigned char *radtlsget(SSL *ssl, int timeout) { return rad; } +int dosslwrite(SSL *ssl, void *buf, int num, uint8_t may_block){ + int ret; + unsigned long error; + struct pollfd fds[1]; + + if(!may_block) { + fds[0].fd = SSL_get_fd(ssl); + fds[0].events = POLLOUT; + if (!poll(fds, 1, 0)) { + debug(DBG_DBG, "dosslwrite: SSL not ready or buffer full; avoid blocking..."); + return 0; + } + } + + while ((ret = SSL_write(ssl, buf, num)) <= 0) { + switch (SSL_get_error(ssl, ret)) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + continue; + default: + while ((error = ERR_get_error())) + debug(DBG_ERR, "dosslwrite: SSL: %s", ERR_error_string(error, NULL)); + return ret; + } + } + return ret; +} + int clientradputtls(struct server *server, unsigned char *rad) { int cnt; size_t len; - unsigned long error; struct clsrvconf *conf = server->conf; - if (server->state != RSP_SERVER_STATE_CONNECTED) - return 0; + pthread_mutex_lock(&server->lock); + if (server->state != RSP_SERVER_STATE_CONNECTED) { + pthread_mutex_unlock(&server->lock); + return 0; + } + len = RADLEN(rad); - if ((cnt = SSL_write(server->ssl, rad, len)) <= 0) { - while ((error = ERR_get_error())) - debug(DBG_ERR, "clientradputtls: TLS: %s", ERR_error_string(error, NULL)); - return 0; + if ((cnt = dosslwrite(server->ssl, rad, len, 0)) <= 0) { + pthread_mutex_unlock(&server->lock); + return 0; } debug(DBG_DBG, "clientradputtls: Sent %d bytes, Radius packet of length %d to TLS peer %s", cnt, len, conf->name); + pthread_mutex_unlock(&server->lock); return 1; } @@ -272,7 +330,7 @@ void *tlsclientrd(void *arg) { for (;;) { /* yes, lastconnecttry is really necessary */ lastconnecttry = server->lastconnecttry; - buf = radtlsget(server->ssl, server->dynamiclookuparg ? IDLE_TIMEOUT : 0); + buf = radtlsget(server->ssl, server->dynamiclookuparg ? IDLE_TIMEOUT : 0, &server->lock); if (!buf) { if (server->dynamiclookuparg) break; @@ -305,7 +363,6 @@ void *tlsclientrd(void *arg) { void *tlsserverwr(void *arg) { int cnt; - unsigned long error; struct client *client = (struct client *)arg; struct gqueue *replyq; struct request *reply; @@ -313,30 +370,33 @@ void *tlsserverwr(void *arg) { debug(DBG_DBG, "tlsserverwr: starting for %s", addr2string(client->addr)); replyq = client->replyq; for (;;) { - pthread_mutex_lock(&replyq->mutex); - while (!list_first(replyq->entries)) { - if (client->ssl) { - debug(DBG_DBG, "tlsserverwr: waiting for signal"); - pthread_cond_wait(&replyq->cond, &replyq->mutex); - debug(DBG_DBG, "tlsserverwr: got signal"); - } - if (!client->ssl) { - /* ssl might have changed while waiting */ - pthread_mutex_unlock(&replyq->mutex); - debug(DBG_DBG, "tlsserverwr: exiting as requested"); - pthread_exit(NULL); - } - } - reply = (struct request *)list_shift(replyq->entries); - pthread_mutex_unlock(&replyq->mutex); - cnt = SSL_write(client->ssl, reply->replybuf, RADLEN(reply->replybuf)); - if (cnt > 0) - debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d to %s", - cnt, RADLEN(reply->replybuf), addr2string(client->addr)); - else - while ((error = ERR_get_error())) - debug(DBG_ERR, "tlsserverwr: SSL: %s", ERR_error_string(error, NULL)); - freerq(reply); + pthread_mutex_lock(&replyq->mutex); + while (!list_first(replyq->entries)) { + if (client->ssl) { + debug(DBG_DBG, "tlsserverwr: waiting for signal"); + pthread_cond_wait(&replyq->cond, &replyq->mutex); + debug(DBG_DBG, "tlsserverwr: got signal"); + } else + break; + } + + reply = (struct request *)list_shift(replyq->entries); + pthread_mutex_unlock(&replyq->mutex); + + pthread_mutex_lock(&client->lock); + if (!client->ssl) { + /* ssl might have changed while waiting */ + pthread_mutex_unlock(&client->lock); + debug(DBG_DBG, "tlsserverwr: exiting as requested"); + pthread_exit(NULL); + } + + if ((cnt = dosslwrite(client->ssl, reply->replybuf, RADLEN(reply->replybuf), 0)) > 0) { + debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d to %s", + cnt, RADLEN(reply->replybuf), addr2string(client->addr)); + } + pthread_mutex_unlock(&client->lock); + freerq(reply); } } @@ -353,7 +413,7 @@ void tlsserverrd(struct client *client) { } for (;;) { - buf = radtlsget(client->ssl, IDLE_TIMEOUT * 3); + buf = radtlsget(client->ssl, IDLE_TIMEOUT * 3, &client->lock); if (!buf) { debug(DBG_ERR, "tlsserverrd: connection from %s lost", addr2string(client->addr)); break; @@ -373,7 +433,9 @@ void tlsserverrd(struct client *client) { } /* stop writer by setting ssl to NULL and give signal in case waiting for data */ + pthread_mutex_lock(&client->lock); client->ssl = NULL; + pthread_mutex_unlock(&client->lock); pthread_mutex_lock(&client->replyq->mutex); pthread_cond_signal(&client->replyq->cond); pthread_mutex_unlock(&client->replyq->mutex); @@ -405,24 +467,29 @@ void *tlsservernew(void *arg) { conf = find_clconf(handle, (struct sockaddr *)&from, &cur); if (conf) { - ctx = tlsgetctx(handle, conf->tlsconf); - if (!ctx) - goto exit; - ssl = SSL_new(ctx); - if (!ssl) - goto exit; - SSL_set_fd(ssl, s); - - if (SSL_accept(ssl) <= 0) { - while ((error = ERR_get_error())) - debug(DBG_ERR, "tlsservernew: SSL: %s", ERR_error_string(error, NULL)); - debug(DBG_ERR, "tlsservernew: SSL_accept failed"); - goto exit; - } - cert = verifytlscert(ssl); - if (!cert) - goto exit; - accepted_tls = conf->tlsconf; + pthread_mutex_lock(&conf->tlsconf->lock); + ctx = tlsgetctx(handle, conf->tlsconf); + if (!ctx) { + pthread_mutex_unlock(&conf->tlsconf->lock); + goto exit; + } + + ssl = SSL_new(ctx); + pthread_mutex_unlock(&conf->tlsconf->lock); + if (!ssl) + goto exit; + + SSL_set_fd(ssl, s); + if (SSL_accept(ssl) <= 0) { + while ((error = ERR_get_error())) + debug(DBG_ERR, "tlsservernew: SSL: %s", ERR_error_string(error, NULL)); + debug(DBG_ERR, "tlsservernew: SSL_accept failed"); + goto exit; + } + cert = verifytlscert(ssl); + if (!cert) + goto exit; + accepted_tls = conf->tlsconf; } while (conf) { diff --git a/tlscommon.c b/tlscommon.c index 7d95178..d5b5325 100644 --- a/tlscommon.c +++ b/tlscommon.c @@ -617,6 +617,7 @@ int conftls_cb(struct gconffile **cf, void *arg, char *block, char *opt, char *v debug(DBG_ERR, "conftls_cb: malloc failed"); goto errexit; } + pthread_mutex_init(&conf->lock, NULL); if (!tlsconfs) tlsconfs = hash_create(); diff --git a/tlscommon.h b/tlscommon.h index 2b98a9c..317d1e4 100644 --- a/tlscommon.h +++ b/tlscommon.h @@ -24,6 +24,7 @@ struct tls { X509_VERIFY_PARAM *vpm; SSL_CTX *tlsctx; SSL_CTX *dtlsctx; + pthread_mutex_t lock; }; #if defined(RADPROT_TLS) || defined(RADPROT_DTLS)