diff --git a/lib/std/math/big/int.zig b/lib/std/math/big/int.zig index ce73c0c648..b182993885 100644 --- a/lib/std/math/big/int.zig +++ b/lib/std/math/big/int.zig @@ -3228,8 +3228,19 @@ pub const Managed = struct { /// r = ⌊√a⌋ pub fn sqrt(rma: *Managed, a: *const Managed) !void { - const needed_limbs = calcSqrtLimbsBufferLen(a.bitCountAbs()); + const bit_count = a.bitCountAbs(); + if (bit_count == 0) { + try rma.set(0); + rma.setMetadata(a.isPositive(), rma.len()); + return; + } + + if (!a.isPositive()) { + return error.SqrtOfNegativeNumber; + } + + const needed_limbs = calcSqrtLimbsBufferLen(bit_count); const limbs_buffer = try rma.allocator.alloc(Limb, needed_limbs); defer rma.allocator.free(limbs_buffer); diff --git a/lib/std/math/big/int_test.zig b/lib/std/math/big/int_test.zig index da8fb98c5c..ac4326ec4e 100644 --- a/lib/std/math/big/int_test.zig +++ b/lib/std/math/big/int_test.zig @@ -3149,3 +3149,29 @@ test "big.int.Const.order 0 == -0" { }; try std.testing.expectEqual(std.math.Order.eq, a.order(b)); } + +test "big.int.Managed sqrt(0) = 0" { + const allocator = testing.allocator; + var a = try Managed.initSet(allocator, 1); + defer a.deinit(); + + var res = try Managed.initSet(allocator, 1); + defer res.deinit(); + + try a.setString(10, "0"); + + try res.sqrt(&a); +} + +test "big.int.Managed sqrt(-1) = error" { + const allocator = testing.allocator; + var a = try Managed.initSet(allocator, 1); + defer a.deinit(); + + var res = try Managed.initSet(allocator, 1); + defer res.deinit(); + + try a.setString(10, "-1"); + + try testing.expectError(error.SqrtOfNegativeNumber, res.sqrt(&a)); +}