diff --git a/cstp.c b/cstp.c index eed45f6f..bf2fc6ec 100644 --- a/cstp.c +++ b/cstp.c @@ -580,6 +580,7 @@ int cstp_connect(struct openconnect_info *vpninfo) { int ret; int deflate_bufsize = 0; + int compr_type; /* This needs to be done before openconnect_setup_dtls() because it's sent with the CSTP CONNECT handshake. Even if we don't end up doing @@ -599,13 +600,18 @@ int cstp_connect(struct openconnect_info *vpninfo) if (ret) goto out; + /* Allow for the theoretical possibility of having *different* + * compression type for CSTP and DTLS. Although all we've seen + * in practice is that one is enabled and the other isn't. */ + compr_type = vpninfo->cstp_compr | vpninfo->dtls_compr; + /* This will definitely be smaller than zlib's */ - if (vpninfo->cstp_compr == COMPR_LZS || vpninfo->cstp_compr == COMPR_LZ4) + if (compr_type & (COMPR_LZS|COMPR_LZ4)) deflate_bufsize = vpninfo->ip_info.mtu; /* If deflate compression is enabled (which is CSTP-only), it needs its * context to be allocated. */ - if (vpninfo->cstp_compr == COMPR_DEFLATE) { + if (compr_type & COMPR_DEFLATE) { vpninfo->deflate_adler32 = 1; vpninfo->inflate_adler32 = 1; @@ -663,22 +669,21 @@ static int cstp_reconnect(struct openconnect_info *vpninfo) return ssl_reconnect(vpninfo); } -int decompress_and_queue_packet(struct openconnect_info *vpninfo, +int decompress_and_queue_packet(struct openconnect_info *vpninfo, int compr_type, unsigned char *buf, int len) { struct pkt *new = malloc(sizeof(struct pkt) + vpninfo->ip_info.mtu); - const char *comprtype = ""; + const char *comprname = ""; if (!new) return -ENOMEM; new->next = NULL; - if (vpninfo->cstp_compr == COMPR_DEFLATE) { + if (compr_type == COMPR_DEFLATE) { uint32_t pkt_sum; - /* Not sure this actually needs to be translated? */ - comprtype = _("deflate"); + comprname = "deflate"; vpninfo->inflate_strm.next_in = buf; vpninfo->inflate_strm.avail_in = len - 4; @@ -703,8 +708,8 @@ int decompress_and_queue_packet(struct openconnect_info *vpninfo, if (vpninfo->inflate_adler32 != pkt_sum) vpninfo->quit_reason = "Compression (inflate) adler32 failure"; - } else if (vpninfo->cstp_compr == COMPR_LZS) { - comprtype = "LZS"; + } else if (compr_type == COMPR_LZS) { + comprname = "LZS"; new->len = lzs_decompress(new->data, vpninfo->ip_info.mtu, buf, len); if (new->len < 0) { @@ -717,8 +722,8 @@ int decompress_and_queue_packet(struct openconnect_info *vpninfo, return len; } #ifdef HAVE_LZ4 - } else if (vpninfo->cstp_compr == COMPR_LZ4) { - comprtype = "LZ4"; + } else if (compr_type == COMPR_LZ4) { + comprname = "LZ4"; new->len = LZ4_decompress_safe((void *)buf, (void *)new->data, len, vpninfo->ip_info.mtu); if (new->len <= 0) { len = new->len; @@ -731,12 +736,12 @@ int decompress_and_queue_packet(struct openconnect_info *vpninfo, #endif } else { vpn_progress(vpninfo, PRG_ERR, - _("Unknown compression type %d\n"), (int)vpninfo->cstp_compr); + _("Unknown compression type %d\n"), compr_type); return -EINVAL; } vpn_progress(vpninfo, PRG_TRACE, _("Received %s compressed data packet of %d bytes (was %d)\n"), - comprtype, new->len, len); + comprname, new->len, len); queue_packet(&vpninfo->incoming_queue, new); return 0; @@ -745,6 +750,7 @@ int decompress_and_queue_packet(struct openconnect_info *vpninfo, int compress_packet(struct openconnect_info *vpninfo, int compr_type, struct pkt *this) { int ret; + if (compr_type == COMPR_DEFLATE) { vpninfo->deflate_strm.next_in = this->data; vpninfo->deflate_strm.avail_in = this->len; @@ -770,7 +776,7 @@ int compress_packet(struct openconnect_info *vpninfo, int compr_type, struct pkt vpninfo->deflate_pkt->len = vpninfo->deflate_strm.total_out + 4; return 0; - } else if (vpninfo->cstp_compr == COMPR_LZS) { + } else if (compr_type == COMPR_LZS) { if (this->len < 40) return -EFBIG; @@ -782,7 +788,7 @@ int compress_packet(struct openconnect_info *vpninfo, int compr_type, struct pkt vpninfo->deflate_pkt->len = ret; return 0; #ifdef HAVE_LZ4 - } else if (vpninfo->cstp_compr == COMPR_LZ4) { + } else if (compr_type == COMPR_LZ4) { if (this->len < 40) return -EFBIG; @@ -909,8 +915,8 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout) _("Compressed packet received in !deflate mode\n")); goto unknown_pkt; } - decompress_and_queue_packet(vpninfo, vpninfo->cstp_pkt->data, - payload_len); + decompress_and_queue_packet(vpninfo, vpninfo->cstp_compr, + vpninfo->cstp_pkt->data, payload_len); work_done = 1; continue; diff --git a/dtls.c b/dtls.c index cc7afd66..fc46f821 100644 --- a/dtls.c +++ b/dtls.c @@ -786,8 +786,8 @@ int dtls_mainloop(struct openconnect_info *vpninfo, int *timeout) _("Compressed DTLS packet received when compression not enabled\n")); goto unknown_pkt; } - decompress_and_queue_packet(vpninfo, vpninfo->dtls_pkt->data, - len - 1); + decompress_and_queue_packet(vpninfo, vpninfo->dtls_compr, + vpninfo->dtls_pkt->data, len - 1); break; default: vpn_progress(vpninfo, PRG_ERR, diff --git a/openconnect-internal.h b/openconnect-internal.h index 4fd0e324..b637a9ea 100644 --- a/openconnect-internal.h +++ b/openconnect-internal.h @@ -758,7 +758,7 @@ void cstp_common_headers(struct openconnect_info *vpninfo, struct oc_text_buf *b int cstp_connect(struct openconnect_info *vpninfo); int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout); int cstp_bye(struct openconnect_info *vpninfo, const char *reason); -int decompress_and_queue_packet(struct openconnect_info *vpninfo, +int decompress_and_queue_packet(struct openconnect_info *vpninfo, int compr_type, unsigned char *buf, int len); int compress_packet(struct openconnect_info *vpninfo, int compr_type, struct pkt *this);