From 67c6cf51ef4924c4db31e88992709cb58fe405e5 Mon Sep 17 00:00:00 2001 From: David Woodhouse Date: Tue, 11 Feb 2014 12:41:51 +0000 Subject: [PATCH] Abstract select() and FD_SET handling This should let us make the mainloop work for Windows, where we can't just select() on the tun device file descriptor. Or indeed *get* a proper file descriptor for the tun device, AFAICT. It might also let us use epoll() etc. if we wanted to. Signed-off-by: David Woodhouse --- cstp.c | 14 +++++++------- dtls.c | 20 +++++++++----------- gnutls.c | 6 +++--- mainloop.c | 13 ++++++------- openconnect-internal.h | 22 ++++++++++++++++++---- openssl.c | 6 +++--- tun.c | 15 +++++++-------- 7 files changed, 53 insertions(+), 43 deletions(-) diff --git a/cstp.c b/cstp.c index 1edd8971..dbc1ae10 100644 --- a/cstp.c +++ b/cstp.c @@ -489,11 +489,11 @@ static int start_cstp_connection(struct openconnect_info *vpninfo) vpn_progress(vpninfo, PRG_INFO, _("CSTP connected. DPD %d, Keepalive %d\n"), vpninfo->ssl_times.dpd, vpninfo->ssl_times.keepalive); - if (vpninfo->select_nfds <= vpninfo->ssl_fd) - vpninfo->select_nfds = vpninfo->ssl_fd + 1; - FD_SET(vpninfo->ssl_fd, &vpninfo->select_rfds); - FD_SET(vpninfo->ssl_fd, &vpninfo->select_efds); + monitor_fd_new(vpninfo, ssl); + + monitor_read_fd(vpninfo, ssl); + monitor_except_fd(vpninfo, ssl); if (!sessid_found) vpninfo->dtls_attempt_period = 0; @@ -666,7 +666,7 @@ static int cstp_write(struct openconnect_info *vpninfo, void *buf, int buflen) case SSL_ERROR_WANT_WRITE: /* Waiting for the socket to become writable -- it's probably stalled, and/or the buffers are full */ - FD_SET(vpninfo->ssl_fd, &vpninfo->select_wfds); + monitor_write_fd(vpninfo, ssl); case SSL_ERROR_WANT_READ: return 0; @@ -706,7 +706,7 @@ static int cstp_write(struct openconnect_info *vpninfo, void *buf, int buflen) if (gnutls_record_get_direction(vpninfo->https_sess)) { /* Waiting for the socket to become writable -- it's probably stalled, and/or the buffers are full */ - FD_SET(vpninfo->ssl_fd, &vpninfo->select_wfds); + monitor_write_fd(vpninfo, ssl); } return 0; } @@ -823,7 +823,7 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout) if (vpninfo->current_ssl_pkt) { handle_outgoing: vpninfo->ssl_times.last_tx = time(NULL); - FD_CLR(vpninfo->ssl_fd, &vpninfo->select_wfds); + unmonitor_write_fd(vpninfo, ssl); ret = cstp_write(vpninfo, vpninfo->current_ssl_pkt->hdr, diff --git a/dtls.c b/dtls.c index 8ae512ac..8a6dbdf3 100644 --- a/dtls.c +++ b/dtls.c @@ -540,11 +540,9 @@ int connect_dtls_socket(struct openconnect_info *vpninfo) vpninfo->dtls_state = DTLS_CONNECTING; vpninfo->dtls_fd = dtls_fd; - if (vpninfo->select_nfds <= dtls_fd) - vpninfo->select_nfds = dtls_fd + 1; - - FD_SET(dtls_fd, &vpninfo->select_rfds); - FD_SET(dtls_fd, &vpninfo->select_efds); + monitor_fd_new(vpninfo, dtls); + monitor_read_fd(vpninfo, dtls); + monitor_except_fd(vpninfo, dtls); time(&vpninfo->new_dtls_started); @@ -556,9 +554,9 @@ void dtls_close(struct openconnect_info *vpninfo) if (vpninfo->dtls_ssl) { DTLS_FREE(vpninfo->dtls_ssl); closesocket(vpninfo->dtls_fd); - FD_CLR(vpninfo->dtls_fd, &vpninfo->select_rfds); - FD_CLR(vpninfo->dtls_fd, &vpninfo->select_wfds); - FD_CLR(vpninfo->dtls_fd, &vpninfo->select_efds); + unmonitor_read_fd(vpninfo, dtls); + unmonitor_write_fd(vpninfo, dtls); + unmonitor_except_fd(vpninfo, dtls); vpninfo->dtls_ssl = NULL; vpninfo->dtls_fd = -1; } @@ -800,7 +798,7 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout) } /* Service outgoing packet queue */ - FD_CLR(vpninfo->dtls_fd, &vpninfo->select_wfds); + unmonitor_write_fd(vpninfo, dtls); while (vpninfo->outgoing_queue) { struct pkt *this = vpninfo->outgoing_queue; int ret; @@ -817,7 +815,7 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout) ret = SSL_get_error(vpninfo->dtls_ssl, ret); if (ret == SSL_ERROR_WANT_WRITE) { - FD_SET(vpninfo->dtls_fd, &vpninfo->select_wfds); + monitor_write_fd(vpninfo, dtls); vpninfo->outgoing_queue = this; vpninfo->outgoing_qlen++; @@ -847,7 +845,7 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout) vpninfo->outgoing_qlen++; work_done = 1; } else if (gnutls_record_get_direction(vpninfo->dtls_ssl)) { - FD_SET(vpninfo->dtls_fd, &vpninfo->select_wfds); + monitor_write_fd(vpninfo, dtls); vpninfo->outgoing_queue = this; vpninfo->outgoing_qlen++; } diff --git a/gnutls.c b/gnutls.c index 3afb9c6b..8d319eff 100644 --- a/gnutls.c +++ b/gnutls.c @@ -1998,9 +1998,9 @@ void openconnect_close_https(struct openconnect_info *vpninfo, int final) } if (vpninfo->ssl_fd != -1) { closesocket(vpninfo->ssl_fd); - FD_CLR(vpninfo->ssl_fd, &vpninfo->select_rfds); - FD_CLR(vpninfo->ssl_fd, &vpninfo->select_wfds); - FD_CLR(vpninfo->ssl_fd, &vpninfo->select_efds); + unmonitor_read_fd(vpninfo, ssl); + unmonitor_write_fd(vpninfo, ssl); + unmonitor_except_fd(vpninfo, ssl); vpninfo->ssl_fd = -1; } if (final && vpninfo->https_cred) { diff --git a/mainloop.c b/mainloop.c index d7663f4a..42f6a8ad 100644 --- a/mainloop.c +++ b/mainloop.c @@ -62,9 +62,8 @@ int openconnect_mainloop(struct openconnect_info *vpninfo, vpninfo->reconnect_interval = reconnect_interval; if (vpninfo->cmd_fd != -1) { - FD_SET(vpninfo->cmd_fd, &vpninfo->select_rfds); - if (vpninfo->cmd_fd >= vpninfo->select_nfds) - vpninfo->select_nfds = vpninfo->cmd_fd + 1; + monitor_fd_new(vpninfo, cmd); + monitor_read_fd(vpninfo, cmd); } while (!vpninfo->quit_reason) { @@ -119,14 +118,14 @@ int openconnect_mainloop(struct openconnect_info *vpninfo, vpn_progress(vpninfo, PRG_TRACE, _("No work to do; sleeping for %d ms...\n"), timeout); - memcpy(&rfds, &vpninfo->select_rfds, sizeof(rfds)); - memcpy(&wfds, &vpninfo->select_wfds, sizeof(wfds)); - memcpy(&efds, &vpninfo->select_efds, sizeof(efds)); + memcpy(&rfds, &vpninfo->_select_rfds, sizeof(rfds)); + memcpy(&wfds, &vpninfo->_select_wfds, sizeof(wfds)); + memcpy(&efds, &vpninfo->_select_efds, sizeof(efds)); tv.tv_sec = timeout / 1000; tv.tv_usec = (timeout % 1000) * 1000; - select(vpninfo->select_nfds, &rfds, &wfds, &efds, &tv); + select(vpninfo->_select_nfds, &rfds, &wfds, &efds, &tv); } cstp_bye(vpninfo, vpninfo->quit_reason); diff --git a/openconnect-internal.h b/openconnect-internal.h index 51e663d1..b3089e3e 100644 --- a/openconnect-internal.h +++ b/openconnect-internal.h @@ -272,10 +272,10 @@ struct openconnect_info { struct oc_ip_info ip_info; - int select_nfds; - fd_set select_rfds; - fd_set select_wfds; - fd_set select_efds; + int _select_nfds; + fd_set _select_rfds; + fd_set _select_wfds; + fd_set _select_efds; #ifdef __sun__ int ip_fd; @@ -319,6 +319,20 @@ struct openconnect_info { openconnect_protect_socket_vfn protect_socket; }; +#define monitor_read_fd(_v, _n) FD_SET(_v-> _n##_fd, &vpninfo->_select_rfds) +#define unmonitor_read_fd(_v, _n) FD_CLR(_v-> _n##_fd, &vpninfo->_select_rfds) +#define monitor_write_fd(_v, _n) FD_SET(_v-> _n##_fd, &vpninfo->_select_wfds) +#define unmonitor_write_fd(_v, _n) FD_CLR(_v-> _n##_fd, &vpninfo->_select_wfds) +#define monitor_except_fd(_v, _n) FD_SET(_v-> _n##_fd, &vpninfo->_select_efds) +#define unmonitor_except_fd(_v, _n) FD_CLR(_v-> _n##_fd, &vpninfo->_select_efds) + +#define monitor_fd_new(_v, _n) do { \ + if (_v->_select_nfds <= vpninfo->_n##_fd) \ + vpninfo->_select_nfds = vpninfo->_n##_fd + 1; \ + } while (0) + +#define read_fd_monitored(_v, _n) FD_ISSET(_v->_n##_fd, &_v->_select_rfds) + #if (defined(DTLS_OPENSSL) && defined(SSL_OP_CISCO_ANYCONNECT)) || \ (defined(DTLS_GNUTLS) && defined(HAVE_GNUTLS_SESSION_SET_PREMASTER)) #define HAVE_DTLS 1 diff --git a/openssl.c b/openssl.c index ad3e6ee8..d6e83cf1 100644 --- a/openssl.c +++ b/openssl.c @@ -1425,9 +1425,9 @@ void openconnect_close_https(struct openconnect_info *vpninfo, int final) } if (vpninfo->ssl_fd != -1) { closesocket(vpninfo->ssl_fd); - FD_CLR(vpninfo->ssl_fd, &vpninfo->select_rfds); - FD_CLR(vpninfo->ssl_fd, &vpninfo->select_wfds); - FD_CLR(vpninfo->ssl_fd, &vpninfo->select_efds); + unmonitor_read_fd(vpninfo, ssl); + unmonitor_write_fd(vpninfo, ssl); + unmonitor_except_fd(vpninfo, ssl); vpninfo->ssl_fd = -1; } if (final) { diff --git a/tun.c b/tun.c index 1fcb19fd..3b98c7ab 100644 --- a/tun.c +++ b/tun.c @@ -640,13 +640,12 @@ int openconnect_setup_tun_fd(struct openconnect_info *vpninfo, int tun_fd) set_fd_cloexec(tun_fd); if (vpninfo->tun_fd != -1) - FD_CLR(vpninfo->tun_fd, &vpninfo->select_rfds); - vpninfo->tun_fd = tun_fd; + unmonitor_read_fd(vpninfo, tun); - if (vpninfo->select_nfds <= tun_fd) - vpninfo->select_nfds = tun_fd + 1; + vpninfo->tun_fd = tun_fd; - FD_SET(tun_fd, &vpninfo->select_rfds); + monitor_fd_new(vpninfo, tun); + monitor_read_fd(vpninfo, tun); set_sock_nonblock(tun_fd); @@ -723,7 +722,7 @@ int tun_mainloop(struct openconnect_info *vpninfo, int *timeout) prefix_size = sizeof(int); #endif - if (FD_ISSET(vpninfo->tun_fd, &vpninfo->select_rfds)) { + if (read_fd_monitored(vpninfo, tun)) { while (1) { int len = vpninfo->ip_info.mtu; @@ -749,12 +748,12 @@ int tun_mainloop(struct openconnect_info *vpninfo, int *timeout) work_done = 1; vpninfo->outgoing_qlen++; if (vpninfo->outgoing_qlen == vpninfo->max_qlen) { - FD_CLR(vpninfo->tun_fd, &vpninfo->select_rfds); + unmonitor_read_fd(vpninfo, tun); break; } } } else if (vpninfo->outgoing_qlen < vpninfo->max_qlen) { - FD_SET(vpninfo->tun_fd, &vpninfo->select_rfds); + monitor_read_fd(vpninfo, tun); } /* The kernel returns -ENOMEM when the queue is full, so theoretically