From 408151a5a4271ab0e3d74ae2336c5162b035a2fc Mon Sep 17 00:00:00 2001
From: Fabian Mauchle <fabian.mauchle@switch.ch>
Date: Fri, 23 Feb 2018 21:54:39 +0100
Subject: [PATCH] replace select by poll

---
 ChangeLog   |  1 +
 dtls.c      | 19 ++++++-------------
 tcp.c       | 18 ++++++++----------
 tls.c       | 21 +++++++++------------
 tlscommon.c |  1 -
 udp.c       |  7 +------
 util.c      | 29 ++++++++++++++---------------
 7 files changed, 39 insertions(+), 57 deletions(-)

diff --git a/ChangeLog b/ChangeLog
index 1ff479c..8342b1b 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -17,6 +17,7 @@ Changes between 1.6.9 and the master branch
 	still enables code known to be buggy.
 	- Replace several server status bits with a single state enum.
 	(RADSECPROXY-71)
+	- Use poll instead of select to allow > 1000 concurrent connections.
 
 	Bug fixes:
 	- Detect the presence of docbook2x-man correctly.
diff --git a/dtls.c b/dtls.c
index 54fb5c7..822f00b 100644
--- a/dtls.c
+++ b/dtls.c
@@ -14,7 +14,7 @@
 #endif
 #include <sys/time.h>
 #include <sys/types.h>
-#include <sys/select.h>
+#include <poll.h>
 #include <ctype.h>
 #include <sys/wait.h>
 #include <arpa/inet.h>
@@ -444,8 +444,8 @@ void *udpdtlsserverrd(void *arg) {
     struct sockaddr_storage from;
     socklen_t fromlen = sizeof(from);
     struct dtlsservernewparams *params;
-    fd_set readfds;
-    struct timeval timeout, lastexpiry;
+    struct pollfd fds[1];
+    struct timeval lastexpiry;
     pthread_t dtlsserverth;
     struct hash *sessioncache;
     struct sessioncacheentry *cacheentry;
@@ -456,11 +456,9 @@ void *udpdtlsserverrd(void *arg) {
     gettimeofday(&lastexpiry, NULL);
 
     for (;;) {
-	FD_ZERO(&readfds);
-        FD_SET(s, &readfds);
-	memset(&timeout, 0, sizeof(struct timeval));
-	timeout.tv_sec = 60;
-	ndesc = select(s + 1, &readfds, NULL, NULL, &timeout);
+    fds[0].fd = s;
+    fds[0].events = POLLIN;
+	ndesc = poll(fds, 1, 60000);
 	if (ndesc < 1) {
 	    cacheexpire(sessioncache, &lastexpiry);
 	    continue;
@@ -622,13 +620,8 @@ void *udpdtlsclientrd(void *arg) {
     struct sockaddr_storage from;
     socklen_t fromlen = sizeof(from);
     struct clsrvconf *conf;
-    fd_set readfds;
 
     for (;;) {
-	FD_ZERO(&readfds);
-        FD_SET(s, &readfds);
-	if (select(s + 1, &readfds, NULL, NULL, NULL) < 1)
-	    continue;
 	cnt = recvfrom(s, buf, 4, MSG_PEEK | MSG_TRUNC, (struct sockaddr *)&from, &fromlen);
 	if (cnt == -1) {
 	    debug(DBG_WARN, "udpdtlsclientrd: recv failed");
diff --git a/tcp.c b/tcp.c
index 80cabee..d4ffccd 100644
--- a/tcp.c
+++ b/tcp.c
@@ -14,7 +14,7 @@
 #endif
 #include <sys/time.h>
 #include <sys/types.h>
-#include <sys/select.h>
+#include <poll.h>
 #include <ctype.h>
 #include <sys/wait.h>
 #include <arpa/inet.h>
@@ -134,23 +134,21 @@ int tcpconnect(struct server *server, struct timeval *when, int timeout, char *t
 /* returns 0 on timeout, -1 on error and num if ok */
 int tcpreadtimeout(int s, unsigned char *buf, int num, int timeout) {
     int ndesc, cnt, len;
-    fd_set readfds;
-    struct timeval timer;
+    struct pollfd fds[1];
 
     if (s < 0)
 	return -1;
     /* make socket non-blocking? */
     for (len = 0; len < num; len += cnt) {
-	FD_ZERO(&readfds);
-	FD_SET(s, &readfds);
-	if (timeout) {
-	    timer.tv_sec = timeout;
-	    timer.tv_usec = 0;
-	}
-	ndesc = select(s + 1, &readfds, NULL, NULL, timeout ? &timer : NULL);
+        fds[0].fd = s;
+        fds[0].events = POLLIN;
+	ndesc = poll(fds, 1, timeout? timeout * 1000 : -1);
 	if (ndesc < 1)
 	    return ndesc;
 
+    if (fds[0].revents & (POLLERR | POLLHUP | POLLNVAL) ) {
+        return -1;
+    }
 	cnt = read(s, buf + len, num - len);
 	if (cnt <= 0)
 	    return -1;
diff --git a/tls.c b/tls.c
index 8d9bfa6..3f2132c 100644
--- a/tls.c
+++ b/tls.c
@@ -14,7 +14,7 @@
 #endif
 #include <sys/time.h>
 #include <sys/types.h>
-#include <sys/select.h>
+#include <poll.h>
 #include <ctype.h>
 #include <sys/wait.h>
 #include <arpa/inet.h>
@@ -168,8 +168,7 @@ int tlsconnect(struct server *server, struct timeval *when, int timeout, char *t
 /* returns 0 on timeout, -1 on error and num if ok */
 int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout) {
     int s, ndesc, cnt, len;
-    fd_set readfds;
-    struct timeval timer;
+    struct pollfd fds[1];
 
     s = SSL_get_fd(ssl);
     if (s < 0)
@@ -177,15 +176,13 @@ int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout) {
     /* make socket non-blocking? */
     for (len = 0; len < num; len += cnt) {
 	if (SSL_pending(ssl) == 0) {
-            FD_ZERO(&readfds);
-            FD_SET(s, &readfds);
-            if (timeout) {
-                timer.tv_sec = timeout;
-                timer.tv_usec = 0;
-            }
-	    ndesc = select(s + 1, &readfds, NULL, NULL, timeout ? &timer : NULL);
-            if (ndesc < 1)
-                return ndesc;
+        fds[0].fd = s;
+        fds[0].events = POLLIN;
+	    ndesc = poll(fds, 1, timeout ? timeout * 1000 : -1);
+        if (ndesc < 1)
+            return ndesc;
+        if (fds[0].revents & (POLLERR | POLLHUP | POLLNVAL))
+            return -1;
 	}
 
 	cnt = SSL_read(ssl, buf + len, num - len);
diff --git a/tlscommon.c b/tlscommon.c
index 8ca67f0..7d95178 100644
--- a/tlscommon.c
+++ b/tlscommon.c
@@ -15,7 +15,6 @@
 #endif
 #include <sys/time.h>
 #include <sys/types.h>
-#include <sys/select.h>
 #include <ctype.h>
 #include <sys/wait.h>
 #include <arpa/inet.h>
diff --git a/udp.c b/udp.c
index 57225e3..a1f14cb 100644
--- a/udp.c
+++ b/udp.c
@@ -14,7 +14,6 @@
 #endif
 #include <sys/time.h>
 #include <sys/types.h>
-#include <sys/select.h>
 #include <ctype.h>
 #include <sys/wait.h>
 #include <arpa/inet.h>
@@ -137,7 +136,6 @@ unsigned char *radudpget(int s, struct client **client, struct server **server,
     socklen_t fromlen = sizeof(from);
     struct clsrvconf *p;
     struct list_node *node;
-    fd_set readfds;
     struct client *c = NULL;
     struct timeval now;
 
@@ -146,10 +144,7 @@ unsigned char *radudpget(int s, struct client **client, struct server **server,
 	    free(rad);
 	    rad = NULL;
 	}
-	FD_ZERO(&readfds);
-        FD_SET(s, &readfds);
-	if (select(s + 1, &readfds, NULL, NULL, NULL) < 1)
-	    continue;
+
 	cnt = recvfrom(s, buf, 4, MSG_PEEK | MSG_TRUNC, (struct sockaddr *)&from, &fromlen);
 	if (cnt == -1) {
 	    debug(DBG_WARN, "radudpget: recv failed");
diff --git a/util.c b/util.c
index cce7432..8e59884 100644
--- a/util.c
+++ b/util.c
@@ -12,7 +12,7 @@
 #include <unistd.h>
 #include <fcntl.h>
 #include <errno.h>
-#include <sys/select.h>
+#include <poll.h>
 #include <stdarg.h>
 #include "debug.h"
 #include "util.h"
@@ -177,10 +177,9 @@ int bindtoaddr(struct addrinfo *addrinfo, int family, int reuse) {
     return -1;
 }
 
-int connectnonblocking(int s, const struct sockaddr *addr, socklen_t addrlen, struct timeval *timeout) {
-    int origflags, error = 0, r = -1;
-    fd_set writefds;
-    socklen_t len;
+int connectnonblocking(int s, const struct sockaddr *addr, socklen_t addrlen, int timeout) {
+    int origflags, r = -1;
+    struct pollfd fds[1];
 
     origflags = fcntl(s, F_GETFL, 0);
     if (origflags == -1) {
@@ -198,14 +197,17 @@ int connectnonblocking(int s, const struct sockaddr *addr, socklen_t addrlen, st
     if (errno != EINPROGRESS)
 	goto exit;
 
-    FD_ZERO(&writefds);
-    FD_SET(s, &writefds);
-    if (select(s + 1, NULL, &writefds, NULL, timeout) < 1)
+    fds[0].fd = s;
+    fds[0].events = POLLOUT;
+    if (poll(fds, 1, timeout * 1000) < 1)
 	goto exit;
 
-    len = sizeof(error);
-    if (!getsockopt(s, SOL_SOCKET, SO_ERROR, (char*)&error, &len) && !error)
-	r = 0;
+    if (fds[0].revents & (POLLERR | POLLHUP | POLLNVAL) ) {
+        debug(DBG_WARN, "Connection failed or refused");
+    } else if(fds[0].revents & POLLOUT) {
+        debug(DBG_DBG, "Connection up");
+        r = 0;
+    }
 
 exit:
     if (fcntl(s, F_SETFL, origflags) == -1)
@@ -216,14 +218,11 @@ int connectnonblocking(int s, const struct sockaddr *addr, socklen_t addrlen, st
 int connecttcp(struct addrinfo *addrinfo, struct addrinfo *src, uint16_t timeout) {
     int s;
     struct addrinfo *res;
-    struct timeval to;
 
     s = -1;
     if (timeout) {
 	if (addrinfo && addrinfo->ai_next && timeout > 5)
 	    timeout = 5;
-	to.tv_sec = timeout;
-	to.tv_usec = 0;
     }
 
     for (res = addrinfo; res; res = res->ai_next) {
@@ -233,7 +232,7 @@ int connecttcp(struct addrinfo *addrinfo, struct addrinfo *src, uint16_t timeout
 	    continue;
 	}
 	if ((timeout
-	     ? connectnonblocking(s, res->ai_addr, res->ai_addrlen, &to)
+	     ? connectnonblocking(s, res->ai_addr, res->ai_addrlen, timeout)
 	     : connect(s, res->ai_addr, res->ai_addrlen)) == 0)
 	    break;
 	debug(DBG_WARN, "connecttoserver: connect failed");