From 443d43fe2ba83869c137f5de465b265e7ffc64e9 Mon Sep 17 00:00:00 2001 From: Andy Wingo Date: Thu, 8 Feb 2018 19:22:43 +0100 Subject: [PATCH 1/7] Add new tcp library --- src/lib/tcp/proto.lua | 612 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 612 insertions(+) create mode 100644 src/lib/tcp/proto.lua diff --git a/src/lib/tcp/proto.lua b/src/lib/tcp/proto.lua new file mode 100644 index 0000000000..944e27e4e9 --- /dev/null +++ b/src/lib/tcp/proto.lua @@ -0,0 +1,612 @@ +-- Use of this source code is governed by the Apache 2.0 license; see COPYING. + +-- Includes code ported from smoltcp +-- (https://github.com/m-labs/smoltcp), whose copyright is the +-- following: +--- +-- Copyright (C) 2016 whitequark@whitequark.org +-- +-- Permission to use, copy, modify, and/or distribute this software for +-- any purpose with or without fee is hereby granted. +-- +-- THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +-- WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +-- MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +-- ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +-- WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN +-- AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +-- OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +module(...,package.seeall) + +local lib = require("core.lib") +local ffi = require("ffi") +local bit = require("bit") +local ipsum = require("lib.checksum").ipsum + +local ntohs, ntohl = lib.ntohs, lib.ntohl +local htons, htonl = ntohs, ntohl +local band, bor, bxor, bnot = bit.band, bit.bor, bit.bxor, bit.bnot +local lshift, rshift = bit.lshift, bit.rshift + +local function ptr_to(t) return ffi.typeof("$*", t) end + +local proto_tcp = 6 + +------ +-- Ethernet +------ + +local ethernet_header_t = ffi.typeof [[ +/* All values in network byte order. */ +struct { + uint8_t dhost[6]; + uint8_t shost[6]; + uint16_t type; +} __attribute__((packed)) +]] +local ethernet_header_size = ffi.sizeof(ethernet_header_t) +local ethernet_type_ipv4 = 0x0800 +local ethernet_type_ipv6 = 0x86dd +local ethernet = {} +ethernet.__index = ethernet +function ethernet:read_type() return ntohs(self.type) end +function ethernet:write_type(type) self.type = htons(type) end +function ethernet:is_ipv4() return self:read_type() == ethernet_type_ipv4 end +function ethernet:is_ipv6() return self:read_type() == ethernet_type_ipv6 end +local ethernet_header_ptr_t = ptr_to(ffi.metatype(ethernet_header_t, ethernet)) + +------ +-- IPv4 +------ + +local ipv4_header_t = ffi.typeof [[ +/* All values in network byte order. */ +struct { + uint8_t version_and_ihl; // version:4, ihl:4 + uint8_t dscp_and_ecn; // dscp:6, ecn:2 + uint16_t total_length; + uint16_t id; + uint16_t flags_and_fragment_offset; // flags:3, fragment_offset:13 + uint8_t ttl; + uint8_t protocol; + uint16_t checksum; + uint8_t src_ip[4]; + uint8_t dst_ip[4]; +} __attribute__((packed)) +]] +local ipv4_header_size = ffi.sizeof(ipv4_header_t) +local ipv4 = {}; ipv4.__index = ipv4 +function ipv4:header_length() + return band(self.version_and_ihl, 0xf) * 4 +end +function ipv4:set_header_length(len) + self.version_and_ihl = bor(lshift(4, 4), len / 4) +end +function ipv4:read_total_length() return ntohs(self.total_length) end +function ipv4:write_total_length(len) self.total_length = htons(len) end +function ipv4:read_checksum() return ntohs(self.checksum) end +function ipv4:write_checksum(checksum) self.checksum = htons(checksum) end +function ipv4:compute_and_set_checksum() + self:write_checksum(0) + self:write_checksum(ipsum(ffi.cast('char*', self), self:header_length(), 0)) +end +function ipv4:is_tcp() return self.protocol == proto_tcp end +local ipv4_header_ptr_t = ptr_to(ffi.metatype(ipv4_header_t, ipv4)) + +------ +-- IPv6 +------ + +local ipv6_header_t = ffi.typeof [[ +/* All values in network byte order. */ +struct { + uint32_t v_tc_fl; // version:4, traffic class:8, flow label:20 + uint16_t payload_length; + uint8_t next_header; + uint8_t hop_limit; + uint8_t src_ip[16]; + uint8_t dst_ip[16]; +} __attribute__((packed)) +]] +local ipv6_header_size = ffi.sizeof(ipv6_header_t) +local ipv6 = {}; ipv6.__index = ipv6 +function ipv6:read_payload_length() return ntohs(self.payload_length) end +function ipv6:write_payload_length(len) self.payload_length = htons(len) end +function ipv6:is_tcp() return self.next_header == proto_tcp end +local ipv6_header_ptr_t = ptr_to(ffi.metatype(ipv6_header_t, ipv6)) + +------ +-- TCP +------ + +local tcp_header_t = ffi.typeof [[ +/* All values in network byte order. */ +struct { + uint16_t src_port; + uint16_t dst_port; + uint32_t seq; + uint32_t ack; + uint16_t data_offset_and_flags; + uint16_t window; + uint16_t checksum; + uint16_t urgent; + uint8_t options_and_payload[0]; +} __attribute__((packed)) +]] +local tcp_header_size = ffi.sizeof(tcp_header_t) +local ipv4_pseudo_header_t = ffi.typeof[[ +struct { + char src_ip[4]; + char dst_ip[4]; + uint16_t l4_protocol; + uint16_t l4_length; +} __attribute__((packed)) +]] +local ipv4_pseudo_header_size = ffi.sizeof(ipv4_pseudo_header_t) +local ipv6_pseudo_header_t = ffi.typeof[[ +struct { + char src_ip[16]; + char dst_ip[16]; + uint32_t l4_length; + uint32_t l4_protocol; +} __attribute__((packed)) +]] +local ipv6_pseudo_header_size = ffi.sizeof(ipv6_pseudo_header_t) + +-- Delay initialization until after we can set metatypes + +local tcp_data_offset_shift = 12 + +local flags = { FIN=0x001, SYN=0x002, RST=0x004, PSH=0x008, + ACK=0x010, URG=0x020, ECE=0x040, CWR=0x080, + NS =0x100 } +local flags_mask = 0x1ff + +local control_flags_mask = 0xf +local controls = { INVALID=0, NONE=1, PSH=2, SYN=3, FIN=4, RST=5 } +local control_array = ffi.new('uint8_t[16]') + +local function control_from_flags(flags) + return control_array[band(flags, control_flags_mask)] +end + +do + local function add_control(control, ...) + local flags = bor(0, ...) + assert(flags == band(flags, control_flags_mask)) + assert(control_from_flags(flags) == controls.INVALID) + control_array[flags] = control + end + + add_control(controls.NONE) + add_control(controls.PSH, flags.PSH) + add_control(controls.SYN, flags.SYN) + add_control(controls.SYN, flags.SYN, flags.PSH) + add_control(controls.FIN, flags.FIN) + add_control(controls.FIN, flags.FIN, flags.PSH) + add_control(controls.RST, flags.RST) + add_control(controls.RST, flags.RST, flags.PSH) +end + +local options = { END=0, NOP=1, MSS=2, WS=3 } + +local tcp = {}; tcp.__index = tcp + +function tcp:read_src_port() return ntohs(self.src_port) end +function tcp:write_src_port(port) self.src_port = htons(port) end + +function tcp:read_dst_port() return ntohs(self.dst_port) end +function tcp:write_dst_port(port) self.dst_port = htons(port) end + +function tcp:read_seq() return ntohl(self.seq) end +function tcp:write_seq(seq) self.seq = htonl(seq) end + +function tcp:read_ack() return ntohl(self.ack) end +function tcp:write_ack(ack) self.ack = htonl(ack) end + +function tcp:read_data_offset_and_flags() + return ntohs(self.data_offset_and_flags) +end +function tcp:write_data_offset_and_flags(data_offset_and_flags) + self.data_offset_and_flags = htons(data_offset_and_flags) +end + +function tcp:read_window() return ntohs(self.window) end +function tcp:write_window(window) self.window = htons(window) end + +function tcp:read_checksum() return ntohs(self.checksum) end +function tcp:write_checksum(checksum) self.checksum = htons(checksum) end + +function tcp:read_urgent() return ntohs(self.urgent) end +function tcp:write_urgent(urgent) self.urgent = htons(urgent) end + +function tcp:set_options_length_and_flags(options_length, flags) + local data_offset = tcp_header_size + options_length + self:write_data_offset_and_flags( + bor(lshift(data_offset/4, tcp_data_offset_shift), flags)) +end + +function tcp:header_length() + return 4 * rshift(self:read_data_offset_and_flags(), + tcp_data_offset_shift) +end +function tcp:set_header_length(len) + local flags = self:flags() + self:write_data_offset_and_flags( + bor(lshift(rshift(len, 2), tcp_data_offset_shift), flags)) +end + +function tcp:options_length() + return self:header_length() - tcp_header_size +end + +function tcp:payload_offset() + return self:header_length() +end +function tcp:payload() + return self.options_and_payload + self:options_length() +end +function tcp:payload_length(l4_length) + return l4_length - self:payload_offset() +end + +function tcp:flags() + return band(self:read_data_offset_and_flags(), flags_mask) +end +function tcp:set_flags(flags) + self:write_data_offset_and_flags( + bor(self:read_data_offset_and_flags(), flags)) +end +function tcp:clear_flags(flags) + self:write_data_offset_and_flags( + band(self:read_data_offset_and_flags(), band(bnot(flags), flags_mask))) +end + +local function has_flag(flags, flag) + return band(flags, flag) ~= 0 +end +function tcp:has_flag(flag) + return has_flag(self:flags(), flag) +end + +-- Return the length of the segment, in terms of sequence space. +function tcp:segment_len(l4_length) + local len = l4_length - self:payload_length(l4_length) + local f = self:flags() + if has_flag(f, flags.SYN) then len = len + 1 end + if has_flag(f, flags.FIN) then len = len + 1 end + return len +end + +local scratch_ipv4_pseudo_header = ipv4_pseudo_header_t() +local function ipv4_tcp_pseudo_header_checksum(src_ip, dst_ip, l4_length) + local ph = scratch_ipv4_pseudo_header + ph.src_ip, ph.dst_ip = src_ip, dst_ip + ph.l4_protocol = htons(proto_tcp) + ph.l4_length = htons(l4_length) + return ipsum(ffi.cast("uint8_t*", ph), ipv4_pseudo_header_size, 0) +end + +local scratch_ipv6_pseudo_header = ipv6_pseudo_header_t() +local function ipv6_tcp_pseudo_header_checksum(src_ip, dst_ip, l4_length) + local ph = scratch_ipv6_pseudo_header + ph.src_ip, ph.dst_ip = src_ip, dst_ip + ph.l4_protocol = htonl(proto_tcp) + ph.l4_length = htonl(l4_length) + return ipsum(ffi.cast("uint8_t*", ph), ipv6_pseudo_header_size, 0) +end + +function tcp:prezeroed_checksum(l4_length, ph_csum) + -- Return checksum of packet, assuming checksum field itself has been + -- zeroed out. + return ipsum(ffi.cast("uint8_t*", self), l4_length, bnot(ph_csum)) +end + +function tcp:compute_and_set_checksum(l4_length, ph_csum) + self:write_checksum(0) + self:write_checksum(self:prezeroed_checksum(l4_length, ph_csum)) +end + +function tcp:compute_checksum(l4_length, ph_csum) + local csum = self:prezeroed_checksum(l4_length, ph_csum) + -- We just did a checksum but didn't reset the checksum value in the + -- header to 0. Now munge the result to give the checksum that would + -- have been, if the checksum field were zero. + csum = band(bnot(csum), 0xffff) + csum = csum + band(bnot(self:read_checksum()), 0xffff) + csum = rshift(csum, 16) + band(csum, 0xffff) + csum = csum + rshift(csum, 16) + return band(bnot(csum), 0xffff) +end + +function tcp:is_valid_checksum(l4_length, ph_csum) + return self:compute_checksum(l4_length, ph_csum) == self:read_checksum() +end + +function tcp:is_valid_checksum_ipv4(src_ip, dst_ip, l4_length) + local ph_csum = ipv4_tcp_pseudo_header_checksum(src_ip, dst_ip, l4_length) + return self:is_valid_checksum(l4_length, ph_csum) +end + +function tcp:is_valid_checksum_ipv6(src_ip, dst_ip, l4_length) + local ph_csum = ipv6_tcp_pseudo_header_checksum(src_ip, dst_ip, l4_length) + return self:is_valid_checksum(l4_length, ph_csum) +end + +-- When packet too short, return false +-- Otherwise return op, param, next_idx +-- Skip over NOP options +local function read_option(base, idx, len) + if idx >= len then return false end + local op = base[idx] + if op == options.END then return op, nil, idx + 1 end + if op == options.NOP then return op, nil, idx + 1 end + local avail = len - idx + if avail < 2 then return false end + local param_len = base[idx + 1] + if avail < param_len then return false end + if param_len < 2 then return false end + if op == options.MSS then + if param_len ~= 4 then return false end + local mss = ntohs(ffi.cast("uint16_t*", base + idx + 2)[0]) + return op, mss, idx + param_len + elseif op == options.WS then + if param_len ~= 3 then return false end + return op, base[idx + 2], idx + param_len + end + -- Unknown option. Return index into base. + return op, idx + 2, idx + param_len +end + +local function read_options(base, idx, len) + local ret = {} + while true do + if idx == len then return ret end + local op, param, next_idx = read_option(base, idx, len) + if not op then return false end + if op == options.END then return ret end + if op ~= options.NOP then + table.insert(ret, { op, param, next_idx }) + end + idx = next_idx + end +end + +local function read_tcp_options(tcp) + return read_options(tcp.options_and_payload, 0, tcp:options_length()) +end + +function tcp:control() + return control_from_flags(self:flags()) +end + +function tcp:is_valid() + return self:read_src_port() ~= 0 and + self:read_dst_port() ~= 0 and + self:control() ~= controls.INVALID + -- could parse options +end + +function tcp:is_valid_ipv4(src_ip, dst_ip, tcp_length) + return self:is_valid() and + self:is_valid_checksum_ipv4(src_ip, dst_ip, tcp_length) +end + +function tcp:is_valid_ipv6(src_ip, dst_ip, tcp_length) + return self:is_valid() and + self:is_valid_checksum_ipv6(src_ip, dst_ip, tcp_length) +end + +local tcp_header_ptr_t = ptr_to(ffi.metatype(tcp_header_t, tcp)) + +------ +-- Getting a TCP header from a packet +------ + +local function as_ethernet(ptr) return ffi.cast(ethernet_header_ptr_t, ptr) end +local function as_ipv4(ptr) return ffi.cast(ipv4_header_ptr_t, ptr) end +local function as_ipv6(ptr) return ffi.cast(ipv6_header_ptr_t, ptr) end +local function as_tcp(ptr) return ffi.cast(tcp_header_ptr_t, ptr) end + +function is_ipv4(p) return as_ethernet(p.data):is_ipv4() end +function is_ipv6(p) return as_ethernet(p.data):is_ipv6() end + +-- Precondition: P's ethertype is IPv4 +function parse_ipv4_tcp(p) + if p.length < ethernet_header_size + ipv4_header_size then return end + local ipv4 = as_ipv4(p.data + ethernet_header_size) + local total_size = ipv4:read_total_length() + local header_size = ipv4:header_length() + if header_size < ipv4_header_size then return end + if total_size < header_size + tcp_header_size then return end + if p.length < ethernet_header_size + total_size then return end + if not ipv4:is_tcp() then return end + -- FIXME: validate IPv4 checksum + local l4_size = total_size - header_size + local tcp = as_tcp(p.data + ethernet_header_size + header_size) + if not tcp:is_valid_ipv4(ipv4.src_ip, ipv4.dst_ip, l4_size) then return end + + return ipv4, tcp, tcp:payload_length(l4_size) +end + +-- Precondition: P's ethertype is IPv6 +function parse_ipv6_tcp(p) + if p.length < ethernet_header_size + ipv6_header_size then return end + local ipv6 = as_ipv6(p.data + ethernet_header_size) + local l4_size = ipv6:read_payload_length() + if l4_size < tcp_header_size then return end + if p.length < ethernet_header_size + ipv6_header_size + l4_size then return end + if not ipv6:is_tcp() then return end + local tcp = as_tcp(p.data + ethernet_header_size + ipv6_header_size) + if not tcp:is_valid_ipv6(ipv6.src_ip, ipv6.dst_ip, l4_size) then return end + + return ipv6, tcp, tcp:payload_length(l4_size) +end + +------ +-- Pushing TCP, IP, and Ethernet headers onto a packet +------ + +local function push_tcp_header(p, src_ip, dst_ip, compute_pseudo_header_checksum, + src_port, dst_port, seq, ack, + options_length, flags, window) + local p = packet.shiftright(p, tcp_header_size) + local tcp = ffi.cast(tcp_header_ptr_t, p.data) + + tcp:write_src_port(src_port); tcp:write_dst_port(dst_port) + tcp:write_seq(seq); tcp:write_ack(ack) + tcp:set_options_length_and_flags(options_length, flags) + tcp:write_window(window) + tcp.urgent = 0 + + local ph_csum = compute_pseudo_header_checksum(src_ip, dst_ip, p.length) + tcp:compute_and_set_checksum(p.length, ph_csum) + + return p +end + +local function push_ipv4_header(p, src_ip, dst_ip, ttl) + local p = packet.shiftright(p, ipv4_header_size) + local ipv4 = ffi.cast(ipv4_header_ptr_t, p.data) + + local version = 4 + + ipv4:set_header_length(ipv4_header_size) + ipv4.dscp_and_ecn = 0 + ipv4:write_total_length(p.length) + ipv4.id = 0 + ipv4.flags_and_fragment_offset = 0 + ipv4.ttl = ttl + ipv4.protocol = proto_tcp + ipv4.src_ip, ipv4.dst_ip = src_ip, dst_ip + + ipv4:compute_and_set_checksum() + + return p +end + +local function push_ipv6_header(p, src_ip, dst_ip, ttl) + local payload_length = p.length + local p = packet.shiftright(p, ipv6_header_size) + local ipv6 = ffi.cast(ipv6_header_ptr_t, p.data) + + ipv6.v_tc_fl = 0 + lib.bitfield(32, ipv6, 'v_tc_fl', 0, 4, 6) -- IPv6 Version + lib.bitfield(32, ipv6, 'v_tc_fl', 4, 8, 0) -- Traffic class + lib.bitfield(32, ipv6, 'v_tc_fl', 12, 20, 0) -- Flow label + ipv6.payload_length = htons(payload_length) + ipv6.next_header = proto_tcp + ipv6.hop_limit = ttl + ipv6.src_ip, ipv6.dst_ip = src_ip, dst_ip + + return p +end + +-- Assume an ARP app sets L2 addresses. +local function push_ethernet_header(p, proto) + local p = packet.shiftright(p, ethernet_header_size) + local ether = ffi.cast(ethernet_header_ptr_t, p.data) + + ffi.fill(p.data, ethernet_header_size) + ether.type = htons(proto) + + return p +end + +function push_ethernet_ipv4_tcp_headers(p, src_ip, dst_ip, ttl, + src_port, dst_port, seq, ack, + options_length, flags, window) + p = push_tcp_header(p, src_ip, dst_ip, ipv4_tcp_pseudo_header_checksum, + src_port, dst_port, seq, ack, + options_length, flags, window) + p = push_ipv4_header(p, src_ip, dst_ip, ttl) + return push_ethernet_header(p, ethernet_type_ipv4) +end + +function push_ethernet_ipv6_tcp_headers(p, src_ip, dst_ip, ttl, + src_port, dst_port, seq, ack, + options_length, flags, window) + p = push_tcp_header(p, src_ip, dst_ip, ipv6_tcp_pseudo_header_checksum, + src_port, dst_port, seq, ack, + options_length, flags, window) + p = push_ipv6_header(p, src_ip, dst_ip, ttl) + return push_ethernet_header(p, ethernet_type_ipv6) +end + +function selftest() + print('selftest: lib.tcp.proto') + local packet = require('core.packet') + + local function assert_eq(a, b) + if not lib.equal(a, b) then + print('not equal', a, b) + error('not equal') + end + end + + local p = packet.from_string(lib.hexundump([[ + 52:54:00:02:02:02 52:54:00:01:01:01 08 00 45 00 + 00 34 00 00 00 00 40 06 49 A9 c0 a8 14 a9 6b 15 + f0 b4 de 0b 01 bb e7 db 57 bc 91 cd 18 32 80 10 + 05 9f 38 2a 00 00 01 01 08 0a 06 0c 5c bd fa 4a + e1 65 + ]], 66)) + + local ipv4, tcp, payload_length = parse_ipv4_tcp(p) + + assert_eq(tcp:read_src_port(), 56843) + assert_eq(tcp:read_dst_port(), 443) + assert_eq(tcp:read_seq(), 0xe7db57bc) + assert_eq(tcp:read_ack(), 0x91cd1832) + assert_eq(tcp:header_length(), 32); + assert_eq(tcp:flags(), flags.ACK) + assert_eq(tcp:has_flag(flags.ACK), true) + assert_eq(tcp:has_flag(flags.SYN), false) + assert_eq(tcp:read_window(), 1439) + assert_eq(tcp:read_urgent(), 0) + -- This particular packet has two nops followed by a timestamps + -- option and 8 bytes of data. We don't really support timestamps + -- yet, so the param is a pointer into the options array with its + -- corresponding end index. + local tcp_option_timestamps = 8 + assert_eq(read_tcp_options(tcp), { { tcp_option_timestamps, 4, 12 } }) + + local p2 = packet.from_pointer(tcp:payload(), payload_length) + + p2 = packet.prepend(p2, tcp.options_and_payload, tcp:options_length()) + p2 = push_ethernet_ipv4_tcp_headers( + p2, ipv4.src_ip, ipv4.dst_ip, ipv4.ttl, + tcp:read_src_port(), tcp:read_dst_port(), + tcp:read_seq(), tcp:read_ack(), + tcp:options_length(), flags.ACK, tcp:read_window()) + + assert_eq(p.length, p2.length) + -- Only compare L3 and onwards; p2 has empty L2 addresses. + p = packet.shiftleft(p, ethernet_header_size) + p2 = packet.shiftleft(p2, ethernet_header_size) + assert(ffi.C.memcmp(p.data, p2.data, p.length) == 0) + + packet.free(p2) + packet.free(p) + + local function validate_options(str, res) + local bytes = lib.hexundump(str, math.floor((#str+1)/3)) + local buf = ffi.cast('uint8_t*', bytes) + assert_eq(read_options(buf, 0, #bytes), res) + end + validate_options("", {}) + validate_options("00", {}) + validate_options("01", {}) + validate_options("02 04 05 dc", {{options.MSS, 1500, 4}}) + validate_options("03 03 0c", {{options.WS, 12, 3}}) + validate_options("0c 05 01 02 03", {{0x0c, 2, 5}}) + + validate_options("0c", false) -- Unknown option, missing length + validate_options("0c 05 01 02", false) -- Missing last byte of option data + validate_options("0c 01", false) -- Length invalid (less than 2) + validate_options("02 02", false) -- Bad length for MSS + validate_options("03 02", false) -- Bad length for WS + + print('selftest: ok') +end From 03dc374be9795bc71301420b9d9a84b1d15b78c4 Mon Sep 17 00:00:00 2001 From: Andy Wingo Date: Thu, 22 Feb 2018 08:21:06 +0100 Subject: [PATCH 2/7] Add ring buffer module --- src/lib/tcp/buffer.lua | 157 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 src/lib/tcp/buffer.lua diff --git a/src/lib/tcp/buffer.lua b/src/lib/tcp/buffer.lua new file mode 100644 index 0000000000..9470c7e6e0 --- /dev/null +++ b/src/lib/tcp/buffer.lua @@ -0,0 +1,157 @@ +-- Use of this source code is governed by the Apache 2.0 license; see COPYING. + +-- Ring buffer for bytes + +module(...,package.seeall) + +local lib = require("core.lib") +local ffi = require("ffi") +local bit = require("bit") + +local band = bit.band + +local buffer_t = ffi.typeof[[ +struct { + uint32_t read_idx, write_idx; + uint32_t size; + uint8_t buf[?]; +} __attribute__((packed)) +]] + +local function to_uint32(n) + return ffi.new('uint32_t[1]', n)[0] +end + +function new(size) + local ret = buffer_t(size) + ret:init(size) + return ret +end + +local buffer = {} +buffer.__index = buffer + +function buffer:init(size) + assert(size ~= 0 and band(size, size - 1) == 0, "size not power of two") + self.size = size + return self +end + +function buffer:is_empty() + return self.write_idx == self.read_idx +end +function buffer:read_avail() + return to_uint32(self.write_idx - self.read_idx) +end +function buffer:is_full() + return self:read_avail() == self.size +end +function buffer:write_avail() + return self.size - self:read_avail() +end + +function buffer:write_pos() + return band(self.write_idx, self.size - 1) +end +function buffer:rewrite_pos(offset) + return band(self.read_idx + offset, self.size - 1) +end +function buffer:read_pos() + return band(self.read_idx, self.size - 1) +end + +function buffer:advance_write(count) + self.write_idx = self.write_idx + count +end +function buffer:advance_read(count) + self.read_idx = self.read_idx + count +end + +function buffer:write(bytes, count) + if count > self:write_avail() then error('write xrun') end + local pos = self:write_pos() + local count1 = math.min(self.size - pos, count) + ffi.copy(self.buf + pos, bytes, count1) + ffi.copy(self.buf, bytes + count1, count - count1) + self:advance_write(count) +end + +function buffer:rewrite(offset, bytes, count) + if offset + count > self:read_avail() then error('rewrite xrun') end + local pos = self:rewrite_pos(offset) + local count1 = math.min(self.size - pos, count) + ffi.copy(self.buf + pos, bytes, count1) + ffi.copy(self.buf, bytes + count1, count - count1) +end + +function buffer:read(bytes, count) + if count > self:read_avail() then error('read xrun') end + local pos = self:read_pos() + local count1 = math.min(self.size - pos, count) + ffi.copy(bytes, self.buf + pos, count1) + ffi.copy(bytes + count1, self.buf, count - count1) + self:advance_read(count) +end + +function buffer:drop() + if count > self:read_avail() then error('read xrun') end + self:advance_read(count) +end + +function buffer:peek() + local pos = self:read_pos() + return self.buf + pos, math.min(self:read_avail(), self.size - pos) +end + +buffer_t = ffi.metatype(buffer_t, buffer) + +function selftest() + print('selftest: lib.buffer') + local function assert_throws(f, ...) + local success, ret = pcall(f, ...) + assert(not success, "expected failure but got "..tostring(ret)) + end + local function assert_avail(b, readable, writable) + assert(b:read_avail() == readable) + assert(b:write_avail() == writable) + end + local function write_str(b, str) + local scratch = ffi.new('uint8_t[?]', #str) + ffi.copy(scratch, str, #str) + b:write(scratch, #str) + end + local function read_str(b, count) + local scratch = ffi.new('uint8_t[?]', count) + b:read(scratch, count) + return ffi.string(scratch, count) + end + + assert_throws(new, 10) + local b = new(16) + assert_avail(b, 0, 16) + for i = 1,10 do + local s = '0123456789' + write_str(b, s) + assert_avail(b, #s, 16-#s) + assert(read_str(b, #s) == s) + assert_avail(b, 0, 16) + end + + local ptr, avail = b:peek() + assert(avail == 0) + write_str(b, "foo") + local ptr, avail = b:peek() + assert(avail > 0) + + -- Test wrap of indices. + local s = "overflow" + b.read_idx = to_uint32(3 - #s) + b.write_idx = b.read_idx + assert_avail(b, 0, 16) + write_str(b, s) + assert_avail(b, #s, 16-#s) + assert(read_str(b, #s) == s) + assert_avail(b, 0, 16) + + print('selftest: ok') +end From 5c6dd8d1a8f99b486bbdb365c206e3175865b95f Mon Sep 17 00:00:00 2001 From: Andy Wingo Date: Thu, 22 Feb 2018 15:43:31 +0100 Subject: [PATCH 3/7] Add reorder buffer Following smoltcp's example, the actual data is written directly in the ring buffer; the reorder buffer just holds bookkeeping data. --- src/lib/tcp/reorder.lua | 213 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 src/lib/tcp/reorder.lua diff --git a/src/lib/tcp/reorder.lua b/src/lib/tcp/reorder.lua new file mode 100644 index 0000000000..a301ca0c8c --- /dev/null +++ b/src/lib/tcp/reorder.lua @@ -0,0 +1,213 @@ +-- Use of this source code is governed by the Apache 2.0 license; see COPYING. + +-- Reorder buffer, designed for use with byte streams where ranges of +-- the byte streams may arrive out of order and might even overlap, as +-- in TCP. +-- +-- The actual data is "on the side" in a ring buffer; the reorder +-- information just tracks which parts of the buffer are valid. +-- +-- The reorder information is tracked in such a way that it's easy to +-- read completed data directly off the front of the ring buffer without +-- having to adjust the contents of the reorder buffer. + +module(..., package.seeall) + +local bit = require("bit") +local ffi = require("ffi") +local lib = require("core.lib") + +local max_hole_count = 4 + +local reorder_t = ffi.typeof([[ + struct { + uint32_t hole_count; + struct { uint32_t start, len; } holes[$]; + } __attribute((packed))]], + max_hole_count) + +local function to_uint32(n) + return ffi.new('uint32_t[1]', n)[0] +end + +function new() + return reorder_t() +end + +local reorder = {} +reorder.__index = reorder + +-- Add COUNT bytes from the uint8_t* BYTES to the ring buffer BUF. The +-- read end of the ring buffer is at stream position BASE, and the bytes +-- to be written start at stream position POS. Both BASE and POS are +-- wrappable uint32_t counters. +-- +-- In the event of overlap between this range and a previously recorded +-- range of data, this implementation will discard the overlapping +-- portion of the new data, preferring the old data. +function reorder:write(buf, base, pos, bytes, count) + local offset = to_uint32(pos - base) + assert(offset + count <= buf.size) + + local i = 0 + while i < self.hole_count do + local hole_offset = to_uint32(self.holes[i].start - base) + local hole_len = self.holes[i].len + if offset < hole_offset then + -- New data overlaps with old data. + local drop = math.min(count, hole_offset - offset) + if drop == count then return end + offset, bytes, count = offset + drop, bytes + drop, count - drop + -- Fall through. + end + if offset < hole_offset + hole_len then + -- New data starts in this hole. + local fill = math.min(count, hole_offset + hole_len - offset) + buf:rewrite(offset, bytes, fill) + if fill == hole_len then + -- Hole completely filled; delete it and loop again with + -- same i. + for j=i,self.hole_count-2 do self.holes[j] = self.holes[j+1] end + self.hole_count = self.hole_count - 1 + -- Fall through. + elseif offset == hole_offset then + -- Hole partially filled from start. + self.holes[i].start = base + hole_offset + fill + self.holes[i].len = hole_len - fill + return + elseif offset + fill < hole_offset + hole_len then + -- Hole split by fill in middle. + assert(fill == count) + if self.hole_count == max_hole_count then + error("fixme: do something sensible here") + else + for j=i+1,self.hole_count-1 do self.holes[j+1] = self.holes[j] end + self.holes[i].len = offset - hole_offset + self.holes[i+1].start = base + offset + count + self.holes[i+1].len = hole_len - count - self.holes[i].len + self.hole_count = self.hole_count + 1 + return + end + else + -- Hole partially filled at end; start looking in next hole. + self.holes[i].len = hole_len - fill + i = i + 1 + -- Fall through. + end + offset, bytes, count = offset + fill, bytes + fill, count - fill + else + -- New data is after this hole; look at next hole. + i = i + 1 + end + end + + -- New data is after all the holes. + if offset < buf:read_avail() then + -- But it starts before the end of the data. Drop overlapping data. + local drop = math.min(count, buf:read_avail() - offset) + if drop == count then return end + offset, bytes, count = offset + drop, bytes + drop, count - drop + end + + -- New data is after all the holes and all of the data. But we might + -- need to extend a hole or open a new hole before it starts. + if offset ~= buf:read_avail() then + local old_size = buf:read_avail() + local new_hole_bytes = offset - old_size + assert(new_hole_bytes > 0) + if i == 0 then + -- First hole. + self.hole_count = 1 + self.holes[0].start = base + old_size + self.holes[0].len = new_hole_bytes + else + local last_hole_offset = to_uint32(self.holes[i-1].start - base) + local last_hole_len = self.holes[i-1].len + if last_hole_offset + last_hole_len == old_size then + -- Reorder buffer ended with a hole. Extend it. + self.holes[i-1].len = self.holes[i-1].len + new_hole_bytes + elseif self.hole_count == max_hole_count then + error("fixme: do something sensible here") + else + -- Reorder buffer ended with data. Make a new hole. + self.hole_count = self.hole_count + 1 + self.holes[i].start = base + old_size + self.holes[i].len = new_hole_bytes + end + end + -- Leave a place for the new or extended hole. + buf:advance_write(new_hole_bytes) + end + + -- Finally, append our new data to the buffer. + buf:write(bytes, count) +end + +function reorder:has_holes() + return self.hole_count ~= 0 +end + +function reorder:read_avail(buf, base) + if self:has_holes() then + return to_uint32(self.holes[0].start - base) + else + return buf:read_avail() + end +end + +reorder_t = ffi.metatype(reorder_t, reorder) + +function selftest() + print('selftest: lib.reorder') + local window = 2^16 + local data_len = math.random(1, window) + local data = lib.random_bytes(data_len) + -- 3 segments. + local offset_12 = math.random(0, data_len) + local offset_23 = math.random(offset_12, data_len) + + local function make_segment(start, len) + local ret = ffi.new('uint8_t[?]', len) + ffi.copy(ret, data + start, len) + return {ret, start, len} + end + local function permute_indices(lo, hi) + if lo == hi then return {{hi}} end + local ret = {} + for _, tail in ipairs(permute_indices(lo + 1, hi)) do + for pos = 1, #tail + 1 do + local order = lib.deepcopy(tail) + table.insert(order, pos, lo) + table.insert(ret, order) + end + end + return ret + end + + local segments = { make_segment(0, offset_12), + make_segment(offset_12, offset_23 - offset_12), + make_segment(offset_23, data_len - offset_23) } + local reorder = new() + local buf = require('lib.buffer').new(window) + + local pos = 0 + for _, order in ipairs(permute_indices(1, #segments)) do + for again = 1,5 do + local advance = math.random(0, 2^32) + buf:advance_read(advance) + buf:advance_write(advance) + pos = to_uint32(pos + advance) + for _, i in ipairs(order) do + local bytes, offset, len = unpack(segments[i]) + reorder:write(buf, pos, to_uint32(pos + offset), bytes, len) + end + assert(reorder.hole_count == 0) + assert(reorder:read_avail(buf) == data_len) + local tmp = ffi.new('uint8_t[?]', data_len) + buf:read(tmp, data_len) + assert(ffi.C.memcmp(data, tmp, data_len) == 0) + end + end + + print('selftest: ok') +end From 14e14435ee04daec13f3627a16576143bea21d77 Mon Sep 17 00:00:00 2001 From: Andy Wingo Date: Fri, 23 Feb 2018 11:02:50 +0100 Subject: [PATCH 4/7] Add timer module --- src/lib/tcp/timer.lua | 113 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 src/lib/tcp/timer.lua diff --git a/src/lib/tcp/timer.lua b/src/lib/tcp/timer.lua new file mode 100644 index 0000000000..9db8c7fbb0 --- /dev/null +++ b/src/lib/tcp/timer.lua @@ -0,0 +1,113 @@ +-- Use of this source code is governed by the Apache 2.0 license; see COPYING. + +-- Includes code ported from smoltcp +-- (https://github.com/m-labs/smoltcp), whose copyright is the +-- following: +--- +-- Copyright (C) 2016 whitequark@whitequark.org +-- +-- Permission to use, copy, modify, and/or distribute this software for +-- any purpose with or without fee is hereby granted. +-- +-- THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +-- WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +-- MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +-- ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +-- WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN +-- AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +-- OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +module(...,package.seeall) + +local ffi = require("ffi") + +local NONE, IDLE, RETRANSMIT, CLOSE = 0, 1, 2, 3 + +timer_t = ffi.typeof[[ +struct { + uint8_t kind; + uint64_t expires_at; /* for idle, retransmit, close */ + uint64_t delay; /* for retransmit */ +} __attribute__((packed)) +]] + +local retransmit_delay = 100 +local close_delay = 10000 + +local timer = {} +timer.__index = timer + +function timer:should_keep_alive(ts) + return self.kind == IDLE and self.expires_at <= ts +end + +function timer:should_retransmit(ts) + if self.kind == RETRANSMIT and self.expires_at <= ts then + return ts - self.expires_at + self.delay + end +end + +function timer:should_close(ts) + return self.kind == CLOSE and self.expires_at <= ts +end + +function timer:poll_at() + if self.kind ~= NONE then return self.expires_at end +end + +function timer:set_none() + self.kind = NONE +end + +function timer:set_idle(t, keep_alive_at) + self.kind = IDLE + self.expires_at = keep_alive_at or -1ULL +end + +function timer:rewind_keep_alive(t, keep_alive_at) + if self.kind == IDLE then self.expires_at = keep_alive_at or -1ULL end +end + +function timer:set_keep_alive(t) + timer:rewind_keep_alive(t, 0) +end + +function timer:set_retransmit(ts) + if self.kind == NONE or self.kind == IDLE then + self.kind = RETRANSMIT + self.expires_at = ts + retransmit_delay + self.delay = retransmit_delay + elseif self.kind == RETRANSMIT and self.expires_at <= ts then + self.expires_at = ts + retransmit_delay + self.delay = self.delay * 2 + end +end + +function timer:set_close(ts) + self.kind = CLOSE + self.expires_at = ts + close_delay +end + +function timer:is_retransmit(t) + return self.kind == RETRANSMIT +end + +timer_t = ffi.metatype(timer_t, timer) + +function selftest() + print('selftest: lib.tcp.timer') + local t = timer_t() + assert(not t:should_retransmit(1000)) + t:set_retransmit(1000) + assert(not t:should_retransmit(1000)) + assert(not t:should_retransmit(1050)) + assert(t:should_retransmit(1101) == 101) + t:set_retransmit(1101) + assert(not t:should_retransmit(1101)) + assert(not t:should_retransmit(1150)) + assert(not t:should_retransmit(1200)) + assert(t:should_retransmit(1301) == 300) + t:set_idle(1301) + assert(not t:should_retransmit(1350)) + print('selftest: ok') +end From 018d24b1b8103772b48c267c4c663b3ecabf4361 Mon Sep 17 00:00:00 2001 From: Andy Wingo Date: Mon, 12 Feb 2018 15:18:28 +0100 Subject: [PATCH 5/7] Add "preserve" argument to ctable:add --- src/lib/README.ctable.md | 15 +++++++++------ src/lib/ctable.lua | 6 ++++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/lib/README.ctable.md b/src/lib/README.ctable.md index fe144a49de..b7dc2cba2a 100644 --- a/src/lib/README.ctable.md +++ b/src/lib/README.ctable.md @@ -80,12 +80,15 @@ Add an entry to the ctable, returning the index of the added entry. *updates_allowed* is an optional parameter. If not present or false, then the `:insert` method will raise an error if the *key* is already present in the table. If *updates_allowed* is the string `"required"`, -then an error will be raised if *key* is *not* already in the table. -Any other true value allows updates but does not require them. An -update will replace the existing entry in the table. - -Returns a pointer to the inserted entry. Any subsequent modification -to the table may invalidate this pointer. +then an error will be raised if *key* is *not* already in the table. If +*updates_allowed* is the string `"preserve"`, then no error will be +raised if a key is already present in the table, but in that case the +corresponding value already in the table won't be updated either. Any +other true value allows updates but does not require them. An update +will replace the existing entry in the table. + +Returns a pointer to the inserted or existing entry. Any subsequent +modification to the table may invalidate this pointer. — Method **:update** *key*, *value* diff --git a/src/lib/ctable.lua b/src/lib/ctable.lua index 96df39bf8f..bd32573a5d 100644 --- a/src/lib/ctable.lua +++ b/src/lib/ctable.lua @@ -311,8 +311,10 @@ function CTable:add(key, value, updates_allowed) local entry = entries + index if self.equal_fn(key, entry.key) then assert(updates_allowed, "key is already present in ctable") - entry.key = key - entry.value = value + if updates_allowed ~= 'preserve' then + entry.key = key + entry.value = value + end return entry end index = index + 1 From e0a77c9430bab996002ddc48a706adf448f29720 Mon Sep 17 00:00:00 2001 From: Andy Wingo Date: Thu, 1 Mar 2018 08:13:17 +0100 Subject: [PATCH 6/7] Add example server app --- src/apps/tcp/server.lua | 87 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 src/apps/tcp/server.lua diff --git a/src/apps/tcp/server.lua b/src/apps/tcp/server.lua new file mode 100644 index 0000000000..79c5ad1a16 --- /dev/null +++ b/src/apps/tcp/server.lua @@ -0,0 +1,87 @@ +-- Use of this source code is governed by the Apache 2.0 license; see COPYING. + +-- Simple TCP echo service. + +module(..., package.seeall) + +local lib = require("core.lib") +local packet = require("core.packet") +local link = require("core.link") +local tcp = require("lib.tcp.tcp") +local proto = require("lib.tcp.proto") +local scheduler = require("lib.fiber.scheduler") +local fiber = require("lib.fiber.fiber") + +Server = {} +local config_params = { + -- Address or list of addresses to which to bind, as strings of the + -- format ADDR:PORT, where ADDR is either an IPv4 or IPv6 address. + bind = { required=true }, +} + +local function parse_ipv4_address(str) + local head, tail = str:match("^([%d.]*):([1-9][0-9]*)$") + if not head then return end + local parsed = ipv4:pton(head) + if not parsed then return end + return { type='ipv4', addr=parsed, port=tonumber(tail) } +end + +local function parse_ipv6_address(str) + local head, tail = str:match("^([%x:]*):([1-9][0-9]*)$") + if not head then return end + local parsed = ipv6:pton(head) + if not parsed then return end + return { type='ipv6', addr=parsed, port=tonumber(tail) } +end + +function Server:new(conf) + conf = lib.parse(conf, config_params) + + local o = setmetatable({}, {__index = Server}) + o.tcp = tcp.new() + + if type(conf.bind) == 'string' then o:bind(conf.bind) + else for _,str in ipairs(conf.bind) do o:bind(str) end end + + return o +end + +local function echo(fam, sock) + while fibers.wait_readable(sock) do + fibers.write(sock, sock:peek()) + end +end + +-- Override me! +Server.accept_fn = echo + +function Server:bind(addr_and_port) + local addr, port = parse_ipv4_address(addr_and_port) + local function accept_ipv4(sock) fiber.spawn(self.accept_fn, 'ipv4', sock) end + if addr then self.tcp:listen_ipv4(addr, port, accept); return end + local addr, port = parse_ipv6_address(addr_and_port) + local function accept_ipv6(sock) fiber.spawn(self.accept_fn, 'ipv6', sock) end + if addr then self.tcp:listen_ipv6(addr, port, accept); return end + error('Invalid bind address for server, expected ADDR:PORT: '..tostring(str)) +end + +function Server:push() + local now = engine.now() + self.tcp:advance_clock(now) + + for _ = 1, link.nreadable(self.input.input) do + local p = link.receive(self.input.input) + if proto.is_ipv4(p) then + local ip, tcp, payload_length = proto.parse_ipv4_tcp(p) + if ip then self.tcp:handle_ipv4(ip, tcp, payload_length) end + elseif proto.is_ipv6(p) then + local ip, tcp, payload_length = proto.parse_ipv6_tcp(p) + if ip then self.tcp:handle_ipv6(ip, tcp, payload_length) end + end + packet.free(pkt) + end + + self.scheduler:advance_clock(now) + self.scheduler:run_tasks() +end From 6a047ddd003ef74bf654ca0dea42f36b83dc551b Mon Sep 17 00:00:00 2001 From: Andy Wingo Date: Thu, 1 Mar 2018 08:23:59 +0100 Subject: [PATCH 7/7] Add smoltcp port --- src/lib/tcp/socket.lua | 3500 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 3500 insertions(+) create mode 100644 src/lib/tcp/socket.lua diff --git a/src/lib/tcp/socket.lua b/src/lib/tcp/socket.lua new file mode 100644 index 0000000000..571bcb598e --- /dev/null +++ b/src/lib/tcp/socket.lua @@ -0,0 +1,3500 @@ +-- Use of this source code is governed by the Apache 2.0 license; see COPYING. + +-- Includes code ported from smoltcp +-- (https://github.com/m-labs/smoltcp), whose copyright is the +-- following: +--- +-- Copyright (C) 2016 whitequark@whitequark.org +-- +-- Permission to use, copy, modify, and/or distribute this software for +-- any purpose with or without fee is hereby granted. +-- +-- THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES +-- WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +-- MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR +-- ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +-- WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN +-- AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +-- OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +-- Heads up! Before working on this file you should read, at least, RFC +-- 793 and the parts of RFC 1122 that discuss TCP. Consult RFC 7414 +-- when implementing a new feature. + +module(...,package.seeall) + +-- NOTE TO READER ~~~ +-- +-- This port is currently unfinished. It's just here as a savepoint. +-- +-- I started this port of smoltcp thinking that I would want a big flat +-- ctable. However in the end I don't think that's the right thing, +-- because the sockets can move around in memory, and if you implement a +-- TCP service, you'd like to be able for the fiber or whatever that +-- serves a connection to be able to work with the socket directly -- +-- but if it can move around in memory, you're inviting problems. +-- +-- So, the next step here is to refactor this to make the "socket" the +-- primary object and not the socket table. At the same time, the +-- "proto" library changed since this code was first written; need to +-- port there. But in general I would say (to myself, probably!), look +-- at apps/tcp/server.lua and figure out what needs to be done here to +-- make that work. + +local lib = require("core.lib") +local ffi = require("ffi") +local bit = require("bit") +local tcp = require("lib.tcp") +local ctable = require("lib.ctable") +local siphash = require("lib.hash.siphash") +local buffer = require("lib.tcp.buffer") +local proto = require("lib.tcp.proto") +local reorder = require("lib.tcp.reorder") +local timer = require("lib.tcp.timer") + +local ntohs, ntohl = lib.ntohs, lib.ntohl +local htons, htonl = ntohs, ntohl +local band, bor, bxor, bnot = bit.band, bit.bor, bit.bxor, bit.bnot +local lshift, rshift = bit.lshift, bit.rshift + +local ipv4_addr_t = ffi.typeof('uint8_t[4]') +local ipv6_addr_t = ffi.typeof('uint8_t[16]') + +local function enum(names) + local ret = {} + for i,name in ipairs(names) do ret[name] = i end + return ret +end + +-- The state of a TCP socket, according to [RFC 793]. +-- +-- [RFC 793]: https://tools.ietf.org/html/rfc793 +local states = enum { 'CLOSED', 'LISTEN', 'SYN_SENT', 'SYN_RECEIVED', + 'ESTABLISHED', 'FIN_WAIT_1', 'FIN_WAIT_2', 'CLOSE_WAIT', + 'CLOSING', 'LAST_ACK', 'TIME_WAIT' } + +-- A Transmission Control Protocol socket. + +local function make_tcp_socket_id_t(addr_t) + -- A socket is identified by the four-tuple of local and remote + -- addresses and ports. A socket in the LISTEN state uses a zero + -- remote address and port. + return ffi.typeof([[ +struct { + $ local_ip; + $ remote_ip; + uint16_t local_port; + uint16_t remote_port; +} __attribute__((packed)); +]], addr_t, addr_t) +end + +local tcp_socket_state_t = ffi.typeof([[ +struct { + uint8_t state; /* one of the tcp_state constants */ + $ timer; + $* rx_buffer; + $* tx_buffer; + + /* Interval after which, if no inbound packets are received, the + connection is aborted, or -1 if not set. */ + uint64_t timeout; + + /* Interval at which keep-alive packets will be sent, or -1 if not + set. */ + uint64_t keep_alive; + + /* The sequence number corresponding to the beginning of the + transmit buffer. I.e. an ACK(local_seq_no+n) packet removes n + bytes from the transmit buffer. */ + uint32_t local_seq_no; + + /* The sequence number corresponding to the beginning of the receive + buffer. I.e. userspace reading n bytes adds n to + remote_seq_no. */ + uint32_t remote_seq_no; + + /* The last sequence number sent. I.e. in an idle socket, + local_seq_no+tx_buffer.len(). */ + uint32_t remote_last_seq; + + /* The last acknowledgement number sent. I.e. in an idle socket, + remote_seq_no+rx_buffer.len(). */ + uint32_t remote_last_ack; // FIXME: Option<> + + /* The last window length sent. */ + uint16_t remote_last_win; + + /* The speculative remote window size. I.e. the actual remote + window size minus the count of in-flight octets. */ + uint32_t remote_win_len; + + /* The maximum number of data octets that the remote side may + receive. */ + uint32_t remote_mss; + + /* The timestamp of the last packet received, or -1 if unknown. */ + uint64_t remote_last_ts; +} __attribute__((packed)); +]], timer.timer_t, socket_buffer_t, socket_buffer_t) + +local default_mss = 536 + +local function clear_tcp_socket_state(sock) + -- note: buffers larger than 65535 require window scaling, which is + -- not implemented + -- FIXME: put rx and tx buffers, if any, back on freelist + ffi.fill(sock, ffi.sizeof(tcp_socket_state_t)) + sock.state = states.CLOSED + sock.timer:set_none() + sock.timeout = -1 + sock.keep_alive = -1 + -- FIXME: Initialize sock.remote_last_ack to "none" + sock.remote_mss = default_mss + sock.remote_last_ts = -1 +end + +local function state_predicate(states) + return function(sock) return lib.bitset(states, sock.state) end +end + +-- This function returns true if the socket will process incoming or +-- dispatch outgoing packets. Note that this does not mean that it is +-- possible to send or receive data through the socket; for that, use +-- [can_send](#method.can_send) or [can_recv](#method.can_recv). +local tcp_socket_is_open = state_predicate( + bnot(lib.bits {states.CLOSED, states.TIME_WAIT})) +local tcp_socket_is_listening = state_predicate( + lib.bits {states.LISTEN}) +-- This function returns true if the socket is actively exchanging packets with +-- a remote endpoint. Note that this does not mean that it is possible to send or receive +-- data through the socket; for that, use [can_send](#method.can_send) or +-- [can_recv](#method.can_recv). +local tcp_socket_is_active = state_predicate( + bnot(lib.bits {states.CLOSED, states.TIME_WAIT, states.LISTEN})) +-- Return whether the transmit half of the full-duplex connection is open. +-- +-- This function returns true if it's possible to send data and have it +-- arrive to the remote endpoint. However, it does not make any +-- guarantees about the state of the transmit buffer, and even if it +-- returns true, [send](#method.send) may not be able to enqueue any +-- octets. +-- +-- In CLOSE_WAIT, the remote endpoint has closed our receive half of the +-- connection but we still can transmit indefinitely. +local tcp_socket_may_send = state_predicate( + lib.bits {states.ESTABLISHED, states.CLOSE_WAIT}) + +local SocketTable = {} + +function new_socket_table(addr_t) + local ret = {} + ret.id_t = make_tcp_socket_id_t(addr_t) + ret.empty_ip = addr_t() + ret.counter = 0xffffffff -- For random_u32. + ret.sockets = ctable.new { + key_type = ret.id_t, + value_type = tcp_socket_state_t, + max_occupancy_rate = 0.4 + } + ret.scratch_entry = ret.table.entry_type() + return setmetatable(ret, { __index = SocketTable }) +end + +function new_ipv4_socket_table() return new_socket_table(ipv4_addr_t) end +function new_ipv6_socket_table() return new_socket_table(ipv6_addr_t) end + +-- A socket with a timeout duration set will abort the connection if +-- either of the following occurs: +-- +-- * After a [connect](#method.connect) call, the remote endpoint does +-- not respond within the specified duration; +-- * After establishing a connection, there is data in the transmit +-- buffer and the remote endpoint exceeds the specified duration +-- between any two packets it sends; +-- * After enabling [keep-alive](#method.set_keep_alive), the remote +-- endpoint exceeds the specified duration between any two packets +-- it sends. +local function tcp_socket_timeout(sock) + if sock.timeout ~= -1 then return sock.timeout end +end +local function set_tcp_socket_timeout(sock, duration) + sock.timeout = duration +end +local function clear_tcp_socket_timeout(sock) + set_tcp_socket_timeout(sock, -1) +end + +-- An idle socket with a keep-alive interval set will transmit a +-- "challenge ACK" packet every time it receives no communication during +-- that interval. As a result, three things may happen: +-- +-- * The remote endpoint is fine and answers with an ACK packet. +-- * The remote endpoint has rebooted and answers with an RST packet. +-- * The remote endpoint has crashed and does not answer. +-- +-- The keep-alive functionality together with the timeout functionality +-- allows to react to these error conditions. +local function tcp_socket_keep_alive(sock) + if sock.keep_alive ~= -1 then return sock.keep_alive end +end +local function set_tcp_socket_keep_alive(sock, interval) + sock.keep_alive = interval + -- If the connection is idle and we've just set the option, it would + -- not take effect until the next packet, unless we wind up the timer + -- explicitly. + sock.timer:set_keep_alive() +end +local function clear_tcp_socket_keep_alive(sock) + sock.keep_alive = -1 + -- and the timer?? +end + +-- Return socket table entry (with key and value properties) or nil. +function SocketTable:lookup_socket(local_ip, remote_ip, local_port, remote_port) + local entry = self.scratch_entry + entry.key.local_ip, entry.key.remote_ip = local_ip, remote_ip + entry.key.local_port, entry.key.remote_port = local_port, remote_port + return self.sockets:lookup_ptr(entry.key) +end + +function SocketTable:random_u32() + local counter = self.random_counter + if counter > 2e9 then + -- Rekey every so often, but start with at least a few bits in the + -- counter. + counter = 0x12345678 + self.hash_u64 = siphash.make_u64_hash() + end + self.random_counter = counter + 1 + return self.hash_u64 +end + +function SocketTable:choose_initial_sequence_number(sockent) + -- FIXME: Consider following RFC 6528 instead. + return bxor(sockent.hash, self:random_u32()) +end + +-- Return socket table entry, adding it to the table if necessary. May +-- signal an error depending on whether the socket exists already or +-- not, according to the updates_allowed parameter; see ctable.add for +-- details. If the socket was newly added, it will be in the +function SocketTable:add_socket(local_ip, remote_ip, local_port, remote_port, + updates_allowed) + local entry = self.scratch_entry + entry.key.local_ip, entry.key.remote_ip = local_ip, remote_ip + entry.key.local_port, entry.key.remote_port = local_port, remote_port + clear_tcp_socket_state(entry.value) + return self.sockets:add(entry.key, entry.value, updates_allowed) +end + +-- Add a new socket to the table. If it exists already, return the +-- existing entry; otherwise return a freshly added entry in the closed +-- state. +function SocketTable:ensure_socket(local_ip, remote_ip, local_port, remote_port) + return self:add_socket(local_ip, remote_ip, local_port, remote_port, + 'preserve') +end + +function SocketTable:listen(ip, port) + assert(port ~= 0, "attempt to listen on port 0") + local sockent = self:add_socket(ip, self.empty_ip, port, 0) + self:set_state(sockent, states.LISTEN) + return sockent +end + +-- The local port must be provided explicitly. Assuming `fn +-- get_ephemeral_port() -> u16` allocates a port between 49152 and +-- 65535, a connection may be established as follows: +-- +-- ```rust,ignore +-- socket.connect((IpAddress::v4(10, 0, 0, 1), 80), get_ephemeral_port()) +-- ``` +-- +-- The local address may optionally be provided. +-- +-- This function signals an error if the socket was open; see +-- [is_open](#method.is_open). It also signals an error if the local or +-- remote port is zero, or if the remote address is unspecified. +function SocketTable:connect(local_ip, local_port, remote_ip, remote_port) + assert(local_port ~= 0, "attempt to connect from port 0") + assert(remote_port ~= 0, "attempt to connect to port 0") + local sockent = self:ensure_socket(ip, self.empty_ip, port, 0) + assert(not tcp_socket_is_open(sockent.value), "socket already open") + clear_tcp_socket_state(sockent.value) + sockent.value.local_seq_no = self:choose_initial_sequence_number(sockent) + sockent.value.remote_last_seq = self.value.local_seq_no + -- The dispatcher will actually send the packet. + self:set_state(sockent, states.SYN_SENT) + return sockent +end + +-- Close the transmit half of the full-duplex connection. +-- +-- Note that there is no corresponding function for the receive half of +-- the full-duplex connection; only the remote end can close it. If you +-- no longer wish to receive any data and would like to reuse the socket +-- right away, use [abort](#method.abort). +function SocketTable:close(sockent) + return self.handle_close[sockent.value.state](self, sockent) +end + +function SocketTable:add_handler(op, state, handler) + local idx = assert(states[state]) + assert(self['handle_'..op])[idx] = assert(self[op..'_'..handler]) +end +function SocketTable:check_handlers(op) + for i,s in ipairs(tcp_state_names) do assert(self['handle_'..op][i]) end +end +function SocketTable:add_handlers(op, ...) + for _,pair in ipairs({...}) do self:add_handler(op, unpack(pair)) end + self:check_handlers(op) +end + +function SocketTable:close_nop(sockent) end +function SocketTable:close_closed(sockent) + self:set_state(sockent, states.CLOSED) +end +function SocketTable:close_fin_wait_1(sockent) + self:set_state(sockent, states.FIN_WAIT_1) +end +function SocketTable:close_last_ack(sockent) + self:set_state(sockent, states.LAST_ACK) +end + +SocketTable:add_handlers( + 'close', + -- In the LISTEN state there is no established connection; in + -- SYN_SENT state the remote endpoint is not yet synchronized and, + -- upon receiving an RST, will abort the connection. + {'LISTEN', 'closed'}, {'SYN_SENT', 'closed' }, + -- In the SYN_RECEIVED, ESTABLISHED and CLOSE-WAIT states the + -- transmit half of the connection is open, and needs to be + -- explicitly closed with a FIN. + {'SYN_RECEIVED', 'fin_wait_1'}, {'ESTABLISHED', 'fin_wait_1'}, + {'CLOSE_WAIT', 'last_ack'}, + -- In the FIN_WAIT_1, FIN_WAIT_2, CLOSING, LAST_ACK, TIME_WAIT and + -- CLOSED states, the transmit half of the connection is already + -- closed, and no further action is needed. + {'FIN_WAIT_1', 'nop'}, {'FIN_WAIT_2', 'nop'}, {'CLOSING', 'nop'}, + {'TIME_WAIT', 'nop'}, {'LAST_ACK', 'nop'}, {'CLOSED', 'nop'}) + +-- Aborts the connection, if any. +-- +-- This function instantly closes the socket. One reset packet will be +-- sent to the remote endpoint. +-- +-- In terms of the TCP state machine, the socket may be in any state and +-- is moved to the `CLOSED` state. +function SocketTable:abort(sockent) + self:set_state(sockent, states.CLOSED) +end + +-- Return whether the receive half of the full-duplex connection is +-- open. +-- +-- This function returns true if it's possible to receive data from the +-- remote endpoint. It will return true while there is data in the +-- receive buffer, and if there isn't, as long as the remote endpoint +-- has not closed the connection. +-- +-- In terms of the TCP state machine, the socket must be in the +-- `ESTABLISHED`, `FIN-WAIT-1`, or `FIN-WAIT-2` state, or have data in +-- the receive buffer instead. +local may_recv_states = state_partition( + "ESTABLISHED", "FIN_WAIT_1", "FIN_WAIT_2") +local function tcp_socket_may_recv(sock) + return may_recv_states[sock.state] or sock.rx_buffer:is_empty() +end + +-- Check whether the transmit half of the full-duplex connection is open +-- (see [may_send](#method.may_send), and the transmit buffer is not full. +local function tcp_socket_can_send(sock) + return tcp_socket_may_send(socket) and not sock.tx_buffer:is_full() +end + +-- Check whether the transmit half of the full-duplex connection is open +-- (see [may_recv](#method.may_recv), and the transmit buffer is not full. +local function tcp_socket_can_recv(sock) + return tcp_socket_may_recv(socket) and not sock.rx_buffer:is_empty() +end + +-- Enqueue a sequence of octets to be sent, and fill it from a slice. +-- +-- This function returns the amount of bytes actually enqueued, which is limited +-- by the amount of free space in the transmit buffer; down to zero. +function SocketTable:enqueue(sockent, buf, count) + local sock = sockent.value + assert(tcp_socket_may_send(sock)) + -- The connection might have been idle for a long time, and so + -- remote_last_ts would be far in the past. Unless we clear it here, + -- we'll abort the connection down over in dispatch() by erroneously + -- detecting it as timed out. + if sock.tx_buffer:is_empty() then sock.remote_last_ts = -1 end + count = math.min(sock.tx_buffer:write_avail(), count) + -- FIXME: trace + sock.tx_buffer:write(buf, count) + return count +end + +-- Dequeue a sequence of received octets, and fill a slice from it. +-- +-- This function returns the amount of bytes actually dequeued, which is limited +-- by the amount of free space in the transmit buffer; down to zero. +function SocketTable:dequeue(sockent, buf, count) + local sock = sockent.value + -- We may have received some data inside the initial SYN, but until + -- the connection is fully open we must not dequeue any data, as it + -- may be overwritten by e.g. another (stale) SYN. (We do not + -- support TCP Fast Open.) + assert(tcp_socket_may_recv(sock)) + count = math.min(sock.rx_buffer:read_avail(), count) + -- FIXME: trace + sock.rx_buffer:read(buf, count) + sock.remote_seq_no = sock.remote_seq_no + count; + return count +end + +-- Peek at a sequence of received octets without removing them from the +-- receive buffer, and return two values: the pointer and a byte count. +-- +-- This function otherwise behaves identically to [recv](#method.recv). +function SocketTable:peek(sockent) + local sock = sockent.value + -- See dequeue() above. + assert(tcp_socket_may_recv(sock)) + return sock.rx_buffer:peek() +end + +-- Return the amount of octets queued in the transmit buffer. +-- +-- Note that the Berkeley sockets interface does not have an equivalent +-- of this API. +function SocketTable:send_queue(sockent) + return sockent.value.tx_buffer:read_avail() +end + +-- Return the amount of octets queued in the receive buffer. +-- +-- Note that the Berkeley sockets interface does not have an equivalent +-- of this API. +function SocketTable:recv_queue(sockent) + return sockent.value.rx_buffer:read_avail() +end + +function SocketTable:set_state(sockent, state) + local sock = sockent.value + if self.state ~= state then + -- FIXME: trace. + end + self.state = state +end + +function TCP:process_ipv4(p, timestamp) + -- Necessary checks: + -- ipv4 source address is unicast + -- ipv4 checksum matches + -- protocol is tcp +end + +function TCP:process_ipv6(p, timestamp) + -- Necessary checks: + -- ipv6 source address is unicast + -- protocol is tcp +end + +function TCP:process_tcp(p, timestamp) + -- Necessary checks: + -- tcp checksum matches + + -- look up a socket. if one found, process the packet, possibly + -- creating a new socket, and return. + + -- if no socket found, then: + local control = compute_tcp_control(tcp_flags(tcp)) + if control == tcp_control_rst then + -- Don't reply to a TCP RST packet with another TCP RST packet; + -- just pass. + else + local seq = htonl(tcp.ack) + -- See https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ + -- for explanation of why we sometimes send an RST and sometimes + -- an RST|ACK + local ack = 0 + if control == tcp_control_syn then + ack = htonl(tcp.seq) + segment_len(tcp, l4_length) + end + local p = self:add_headers(packet.allocate(), + dst_ip, src_ip, dst_port, src_port, + seq, ack, window, flags) + -- FIXME: send P. + end +end + +function SocketTable:prepare_ack_reply(ip, tcp) + -- From RFC 793: + -- [...] an empty acknowledgment segment containing the current + -- send-sequence number and an acknowledgment indicating the next + -- sequence number expected to be received. + local seq = sock.remote_last_seq + local ack = sock.remote_last_ack + local window_len = sock.window() -- apply window scaling + local p = self:add_headers(packet.allocate(), + dst_ip, src_ip, dst_port, src_port, + seq, ack, window, flags) + -- fixme flags +end + +local default_ttl = 64 +function SocketTable:add_headers(p, src_ip, dst_ip, src_port, dst_port, + seq, ack, window, flags) + -- FIXME set self.push_headers + return self.push_headers(p, src_ip, dst_ip, default_ttl, src_port, dst_port, + seq, ack, window, 0, flags) +end + +-- FIXME: this function is just *not* what we need. +function SocketTable:accepts(sockent, p) + if sockent.value.state == states.CLOSED then + return false + end + + -- If we're still listening for SYNs and the packet has an ACK, it + -- cannot be destined to this socket, but another one may well listen + -- on the same local endpoint. + if sockent.state == states.LISTEN and repr.ack ~= 0 then + return false + end + + -- Reject packets with a wrong destination. + -- Reject packets from a source to which we aren't connected. + + return true +end + +local function make_state_array(typ, val) + local len = #tcp_state_names + 1 + local ret = ffi.typeof("$[?]", ffi.typeof(typ))(len) + for i=0,len do ret[i] = val end + return ret +end + +local function adjoin_states(tab, val, ...) + for _, state in ipairs({...}) do tab[assert(states[state])] = val end +end + +local sent_syn_offsets = make_state_array("uint8_t", 0) +local sent_fin_offsets = make_state_array("uint8_t", 0) + +-- In SYN-SENT or SYN-RECEIVED, we've just sent a SYN. +adjoin_states(sent_syn_offsets, 1, "SYN_SENT", "SYN_RECEIVED") +-- In FIN-WAIT-1, LAST-ACK, or CLOSING, we've just sent a FIN. +adjoin_states(sent_fin_offsets, 1, "FIN_WAIT_1", "LAST_ACK", "CLOSING") +-- In all other states we've already got acknowledgements for all of the +-- control flags we sent. + +function SocketTable:process(sockent, p) + assert(self:accepts(sockent, p)) + + local state = sockent.value.state + local control = compute_tcp_control(tcp_flags(tcp)) + + -- Consider how much the sequence number space differs from the + -- transmit buffer space. + local sent_syn_offset = sent_syn_offsets[state] + local sent_fin_offset = sent_fin_offsets[state] + local control_len = sent_syn_offset + sent_fin_offset + + -- Reject unacceptable acknowledgements. + if has_ack(tcp) then + if state == states.SYN_SENT then + if ack_number == sockent.value.local_seq_no + 1 then + -- RST received SYN_SENT must acknowledge the initial SYN. + else + net_debug("{}:{}:{}: unacceptable RST|ACK in response to initial SYN", + self.meta.handle, self.local_endpoint, self.remote_endpoint) + return drop(p) + end + elseif state == states.LISTEN then + -- The initial SYN cannot contain an acknowledgement. + error("unreachable") -- see accepts(); a FIXME to refactor + else + -- Every acknowledgement must be for transmitted but unacknowledged data. + local unacknowledged = sockent.value.tx_buffer:read_avail() + control_len; + + if ack_number < self.local_seq_no then + net_debug("{}:{}:{}: duplicate ACK ({} not in {}...{})", + self.meta.handle, self.local_endpoint, self.remote_endpoint, + ack_number, self.local_seq_no, self.local_seq_no + unacknowledged) + -- FIXME: implement fast retransmit + return drop(p) + elseif ack_number > self.local_seq_no + unacknowledged then + net_debug("{}:{}:{}: unacceptable ACK ({} not in {}...{})", + self.meta.handle, self.local_endpoint, self.remote_endpoint, + ack_number, self.local_seq_no, self.local_seq_no + unacknowledged) + return Ok(Some(self.ack_reply(ip_repr, repr))) + end + end + else + -- Packet has no ACK; there are only a limited number of states in + -- which this is valid. + if control == tcp_control_rst then + if state == states.SYN_SENT then + -- RST received in SYN_SENT must acknowledge the initial + -- SYN. + net_debug("{}:{}:{}: unacceptable RST (expecting RST|ACK) in response to initial SYN", + self.meta.handle, self.local_endpoint, self.remote_endpoint) + return drop(p) + else + -- Otherwise RST just has to have a valid sequence number. + end + else + -- Every packet after the initial SYN must include ACK. + net_debug("{}:{}:{}: expecting an ACK", + self.meta.handle, self.local_endpoint, self.remote_endpoint) + return drop(p) + end + end + + local window_start = self.remote_seq_no + self.rx_buffer:read_avail(); + local window_end = self.remote_seq_no + self.rx_buffer.size; + local segment_start = seq_number; + local segment_end = seq_number + segment_len(tcp, l4_length) + + local payload_offset + if state == states.LISTEN or state == states.SYN_SENT then + -- In LISTEN and SYN-SENT states, we have not yet synchronized + -- with the remote end. + payload_offset = 0 + else + -- In all other states, segments must occupy a valid portion of + -- the receive window. + local segment_in_window; + + if window_start == window_end and segment_start ~= segment_end then + net_debug("{}:{}:{}: non-zero-length segment with zero receive window, will only send an ACK", + self.meta.handle, self.local_endpoint, self.remote_endpoint) + segment_in_window = false + elseif segment_start == segment_end and segment_end == window_start - 1 then + net_debug("{}:{}:{}: received a keep-alive or window probe packet, will send an ACK", + self.meta.handle, self.local_endpoint, self.remote_endpoint) + segment_in_window = false + elseif not ((window_start <= segment_start and segment_start <= window_end) and + (window_start <= segment_end and segment_end <= window_end)) then + net_debug("{}:{}:{}: segment not in receive window ({}..{} not intersecting {}..{}), will send challenge ACK", + self.meta.handle, self.local_endpoint, self.remote_endpoint, + segment_start, segment_end, window_start, window_end) + segment_in_window = false + else + segment_in_window = true + end + + if segment_in_window then + -- We've checked that segment_start >= window_start above. + payload_offset = to_uint32(segment_start - window_start) + else + -- If we're in the TIME-WAIT state, restart the TIME-WAIT timeout, since + -- the remote end may not have realized we've closed the connection. + if state == states.TIME_WAIT then + self.timer:set_for_close(timestamp) + end + + return self.ack_reply(ip, tcp) + end + end + + -- Compute the amount of acknowledged octets, removing the SYN and FIN bits + -- from the sequence space. + local ack_len = 0 + local ack_of_fin = false + if control ~= tcp_control_rst then + if has_ack(tcp) then + local ack_number = get_ack_number(tcp) + ack_len = ack_number - self.local_seq_no + -- There could have been no data sent before the SYN, so we always remove it + -- from the sequence space. + if sent_syn then ack_len = ack_len - 1 end + -- We could've sent data before the FIN, so only remove FIN from the sequence + -- space if all of that data is acknowledged. + if sent_fin and sockent.value.tx_buffer:read_avail() + 1 == ack_len then + ack_len = ack_len - 1 + net_trace("{}:{}:{}: received ACK of FIN", + self.meta.handle, self.local_endpoint, self.remote_endpoint) + ack_of_fin = true + end + end + end + + if control == tcp_control_psh then + -- Disregard control flags we don't care about or shouldn't act on + -- yet. + control = tcp_control_none + elseif control == tcp_control_fin and window_start ~= segment_start then + -- If a FIN is received at the end of the current segment but the + -- start of the segment is not at the start of the receive window, + -- disregard this FIN. + control = tcp_control_none + end + + local update_state=[[ + + -- Validate and update the state. + match (self.state, control) { + -- RSTs are not accepted in the LISTEN state. + (State::Listen, TcpControl::Rst) => + return Err(Error::Dropped), + + -- RSTs in SYN-RECEIVED flip the socket back to the LISTEN state. + (State::SynReceived, TcpControl::Rst) => { + net_trace!("{}:{}:{}: received RST", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + self.local_endpoint.addr = self.listen_address; + self.remote_endpoint = IpEndpoint::default(); + self.set_state(State::Listen); + return Ok(None) + } + + -- RSTs in any other state close the socket. + (_, TcpControl::Rst) => { + net_trace!("{}:{}:{}: received RST", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + self.set_state(State::Closed); + self.local_endpoint = IpEndpoint::default(); + self.remote_endpoint = IpEndpoint::default(); + return Ok(None) + } + + -- SYN packets in the LISTEN state change it to SYN-RECEIVED. + (State::Listen, TcpControl::Syn) => { + net_trace!("{}:{}: received SYN", + self.meta.handle, self.local_endpoint); + self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port); + self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), repr.src_port); + -- FIXME: use something more secure here + self.local_seq_no = TcpSeqNumber(-repr.seq_number.0); + self.remote_seq_no = repr.seq_number + 1; + self.remote_last_seq = self.local_seq_no; + if let Some(max_seg_size) = repr.max_seg_size { + self.remote_mss = max_seg_size as usize + } + self.set_state(State::SynReceived); + self.timer:set_for_idle(timestamp, self.keep_alive); + } + + -- ACK packets in the SYN-RECEIVED state change it to ESTABLISHED. + (State::SynReceived, TcpControl::None) => { + self.set_state(State::Established); + self.timer:set_for_idle(timestamp, self.keep_alive); + } + + -- FIN packets in the SYN-RECEIVED state change it to CLOSE-WAIT. + -- It's not obvious from RFC 793 that this is permitted, but + -- 7th and 8th steps in the "SEGMENT ARRIVES" event describe this behavior. + (State::SynReceived, TcpControl::Fin) => { + self.remote_seq_no += 1; + self.set_state(State::CloseWait); + self.timer:set_for_idle(timestamp, self.keep_alive); + } + + -- SYN|ACK packets in the SYN-SENT state change it to ESTABLISHED. + (State::SynSent, TcpControl::Syn) => { + net_trace!("{}:{}:{}: received SYN|ACK", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port); + self.remote_seq_no = repr.seq_number + 1; + self.remote_last_seq = self.local_seq_no + 1; + self.remote_last_ack = Some(repr.seq_number); + if let Some(max_seg_size) = repr.max_seg_size { + self.remote_mss = max_seg_size as usize; + } + self.set_state(State::Established); + self.timer.set_for_idle(timestamp, self.keep_alive); + } + + -- ACK packets in ESTABLISHED state reset the retransmit timer, + -- except for duplicate ACK packets which preserve it. + (State::Established, TcpControl::None) => { + if !self.timer.is_retransmit() || ack_len != 0 { + self.timer.set_for_idle(timestamp, self.keep_alive); + } + }, + + -- FIN packets in ESTABLISHED state indicate the remote side has closed. + (State::Established, TcpControl::Fin) => { + self.remote_seq_no += 1; + self.set_state(State::CloseWait); + self.timer.set_for_idle(timestamp, self.keep_alive); + } + + -- ACK packets in FIN-WAIT-1 state change it to FIN-WAIT-2, if we've already + -- sent everything in the transmit buffer. If not, they reset the retransmit timer. + (State::FinWait1, TcpControl::None) => { + if ack_of_fin { + self.set_state(State::FinWait2); + } + self.timer.set_for_idle(timestamp, self.keep_alive); + } + + -- FIN packets in FIN-WAIT-1 state change it to CLOSING, or to TIME-WAIT + -- if they also acknowledge our FIN. + (State::FinWait1, TcpControl::Fin) => { + self.remote_seq_no += 1; + if ack_of_fin { + self.set_state(State::TimeWait); + self.timer.set_for_close(timestamp); + } else { + self.set_state(State::Closing); + self.timer.set_for_idle(timestamp, self.keep_alive); + } + } + + -- FIN packets in FIN-WAIT-2 state change it to TIME-WAIT. + (State::FinWait2, TcpControl::Fin) => { + self.remote_seq_no += 1; + self.set_state(State::TimeWait); + self.timer.set_for_close(timestamp); + } + + -- ACK packets in CLOSING state change it to TIME-WAIT. + (State::Closing, TcpControl::None) => { + if ack_of_fin { + self.set_state(State::TimeWait); + self.timer.set_for_close(timestamp); + } else { + self.timer.set_for_idle(timestamp, self.keep_alive); + } + } + + -- ACK packets in CLOSE-WAIT state reset the retransmit timer. + (State::CloseWait, TcpControl::None) => { + self.timer.set_for_idle(timestamp, self.keep_alive); + } + + -- ACK packets in LAST-ACK state change it to CLOSED. + (State::LastAck, TcpControl::None) => { + -- Clear the remote endpoint, or we'll send an RST there. + self.set_state(State::Closed); + self.local_endpoint = IpEndpoint::default(); + self.remote_endpoint = IpEndpoint::default(); + } + + _ => { + net_debug!("{}:{}:{}: unexpected packet {}", + self.meta.handle, self.local_endpoint, self.remote_endpoint, repr); + return Err(Error::Dropped) + } + } + ]] + + -- Update remote state. + self.remote_last_ts = timestamp + self.remote_win_len = ntohs(tcp.window_len) + + if ack_len > 0 then + -- Dequeue acknowledged octets. + debug_assert(self.tx_buffer.len() >= ack_len) + net_trace("{}:{}:{}: tx buffer: dequeueing {} octets (now {})", + self.meta.handle, self.local_endpoint, self.remote_endpoint, + ack_len, self.tx_buffer.len() - ack_len) + self.tx_buffer.drop(ack_len) + end + + if has_ack(tcp) then + local ack_number = tcp_ack_number(tcp) + -- We've processed everything in the incoming segment, so advance the local + -- sequence number past it. + self.local_seq_no = ack_number + -- During retransmission, if an earlier segment got lost but later + -- was successfully received, self.local_seq_no can move past + -- self.remote_last_seq. + -- + -- Do not attempt to retransmit the latter segments; not only this + -- is pointless in theory but also impossible in practice, since + -- they have been already deallocated from the buffer. + if self.remote_last_seq < self.local_seq_no then + self.remote_last_seq = self.local_seq_no + end + end + + local payload_len = tcp_payload_length(tcp, l4_length) + if payload_len == 0 then return end + + local reordering_before = self.reorder:has_holes() + self.reorder:write(sockent.value.rx_buffer, window_start, segment_start, + tcp_payload(tcp), payload_length) + local reordering_after = self.reorder:has_holes() + + -- Now there may be some data! + + -- Per RFC 5681, we should send an immediate ACK when either: + -- 1) an out-of-order segment is received, or + -- 2) a segment arrives that fills in all or part of a gap in sequence space. + if reordering_before or reordering_after then + -- Note that we change the transmitter state here. This is fine + -- because smoltcp assumes that it can always transmit zero or one + -- packets for every packet it receives. + net_trace("{}:{}:{}: ACKing incoming segment", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + self.remote_last_ack = Some(self.remote_seq_no + self.rx_buffer.len()); + Ok(Some(self.ack_reply(ip_repr, repr))) + else + Ok(None) + end +end +str=[[ + + fn timed_out(&self, timestamp: u64) -> bool { + match (self.remote_last_ts, self.timeout) { + (Some(remote_last_ts), Some(timeout)) => + timestamp >= remote_last_ts + timeout, + (_, _) => + false + } + } + + fn seq_to_transmit(&self) -> bool { + let control; + match self.state { + State::SynSent | State::SynReceived => + control = TcpControl::Syn, + State::FinWait1 | State::LastAck => + control = TcpControl::Fin, + _ => control = TcpControl::None + } + + if self.remote_win_len > 0 { + self.remote_last_seq < self.local_seq_no + self.tx_buffer.len() + control.len() + } else { + false + } + } + + fn ack_to_transmit(&self) -> bool { + if let Some(remote_last_ack) = self.remote_last_ack { + remote_last_ack < self.remote_seq_no + self.rx_buffer.len() + } else { + false + } + } + + fn window_to_update(&self) -> bool { + self.rx_buffer.window() as u16 > self.remote_last_win + } + + pub(crate) fn dispatch(&mut self, timestamp: u64, caps: &DeviceCapabilities, + emit: F) -> Result<()> + where F: FnOnce((IpRepr, TcpRepr)) -> Result<()> { + if !self.remote_endpoint.is_specified() { return Err(Error::Exhausted) } + + if self.remote_last_ts.is_none() { + // We get here in exactly two cases: + // 1) This socket just transitioned into SYN-SENT. + // 2) This socket had an empty transmit buffer and some data was added there. + // Both are similar in that the socket has been quiet for an indefinite + // period of time, it isn't anymore, and the local endpoint is talking. + // So, we start counting the timeout not from the last received packet + // but from the first transmitted one. + self.remote_last_ts = Some(timestamp); + } + + // Check if any state needs to be changed because of a timer. + if self.timed_out(timestamp) { + // If a timeout expires, we should abort the connection. + net_debug!("{}:{}:{}: timeout exceeded", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + self.set_state(State::Closed); + } else if !self.seq_to_transmit() { + if let Some(retransmit_delta) = self.timer.should_retransmit(timestamp) { + // If a retransmit timer expired, we should resend data starting at the last ACK. + net_debug!("{}:{}:{}: retransmitting at t+{}ms", + self.meta.handle, self.local_endpoint, self.remote_endpoint, + retransmit_delta); + self.remote_last_seq = self.local_seq_no; + } + } + + // Decide whether we're sending a packet. + if self.seq_to_transmit() { + // If we have data to transmit and it fits into partner's window, do it. + net_trace!("{}:{}:{}: outgoing segment will send data or flags", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + } else if self.ack_to_transmit() { + // If we have data to acknowledge, do it. + net_trace!("{}:{}:{}: outgoing segment will acknowledge", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + } else if self.window_to_update() { + // If we have window length increase to advertise, do it. + net_trace!("{}:{}:{}: outgoing segment will update window", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + } else if self.state == State::Closed { + // If we need to abort the connection, do it. + net_trace!("{}:{}:{}: outgoing segment will abort connection", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + } else if self.timer.should_retransmit(timestamp).is_some() { + // If we have packets to retransmit, do it. + net_trace!("{}:{}:{}: retransmit timer expired", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + } else if self.timer.should_keep_alive(timestamp) { + // If we need to transmit a keep-alive packet, do it. + net_trace!("{}:{}:{}: keep-alive timer expired", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + } else if self.timer.should_close(timestamp) { + // If we have spent enough time in the TIME-WAIT state, close the socket. + net_trace!("{}:{}:{}: TIME-WAIT timer expired", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + self.reset(); + return Err(Error::Exhausted) + } else { + return Err(Error::Exhausted) + } + + // Construct the lowered IP representation. + // We might need this to calculate the MSS, so do it early. + let mut ip_repr = IpRepr::Unspecified { + src_addr: self.local_endpoint.addr, + dst_addr: self.remote_endpoint.addr, + protocol: IpProtocol::Tcp, + hop_limit: self.hop_limit.unwrap_or(64), + payload_len: 0 + }.lower(&[])?; + + // Construct the basic TCP representation, an empty ACK packet. + // We'll adjust this to be more specific as needed. + let mut repr = TcpRepr { + src_port: self.local_endpoint.port, + dst_port: self.remote_endpoint.port, + control: TcpControl::None, + seq_number: self.remote_last_seq, + ack_number: Some(self.remote_seq_no + self.rx_buffer.len()), + window_len: self.rx_buffer.window() as u16, + max_seg_size: None, + payload: &[] + }; + + match self.state { + // We transmit an RST in the CLOSED state. If we ended up in the CLOSED state + // with a specified endpoint, it means that the socket was aborted. + State::Closed => { + repr.control = TcpControl::Rst; + } + + // We never transmit anything in the LISTEN state. + State::Listen => return Err(Error::Exhausted), + + // We transmit a SYN in the SYN-SENT state. + // We transmit a SYN|ACK in the SYN-RECEIVED state. + State::SynSent | State::SynReceived => { + repr.control = TcpControl::Syn; + if self.state == State::SynSent { + repr.ack_number = None; + } + } + + // We transmit data in all states where we may have data in the buffer, + // or the transmit half of the connection is still open: + // the ESTABLISHED, FIN-WAIT-1, CLOSE-WAIT and LAST-ACK states. + State::Established | State::FinWait1 | State::CloseWait | State::LastAck => { + // Extract as much data as the remote side can receive in this packet + // from the transmit buffer. + let offset = self.remote_last_seq - self.local_seq_no; + let size = cmp::min(self.remote_win_len, self.remote_mss); + repr.payload = self.tx_buffer.get_allocated(offset, size); + // If we've sent everything we had in the buffer, follow it with the PSH or FIN + // flags, depending on whether the transmit half of the connection is open. + if offset + repr.payload.len() == self.tx_buffer.len() { + match self.state { + State::FinWait1 | State::LastAck => + repr.control = TcpControl::Fin, + State::Established | State::CloseWait if repr.payload.len() > 0 => + repr.control = TcpControl::Psh, + _ => () + } + } + } + + // We do not transmit anything in the FIN-WAIT-2 state. + State::FinWait2 => return Err(Error::Exhausted), + + // We do not transmit data or control flags in the CLOSING or TIME-WAIT states, + // but we may retransmit an ACK. + State::Closing | State::TimeWait => () + } + + // There might be more than one reason to send a packet. E.g. the keep-alive timer + // has expired, and we also have data in transmit buffer. Since any packet that occupies + // sequence space will elicit an ACK, we only need to send an explicit packet if we + // couldn't fill the sequence space with anything. + let is_keep_alive; + if self.timer.should_keep_alive(timestamp) && repr.is_empty() { + repr.seq_number = repr.seq_number - 1; + repr.payload = b"\x00"; // RFC 1122 says we should do this + is_keep_alive = true; + } else { + is_keep_alive = false; + } + + // Trace a summary of what will be sent. + if is_keep_alive { + net_trace!("{}:{}:{}: sending a keep-alive", + self.meta.handle, self.local_endpoint, self.remote_endpoint); + } else if repr.payload.len() > 0 { + net_trace!("{}:{}:{}: tx buffer: sending {} octets at offset {}", + self.meta.handle, self.local_endpoint, self.remote_endpoint, + repr.payload.len(), self.remote_last_seq - self.local_seq_no); + } + if repr.control != TcpControl::None || repr.payload.len() == 0 { + let flags = + match (repr.control, repr.ack_number) { + (TcpControl::Syn, None) => "SYN", + (TcpControl::Syn, Some(_)) => "SYN|ACK", + (TcpControl::Fin, Some(_)) => "FIN|ACK", + (TcpControl::Rst, Some(_)) => "RST|ACK", + (TcpControl::Psh, Some(_)) => "PSH|ACK", + (TcpControl::None, Some(_)) => "ACK", + _ => "" + }; + net_trace!("{}:{}:{}: sending {}", + self.meta.handle, self.local_endpoint, self.remote_endpoint, + flags); + } + + if repr.control == TcpControl::Syn { + // Fill the MSS option. See RFC 6691 for an explanation of this calculation. + let mut max_segment_size = caps.max_transmission_unit; + max_segment_size -= ip_repr.buffer_len(); + max_segment_size -= repr.header_len(); + repr.max_seg_size = Some(max_segment_size as u16); + } + + // Actually send the packet. If this succeeds, it means the packet is in + // the device buffer, and its transmission is imminent. If not, we might have + // a number of problems, e.g. we need neighbor discovery. + // + // Bailing out if the packet isn't placed in the device buffer allows us + // to not waste time waiting for the retransmit timer on packets that we know + // for sure will not be successfully transmitted. + ip_repr.set_payload_len(repr.buffer_len()); + emit((ip_repr, repr))?; + + // We've sent something, whether useful data or a keep-alive packet, so rewind + // the keep-alive timer. + self.timer.rewind_keep_alive(timestamp, self.keep_alive); + + // Leave the rest of the state intact if sending a keep-alive packet, since those + // carry a fake segment. + if is_keep_alive { return Ok(()) } + + // We've sent a packet successfully, so we can update the internal state now. + self.remote_last_seq = repr.seq_number + repr.segment_len(); + self.remote_last_ack = repr.ack_number; + self.remote_last_win = repr.window_len; + + if !self.seq_to_transmit() && repr.segment_len() > 0 { + // If we've transmitted all data we could (and there was something at all, + // data or flag, to transmit, not just an ACK), wind up the retransmit timer. + self.timer.set_for_retransmit(timestamp); + } + + if self.state == State::Closed { + // When aborting a connection, forget about it after sending a single RST packet. + self.local_endpoint = IpEndpoint::default(); + self.remote_endpoint = IpEndpoint::default(); + } + + Ok(()) + } + + pub(crate) fn poll_at(&self) -> Option { + // The logic here mirrors the beginning of dispatch() closely. + if !self.remote_endpoint.is_specified() { + // No one to talk to, nothing to transmit. + None + } else if self.remote_last_ts.is_none() { + // Socket stopped being quiet recently, we need to acquire a timestamp. + Some(0) + } else if self.state == State::Closed { + // Socket was aborted, we have an RST packet to transmit. + Some(0) + } else if self.seq_to_transmit() || self.ack_to_transmit() || self.window_to_update() { + // We have a data or flag packet to transmit. + Some(0) + } else { + let timeout_poll_at; + match (self.remote_last_ts, self.timeout) { + // If we're transmitting or retransmitting data, we need to poll at the moment + // when the timeout would expire. + (Some(remote_last_ts), Some(timeout)) => + timeout_poll_at = Some(remote_last_ts + timeout), + // Otherwise we have no timeout. + (_, _) => + timeout_poll_at = None + } + + // We wait for the earliest of our timers to fire. + [self.timer.poll_at(), timeout_poll_at] + .iter() + .filter_map(|x| *x) + .min() + } + } +} + +impl<'a> Into> for TcpSocket<'a> { + fn into(self) -> Socket<'a, 'static> { + Socket::Tcp(self) + } +} + +impl<'a> fmt::Write for TcpSocket<'a> { + fn write_str(&mut self, slice: &str) -> fmt::Result { + let slice = slice.as_bytes(); + if self.send_slice(slice) == Ok(slice.len()) { + Ok(()) + } else { + Err(fmt::Error) + } + } +} + +#[cfg(test)] +mod test { + use core::i32; + use wire::{IpAddress, IpRepr, IpCidr}; + use wire::ip::test::{MOCK_IP_ADDR_1, MOCK_IP_ADDR_2, MOCK_IP_ADDR_3, MOCK_UNSPECIFIED}; + use super::*; + + // =========================================================================================// + // Constants + // =========================================================================================// + + const LOCAL_PORT: u16 = 80; + const REMOTE_PORT: u16 = 49500; + const LOCAL_END: IpEndpoint = IpEndpoint { addr: MOCK_IP_ADDR_1, port: LOCAL_PORT }; + const REMOTE_END: IpEndpoint = IpEndpoint { addr: MOCK_IP_ADDR_2, port: REMOTE_PORT }; + const LOCAL_SEQ: TcpSeqNumber = TcpSeqNumber(10000); + const REMOTE_SEQ: TcpSeqNumber = TcpSeqNumber(-10000); + + const SEND_IP_TEMPL: IpRepr = IpRepr::Unspecified { + src_addr: MOCK_IP_ADDR_1, dst_addr: MOCK_IP_ADDR_2, + protocol: IpProtocol::Tcp, payload_len: 20, + hop_limit: 64 + }; + const SEND_TEMPL: TcpRepr<'static> = TcpRepr { + src_port: REMOTE_PORT, dst_port: LOCAL_PORT, + control: TcpControl::None, + seq_number: TcpSeqNumber(0), ack_number: Some(TcpSeqNumber(0)), + window_len: 256, max_seg_size: None, + payload: &[] + }; + const _RECV_IP_TEMPL: IpRepr = IpRepr::Unspecified { + src_addr: MOCK_IP_ADDR_1, dst_addr: MOCK_IP_ADDR_2, + protocol: IpProtocol::Tcp, payload_len: 20, + hop_limit: 64 + }; + const RECV_TEMPL: TcpRepr<'static> = TcpRepr { + src_port: LOCAL_PORT, dst_port: REMOTE_PORT, + control: TcpControl::None, + seq_number: TcpSeqNumber(0), ack_number: Some(TcpSeqNumber(0)), + window_len: 64, max_seg_size: None, + payload: &[] + }; + + #[cfg(feature = "proto-ipv6")] + const BASE_MSS: u16 = 1460; + #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] + const BASE_MSS: u16 = 1480; + + // =========================================================================================// + // Helper functions + // =========================================================================================// + + fn send(socket: &mut TcpSocket, timestamp: u64, repr: &TcpRepr) -> + Result>> { + let ip_repr = IpRepr::Unspecified { + src_addr: MOCK_IP_ADDR_2, + dst_addr: MOCK_IP_ADDR_1, + protocol: IpProtocol::Tcp, + payload_len: repr.buffer_len(), + hop_limit: 64 + }; + net_trace!("send: {}", repr); + + assert!(socket.accepts(&ip_repr, repr)); + match socket.process(timestamp, &ip_repr, repr) { + Ok(Some((_ip_repr, repr))) => { + net_trace!("recv: {}", repr); + Ok(Some(repr)) + } + Ok(None) => Ok(None), + Err(err) => Err(err) + } + } + + fn recv(socket: &mut TcpSocket, timestamp: u64, mut f: F) + where F: FnMut(Result) { + let mut caps = DeviceCapabilities::default(); + caps.max_transmission_unit = 1520; + let result = socket.dispatch(timestamp, &caps, |(ip_repr, tcp_repr)| { + let ip_repr = ip_repr.lower(&[IpCidr::new(LOCAL_END.addr, 24)]).unwrap(); + + assert_eq!(ip_repr.protocol(), IpProtocol::Tcp); + assert_eq!(ip_repr.src_addr(), MOCK_IP_ADDR_1); + assert_eq!(ip_repr.dst_addr(), MOCK_IP_ADDR_2); + assert_eq!(ip_repr.payload_len(), tcp_repr.buffer_len()); + + net_trace!("recv: {}", tcp_repr); + Ok(f(Ok(tcp_repr))) + }); + match result { + Ok(()) => (), + Err(e) => f(Err(e)) + } + } + + macro_rules! send { + ($socket:ident, $repr:expr) => + (send!($socket, time 0, $repr)); + ($socket:ident, $repr:expr, $result:expr) => + (send!($socket, time 0, $repr, $result)); + ($socket:ident, time $time:expr, $repr:expr) => + (send!($socket, time $time, $repr, Ok(None))); + ($socket:ident, time $time:expr, $repr:expr, $result:expr) => + (assert_eq!(send(&mut $socket, $time, &$repr), $result)); + } + + macro_rules! recv { + ($socket:ident, [$( $repr:expr ),*]) => ({ + $( recv!($socket, Ok($repr)); )* + recv!($socket, Err(Error::Exhausted)) + }); + ($socket:ident, $result:expr) => + (recv!($socket, time 0, $result)); + ($socket:ident, time $time:expr, $result:expr) => + (recv(&mut $socket, $time, |result| { + // Most of the time we don't care about the PSH flag. + let result = result.map(|mut repr| { + repr.control = repr.control.quash_psh(); + repr + }); + assert_eq!(result, $result) + })); + ($socket:ident, time $time:expr, $result:expr, exact) => + (recv(&mut $socket, $time, |repr| assert_eq!(repr, $result))); + } + + macro_rules! sanity { + ($socket1:expr, $socket2:expr) => ({ + let (s1, s2) = ($socket1, $socket2); + assert_eq!(s1.state, s2.state, "state"); + assert_eq!(s1.listen_address, s2.listen_address, "listen_address"); + assert_eq!(s1.local_endpoint, s2.local_endpoint, "local_endpoint"); + assert_eq!(s1.remote_endpoint, s2.remote_endpoint, "remote_endpoint"); + assert_eq!(s1.local_seq_no, s2.local_seq_no, "local_seq_no"); + assert_eq!(s1.remote_seq_no, s2.remote_seq_no, "remote_seq_no"); + assert_eq!(s1.remote_last_seq, s2.remote_last_seq, "remote_last_seq"); + assert_eq!(s1.remote_last_ack, s2.remote_last_ack, "remote_last_ack"); + assert_eq!(s1.remote_last_win, s2.remote_last_win, "remote_last_win"); + assert_eq!(s1.remote_win_len, s2.remote_win_len, "remote_win_len"); + assert_eq!(s1.timer, s2.timer, "timer"); + }) + } + + #[cfg(feature = "log")] + fn init_logger() { + extern crate log; + use std::boxed::Box; + + struct Logger(()); + + impl log::Log for Logger { + fn enabled(&self, _metadata: &log::LogMetadata) -> bool { + true + } + + fn log(&self, record: &log::LogRecord) { + println!("{}", record.args()); + } + } + + let _ = log::set_logger(|max_level| { + max_level.set(log::LogLevelFilter::Trace); + Box::new(Logger(())) + }); + + println!(""); + } + + fn socket() -> TcpSocket<'static> { + #[cfg(feature = "log")] + init_logger(); + + let rx_buffer = SocketBuffer::new(vec![0; 64]); + let tx_buffer = SocketBuffer::new(vec![0; 64]); + TcpSocket::new(rx_buffer, tx_buffer) + } + + fn socket_syn_received() -> TcpSocket<'static> { + let mut s = socket(); + s.state = State::SynReceived; + s.local_endpoint = LOCAL_END; + s.remote_endpoint = REMOTE_END; + s.local_seq_no = LOCAL_SEQ; + s.remote_seq_no = REMOTE_SEQ + 1; + s.remote_last_seq = LOCAL_SEQ; + s.remote_win_len = 256; + s + } + + fn socket_syn_sent() -> TcpSocket<'static> { + let mut s = socket(); + s.state = State::SynSent; + s.local_endpoint = IpEndpoint::new(MOCK_UNSPECIFIED, LOCAL_PORT); + s.remote_endpoint = REMOTE_END; + s.local_seq_no = LOCAL_SEQ; + s.remote_last_seq = LOCAL_SEQ; + s + } + + fn socket_established() -> TcpSocket<'static> { + let mut s = socket_syn_received(); + s.state = State::Established; + s.local_seq_no = LOCAL_SEQ + 1; + s.remote_last_seq = LOCAL_SEQ + 1; + s.remote_last_ack = Some(REMOTE_SEQ + 1); + s.remote_last_win = 64; + s + } + + fn socket_fin_wait_1() -> TcpSocket<'static> { + let mut s = socket_established(); + s.state = State::FinWait1; + s + } + + fn socket_fin_wait_2() -> TcpSocket<'static> { + let mut s = socket_fin_wait_1(); + s.state = State::FinWait2; + s.local_seq_no = LOCAL_SEQ + 1 + 1; + s.remote_last_seq = LOCAL_SEQ + 1 + 1; + s + } + + fn socket_closing() -> TcpSocket<'static> { + let mut s = socket_fin_wait_1(); + s.state = State::Closing; + s.remote_last_seq = LOCAL_SEQ + 1 + 1; + s.remote_seq_no = REMOTE_SEQ + 1 + 1; + s + } + + fn socket_time_wait(from_closing: bool) -> TcpSocket<'static> { + let mut s = socket_fin_wait_2(); + s.state = State::TimeWait; + s.remote_seq_no = REMOTE_SEQ + 1 + 1; + if from_closing { + s.remote_last_ack = Some(REMOTE_SEQ + 1 + 1); + } + s.timer = Timer::Close { expires_at: 1_000 + CLOSE_DELAY }; + s + } + + fn socket_close_wait() -> TcpSocket<'static> { + let mut s = socket_established(); + s.state = State::CloseWait; + s.remote_seq_no = REMOTE_SEQ + 1 + 1; + s.remote_last_ack = Some(REMOTE_SEQ + 1 + 1); + s + } + + fn socket_last_ack() -> TcpSocket<'static> { + let mut s = socket_close_wait(); + s.state = State::LastAck; + s + } + + fn socket_recved() -> TcpSocket<'static> { + let mut s = socket_established(); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }]); + s + } + + // =========================================================================================// + // Tests for the CLOSED state. + // =========================================================================================// + #[test] + fn test_closed_reject() { + let s = socket(); + assert_eq!(s.state, State::Closed); + + let tcp_repr = TcpRepr { + control: TcpControl::Syn, + ..SEND_TEMPL + }; + assert!(!s.accepts(&SEND_IP_TEMPL, &tcp_repr)); + } + + #[test] + fn test_closed_reject_after_listen() { + let mut s = socket(); + s.listen(LOCAL_END).unwrap(); + s.close(); + + let tcp_repr = TcpRepr { + control: TcpControl::Syn, + ..SEND_TEMPL + }; + assert!(!s.accepts(&SEND_IP_TEMPL, &tcp_repr)); + } + + #[test] + fn test_closed_close() { + let mut s = socket(); + s.close(); + assert_eq!(s.state, State::Closed); + } + + // =========================================================================================// + // Tests for the LISTEN state. + // =========================================================================================// + fn socket_listen() -> TcpSocket<'static> { + let mut s = socket(); + s.state = State::Listen; + s.local_endpoint = IpEndpoint::new(IpAddress::default(), LOCAL_PORT); + s + } + + #[test] + fn test_listen_sanity() { + let mut s = socket(); + s.listen(LOCAL_PORT).unwrap(); + sanity!(s, socket_listen()); + } + + #[test] + fn test_listen_validation() { + let mut s = socket(); + assert_eq!(s.listen(0), Err(Error::Unaddressable)); + } + + #[test] + fn test_listen_twice() { + let mut s = socket(); + assert_eq!(s.listen(80), Ok(())); + assert_eq!(s.listen(80), Err(Error::Illegal)); + } + + #[test] + fn test_listen_syn() { + let mut s = socket_listen(); + send!(s, TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + }); + sanity!(s, socket_syn_received()); + } + + #[test] + fn test_listen_syn_reject_ack() { + let s = socket_listen(); + + let tcp_repr = TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ), + ..SEND_TEMPL + }; + assert!(!s.accepts(&SEND_IP_TEMPL, &tcp_repr)); + + assert_eq!(s.state, State::Listen); + } + + #[test] + fn test_listen_rst() { + let mut s = socket_listen(); + send!(s, TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + }, Err(Error::Dropped)); + } + + #[test] + fn test_listen_close() { + let mut s = socket_listen(); + s.close(); + assert_eq!(s.state, State::Closed); + } + + // =========================================================================================// + // Tests for the SYN-RECEIVED state. + // =========================================================================================// + + #[test] + fn test_syn_received_ack() { + let mut s = socket_syn_received(); + recv!(s, [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::Established); + sanity!(s, socket_established()); + } + + #[test] + fn test_syn_received_fin() { + let mut s = socket_syn_received(); + recv!(s, [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6 + 1), + window_len: 58, + ..RECV_TEMPL + }]); + assert_eq!(s.state, State::CloseWait); + sanity!(s, TcpSocket { + remote_last_ack: Some(REMOTE_SEQ + 1 + 6 + 1), + remote_last_win: 58, + ..socket_close_wait() + }); + } + + #[test] + fn test_syn_received_rst() { + let mut s = socket_syn_received(); + recv!(s, [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::Listen); + assert_eq!(s.local_endpoint, IpEndpoint::new(IpAddress::Unspecified, LOCAL_END.port)); + assert_eq!(s.remote_endpoint, IpEndpoint::default()); + } + + #[test] + fn test_syn_received_close() { + let mut s = socket_syn_received(); + s.close(); + assert_eq!(s.state, State::FinWait1); + } + + // =========================================================================================// + // Tests for the SYN-SENT state. + // =========================================================================================// + + #[test] + fn test_connect_validation() { + let mut s = socket(); + assert_eq!(s.connect((IpAddress::Unspecified, 80), LOCAL_END), + Err(Error::Unaddressable)); + assert_eq!(s.connect(REMOTE_END, (MOCK_UNSPECIFIED, 0)), + Err(Error::Unaddressable)); + assert_eq!(s.connect((MOCK_UNSPECIFIED, 0), LOCAL_END), + Err(Error::Unaddressable)); + assert_eq!(s.connect((IpAddress::Unspecified, 80), LOCAL_END), + Err(Error::Unaddressable)); + } + + #[test] + fn test_connect() { + let mut s = socket(); + s.local_seq_no = LOCAL_SEQ; + s.connect(REMOTE_END, LOCAL_END.port).unwrap(); + assert_eq!(s.local_endpoint, IpEndpoint::new(MOCK_UNSPECIFIED, LOCAL_END.port)); + recv!(s, [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(BASE_MSS - 80), + ..SEND_TEMPL + }); + assert_eq!(s.local_endpoint, LOCAL_END); + } + + #[test] + fn test_connect_unspecified_local() { + let mut s = socket(); + assert_eq!(s.connect(REMOTE_END, (MOCK_UNSPECIFIED, 80)), + Ok(())); + s.abort(); + assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)), + Ok(())); + s.abort(); + } + + #[test] + fn test_connect_specified_local() { + let mut s = socket(); + assert_eq!(s.connect(REMOTE_END, (MOCK_IP_ADDR_2, 80)), + Ok(())); + } + + #[test] + fn test_connect_twice() { + let mut s = socket(); + assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)), + Ok(())); + assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)), + Err(Error::Illegal)); + } + + #[test] + fn test_syn_sent_sanity() { + let mut s = socket(); + s.local_seq_no = LOCAL_SEQ; + s.connect(REMOTE_END, LOCAL_END).unwrap(); + sanity!(s, socket_syn_sent()); + } + + #[test] + fn test_syn_sent_syn_ack() { + let mut s = socket_syn_sent(); + recv!(s, [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(BASE_MSS - 80), + ..SEND_TEMPL + }); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }]); + recv!(s, time 1000, Err(Error::Exhausted)); + assert_eq!(s.state, State::Established); + sanity!(s, socket_established()); + } + + #[test] + fn test_syn_sent_rst() { + let mut s = socket_syn_sent(); + send!(s, TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_syn_sent_rst_no_ack() { + let mut s = socket_syn_sent(); + send!(s, TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + }, Err(Error::Dropped)); + assert_eq!(s.state, State::SynSent); + } + + #[test] + fn test_syn_sent_rst_bad_ack() { + let mut s = socket_syn_sent(); + send!(s, TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, + ack_number: Some(TcpSeqNumber(1234)), + ..SEND_TEMPL + }, Err(Error::Dropped)); + assert_eq!(s.state, State::SynSent); + } + + #[test] + fn test_syn_sent_close() { + let mut s = socket(); + s.close(); + assert_eq!(s.state, State::Closed); + } + + // =========================================================================================// + // Tests for the ESTABLISHED state. + // =========================================================================================// + + #[test] + fn test_established_recv() { + let mut s = socket_established(); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }]); + assert_eq!(s.rx_buffer.dequeue_many(6), &b"abcdef"[..]); + } + + #[test] + fn test_established_send() { + let mut s = socket_established(); + // First roundtrip after establishing. + s.send_slice(b"abcdef").unwrap(); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }]); + assert_eq!(s.tx_buffer.len(), 6); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + ..SEND_TEMPL + }); + assert_eq!(s.tx_buffer.len(), 0); + // Second roundtrip. + s.send_slice(b"foobar").unwrap(); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"foobar"[..], + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), + ..SEND_TEMPL + }); + assert_eq!(s.tx_buffer.len(), 0); + } + + #[test] + fn test_established_send_no_ack_send() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }]); + s.send_slice(b"foobar").unwrap(); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"foobar"[..], + ..RECV_TEMPL + }]); + } + + #[test] + fn test_established_send_buf_gt_win() { + let mut data = [0; 32]; + for (i, elem) in data.iter_mut().enumerate() { + *elem = i as u8 + } + + let mut s = socket_established(); + s.remote_win_len = 16; + s.send_slice(&data[..]).unwrap(); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &data[0..16], + ..RECV_TEMPL + }, TcpRepr { + seq_number: LOCAL_SEQ + 1 + 16, + ack_number: Some(REMOTE_SEQ + 1), + payload: &data[16..32], + ..RECV_TEMPL + }]); + } + + #[test] + fn test_established_send_wrap() { + let mut s = socket_established(); + let local_seq_start = TcpSeqNumber(i32::MAX - 1); + s.local_seq_no = local_seq_start + 1; + s.remote_last_seq = local_seq_start + 1; + s.send_slice(b"abc").unwrap(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: local_seq_start + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abc"[..], + ..RECV_TEMPL + })); + } + + #[test] + fn test_established_no_ack() { + let mut s = socket_established(); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: None, + ..SEND_TEMPL + }, Err(Error::Dropped)); + } + + #[test] + fn test_established_bad_ack() { + let mut s = socket_established(); + // Already acknowledged data. + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(TcpSeqNumber(LOCAL_SEQ.0 - 1)), + ..SEND_TEMPL + }, Err(Error::Dropped)); + assert_eq!(s.local_seq_no, LOCAL_SEQ + 1); + // Data not yet transmitted. + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 10), + ..SEND_TEMPL + }, Ok(Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }))); + assert_eq!(s.local_seq_no, LOCAL_SEQ + 1); + } + + #[test] + fn test_established_bad_seq() { + let mut s = socket_established(); + // Data outside of receive window. + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 256, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }, Ok(Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }))); + assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1); + } + + #[test] + fn test_established_fin() { + let mut s = socket_established(); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + assert_eq!(s.state, State::CloseWait); + sanity!(s, socket_close_wait()); + } + + #[test] + fn test_established_fin_after_missing() { + let mut s = socket_established(); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"123456"[..], + ..SEND_TEMPL + }, Ok(Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }))); + assert_eq!(s.state, State::Established); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }, Ok(Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6 + 6), + window_len: 52, + ..RECV_TEMPL + }))); + assert_eq!(s.state, State::Established); + } + + #[test] + fn test_established_send_fin() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::CloseWait); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }]); + } + + #[test] + fn test_established_rst() { + let mut s = socket_established(); + send!(s, TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_established_rst_no_ack() { + let mut s = socket_established(); + send!(s, TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1, + ack_number: None, + ..SEND_TEMPL + }); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_established_close() { + let mut s = socket_established(); + s.close(); + assert_eq!(s.state, State::FinWait1); + sanity!(s, socket_fin_wait_1()); + } + + #[test] + fn test_established_abort() { + let mut s = socket_established(); + s.abort(); + assert_eq!(s.state, State::Closed); + recv!(s, [TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }]); + } + + // =========================================================================================// + // Tests for the FIN-WAIT-1 state. + // =========================================================================================// + + #[test] + fn test_fin_wait_1_fin_ack() { + let mut s = socket_fin_wait_1(); + recv!(s, [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::FinWait2); + sanity!(s, socket_fin_wait_2()); + } + + #[test] + fn test_fin_wait_1_fin_fin() { + let mut s = socket_fin_wait_1(); + recv!(s, [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::Closing); + sanity!(s, socket_closing()); + } + + #[test] + fn test_fin_wait_1_fin_with_data_queued() { + let mut s = socket_established(); + s.remote_win_len = 6; + s.send_slice(b"abcdef123456").unwrap(); + s.close(); + recv!(s, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::FinWait1); + } + + #[test] + fn test_fin_wait_1_close() { + let mut s = socket_fin_wait_1(); + s.close(); + assert_eq!(s.state, State::FinWait1); + } + + // =========================================================================================// + // Tests for the FIN-WAIT-2 state. + // =========================================================================================// + + #[test] + fn test_fin_wait_2_fin() { + let mut s = socket_fin_wait_2(); + send!(s, time 1_000, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::TimeWait); + sanity!(s, socket_time_wait(false)); + } + + #[test] + fn test_fin_wait_2_close() { + let mut s = socket_fin_wait_2(); + s.close(); + assert_eq!(s.state, State::FinWait2); + } + + // =========================================================================================// + // Tests for the CLOSING state. + // =========================================================================================// + + #[test] + fn test_closing_ack_fin() { + let mut s = socket_closing(); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + send!(s, time 1_000, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::TimeWait); + sanity!(s, socket_time_wait(true)); + } + + #[test] + fn test_closing_close() { + let mut s = socket_closing(); + s.close(); + assert_eq!(s.state, State::Closing); + } + + // =========================================================================================// + // Tests for the TIME-WAIT state. + // =========================================================================================// + + #[test] + fn test_time_wait_from_fin_wait_2_ack() { + let mut s = socket_time_wait(false); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + } + + #[test] + fn test_time_wait_from_closing_no_ack() { + let mut s = socket_time_wait(true); + recv!(s, []); + } + + #[test] + fn test_time_wait_close() { + let mut s = socket_time_wait(false); + s.close(); + assert_eq!(s.state, State::TimeWait); + } + + #[test] + fn test_time_wait_retransmit() { + let mut s = socket_time_wait(false); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + send!(s, time 5_000, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }, Ok(Some(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }))); + assert_eq!(s.timer, Timer::Close { expires_at: 5_000 + CLOSE_DELAY }); + } + + #[test] + fn test_time_wait_timeout() { + let mut s = socket_time_wait(false); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + assert_eq!(s.state, State::TimeWait); + recv!(s, time 60_000, Err(Error::Exhausted)); + assert_eq!(s.state, State::Closed); + } + + // =========================================================================================// + // Tests for the CLOSE-WAIT state. + // =========================================================================================// + + #[test] + fn test_close_wait_ack() { + let mut s = socket_close_wait(); + s.send_slice(b"abcdef").unwrap(); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + ..SEND_TEMPL + }); + } + + #[test] + fn test_close_wait_close() { + let mut s = socket_close_wait(); + s.close(); + assert_eq!(s.state, State::LastAck); + sanity!(s, socket_last_ack()); + } + + // =========================================================================================// + // Tests for the LAST-ACK state. + // =========================================================================================// + #[test] + fn test_last_ack_fin_ack() { + let mut s = socket_last_ack(); + recv!(s, [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + assert_eq!(s.state, State::LastAck); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_last_ack_close() { + let mut s = socket_last_ack(); + s.close(); + assert_eq!(s.state, State::LastAck); + } + + // =========================================================================================// + // Tests for transitioning through multiple states. + // =========================================================================================// + + #[test] + fn test_listen() { + let mut s = socket(); + s.listen(IpEndpoint::new(IpAddress::default(), LOCAL_PORT)).unwrap(); + assert_eq!(s.state, State::Listen); + } + + #[test] + fn test_three_way_handshake() { + let mut s = socket_listen(); + send!(s, TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + }); + assert_eq!(s.state(), State::SynReceived); + assert_eq!(s.local_endpoint(), LOCAL_END); + assert_eq!(s.remote_endpoint(), REMOTE_END); + recv!(s, [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state(), State::Established); + assert_eq!(s.local_seq_no, LOCAL_SEQ + 1); + assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1); + } + + #[test] + fn test_remote_close() { + let mut s = socket_established(); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::CloseWait); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + s.close(); + assert_eq!(s.state, State::LastAck); + recv!(s, [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_local_close() { + let mut s = socket_established(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!(s, [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::FinWait2); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::TimeWait); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + } + + #[test] + fn test_simultaneous_close() { + let mut s = socket_established(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!(s, [TcpRepr { // due to reordering, this is logically located... + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::Closing); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + // ... at this point + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::TimeWait); + recv!(s, []); + } + + #[test] + fn test_simultaneous_close_combined_fin_ack() { + let mut s = socket_established(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!(s, [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::TimeWait); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + } + + #[test] + fn test_fin_with_data() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + s.close(); + recv!(s, [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }]) + } + + #[test] + fn test_mutual_close_with_data_1() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!(s, [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + }); + } + + #[test] + fn test_mutual_close_with_data_2() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!(s, [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state, State::FinWait2); + send!(s, TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + }); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }]); + assert_eq!(s.state, State::TimeWait); + } + + // =========================================================================================// + // Tests for retransmission on packet loss. + // =========================================================================================// + + #[test] + fn test_duplicate_seq_ack() { + let mut s = socket_recved(); + // remote retransmission + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }, Ok(Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }))); + } + + #[test] + fn test_data_retransmit() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + recv!(s, time 1050, Err(Error::Exhausted)); + recv!(s, time 1100, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + } + + #[test] + fn test_data_retransmit_bursts() { + let mut s = socket_established(); + s.remote_win_len = 6; + s.send_slice(b"abcdef012345").unwrap(); + + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::None, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }), exact); + s.remote_win_len = 6; + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::Psh, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"012345"[..], + ..RECV_TEMPL + }), exact); + s.remote_win_len = 6; + recv!(s, time 0, Err(Error::Exhausted)); + + recv!(s, time 50, Err(Error::Exhausted)); + + recv!(s, time 100, Ok(TcpRepr { + control: TcpControl::None, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }), exact); + s.remote_win_len = 6; + recv!(s, time 150, Ok(TcpRepr { + control: TcpControl::Psh, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"012345"[..], + ..RECV_TEMPL + }), exact); + s.remote_win_len = 6; + recv!(s, time 200, Err(Error::Exhausted)); + } + + #[test] + fn test_send_data_after_syn_ack_retransmit() { + let mut s = socket_syn_received(); + recv!(s, time 50, Ok(TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + })); + recv!(s, time 150, Ok(TcpRepr { // retransmit + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + })); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!(s.state(), State::Established); + s.send_slice(b"abcdef").unwrap(); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }]) + } + + #[test] + fn test_established_retransmit_for_dup_ack() { + let mut s = socket_established(); + // Duplicate ACKs do not replace the retransmission timer + s.send_slice(b"abc").unwrap(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abc"[..], + ..RECV_TEMPL + })); + // Retransmit timer is on because all data was sent + assert_eq!(s.tx_buffer.len(), 3); + // ACK nothing new + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + // Retransmit + recv!(s, time 4000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abc"[..], + ..RECV_TEMPL + })); + } + + #[test] + fn test_established_retransmit_reset_after_ack() { + let mut s = socket_established(); + s.remote_win_len = 6; + s.send_slice(b"abcdef").unwrap(); + s.send_slice(b"123456").unwrap(); + s.send_slice(b"ABCDEF").unwrap(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + send!(s, time 1005, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + window_len: 6, + ..SEND_TEMPL + }); + recv!(s, time 1010, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + })); + send!(s, time 1015, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), + window_len: 6, + ..SEND_TEMPL + }); + recv!(s, time 1020, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"ABCDEF"[..], + ..RECV_TEMPL + })); + } + + #[test] + fn test_established_queue_during_retransmission() { + let mut s = socket_established(); + s.remote_mss = 6; + s.send_slice(b"abcdef123456ABCDEF").unwrap(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); // this one is dropped + recv!(s, time 1005, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + })); // this one is received + recv!(s, time 1010, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"ABCDEF"[..], + ..RECV_TEMPL + })); // also dropped + recv!(s, time 2000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); // retransmission + send!(s, time 2005, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), + ..SEND_TEMPL + }); // acknowledgement of both segments + recv!(s, time 2010, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"ABCDEF"[..], + ..RECV_TEMPL + })); // retransmission of only unacknowledged data + } + + #[test] + fn test_close_wait_retransmit_reset_after_ack() { + let mut s = socket_close_wait(); + s.remote_win_len = 6; + s.send_slice(b"abcdef").unwrap(); + s.send_slice(b"123456").unwrap(); + s.send_slice(b"ABCDEF").unwrap(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + send!(s, time 1005, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + window_len: 6, + ..SEND_TEMPL + }); + recv!(s, time 1010, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + })); + send!(s, time 1015, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), + window_len: 6, + ..SEND_TEMPL + }); + recv!(s, time 1020, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"ABCDEF"[..], + ..RECV_TEMPL + })); + } + + #[test] + fn test_fin_wait_1_retransmit_reset_after_ack() { + let mut s = socket_established(); + s.remote_win_len = 6; + s.send_slice(b"abcdef").unwrap(); + s.send_slice(b"123456").unwrap(); + s.send_slice(b"ABCDEF").unwrap(); + s.close(); + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + send!(s, time 1005, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + window_len: 6, + ..SEND_TEMPL + }); + recv!(s, time 1010, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + })); + send!(s, time 1015, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), + window_len: 6, + ..SEND_TEMPL + }); + recv!(s, time 1020, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"ABCDEF"[..], + ..RECV_TEMPL + })); + } + + // =========================================================================================// + // Tests for window management. + // =========================================================================================// + + #[test] + fn test_maximum_segment_size() { + let mut s = socket_listen(); + s.tx_buffer = SocketBuffer::new(vec![0; 32767]); + send!(s, TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + max_seg_size: Some(1000), + ..SEND_TEMPL + }); + recv!(s, [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + window_len: 32767, + ..SEND_TEMPL + }); + s.send_slice(&[0; 1200][..]).unwrap(); + recv!(s, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0; 1000][..], + ..RECV_TEMPL + })); + } + + // =========================================================================================// + // Tests for flow control. + // =========================================================================================// + + #[test] + fn test_psh_transmit() { + let mut s = socket_established(); + s.remote_win_len = 6; + s.send_slice(b"abcdef").unwrap(); + s.send_slice(b"123456").unwrap(); + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::None, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }), exact); + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::Psh, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + }), exact); + } + + #[test] + fn test_psh_receive() { + let mut s = socket_established(); + send!(s, TcpRepr { + control: TcpControl::Psh, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }]); + } + + #[test] + fn test_zero_window_ack() { + let mut s = socket_established(); + s.rx_buffer = SocketBuffer::new(vec![0; 6]); + s.assembler = Assembler::new(s.rx_buffer.capacity()); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 0, + ..RECV_TEMPL + }]); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"123456"[..], + ..SEND_TEMPL + }, Ok(Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 0, + ..RECV_TEMPL + }))); + } + + #[test] + fn test_zero_window_ack_on_window_growth() { + let mut s = socket_established(); + s.rx_buffer = SocketBuffer::new(vec![0; 6]); + s.assembler = Assembler::new(s.rx_buffer.capacity()); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 0, + ..RECV_TEMPL + }]); + recv!(s, time 0, Err(Error::Exhausted)); + s.recv(|buffer| { + assert_eq!(&buffer[..3], b"abc"); + (3, ()) + }).unwrap(); + recv!(s, time 0, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 3, + ..RECV_TEMPL + })); + recv!(s, time 0, Err(Error::Exhausted)); + s.recv(|buffer| { + assert_eq!(buffer, b"def"); + (buffer.len(), ()) + }).unwrap(); + recv!(s, time 0, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 6, + ..RECV_TEMPL + })); + } + + #[test] + fn test_fill_peer_window() { + let mut s = socket_established(); + s.remote_mss = 6; + s.send_slice(b"abcdef123456!@#$%^").unwrap(); + recv!(s, [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }, TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + }, TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"!@#$%^"[..], + ..RECV_TEMPL + }]); + } + + // =========================================================================================// + // Tests for timeouts. + // =========================================================================================// + + #[test] + fn test_listen_timeout() { + let mut s = socket_listen(); + s.set_timeout(Some(100)); + assert_eq!(s.poll_at(), None); + } + + #[test] + fn test_connect_timeout() { + let mut s = socket(); + s.local_seq_no = LOCAL_SEQ; + s.connect(REMOTE_END, LOCAL_END.port).unwrap(); + s.set_timeout(Some(100)); + recv!(s, time 150, Ok(TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + })); + assert_eq!(s.state, State::SynSent); + assert_eq!(s.poll_at(), Some(250)); + recv!(s, time 250, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(TcpSeqNumber(0)), + ..RECV_TEMPL + })); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_established_timeout() { + let mut s = socket_established(); + s.set_timeout(Some(200)); + recv!(s, time 250, Err(Error::Exhausted)); + assert_eq!(s.poll_at(), Some(450)); + s.send_slice(b"abcdef").unwrap(); + assert_eq!(s.poll_at(), Some(0)); + recv!(s, time 255, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + assert_eq!(s.poll_at(), Some(355)); + recv!(s, time 355, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + })); + assert_eq!(s.poll_at(), Some(455)); + recv!(s, time 500, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_established_keep_alive_timeout() { + let mut s = socket_established(); + s.set_keep_alive(Some(50)); + s.set_timeout(Some(100)); + recv!(s, time 100, Ok(TcpRepr { + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0], + ..RECV_TEMPL + })); + recv!(s, time 100, Err(Error::Exhausted)); + assert_eq!(s.poll_at(), Some(150)); + send!(s, time 105, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!(s.poll_at(), Some(155)); + recv!(s, time 155, Ok(TcpRepr { + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0], + ..RECV_TEMPL + })); + recv!(s, time 155, Err(Error::Exhausted)); + assert_eq!(s.poll_at(), Some(205)); + recv!(s, time 200, Err(Error::Exhausted)); + recv!(s, time 205, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + recv!(s, time 205, Err(Error::Exhausted)); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_fin_wait_1_timeout() { + let mut s = socket_fin_wait_1(); + s.set_timeout(Some(200)); + recv!(s, time 100, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + assert_eq!(s.poll_at(), Some(200)); + recv!(s, time 400, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_last_ack_timeout() { + let mut s = socket_last_ack(); + s.set_timeout(Some(200)); + recv!(s, time 100, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + })); + assert_eq!(s.poll_at(), Some(200)); + recv!(s, time 400, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + })); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_closed_timeout() { + let mut s = socket_established(); + s.set_timeout(Some(200)); + s.remote_last_ts = Some(100); + s.abort(); + assert_eq!(s.poll_at(), Some(0)); + recv!(s, time 100, Ok(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + })); + assert_eq!(s.poll_at(), None); + } + + // =========================================================================================// + // Tests for keep-alive. + // =========================================================================================// + + #[test] + fn test_responds_to_keep_alive() { + let mut s = socket_established(); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }, Ok(Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }))); + } + + #[test] + fn test_sends_keep_alive() { + let mut s = socket_established(); + s.set_keep_alive(Some(100)); + + // drain the forced keep-alive packet + assert_eq!(s.poll_at(), Some(0)); + recv!(s, time 0, Ok(TcpRepr { + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0], + ..RECV_TEMPL + })); + + assert_eq!(s.poll_at(), Some(100)); + recv!(s, time 95, Err(Error::Exhausted)); + recv!(s, time 100, Ok(TcpRepr { + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0], + ..RECV_TEMPL + })); + + assert_eq!(s.poll_at(), Some(200)); + recv!(s, time 195, Err(Error::Exhausted)); + recv!(s, time 200, Ok(TcpRepr { + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0], + ..RECV_TEMPL + })); + + send!(s, time 250, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + assert_eq!(s.poll_at(), Some(350)); + recv!(s, time 345, Err(Error::Exhausted)); + recv!(s, time 350, Ok(TcpRepr { + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"\x00"[..], + ..RECV_TEMPL + })); + } + + // =========================================================================================// + // Tests for time-to-live configuration. + // =========================================================================================// + + #[test] + fn test_set_hop_limit() { + let mut s = socket_syn_received(); + let mut caps = DeviceCapabilities::default(); + caps.max_transmission_unit = 1520; + + s.set_hop_limit(Some(0x2a)); + assert_eq!(s.dispatch(0, &caps, |(ip_repr, _)| { + assert_eq!(ip_repr.hop_limit(), 0x2a); + Ok(()) + }), Ok(())); + } + + #[test] + #[should_panic(expected = "the time-to-live value of a packet must not be zero")] + fn test_set_hop_limit_zero() { + let mut s = socket_syn_received(); + s.set_hop_limit(Some(0)); + } + + // =========================================================================================// + // Tests for reassembly. + // =========================================================================================// + + #[test] + fn test_out_of_order() { + let mut s = socket_established(); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"def"[..], + ..SEND_TEMPL + }, Ok(Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }))); + s.recv(|buffer| { + assert_eq!(buffer, b""); + (buffer.len(), ()) + }).unwrap(); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }, Ok(Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }))); + s.recv(|buffer| { + assert_eq!(buffer, b"abcdef"); + (buffer.len(), ()) + }).unwrap(); + } + + #[test] + fn test_buffer_wraparound_rx() { + let mut s = socket_established(); + s.rx_buffer = SocketBuffer::new(vec![0; 6]); + s.assembler = Assembler::new(s.rx_buffer.capacity()); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + }); + s.recv(|buffer| { + assert_eq!(buffer, b"abc"); + (buffer.len(), ()) + }).unwrap(); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"defghi"[..], + ..SEND_TEMPL + }); + let mut data = [0; 6]; + assert_eq!(s.recv_slice(&mut data[..]), Ok(6)); + assert_eq!(data, &b"defghi"[..]); + } + + #[test] + fn test_buffer_wraparound_tx() { + let mut s = socket_established(); + s.tx_buffer = SocketBuffer::new(vec![0; 6]); + assert_eq!(s.send_slice(b"abc"), Ok(3)); + recv!(s, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abc"[..], + ..RECV_TEMPL + })); + send!(s, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 3), + ..SEND_TEMPL + }); + assert_eq!(s.send_slice(b"defghi"), Ok(6)); + recv!(s, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 3, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"def"[..], + ..RECV_TEMPL + })); + // "defghi" not contiguous in tx buffer + recv!(s, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 3 + 3, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"ghi"[..], + ..RECV_TEMPL + })); + } + + // =========================================================================================// + // Tests for packet filtering. + // =========================================================================================// + + #[test] + fn test_doesnt_accept_wrong_port() { + let mut s = socket_established(); + s.rx_buffer = SocketBuffer::new(vec![0; 6]); + s.assembler = Assembler::new(s.rx_buffer.capacity()); + + let tcp_repr = TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + dst_port: LOCAL_PORT + 1, + ..SEND_TEMPL + }; + assert!(!s.accepts(&SEND_IP_TEMPL, &tcp_repr)); + + let tcp_repr = TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + src_port: REMOTE_PORT + 1, + ..SEND_TEMPL + }; + assert!(!s.accepts(&SEND_IP_TEMPL, &tcp_repr)); + } + + #[test] + fn test_doesnt_accept_wrong_ip() { + let s = socket_established(); + + let tcp_repr = TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }; + + let ip_repr = IpRepr::Unspecified { + src_addr: MOCK_IP_ADDR_2, + dst_addr: MOCK_IP_ADDR_1, + protocol: IpProtocol::Tcp, + payload_len: tcp_repr.buffer_len(), + hop_limit: 64 + }; + assert!(s.accepts(&ip_repr, &tcp_repr)); + + let ip_repr_wrong_src = IpRepr::Unspecified { + src_addr: MOCK_IP_ADDR_3, + dst_addr: MOCK_IP_ADDR_1, + protocol: IpProtocol::Tcp, + payload_len: tcp_repr.buffer_len(), + hop_limit: 64 + }; + assert!(!s.accepts(&ip_repr_wrong_src, &tcp_repr)); + + let ip_repr_wrong_dst = IpRepr::Unspecified { + src_addr: MOCK_IP_ADDR_2, + dst_addr: MOCK_IP_ADDR_3, + protocol: IpProtocol::Tcp, + payload_len: tcp_repr.buffer_len(), + hop_limit: 64 + }; + assert!(!s.accepts(&ip_repr_wrong_dst, &tcp_repr)); + } +} +]]