zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

commit 13068da43e8fbd0ae5d03aff27fb4e8802e1218c (tree)
parent 2ab588049e41a96337a4fa8c2d9507320bc4278b
Author: lithdew <kenta@lithdew.net>
Date:   Fri, 30 Apr 2021 21:08:49 +0900

x/os, x/net: re-approach `Address`, rename namespace `TCP -> tcp`

Address comments from @ifreund and @MasterQ32 to address unsafeness and
ergonomics of the `Address` API.

Rename the `TCP` namespace to `tcp` as it does not contain any
top-level fields.

Fix missing reference to `sockaddr` which was identified by @kprotty in
os/bits/linux/arm64.zig.

Diffstat:
Mlib/std/builtin.zig | 18+-----------------
Mlib/std/compress/deflate.zig | 6++----
Mlib/std/os/bits/linux/arm64.zig | 1+
Mlib/std/x.zig | 16+++++++++++++++-
Dlib/std/x/net/TCP.zig | 399-------------------------------------------------------------------------------
Alib/std/x/net/tcp.zig | 383+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mlib/std/x/os/Socket.zig | 128++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------
Mlib/std/x/os/net.zig | 18++++++++++++++++++
8 files changed, 534 insertions(+), 435 deletions(-)

diff --git a/lib/std/builtin.zig b/lib/std/builtin.zig @@ -150,23 +150,7 @@ pub const Mode = enum { /// This data structure is used by the Zig language code generation and /// therefore must be kept in sync with the compiler implementation. -pub const CallingConvention = enum { - Unspecified, - C, - Naked, - Async, - Inline, - Interrupt, - Signal, - Stdcall, - Fastcall, - Vectorcall, - Thiscall, - APCS, - AAPCS, - AAPCSVFP, - SysV -}; +pub const CallingConvention = enum { Unspecified, C, Naked, Async, Inline, Interrupt, Signal, Stdcall, Fastcall, Vectorcall, Thiscall, APCS, AAPCS, AAPCSVFP, SysV }; /// This data structure is used by the Zig language code generation and /// therefore must be kept in sync with the compiler implementation. diff --git a/lib/std/compress/deflate.zig b/lib/std/compress/deflate.zig @@ -662,14 +662,12 @@ test "lengths overflow" { // malformed final dynamic block, tries to write 321 code lengths (MAXCODES is 316) // f dy hlit hdist hclen 16 17 18 0 (18) x138 (18) x138 (18) x39 (16) x6 // 1 10 11101 11101 0000 010 010 010 010 (11) 1111111 (11) 1111111 (11) 0011100 (01) 11 - const stream = [_]u8{ - 0b11101101, 0b00011101, 0b00100100, 0b11101001, 0b11111111, 0b11111111, 0b00111001, 0b00001110 - }; + const stream = [_]u8{ 0b11101101, 0b00011101, 0b00100100, 0b11101001, 0b11111111, 0b11111111, 0b00111001, 0b00001110 }; const reader = std.io.fixedBufferStream(&stream).reader(); var window: [0x8000]u8 = undefined; var inflate = inflateStream(reader, &window); var buf: [1]u8 = undefined; - std.testing.expectError(error.InvalidLength, inflate.read(&buf)); + std.testing.expectError(error.InvalidLength, inflate.read(&buf)); } diff --git a/lib/std/os/bits/linux/arm64.zig b/lib/std/os/bits/linux/arm64.zig @@ -9,6 +9,7 @@ const std = @import("../../../std.zig"); const linux = std.os.linux; const socklen_t = linux.socklen_t; +const sockaddr = linux.sockaddr; const iovec = linux.iovec; const iovec_const = linux.iovec_const; const uid_t = linux.uid_t; diff --git a/lib/std/x.zig b/lib/std/x.zig @@ -1,8 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2015-2021 Zig Contributors +// This file is part of [zig](https://ziglang.org/), which is MIT licensed. +// The MIT license requires this copyright notice to be included in all copies +// and substantial portions of the software. + +const std = @import("std.zig"); + pub const os = struct { pub const Socket = @import("x/os/Socket.zig"); pub usingnamespace @import("x/os/net.zig"); }; pub const net = struct { - pub const TCP = @import("x/net/TCP.zig"); + pub const tcp = @import("x/net/tcp.zig"); }; + +test { + inline for (.{ os, net }) |module| { + std.testing.refAllDecls(module); + } +} diff --git a/lib/std/x/net/TCP.zig b/lib/std/x/net/TCP.zig @@ -1,399 +0,0 @@ -const std = @import("../../std.zig"); - -const os = std.os; -const fmt = std.fmt; -const mem = std.mem; -const testing = std.testing; - -const IPv4 = std.x.os.IPv4; -const IPv6 = std.x.os.IPv6; -const Socket = std.x.os.Socket; - -/// A generic TCP socket abstraction. -const TCP = @This(); - -/// A TCP client-address pair. -pub const Connection = struct { - client: TCP.Client, - address: TCP.Address, - - /// Enclose a TCP client and address into a client-address pair. - pub fn from(socket: Socket, address: TCP.Address) Connection { - return .{ .client = TCP.Client.from(socket), .address = address }; - } - - /// Closes the underlying client of the connection. - pub fn deinit(self: TCP.Connection) void { - self.client.deinit(); - } -}; - -/// Possible domains that a TCP client/listener may operate over. -pub const Domain = extern enum(u16) { - ip = os.AF_INET, - ipv6 = os.AF_INET6, -}; - -/// A TCP client. -pub const Client = struct { - socket: Socket, - - /// Opens a new client. - pub fn init(domain: TCP.Domain, flags: u32) !Client { - return Client{ - .socket = try Socket.init( - @enumToInt(domain), - os.SOCK_STREAM | flags, - os.IPPROTO_TCP, - ), - }; - } - - /// Enclose a TCP client over an existing socket. - pub fn from(socket: Socket) Client { - return Client{ .socket = socket }; - } - - /// Closes the client. - pub fn deinit(self: Client) void { - self.socket.deinit(); - } - - /// Shutdown either the read side, write side, or all sides of the client's underlying socket. - pub fn shutdown(self: Client, how: os.ShutdownHow) !void { - return self.socket.shutdown(how); - } - - /// Have the client attempt to the connect to an address. - pub fn connect(self: Client, address: TCP.Address) !void { - return self.socket.connect(TCP.Address, address); - } - - /// Read data from the socket into the buffer provided. It returns the - /// number of bytes read into the buffer provided. - pub fn read(self: Client, buf: []u8) !usize { - return self.socket.read(buf); - } - - /// Read data from the socket into the buffer provided with a set of flags - /// specified. It returns the number of bytes read into the buffer provided. - pub fn recv(self: Client, buf: []u8, flags: u32) !usize { - return self.socket.recv(buf, flags); - } - - /// Write a buffer of data provided to the socket. It returns the number - /// of bytes that are written to the socket. - pub fn write(self: Client, buf: []const u8) !usize { - return self.socket.write(buf); - } - - /// Writes multiple I/O vectors to the socket. It returns the number - /// of bytes that are written to the socket. - pub fn writev(self: Client, buffers: []const os.iovec_const) !usize { - return self.socket.writev(buffers); - } - - /// Write a buffer of data provided to the socket with a set of flags specified. - /// It returns the number of bytes that are written to the socket. - pub fn send(self: Client, buf: []const u8, flags: u32) !usize { - return self.socket.send(buf, flags); - } - - /// Writes multiple I/O vectors with a prepended message header to the socket - /// with a set of flags specified. It returns the number of bytes that are - /// written to the socket. - pub fn sendmsg(self: Client, msg: os.msghdr_const, flags: u32) !usize { - return self.socket.sendmsg(msg, flags); - } - - /// Query and return the latest cached error on the client's underlying socket. - pub fn getError(self: Client) !void { - return self.socket.getError(); - } - - /// Query the read buffer size of the client's underlying socket. - pub fn getReadBufferSize(self: Client) !u32 { - return self.socket.getReadBufferSize(); - } - - /// Query the write buffer size of the client's underlying socket. - pub fn getWriteBufferSize(self: Client) !u32 { - return self.socket.getWriteBufferSize(); - } - - /// Query the address that the client's socket is locally bounded to. - pub fn getLocalAddress(self: Client) !TCP.Address { - return self.socket.getLocalAddress(TCP.Address); - } - - /// Disable Nagle's algorithm on a TCP socket. It returns `error.UnsupportedSocketOption` if - /// the host does not support sockets disabling Nagle's algorithm. - pub fn setNoDelay(self: Client, enabled: bool) !void { - if (comptime @hasDecl(os, "TCP_NODELAY")) { - const bytes = mem.asBytes(&@as(usize, @boolToInt(enabled))); - return os.setsockopt(self.socket.fd, os.IPPROTO_TCP, os.TCP_NODELAY, bytes); - } - return error.UnsupportedSocketOption; - } - - /// Set the write buffer size of the socket. - pub fn setWriteBufferSize(self: Client, size: u32) !void { - return self.socket.setWriteBufferSize(size); - } - - /// Set the read buffer size of the socket. - pub fn setReadBufferSize(self: Client, size: u32) !void { - return self.socket.setReadBufferSize(size); - } - - /// Set a timeout on the socket that is to occur if no messages are successfully written - /// to its bound destination after a specified number of milliseconds. A subsequent write - /// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded. - pub fn setWriteTimeout(self: Client, milliseconds: usize) !void { - return self.socket.setWriteTimeout(milliseconds); - } - - /// Set a timeout on the socket that is to occur if no messages are successfully read - /// from its bound destination after a specified number of milliseconds. A subsequent - /// read from the socket will thereafter return `error.WouldBlock` should the timeout be - /// exceeded. - pub fn setReadTimeout(self: Client, milliseconds: usize) !void { - return self.socket.setReadTimeout(milliseconds); - } -}; - -/// A TCP listener. -pub const Listener = struct { - socket: Socket, - - /// Opens a new listener. - pub fn init(domain: TCP.Domain, flags: u32) !Listener { - return Listener{ - .socket = try Socket.init( - @enumToInt(domain), - os.SOCK_STREAM | flags, - os.IPPROTO_TCP, - ), - }; - } - - /// Closes the listener. - pub fn deinit(self: Listener) void { - self.socket.deinit(); - } - - /// Shuts down the underlying listener's socket. The next subsequent call, or - /// a current pending call to accept() after shutdown is called will return - /// an error. - pub fn shutdown(self: Listener) !void { - return self.socket.shutdown(.recv); - } - - /// Binds the listener's socket to an address. - pub fn bind(self: Listener, address: TCP.Address) !void { - return self.socket.bind(TCP.Address, address); - } - - /// Start listening for incoming connections. - pub fn listen(self: Listener, max_backlog_size: u31) !void { - return self.socket.listen(max_backlog_size); - } - - /// Accept a pending incoming connection queued to the kernel backlog - /// of the listener's socket. - pub fn accept(self: Listener, flags: u32) !TCP.Connection { - return self.socket.accept(TCP.Connection, TCP.Address, flags); - } - - /// Query and return the latest cached error on the listener's underlying socket. - pub fn getError(self: Client) !void { - return self.socket.getError(); - } - - /// Query the address that the listener's socket is locally bounded to. - pub fn getLocalAddress(self: Listener) !TCP.Address { - return self.socket.getLocalAddress(TCP.Address); - } - - /// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if - /// the host does not support sockets listening the same address. - pub fn setReuseAddress(self: Listener, enabled: bool) !void { - return self.socket.setReuseAddress(enabled); - } - - /// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if - /// the host does not supports sockets listening on the same port. - pub fn setReusePort(self: Listener, enabled: bool) !void { - return self.socket.setReusePort(enabled); - } - - /// Enables TCP Fast Open (RFC 7413) on a TCP socket. It returns `error.UnsupportedSocketOption` if the host does not - /// support TCP Fast Open. - pub fn setFastOpen(self: Listener, enabled: bool) !void { - if (comptime @hasDecl(os, "TCP_FASTOPEN")) { - return os.setsockopt(self.socket.fd, os.IPPROTO_TCP, os.TCP_FASTOPEN, mem.asBytes(&@as(usize, @boolToInt(enabled)))); - } - return error.UnsupportedSocketOption; - } - - /// Enables TCP Quick ACK on a TCP socket to immediately send rather than delay ACKs when necessary. It returns - /// `error.UnsupportedSocketOption` if the host does not support TCP Quick ACK. - pub fn setQuickACK(self: Listener, enabled: bool) !void { - if (comptime @hasDecl(os, "TCP_QUICKACK")) { - return os.setsockopt(self.socket.fd, os.IPPROTO_TCP, os.TCP_QUICKACK, mem.asBytes(&@as(usize, @boolToInt(enabled)))); - } - return error.UnsupportedSocketOption; - } - - /// Set a timeout on the listener that is to occur if no new incoming connections come in - /// after a specified number of milliseconds. A subsequent accept call to the listener - /// will thereafter return `error.WouldBlock` should the timeout be exceeded. - pub fn setAcceptTimeout(self: Listener, milliseconds: usize) !void { - return self.socket.setReadTimeout(milliseconds); - } -}; - -/// A TCP socket address designated by a host IP and port. A TCP socket -/// address comprises of 28 bytes. It may freely be used in place of -/// `sockaddr` when working with socket syscalls. -/// -/// It is not recommended to touch the fields of an `Address`, but to -/// instead make use of its available accessor methods. -pub const Address = extern struct { - family: u16, - port: u16, - host: extern union { - ipv4: extern struct { - address: IPv4, - }, - ipv6: extern struct { - flow_info: u32 = 0, - address: IPv6, - }, - }, - - /// Instantiate a new TCP address with a IPv4 host and port. - pub fn initIPv4(host: IPv4, port: u16) Address { - return Address{ - .family = os.AF_INET, - .port = mem.nativeToBig(u16, port), - .host = .{ - .ipv4 = .{ - .address = host, - }, - }, - }; - } - - /// Instantiate a new TCP address with a IPv6 host and port. - pub fn initIPv6(host: IPv6, port: u16) Address { - return Address{ - .family = os.AF_INET6, - .port = mem.nativeToBig(u16, port), - .host = .{ - .ipv6 = .{ - .address = host, - }, - }, - }; - } - - /// Extract the host of the address. - pub fn getHost(self: Address) union(enum) { v4: IPv4, v6: IPv6 } { - return switch (self.family) { - os.AF_INET => .{ .v4 = self.host.ipv4.address }, - os.AF_INET6 => .{ .v6 = self.host.ipv6.address }, - else => unreachable, - }; - } - - /// Extract the port of the address. - pub fn getPort(self: Address) u16 { - return mem.nativeToBig(u16, self.port); - } - - /// Set the port of the address. - pub fn setPort(self: *Address, port: u16) void { - self.port = mem.nativeToBig(u16, port); - } - - /// Implements the `std.fmt.format` API. - pub fn format( - self: Address, - comptime layout: []const u8, - opts: fmt.FormatOptions, - writer: anytype, - ) !void { - switch (self.getHost()) { - .v4 => |host| try fmt.format(writer, "{}:{}", .{ host, self.getPort() }), - .v6 => |host| try fmt.format(writer, "{}:{}", .{ host, self.getPort() }), - } - } -}; - -test { - testing.refAllDecls(@This()); -} - -test "tcp: create non-blocking pair" { - const a = try TCP.Listener.init(.ip, os.SOCK_NONBLOCK | os.SOCK_CLOEXEC); - defer a.deinit(); - - try a.bind(TCP.Address.initIPv4(IPv4.unspecified, 0)); - try a.listen(128); - - const binded_address = try a.getLocalAddress(); - - const b = try TCP.Client.init(.ip, os.SOCK_NONBLOCK | os.SOCK_CLOEXEC); - defer b.deinit(); - - testing.expectError(error.WouldBlock, b.connect(binded_address)); - try b.getError(); - - const ab = try a.accept(os.SOCK_NONBLOCK | os.SOCK_CLOEXEC); - defer ab.deinit(); -} - -test "tcp/client: set read timeout of 1 millisecond on blocking client" { - const a = try TCP.Listener.init(.ip, os.SOCK_CLOEXEC); - defer a.deinit(); - - try a.bind(TCP.Address.initIPv4(IPv4.unspecified, 0)); - try a.listen(128); - - const binded_address = try a.getLocalAddress(); - - const b = try TCP.Client.init(.ip, os.SOCK_CLOEXEC); - defer b.deinit(); - - try b.connect(binded_address); - try b.setReadTimeout(1); - - const ab = try a.accept(os.SOCK_CLOEXEC); - defer ab.deinit(); - - var buf: [1]u8 = undefined; - testing.expectError(error.WouldBlock, b.read(&buf)); -} - -test "tcp/listener: bind to unspecified ipv4 address" { - const socket = try TCP.Listener.init(.ip, os.SOCK_CLOEXEC); - defer socket.deinit(); - - try socket.bind(TCP.Address.initIPv4(IPv4.unspecified, 0)); - try socket.listen(128); - - const address = try socket.getLocalAddress(); - testing.expect(address.getHost() == .v4); -} - -test "tcp/listener: bind to unspecified ipv6 address" { - const socket = try TCP.Listener.init(.ipv6, os.SOCK_CLOEXEC); - defer socket.deinit(); - - try socket.bind(TCP.Address.initIPv6(IPv6.unspecified, 0)); - try socket.listen(128); - - const address = try socket.getLocalAddress(); - testing.expect(address.getHost() == .v6); -} diff --git a/lib/std/x/net/tcp.zig b/lib/std/x/net/tcp.zig @@ -0,0 +1,383 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2015-2021 Zig Contributors +// This file is part of [zig](https://ziglang.org/), which is MIT licensed. +// The MIT license requires this copyright notice to be included in all copies +// and substantial portions of the software. + +const std = @import("../../std.zig"); + +const os = std.os; +const fmt = std.fmt; +const mem = std.mem; +const testing = std.testing; + +const IPv4 = std.x.os.IPv4; +const IPv6 = std.x.os.IPv6; +const Socket = std.x.os.Socket; + +/// A generic TCP socket abstraction. +const tcp = @This(); + +/// A union of all eligible types of socket addresses over TCP. +pub const Address = union(enum) { + ipv4: IPv4.Address, + ipv6: IPv6.Address, + + /// Instantiate a new address with a IPv4 host and port. + pub fn initIPv4(host: IPv4, port: u16) Address { + return .{ .ipv4 = .{ .host = host, .port = port } }; + } + + /// Instantiate a new address with a IPv6 host and port. + pub fn initIPv6(host: IPv6, port: u16) Address { + return .{ .ipv6 = .{ .host = host, .port = port } }; + } + + /// Re-interpret a generic socket address into a TCP socket address. + pub fn from(address: Socket.Address) tcp.Address { + return switch (address) { + .ipv4 => |ipv4_address| .{ .ipv4 = ipv4_address }, + .ipv6 => |ipv6_address| .{ .ipv6 = ipv6_address }, + }; + } + + /// Re-interpret a TCP socket address into a generic socket address. + pub fn into(self: tcp.Address) Socket.Address { + return switch (self) { + .ipv4 => |ipv4_address| .{ .ipv4 = ipv4_address }, + .ipv6 => |ipv6_address| .{ .ipv6 = ipv6_address }, + }; + } + + /// Implements the `std.fmt.format` API. + pub fn format( + self: tcp.Address, + comptime layout: []const u8, + opts: fmt.FormatOptions, + writer: anytype, + ) !void { + switch (self) { + .ipv4 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }), + .ipv6 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }), + } + } +}; + +/// A TCP client-address pair. +pub const Connection = struct { + client: tcp.Client, + address: tcp.Address, + + /// Enclose a TCP client and address into a client-address pair. + pub fn from(conn: Socket.Connection) tcp.Connection { + return .{ + .client = tcp.Client.from(conn.socket), + .address = tcp.Address.from(conn.address), + }; + } + + /// Unravel a TCP client-address pair into a socket-address pair. + pub fn into(self: tcp.Connection) Socket.Connection { + return .{ + .socket = self.client.socket, + .address = self.address.into(), + }; + } + + /// Closes the underlying client of the connection. + pub fn deinit(self: tcp.Connection) void { + self.client.deinit(); + } +}; + +/// Possible domains that a TCP client/listener may operate over. +pub const Domain = extern enum(u16) { + ip = os.AF_INET, + ipv6 = os.AF_INET6, +}; + +/// A TCP client. +pub const Client = struct { + socket: Socket, + + /// Opens a new client. + pub fn init(domain: tcp.Domain, flags: u32) !Client { + return Client{ + .socket = try Socket.init( + @enumToInt(domain), + os.SOCK_STREAM | flags, + os.IPPROTO_TCP, + ), + }; + } + + /// Enclose a TCP client over an existing socket. + pub fn from(socket: Socket) Client { + return Client{ .socket = socket }; + } + + /// Closes the client. + pub fn deinit(self: Client) void { + self.socket.deinit(); + } + + /// Shutdown either the read side, write side, or all sides of the client's underlying socket. + pub fn shutdown(self: Client, how: os.ShutdownHow) !void { + return self.socket.shutdown(how); + } + + /// Have the client attempt to the connect to an address. + pub fn connect(self: Client, address: tcp.Address) !void { + return self.socket.connect(address.into()); + } + + /// Read data from the socket into the buffer provided. It returns the + /// number of bytes read into the buffer provided. + pub fn read(self: Client, buf: []u8) !usize { + return self.socket.read(buf); + } + + /// Read data from the socket into the buffer provided with a set of flags + /// specified. It returns the number of bytes read into the buffer provided. + pub fn recv(self: Client, buf: []u8, flags: u32) !usize { + return self.socket.recv(buf, flags); + } + + /// Write a buffer of data provided to the socket. It returns the number + /// of bytes that are written to the socket. + pub fn write(self: Client, buf: []const u8) !usize { + return self.socket.write(buf); + } + + /// Writes multiple I/O vectors to the socket. It returns the number + /// of bytes that are written to the socket. + pub fn writev(self: Client, buffers: []const os.iovec_const) !usize { + return self.socket.writev(buffers); + } + + /// Write a buffer of data provided to the socket with a set of flags specified. + /// It returns the number of bytes that are written to the socket. + pub fn send(self: Client, buf: []const u8, flags: u32) !usize { + return self.socket.send(buf, flags); + } + + /// Writes multiple I/O vectors with a prepended message header to the socket + /// with a set of flags specified. It returns the number of bytes that are + /// written to the socket. + pub fn sendmsg(self: Client, msg: os.msghdr_const, flags: u32) !usize { + return self.socket.sendmsg(msg, flags); + } + + /// Query and return the latest cached error on the client's underlying socket. + pub fn getError(self: Client) !void { + return self.socket.getError(); + } + + /// Query the read buffer size of the client's underlying socket. + pub fn getReadBufferSize(self: Client) !u32 { + return self.socket.getReadBufferSize(); + } + + /// Query the write buffer size of the client's underlying socket. + pub fn getWriteBufferSize(self: Client) !u32 { + return self.socket.getWriteBufferSize(); + } + + /// Query the address that the client's socket is locally bounded to. + pub fn getLocalAddress(self: Client) !tcp.Address { + return tcp.Address.from(try self.socket.getLocalAddress()); + } + + /// Disable Nagle's algorithm on a TCP socket. It returns `error.UnsupportedSocketOption` if + /// the host does not support sockets disabling Nagle's algorithm. + pub fn setNoDelay(self: Client, enabled: bool) !void { + if (comptime @hasDecl(os, "TCP_NODELAY")) { + const bytes = mem.asBytes(&@as(usize, @boolToInt(enabled))); + return os.setsockopt(self.socket.fd, os.IPPROTO_TCP, os.TCP_NODELAY, bytes); + } + return error.UnsupportedSocketOption; + } + + /// Set the write buffer size of the socket. + pub fn setWriteBufferSize(self: Client, size: u32) !void { + return self.socket.setWriteBufferSize(size); + } + + /// Set the read buffer size of the socket. + pub fn setReadBufferSize(self: Client, size: u32) !void { + return self.socket.setReadBufferSize(size); + } + + /// Set a timeout on the socket that is to occur if no messages are successfully written + /// to its bound destination after a specified number of milliseconds. A subsequent write + /// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded. + pub fn setWriteTimeout(self: Client, milliseconds: usize) !void { + return self.socket.setWriteTimeout(milliseconds); + } + + /// Set a timeout on the socket that is to occur if no messages are successfully read + /// from its bound destination after a specified number of milliseconds. A subsequent + /// read from the socket will thereafter return `error.WouldBlock` should the timeout be + /// exceeded. + pub fn setReadTimeout(self: Client, milliseconds: usize) !void { + return self.socket.setReadTimeout(milliseconds); + } +}; + +/// A TCP listener. +pub const Listener = struct { + socket: Socket, + + /// Opens a new listener. + pub fn init(domain: tcp.Domain, flags: u32) !Listener { + return Listener{ + .socket = try Socket.init( + @enumToInt(domain), + os.SOCK_STREAM | flags, + os.IPPROTO_TCP, + ), + }; + } + + /// Closes the listener. + pub fn deinit(self: Listener) void { + self.socket.deinit(); + } + + /// Shuts down the underlying listener's socket. The next subsequent call, or + /// a current pending call to accept() after shutdown is called will return + /// an error. + pub fn shutdown(self: Listener) !void { + return self.socket.shutdown(.recv); + } + + /// Binds the listener's socket to an address. + pub fn bind(self: Listener, address: tcp.Address) !void { + return self.socket.bind(address.into()); + } + + /// Start listening for incoming connections. + pub fn listen(self: Listener, max_backlog_size: u31) !void { + return self.socket.listen(max_backlog_size); + } + + /// Accept a pending incoming connection queued to the kernel backlog + /// of the listener's socket. + pub fn accept(self: Listener, flags: u32) !tcp.Connection { + return tcp.Connection.from(try self.socket.accept(flags)); + } + + /// Query and return the latest cached error on the listener's underlying socket. + pub fn getError(self: Client) !void { + return self.socket.getError(); + } + + /// Query the address that the listener's socket is locally bounded to. + pub fn getLocalAddress(self: Listener) !tcp.Address { + return tcp.Address.from(try self.socket.getLocalAddress()); + } + + /// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if + /// the host does not support sockets listening the same address. + pub fn setReuseAddress(self: Listener, enabled: bool) !void { + return self.socket.setReuseAddress(enabled); + } + + /// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if + /// the host does not supports sockets listening on the same port. + pub fn setReusePort(self: Listener, enabled: bool) !void { + return self.socket.setReusePort(enabled); + } + + /// Enables TCP Fast Open (RFC 7413) on a TCP socket. It returns `error.UnsupportedSocketOption` if the host does not + /// support TCP Fast Open. + pub fn setFastOpen(self: Listener, enabled: bool) !void { + if (comptime @hasDecl(os, "TCP_FASTOPEN")) { + return os.setsockopt(self.socket.fd, os.IPPROTO_TCP, os.TCP_FASTOPEN, mem.asBytes(&@as(usize, @boolToInt(enabled)))); + } + return error.UnsupportedSocketOption; + } + + /// Enables TCP Quick ACK on a TCP socket to immediately send rather than delay ACKs when necessary. It returns + /// `error.UnsupportedSocketOption` if the host does not support TCP Quick ACK. + pub fn setQuickACK(self: Listener, enabled: bool) !void { + if (comptime @hasDecl(os, "TCP_QUICKACK")) { + return os.setsockopt(self.socket.fd, os.IPPROTO_TCP, os.TCP_QUICKACK, mem.asBytes(&@as(usize, @boolToInt(enabled)))); + } + return error.UnsupportedSocketOption; + } + + /// Set a timeout on the listener that is to occur if no new incoming connections come in + /// after a specified number of milliseconds. A subsequent accept call to the listener + /// will thereafter return `error.WouldBlock` should the timeout be exceeded. + pub fn setAcceptTimeout(self: Listener, milliseconds: usize) !void { + return self.socket.setReadTimeout(milliseconds); + } +}; + +test { + testing.refAllDecls(@This()); +} + +test "tcp: create non-blocking pair" { + const listener = try tcp.Listener.init(.ip, os.SOCK_NONBLOCK | os.SOCK_CLOEXEC); + defer listener.deinit(); + + try listener.bind(tcp.Address.initIPv4(IPv4.unspecified, 0)); + try listener.listen(128); + + const binded_address = try listener.getLocalAddress(); + + const client = try tcp.Client.init(.ip, os.SOCK_NONBLOCK | os.SOCK_CLOEXEC); + defer client.deinit(); + + testing.expectError(error.WouldBlock, client.connect(binded_address)); + try client.getError(); + + const conn = try listener.accept(os.SOCK_NONBLOCK | os.SOCK_CLOEXEC); + defer conn.deinit(); +} + +test "tcp/client: set read timeout of 1 millisecond on blocking client" { + const listener = try tcp.Listener.init(.ip, os.SOCK_CLOEXEC); + defer listener.deinit(); + + try listener.bind(tcp.Address.initIPv4(IPv4.unspecified, 0)); + try listener.listen(128); + + const binded_address = try listener.getLocalAddress(); + + const client = try tcp.Client.init(.ip, os.SOCK_CLOEXEC); + defer client.deinit(); + + try client.connect(binded_address); + try client.setReadTimeout(1); + + const conn = try listener.accept(os.SOCK_CLOEXEC); + defer conn.deinit(); + + var buf: [1]u8 = undefined; + testing.expectError(error.WouldBlock, client.read(&buf)); +} + +test "tcp/listener: bind to unspecified ipv4 address" { + const listener = try tcp.Listener.init(.ip, os.SOCK_CLOEXEC); + defer listener.deinit(); + + try listener.bind(tcp.Address.initIPv4(IPv4.unspecified, 0)); + try listener.listen(128); + + const address = try listener.getLocalAddress(); + testing.expect(address == .ipv4); +} + +test "tcp/listener: bind to unspecified ipv6 address" { + const listener = try tcp.Listener.init(.ipv6, os.SOCK_CLOEXEC); + defer listener.deinit(); + + try listener.bind(tcp.Address.initIPv6(IPv6.unspecified, 0)); + try listener.listen(128); + + const address = try listener.getLocalAddress(); + testing.expect(address == .ipv6); +} diff --git a/lib/std/x/os/Socket.zig b/lib/std/x/os/Socket.zig @@ -1,4 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2015-2021 Zig Contributors +// This file is part of [zig](https://ziglang.org/), which is MIT licensed. +// The MIT license requires this copyright notice to be included in all copies +// and substantial portions of the software. + const std = @import("../../std.zig"); +const net = @import("net.zig"); const os = std.os; const mem = std.mem; @@ -7,6 +14,98 @@ const time = std.time; /// A generic socket abstraction. const Socket = @This(); +/// A socket-address pair. +pub const Connection = struct { + socket: Socket, + address: Socket.Address, + + /// Enclose a socket and address into a socket-address pair. + pub fn from(socket: Socket, address: Socket.Address) Socket.Connection { + return .{ .socket = socket, .address = address }; + } +}; + +/// A generic socket address abstraction. It is safe to directly access and modify +/// the fields of a `Socket.Address`. +pub const Address = union(enum) { + ipv4: net.IPv4.Address, + ipv6: net.IPv6.Address, + + /// Instantiate a new address with a IPv4 host and port. + pub fn initIPv4(host: net.IPv4, port: u16) Socket.Address { + return .{ .ipv4 = .{ .host = host, .port = port } }; + } + + /// Instantiate a new address with a IPv6 host and port. + pub fn initIPv6(host: net.IPv6, port: u16) Socket.Address { + return .{ .ipv6 = .{ .host = host, .port = port } }; + } + + /// Parses a `sockaddr` into a generic socket address. + pub fn fromNative(address: *align(4) const os.sockaddr) Socket.Address { + switch (address.family) { + os.AF_INET => { + const info = @ptrCast(*const os.sockaddr_in, address); + const host = net.IPv4{ .octets = @bitCast([4]u8, info.addr) }; + const port = mem.bigToNative(u16, info.port); + return Socket.Address.initIPv4(host, port); + }, + os.AF_INET6 => { + const info = @ptrCast(*const os.sockaddr_in6, address); + const host = net.IPv6{ .octets = info.addr, .scope_id = info.scope_id }; + const port = mem.bigToNative(u16, info.port); + return Socket.Address.initIPv6(host, port); + }, + else => unreachable, + } + } + + /// Encodes a generic socket address into an extern union that may be reliably + /// casted into a `sockaddr` which may be passed into socket syscalls. + pub fn toNative(self: Socket.Address) extern union { + ipv4: os.sockaddr_in, + ipv6: os.sockaddr_in6, + } { + return switch (self) { + .ipv4 => |address| .{ + .ipv4 = .{ + .addr = @bitCast(u32, address.host.octets), + .port = mem.nativeToBig(u16, address.port), + }, + }, + .ipv6 => |address| .{ + .ipv6 = .{ + .addr = address.host.octets, + .port = mem.nativeToBig(u16, address.port), + .scope_id = address.host.scope_id, + .flowinfo = 0, + }, + }, + }; + } + + /// Returns the number of bytes that make up the `sockaddr` equivalent to the address. + pub fn getNativeSize(self: Socket.Address) u32 { + return switch (self) { + .ipv4 => @sizeOf(os.sockaddr_in), + .ipv6 => @sizeOf(os.sockaddr_in6), + }; + } + + /// Implements the `std.fmt.format` API. + pub fn format( + self: Socket.Address, + comptime layout: []const u8, + opts: fmt.FormatOptions, + writer: anytype, + ) !void { + switch (self) { + .ipv4 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }), + .ipv6 => |address| try fmt.format(writer, "{}:{}", .{ address.host, address.port }), + } + } +}; + /// The underlying handle of a socket. fd: os.socket_t, @@ -31,8 +130,8 @@ pub fn shutdown(self: Socket, how: os.ShutdownHow) !void { } /// Binds the socket to an address. -pub fn bind(self: Socket, comptime Address: type, address: Address) !void { - return os.bind(self.fd, @ptrCast(*const os.sockaddr, &address), @sizeOf(Address)); +pub fn bind(self: Socket, address: Socket.Address) !void { + return os.bind(self.fd, @ptrCast(*const os.sockaddr, &address.toNative()), address.getNativeSize()); } /// Start listening for incoming connections on the socket. @@ -41,19 +140,20 @@ pub fn listen(self: Socket, max_backlog_size: u31) !void { } /// Have the socket attempt to the connect to an address. -pub fn connect(self: Socket, comptime Address: type, address: Address) !void { - return os.connect(self.fd, @ptrCast(*const os.sockaddr, &address), @sizeOf(Address)); +pub fn connect(self: Socket, address: Socket.Address) !void { + return os.connect(self.fd, @ptrCast(*const os.sockaddr, &address.toNative()), address.getNativeSize()); } /// Accept a pending incoming connection queued to the kernel backlog /// of the socket. -pub fn accept(self: Socket, comptime Connection: type, comptime Address: type, flags: u32) !Connection { - var address: Address = undefined; - var address_len: u32 = @sizeOf(Address); +pub fn accept(self: Socket, flags: u32) !Socket.Connection { + var address: os.sockaddr = undefined; + var address_len: u32 = @sizeOf(os.sockaddr); - const fd = try os.accept(self.fd, @ptrCast(*os.sockaddr, &address), &address_len, flags); + const socket = Socket{ .fd = try os.accept(self.fd, &address, &address_len, flags) }; + const socket_address = Socket.Address.fromNative(@alignCast(4, &address)); - return Connection.from(.{ .fd = fd }, address); + return Socket.Connection.from(socket, socket_address); } /// Read data from the socket into the buffer provided. It returns the @@ -94,11 +194,11 @@ pub fn sendmsg(self: Socket, msg: os.msghdr_const, flags: u32) !usize { } /// Query the address that the socket is locally bounded to. -pub fn getLocalAddress(self: Socket, comptime Address: type) !Address { - var address: Address = undefined; - var address_len: u32 = @sizeOf(Address); - try os.getsockname(self.fd, @ptrCast(*os.sockaddr, &address), &address_len); - return address; +pub fn getLocalAddress(self: Socket) !Socket.Address { + var address: os.sockaddr = undefined; + var address_len: u32 = @sizeOf(os.sockaddr); + try os.getsockname(self.fd, &address, &address_len); + return Socket.Address.fromNative(@alignCast(4, &address)); } /// Query and return the latest cached error on the socket. diff --git a/lib/std/x/os/net.zig b/lib/std/x/os/net.zig @@ -1,3 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2015-2021 Zig Contributors +// This file is part of [zig](https://ziglang.org/), which is MIT licensed. +// The MIT license requires this copyright notice to be included in all copies +// and substantial portions of the software. + const std = @import("../../std.zig"); const os = std.os; @@ -27,6 +33,12 @@ pub fn resolveScopeID(name: []const u8) !u32 { /// An IPv4 address comprised of 4 bytes. pub const IPv4 = extern struct { + /// A IPv4 host-port pair. + pub const Address = extern struct { + host: IPv4, + port: u16, + }; + /// Octets of a IPv4 address designating the local host. pub const localhost_octets = [_]u8{ 127, 0, 0, 1 }; @@ -200,6 +212,12 @@ pub const IPv4 = extern struct { /// An IPv6 address comprised of 16 bytes for an address, and 4 bytes /// for a scope ID; cumulatively summing to 20 bytes in total. pub const IPv6 = extern struct { + /// A IPv6 host-port pair. + pub const Address = extern struct { + host: IPv6, + port: u16, + }; + /// Octets of a IPv6 address designating the local host. pub const localhost_octets = [_]u8{0} ** 15 ++ [_]u8{0x01};