1
Fork 0

more safety checks in user parsing

This commit is contained in:
Motiejus Jakštys 2023-06-06 20:33:07 +03:00
parent 277a48296a
commit e1cae43d08
5 changed files with 37 additions and 26 deletions

View File

@ -342,7 +342,7 @@ pub fn packCGroup(self: *const DB, group: *const PackedGroup, buf: []u8) error{
var i: usize = 0; var i: usize = 0;
while (try it.next()) |member_offset| : (i += 1) { while (try it.next()) |member_offset| : (i += 1) {
const entry = PackedUser.fromBytes(@alignCast(8, self.users[member_offset << 3 ..])); const entry = try PackedUser.fromBytes(@alignCast(8, self.users[member_offset << 3 ..]));
const start = buf_offset; const start = buf_offset;
const name = entry.user.name(); const name = entry.user.name();
if (buf_offset + name.len + 1 > buf.len) if (buf_offset + name.len + 1 > buf.len)
@ -458,34 +458,36 @@ pub fn writeUser(self: *const DB, user: PackedUser, buf: []u8) error{BufferTooSm
}; };
} }
pub fn getUserByName(self: *const DB, name: []const u8) ?PackedUser { pub fn getUserByName(self: *const DB, name: []const u8) error{Overflow}!?PackedUser {
const idx = bdz.search(self.bdz_username, name); const idx = bdz.search(self.bdz_username, name);
// bdz may return a hash that's bigger than the number of users // bdz may return a hash that's bigger than the number of users
if (idx >= self.header.num_users) return null; if (idx >= self.header.num_users)
return null;
const offset = self.idx_name2user[idx]; const offset = self.idx_name2user[idx];
const user = PackedUser.fromBytes(@alignCast(8, self.users[offset << 3 ..])).user; const user = (try PackedUser.fromBytes(@alignCast(8, self.users[offset << 3 ..]))).user;
if (!mem.eql(u8, name, user.name())) return null; if (!mem.eql(u8, name, user.name())) return null;
return user; return user;
} }
// get a CUser entry by name. // get a CUser entry by name.
pub fn getpwnam(self: *const DB, name: []const u8, buf: []u8) error{BufferTooSmall}!?CUser { pub fn getpwnam(self: *const DB, name: []const u8, buf: []u8) error{ Overflow, BufferTooSmall }!?CUser {
const user = self.getUserByName(name) orelse return null; const user = try self.getUserByName(name) orelse return null;
return try self.writeUser(user, buf); return try self.writeUser(user, buf);
} }
pub fn getUserByUid(self: *const DB, uid: u32) ?PackedUser { pub fn getUserByUid(self: *const DB, uid: u32) error{Overflow}!?PackedUser {
const idx = bdz.search_u32(self.bdz_uid, uid); const idx = bdz.search_u32(self.bdz_uid, uid);
if (idx >= self.header.num_users) return null; if (idx >= self.header.num_users)
return null;
const offset = self.idx_uid2user[idx]; const offset = self.idx_uid2user[idx];
const user = PackedUser.fromBytes(@alignCast(8, self.users[offset << 3 ..])).user; const user = (try PackedUser.fromBytes(@alignCast(8, self.users[offset << 3 ..]))).user;
if (uid != user.uid()) return null; if (uid != user.uid()) return null;
return user; return user;
} }
// get a CUser entry by uid. // get a CUser entry by uid.
pub fn getpwuid(self: *const DB, uid: u32, buf: []u8) error{BufferTooSmall}!?CUser { pub fn getpwuid(self: *const DB, uid: u32, buf: []u8) error{ Overflow, BufferTooSmall }!?CUser {
const user = self.getUserByUid(uid) orelse return null; const user = try self.getUserByUid(uid) orelse return null;
return try self.writeUser(user, buf); return try self.writeUser(user, buf);
} }

View File

@ -83,13 +83,11 @@ pub const Entry = struct {
end: usize, end: usize,
}; };
pub fn fromBytes(blob: []align(8) const u8) Entry { pub fn fromBytes(blob: []align(8) const u8) error{Overflow}!Entry {
const start_var_payload = @bitSizeOf(Inner) / 8; const start_var_payload = @bitSizeOf(Inner) / 8;
const inner = @ptrCast(*align(8) const Inner, blob[0..start_var_payload]); const inner = @ptrCast(*align(8) const Inner, blob[0..start_var_payload]);
const end_strings = start_var_payload + inner.stringLength(); const end_strings = start_var_payload + inner.stringLength();
const gids_offset = compress.uvarint(blob[end_strings..]) catch |err| switch (err) { const gids_offset = try compress.uvarint(blob[end_strings..]);
error.Overflow => unreachable,
};
const end_payload = end_strings + gids_offset.bytes_read; const end_payload = end_strings + gids_offset.bytes_read;
return Entry{ return Entry{
@ -110,9 +108,9 @@ pub const Iterator = struct {
total: u32, total: u32,
advanced_by: usize = 0, advanced_by: usize = 0,
pub fn next(it: *Iterator) ?PackedUser { pub fn next(it: *Iterator) error{Overflow}!?PackedUser {
if (it.idx == it.total) return null; if (it.idx == it.total) return null;
const entry = fromBytes(@alignCast(8, it.section[it.next_start..])); const entry = try fromBytes(@alignCast(8, it.section[it.next_start..]));
it.idx += 1; it.idx += 1;
it.next_start += entry.end; it.next_start += entry.end;
it.advanced_by = entry.end; it.advanced_by = entry.end;
@ -296,7 +294,7 @@ test "PackedUser construct section" {
var i: u29 = 0; var i: u29 = 0;
var it1 = PackedUser.iterator(buf.items, users.len, test_shell_reader); var it1 = PackedUser.iterator(buf.items, users.len, test_shell_reader);
while (it1.next()) |user| : (i += 1) { while (try it1.next()) |user| : (i += 1) {
try testing.expectEqual(users[i].uid, user.uid()); try testing.expectEqual(users[i].uid, user.uid());
try testing.expectEqual(users[i].gid, user.gid()); try testing.expectEqual(users[i].gid, user.gid());
try testing.expectEqual(user.additionalGidsOffset(), additional_gids); try testing.expectEqual(user.additionalGidsOffset(), additional_gids);

View File

@ -140,6 +140,7 @@ fn getpwuid_r(
errnop: *c_int, errnop: *c_int,
) c.enum_nss_status { ) c.enum_nss_status {
var cuser = db.getpwuid(uid, buffer[0..buflen]) catch |err| switch (err) { var cuser = db.getpwuid(uid, buffer[0..buflen]) catch |err| switch (err) {
error.Overflow => return badFile(errnop),
error.BufferTooSmall => { error.BufferTooSmall => {
errnop.* = @enumToInt(os.E.RANGE); errnop.* = @enumToInt(os.E.RANGE);
return c.NSS_STATUS_TRYAGAIN; return c.NSS_STATUS_TRYAGAIN;
@ -178,6 +179,7 @@ fn getpwnam_r(
var buf = buffer[0..buflen]; var buf = buffer[0..buflen];
const cuser = db.getpwnam(nameSlice, buf) catch |err| switch (err) { const cuser = db.getpwnam(nameSlice, buf) catch |err| switch (err) {
error.Overflow => return badFile(errnop),
error.BufferTooSmall => { error.BufferTooSmall => {
errnop.* = @enumToInt(os.E.RANGE); errnop.* = @enumToInt(os.E.RANGE);
return c.NSS_STATUS_TRYAGAIN; return c.NSS_STATUS_TRYAGAIN;
@ -424,7 +426,9 @@ fn getpwent_r(
return c.NSS_STATUS_UNAVAIL; return c.NSS_STATUS_UNAVAIL;
}); });
const user = it.next() orelse { const user = it.next() catch |err| switch (err) {
error.Overflow => return badFile(errnop),
} orelse {
errnop.* = 0; errnop.* = 0;
return c.NSS_STATUS_NOTFOUND; return c.NSS_STATUS_NOTFOUND;
}; };
@ -470,7 +474,9 @@ fn initgroups_dyn(
errnop: *c_int, errnop: *c_int,
) c.enum_nss_status { ) c.enum_nss_status {
const db = state.file.db; const db = state.file.db;
const user = db.getUserByName(mem.sliceTo(user_name, 0)) orelse { const user = db.getUserByName(mem.sliceTo(user_name, 0)) catch |err| switch (err) {
error.Overflow => return badFile(errnop),
} orelse {
errnop.* = @enumToInt(os.E.NOENT); errnop.* = @enumToInt(os.E.NOENT);
return c.NSS_STATUS_NOTFOUND; return c.NSS_STATUS_NOTFOUND;
}; };

View File

@ -159,7 +159,12 @@ fn execute(
db.header.num_users, db.header.num_users,
db.shellReader(), db.shellReader(),
); );
while (it.next()) |packed_user| { while (it.next() catch |err| switch (err) {
error.Overflow => {
stderr.print("ERROR: file '{s}' is corrupted\n", .{db_file}) catch {};
return 1;
},
}) |packed_user| {
const offset = packed_user.additional_gids_offset; const offset = packed_user.additional_gids_offset;
const additional_gids = db.additional_gids[offset..]; const additional_gids = db.additional_gids[offset..];
const vit = compress.varintSliceIterator(additional_gids) catch |err| { const vit = compress.varintSliceIterator(additional_gids) catch |err| {

View File

@ -120,12 +120,12 @@ fn execute(
fn passwd(stdout: anytype, db: *const DB, keys: []const [*:0]const u8) error{Overflow}!u8 { fn passwd(stdout: anytype, db: *const DB, keys: []const [*:0]const u8) error{Overflow}!u8 {
if (keys.len == 0) if (keys.len == 0)
return passwdAll(stdout, db); return try passwdAll(stdout, db);
var some_notfound = false; var some_notfound = false;
const shell_reader = db.shellReader(); const shell_reader = db.shellReader();
for (keys) |key| { for (keys) |key| {
const keyZ = mem.span(key); const keyZ = mem.span(key);
const maybe_packed_user = if (fmt.parseUnsigned(u32, keyZ, 10)) |uid| const maybe_packed_user = try if (fmt.parseUnsigned(u32, keyZ, 10)) |uid|
db.getUserByUid(uid) db.getUserByUid(uid)
else |_| else |_|
db.getUserByName(keyZ); db.getUserByName(keyZ);
@ -142,10 +142,10 @@ fn passwd(stdout: anytype, db: *const DB, keys: []const [*:0]const u8) error{Ove
return if (some_notfound) 2 else 0; return if (some_notfound) 2 else 0;
} }
fn passwdAll(stdout: anytype, db: *const DB) u8 { fn passwdAll(stdout: anytype, db: *const DB) error{Overflow}!u8 {
const shell_reader = db.shellReader(); const shell_reader = db.shellReader();
var it = PackedUser.iterator(db.users, db.header.num_users, shell_reader); var it = PackedUser.iterator(db.users, db.header.num_users, shell_reader);
while (it.next()) |packed_user| { while (try it.next()) |packed_user| {
const line = packed_user.toUser(db.shellReader()).toLine(); const line = packed_user.toUser(db.shellReader()).toLine();
stdout.writeAll(line.constSlice()) catch return 3; stdout.writeAll(line.constSlice()) catch return 3;
} }
@ -194,7 +194,7 @@ fn printGroup(stdout: anytype, db: *const DB, g: *const PackedGroup) error{Overf
var line_writer = io.bufferedWriter(stdout); var line_writer = io.bufferedWriter(stdout);
var i: usize = 0; var i: usize = 0;
while (try it.next()) |member_offset| : (i += 1) { while (try it.next()) |member_offset| : (i += 1) {
const puser = PackedUser.fromBytes(@alignCast(8, db.users[member_offset << 3 ..])); const puser = try PackedUser.fromBytes(@alignCast(8, db.users[member_offset << 3 ..]));
const name = puser.user.name(); const name = puser.user.name();
if (i != 0) if (i != 0)
_ = line_writer.write(",") catch return 3; _ = line_writer.write(",") catch return 3;