diff --git a/src/DB.zig b/src/DB.zig index 796265a..dee5c0e 100644 --- a/src/DB.zig +++ b/src/DB.zig @@ -307,13 +307,13 @@ const GroupMembersIter = struct { total: usize, arr: []const u8, - pub fn nextMust(self: *GroupMembersIter) ?u64 { - return self.it.nextMust(); + pub fn next(self: *GroupMembersIter) error{Overflow}!?u64 { + return self.it.next(); } }; -pub fn groupMembersIter(members_slice: []const u8) GroupMembersIter { - var vit = compress.varintSliceIteratorMust(members_slice); +pub fn groupMembersIter(members_slice: []const u8) error{Overflow}!GroupMembersIter { + var vit = try compress.varintSliceIterator(members_slice); var it = compress.deltaDecompressionIterator(&vit); return GroupMembersIter{ .arr = members_slice, @@ -324,9 +324,12 @@ pub fn groupMembersIter(members_slice: []const u8) GroupMembersIter { } // dumps PackedGroup to []u8 and returns a CGroup. -pub fn packCGroup(self: *const DB, group: *const PackedGroup, buf: []u8) error{BufferTooSmall}!CGroup { +pub fn packCGroup(self: *const DB, group: *const PackedGroup, buf: []u8) error{ + Overflow, + BufferTooSmall, +}!CGroup { const members_slice = self.groupmembers[group.members_offset..]; - var it = groupMembersIter(members_slice); + var it = try groupMembersIter(members_slice); const num_members = it.total; const ptr_end = @sizeOf(?[*:0]const u8) * (num_members + 1); @@ -338,7 +341,7 @@ pub fn packCGroup(self: *const DB, group: *const PackedGroup, buf: []u8) error{B var buf_offset: usize = ptr_end; var i: usize = 0; - while (it.nextMust()) |member_offset| : (i += 1) { + while (try it.next()) |member_offset| : (i += 1) { const entry = PackedUser.fromBytes(@alignCast(8, self.users[member_offset << 3 ..])); const start = buf_offset; const name = entry.user.name(); @@ -389,7 +392,7 @@ pub fn getgrnam( name: []const u8, buf: []u8, omit_members: bool, -) error{BufferTooSmall}!?CGroup { +) error{ Overflow, BufferTooSmall }!?CGroup { const group = self.getGroupByName(name) orelse return null; if (omit_members) return try packCGroupNoMembers(&group, buf) @@ -403,7 +406,7 @@ pub fn getgrgid( gid: u32, buf: []u8, omit_members: bool, -) error{BufferTooSmall}!?CGroup { +) error{ Overflow, BufferTooSmall }!?CGroup { const group = self.getGroupByGid(gid) orelse return null; if (omit_members) return try packCGroupNoMembers(&group, buf) @@ -687,8 +690,8 @@ fn groupsSection( }; } -pub fn userGids(self: *const DB, offset: u64) compress.DeltaDecompressionIterator { - var vit = compress.varintSliceIteratorMust(self.additional_gids[offset..]); +pub fn userGids(self: *const DB, offset: u64) error{Overflow}!compress.DeltaDecompressionIterator { + var vit = try compress.varintSliceIterator(self.additional_gids[offset..]); return compress.deltaDecompressionIterator(&vit); } diff --git a/src/compress.zig b/src/compress.zig index 364c274..46392c8 100644 --- a/src/compress.zig +++ b/src/compress.zig @@ -80,12 +80,6 @@ pub fn uvarint(buf: []const u8) error{Overflow}!Varint { }; } -pub fn uvarintMust(buf: []const u8) Varint { - return uvarint(buf) catch |err| switch (err) { - error.Overflow => unreachable, - }; -} - // https://golang.org/pkg/encoding/binary/#PutUvarint pub fn putUvarint(buf: []u8, x: u64) usize { var i: usize = 0; @@ -118,12 +112,6 @@ pub const VarintSliceIterator = struct { return value.value; } - pub fn nextMust(self: *VarintSliceIterator) ?u64 { - return self.next() catch |err| switch (err) { - error.Overflow => unreachable, - }; - } - // returns the number of remaining items. If called before the first // next(), returns the length of the slice. pub fn remaining(self: *const VarintSliceIterator) usize { @@ -140,12 +128,6 @@ pub fn varintSliceIterator(arr: []const u8) error{Overflow}!VarintSliceIterator }; } -pub fn varintSliceIteratorMust(arr: []const u8) VarintSliceIterator { - return varintSliceIterator(arr) catch |err| switch (err) { - error.Overflow => unreachable, - }; -} - pub const DeltaDecompressionIterator = struct { vit: *VarintSliceIterator, prev: u64, @@ -167,12 +149,6 @@ pub const DeltaDecompressionIterator = struct { pub fn remaining(self: *const DeltaDecompressionIterator) usize { return self.vit.remaining; } - - pub fn nextMust(self: *DeltaDecompressionIterator) ?u64 { - return self.next() catch |err| switch (err) { - error.Overflow => unreachable, - }; - } }; pub fn deltaDecompressionIterator(vit: *VarintSliceIterator) DeltaDecompressionIterator { @@ -344,8 +320,8 @@ const GroupMembersIter = struct { total: usize, }; -pub fn groupMembersIter(members_slice: []const u8) GroupMembersIter { - var vit = compress.varintSliceIteratorMust(members_slice); +pub fn groupMembersIter(members_slice: []const u8) error{Overflow}!GroupMembersIter { + var vit = try compress.varintSliceIterator(members_slice); var it = compress.deltaDecompressionIterator(&vit); return GroupMembersIter{ .vit = vit, @@ -357,11 +333,10 @@ pub fn groupMembersIter(members_slice: []const u8) GroupMembersIter { test "compress: trying to repro pointer change of DB.groupMembersIter" { const members_slice = &[_]u8{ 4, 0, 60, 2, 2, 2, 64, 2 }; - var members = groupMembersIter(members_slice); + var members = try groupMembersIter(members_slice); var i: usize = 0; - while (members.it.nextMust()) |member_offset| : (i += 1) { + while (try members.it.next()) |member_offset| : (i += 1) { _ = member_offset; - //std.debug.print("member_offset: {d}\n", .{member_offset}); } } diff --git a/src/libnss.zig b/src/libnss.zig index af7d048..8b1cba3 100644 --- a/src/libnss.zig +++ b/src/libnss.zig @@ -221,6 +221,7 @@ fn getgrgid_r( var buf = buffer[0..buflen]; const cgroup = db.getgrgid(gid, buf, omit_members) catch |err| switch (err) { + error.Overflow => return badFile(errnop), error.BufferTooSmall => { errnop.* = @enumToInt(os.E.RANGE); return c.NSS_STATUS_TRYAGAIN; @@ -263,6 +264,7 @@ fn getgrnam_r( const nameSlice = mem.sliceTo(name, 0); var buf = buffer[0..buflen]; const cgroup = db.getgrnam(nameSlice, buf, omit_members) catch |err| switch (err) { + error.Overflow => return badFile(errnop), error.BufferTooSmall => { errnop.* = @enumToInt(os.E.RANGE); return c.NSS_STATUS_TRYAGAIN; @@ -379,6 +381,10 @@ fn getgrent_r( result.* = cgroup; return c.NSS_STATUS_SUCCESS; } else |err| switch (err) { + error.Overflow => { + it.rollback(); + return badFile(errnop); + }, error.BufferTooSmall => { it.rollback(); errnop.* = @enumToInt(os.E.RANGE); @@ -469,14 +475,18 @@ fn initgroups_dyn( return c.NSS_STATUS_NOTFOUND; }; - var gids = db.userGids(user.additional_gids_offset); + var gids = db.userGids(user.additional_gids_offset) catch |err| switch (err) { + error.Overflow => return badFile(errnop), + }; const remaining = gids.vit.remaining; // the implementation below is ported from glibc's db-initgroups.c // even though we know the size of the groups upfront, I found it too difficult // to preallocate and juggle size, start and limit while keeping glibc happy. var any: bool = false; - while (gids.nextMust()) |gid| { + while (gids.next() catch |err| switch (err) { + error.Overflow => return badFile(errnop), + }) |gid| { if (start.* == size.*) { if (limit > 0 and size.* == limit) return c.NSS_STATUS_SUCCESS; @@ -528,6 +538,11 @@ fn getDBErrno(errnop: *c_int) ?*const DB { return &state.file.db; } +fn badFile(errnop: *c_int) c.enum_nss_status { + errnop.* = @enumToInt(os.E.NOENT); + return c.NSS_STATUS_NOTFOUND; +} + // isId tells if this command is "id". Reads the cmdline // from the given fd. Returns false on any error. fn isId(fd: os.fd_t) bool { diff --git a/src/turbonss-analyze.zig b/src/turbonss-analyze.zig index 2d4031b..79edf07 100644 --- a/src/turbonss-analyze.zig +++ b/src/turbonss-analyze.zig @@ -162,7 +162,13 @@ fn execute( while (it.next()) |packed_user| { const offset = packed_user.additional_gids_offset; const additional_gids = db.additional_gids[offset..]; - const vit = compress.varintSliceIteratorMust(additional_gids); + const vit = compress.varintSliceIterator(additional_gids) catch |err| { + stderr.print( + "ERROR {s}: file '{s}' is corrupted or cannot be read\n", + .{ @errorName(err), db_file }, + ) catch {}; + return 1; + }; // the primary gid of the user is never in "additional gids" const ngroups = vit.remaining + 1; if (ngroups > popUser.score) { diff --git a/src/turbonss-getent.zig b/src/turbonss-getent.zig index d21c584..b2a39d3 100644 --- a/src/turbonss-getent.zig +++ b/src/turbonss-getent.zig @@ -79,7 +79,7 @@ fn execute( } else { stderr.print("bad argument {s}: expected passwd or group\n", .{ myflags.args[0], - }) catch return 3; + }) catch {}; return 1; } }; @@ -89,18 +89,36 @@ fn execute( stderr.print( "ERROR {s}: file '{s}' is corrupted or cannot be read\n", .{ @errorName(err), db_file }, - ) catch return 3; + ) catch {}; return 1; }; defer file.close(); + // TODO: how can I handle error.Overflow in a single place? Can I catch on + // a block? return switch (mode) { - .passwd => passwd(stdout, &file.db, myflags.args[1..]), - .group => group(stdout, &file.db, myflags.args[1..]), + .passwd => passwd(stdout, &file.db, myflags.args[1..]) catch |err| switch (err) { + error.Overflow => { + stderr.print( + "ERROR {s}: file '{s}' is corrupted or cannot be read\n", + .{ @errorName(err), db_file }, + ) catch {}; + return 1; + }, + }, + .group => group(stdout, &file.db, myflags.args[1..]) catch |err| switch (err) { + error.Overflow => { + stderr.print( + "ERROR {s}: file '{s}' is corrupted or cannot be read\n", + .{ @errorName(err), db_file }, + ) catch {}; + return 1; + }, + }, }; } -fn passwd(stdout: anytype, db: *const DB, keys: []const [*:0]const u8) u8 { +fn passwd(stdout: anytype, db: *const DB, keys: []const [*:0]const u8) error{Overflow}!u8 { if (keys.len == 0) return passwdAll(stdout, db); var some_notfound = false; @@ -134,7 +152,7 @@ fn passwdAll(stdout: anytype, db: *const DB) u8 { return 0; } -fn group(stdout: anytype, db: *const DB, keys: []const [*:0]const u8) u8 { +fn group(stdout: anytype, db: *const DB, keys: []const [*:0]const u8) error{Overflow}!u8 { if (keys.len == 0) return groupAll(stdout, db); var some_notfound = false; @@ -151,31 +169,31 @@ fn group(stdout: anytype, db: *const DB, keys: []const [*:0]const u8) u8 { continue; }; - if (printGroup(stdout, db, &g)) |exit_code| + if (try printGroup(stdout, db, &g)) |exit_code| return exit_code; } return if (some_notfound) 2 else 0; } -fn groupAll(stdout: anytype, db: *const DB) u8 { +fn groupAll(stdout: anytype, db: *const DB) error{Overflow}!u8 { var it = PackedGroup.iterator(db.groups, db.header.num_groups); while (it.next()) |g| - if (printGroup(stdout, db, &g)) |exit_code| + if (try printGroup(stdout, db, &g)) |exit_code| return exit_code; return 0; } -fn printGroup(stdout: anytype, db: *const DB, g: *const PackedGroup) ?u8 { +fn printGroup(stdout: anytype, db: *const DB, g: *const PackedGroup) error{Overflow}!?u8 { // not converting to Group to save a few memory allocations. stdout.print("{s}:x:{d}:", .{ g.name(), g.gid() }) catch return 3; - var it = DB.groupMembersIter(db.groupmembers[g.members_offset..]); + var it = try DB.groupMembersIter(db.groupmembers[g.members_offset..]); // lines will be buffered, flushed on every EOL. var line_writer = io.bufferedWriter(stdout); var i: usize = 0; - while (it.nextMust()) |member_offset| : (i += 1) { + while (try it.next()) |member_offset| : (i += 1) { const puser = PackedUser.fromBytes(@alignCast(8, db.users[member_offset << 3 ..])); const name = puser.user.name(); if (i != 0)