From e1cae43d08d1573873aeaf9e4c1da8f10cbe27a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Motiejus=20Jak=C5=A1tys?= Date: Tue, 6 Jun 2023 20:33:07 +0300 Subject: [PATCH] more safety checks in user parsing --- src/DB.zig | 24 +++++++++++++----------- src/PackedUser.zig | 12 +++++------- src/libnss.zig | 10 ++++++++-- src/turbonss-analyze.zig | 7 ++++++- src/turbonss-getent.zig | 10 +++++----- 5 files changed, 37 insertions(+), 26 deletions(-) diff --git a/src/DB.zig b/src/DB.zig index dee5c0e..3f65d64 100644 --- a/src/DB.zig +++ b/src/DB.zig @@ -342,7 +342,7 @@ pub fn packCGroup(self: *const DB, group: *const PackedGroup, buf: []u8) error{ var i: usize = 0; while (try it.next()) |member_offset| : (i += 1) { - const entry = PackedUser.fromBytes(@alignCast(8, self.users[member_offset << 3 ..])); + const entry = try PackedUser.fromBytes(@alignCast(8, self.users[member_offset << 3 ..])); const start = buf_offset; const name = entry.user.name(); if (buf_offset + name.len + 1 > buf.len) @@ -458,34 +458,36 @@ pub fn writeUser(self: *const DB, user: PackedUser, buf: []u8) error{BufferTooSm }; } -pub fn getUserByName(self: *const DB, name: []const u8) ?PackedUser { +pub fn getUserByName(self: *const DB, name: []const u8) error{Overflow}!?PackedUser { const idx = bdz.search(self.bdz_username, name); // bdz may return a hash that's bigger than the number of users - if (idx >= self.header.num_users) return null; + if (idx >= self.header.num_users) + return null; const offset = self.idx_name2user[idx]; - const user = PackedUser.fromBytes(@alignCast(8, self.users[offset << 3 ..])).user; + const user = (try PackedUser.fromBytes(@alignCast(8, self.users[offset << 3 ..]))).user; if (!mem.eql(u8, name, user.name())) return null; return user; } // get a CUser entry by name. -pub fn getpwnam(self: *const DB, name: []const u8, buf: []u8) error{BufferTooSmall}!?CUser { - const user = self.getUserByName(name) orelse return null; +pub fn getpwnam(self: *const DB, name: []const u8, buf: []u8) error{ Overflow, BufferTooSmall }!?CUser { + const user = try self.getUserByName(name) orelse return null; return try self.writeUser(user, buf); } -pub fn getUserByUid(self: *const DB, uid: u32) ?PackedUser { +pub fn getUserByUid(self: *const DB, uid: u32) error{Overflow}!?PackedUser { const idx = bdz.search_u32(self.bdz_uid, uid); - if (idx >= self.header.num_users) return null; + if (idx >= self.header.num_users) + return null; const offset = self.idx_uid2user[idx]; - const user = PackedUser.fromBytes(@alignCast(8, self.users[offset << 3 ..])).user; + const user = (try PackedUser.fromBytes(@alignCast(8, self.users[offset << 3 ..]))).user; if (uid != user.uid()) return null; return user; } // get a CUser entry by uid. -pub fn getpwuid(self: *const DB, uid: u32, buf: []u8) error{BufferTooSmall}!?CUser { - const user = self.getUserByUid(uid) orelse return null; +pub fn getpwuid(self: *const DB, uid: u32, buf: []u8) error{ Overflow, BufferTooSmall }!?CUser { + const user = try self.getUserByUid(uid) orelse return null; return try self.writeUser(user, buf); } diff --git a/src/PackedUser.zig b/src/PackedUser.zig index d9a7645..aee8a53 100644 --- a/src/PackedUser.zig +++ b/src/PackedUser.zig @@ -83,13 +83,11 @@ pub const Entry = struct { end: usize, }; -pub fn fromBytes(blob: []align(8) const u8) Entry { +pub fn fromBytes(blob: []align(8) const u8) error{Overflow}!Entry { const start_var_payload = @bitSizeOf(Inner) / 8; const inner = @ptrCast(*align(8) const Inner, blob[0..start_var_payload]); const end_strings = start_var_payload + inner.stringLength(); - const gids_offset = compress.uvarint(blob[end_strings..]) catch |err| switch (err) { - error.Overflow => unreachable, - }; + const gids_offset = try compress.uvarint(blob[end_strings..]); const end_payload = end_strings + gids_offset.bytes_read; return Entry{ @@ -110,9 +108,9 @@ pub const Iterator = struct { total: u32, advanced_by: usize = 0, - pub fn next(it: *Iterator) ?PackedUser { + pub fn next(it: *Iterator) error{Overflow}!?PackedUser { if (it.idx == it.total) return null; - const entry = fromBytes(@alignCast(8, it.section[it.next_start..])); + const entry = try fromBytes(@alignCast(8, it.section[it.next_start..])); it.idx += 1; it.next_start += entry.end; it.advanced_by = entry.end; @@ -296,7 +294,7 @@ test "PackedUser construct section" { var i: u29 = 0; var it1 = PackedUser.iterator(buf.items, users.len, test_shell_reader); - while (it1.next()) |user| : (i += 1) { + while (try it1.next()) |user| : (i += 1) { try testing.expectEqual(users[i].uid, user.uid()); try testing.expectEqual(users[i].gid, user.gid()); try testing.expectEqual(user.additionalGidsOffset(), additional_gids); diff --git a/src/libnss.zig b/src/libnss.zig index 8b1cba3..813875f 100644 --- a/src/libnss.zig +++ b/src/libnss.zig @@ -140,6 +140,7 @@ fn getpwuid_r( errnop: *c_int, ) c.enum_nss_status { var cuser = db.getpwuid(uid, buffer[0..buflen]) catch |err| switch (err) { + error.Overflow => return badFile(errnop), error.BufferTooSmall => { errnop.* = @enumToInt(os.E.RANGE); return c.NSS_STATUS_TRYAGAIN; @@ -178,6 +179,7 @@ fn getpwnam_r( var buf = buffer[0..buflen]; const cuser = db.getpwnam(nameSlice, buf) catch |err| switch (err) { + error.Overflow => return badFile(errnop), error.BufferTooSmall => { errnop.* = @enumToInt(os.E.RANGE); return c.NSS_STATUS_TRYAGAIN; @@ -424,7 +426,9 @@ fn getpwent_r( return c.NSS_STATUS_UNAVAIL; }); - const user = it.next() orelse { + const user = it.next() catch |err| switch (err) { + error.Overflow => return badFile(errnop), + } orelse { errnop.* = 0; return c.NSS_STATUS_NOTFOUND; }; @@ -470,7 +474,9 @@ fn initgroups_dyn( errnop: *c_int, ) c.enum_nss_status { const db = state.file.db; - const user = db.getUserByName(mem.sliceTo(user_name, 0)) orelse { + const user = db.getUserByName(mem.sliceTo(user_name, 0)) catch |err| switch (err) { + error.Overflow => return badFile(errnop), + } orelse { errnop.* = @enumToInt(os.E.NOENT); return c.NSS_STATUS_NOTFOUND; }; diff --git a/src/turbonss-analyze.zig b/src/turbonss-analyze.zig index 79edf07..dfaf232 100644 --- a/src/turbonss-analyze.zig +++ b/src/turbonss-analyze.zig @@ -159,7 +159,12 @@ fn execute( db.header.num_users, db.shellReader(), ); - while (it.next()) |packed_user| { + while (it.next() catch |err| switch (err) { + error.Overflow => { + stderr.print("ERROR: file '{s}' is corrupted\n", .{db_file}) catch {}; + return 1; + }, + }) |packed_user| { const offset = packed_user.additional_gids_offset; const additional_gids = db.additional_gids[offset..]; const vit = compress.varintSliceIterator(additional_gids) catch |err| { diff --git a/src/turbonss-getent.zig b/src/turbonss-getent.zig index b2a39d3..463ec13 100644 --- a/src/turbonss-getent.zig +++ b/src/turbonss-getent.zig @@ -120,12 +120,12 @@ fn execute( fn passwd(stdout: anytype, db: *const DB, keys: []const [*:0]const u8) error{Overflow}!u8 { if (keys.len == 0) - return passwdAll(stdout, db); + return try passwdAll(stdout, db); var some_notfound = false; const shell_reader = db.shellReader(); for (keys) |key| { const keyZ = mem.span(key); - const maybe_packed_user = if (fmt.parseUnsigned(u32, keyZ, 10)) |uid| + const maybe_packed_user = try if (fmt.parseUnsigned(u32, keyZ, 10)) |uid| db.getUserByUid(uid) else |_| db.getUserByName(keyZ); @@ -142,10 +142,10 @@ fn passwd(stdout: anytype, db: *const DB, keys: []const [*:0]const u8) error{Ove return if (some_notfound) 2 else 0; } -fn passwdAll(stdout: anytype, db: *const DB) u8 { +fn passwdAll(stdout: anytype, db: *const DB) error{Overflow}!u8 { const shell_reader = db.shellReader(); var it = PackedUser.iterator(db.users, db.header.num_users, shell_reader); - while (it.next()) |packed_user| { + while (try it.next()) |packed_user| { const line = packed_user.toUser(db.shellReader()).toLine(); stdout.writeAll(line.constSlice()) catch return 3; } @@ -194,7 +194,7 @@ fn printGroup(stdout: anytype, db: *const DB, g: *const PackedGroup) error{Overf var line_writer = io.bufferedWriter(stdout); var i: usize = 0; while (try it.next()) |member_offset| : (i += 1) { - const puser = PackedUser.fromBytes(@alignCast(8, db.users[member_offset << 3 ..])); + const puser = try PackedUser.fromBytes(@alignCast(8, db.users[member_offset << 3 ..])); const name = puser.user.name(); if (i != 0) _ = line_writer.write(",") catch return 3;