# Patch summary (commentary only; text outside hunks is not part of any file change).
#
# Changes the payload of the ZIR `array_mul` instruction from `Zir.Inst.Bin` to a new
# three-field `Zir.Inst.ArrayMul` (`res_ty`, `lhs`, `rhs`), and updates the producer
# (AstGen), the consumers (Sema, print_zir) and Autodoc accordingly.
#
# AstGen: `res_ty` is filled from `ri.rl.resultType(gz, node)` when that yields a value,
# and is `.none` otherwise. The LHS is still evaluated with `.rl = .none` and the RHS is
# still a comptime expression coerced to `usize`.
diff --git a/src/AstGen.zig b/src/AstGen.zig
index 245ec45ea0..5b957e48c5 100644
--- a/src/AstGen.zig
+++ b/src/AstGen.zig
@@ -758,7 +758,11 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
         .array_cat => return simpleBinOp(gz, scope, ri, node, .array_cat),
 
         .array_mult => {
-            const result = try gz.addPlNode(.array_mul, node, Zir.Inst.Bin{
+            // This syntax form does not currently use the result type in the language specification.
+            // However, the result type can be used to emit more optimal code for large multiplications by
+            // having Sema perform a coercion before the multiplication operation.
+            const result = try gz.addPlNode(.array_mul, node, Zir.Inst.ArrayMul{
+                .res_ty = if (try ri.rl.resultType(gz, node)) |t| t else .none,
                 .lhs = try expr(gz, scope, .{ .rl = .none }, node_datas[node].lhs),
                 .rhs = try comptimeExpr(gz, scope, .{ .rl = .{ .coerced_ty = .usize_type } }, node_datas[node].rhs),
             });
# Autodoc: `.array_mul` is removed from the switch arm that decodes its payload as
# `Zir.Inst.Bin`, which no longer matches the instruction's payload layout.
# NOTE(review): no replacement arm for `.array_mul` is visible in this patch; confirm that
# `walkInstruction` handles the tag acceptably elsewhere.
diff --git a/src/Autodoc.zig b/src/Autodoc.zig
index 500f42dfd3..cd64b5e2cf 100644
--- a/src/Autodoc.zig
+++ b/src/Autodoc.zig
@@ -1567,7 +1567,6 @@ fn walkInstruction(
         .bit_and,
         .xor,
         .array_cat,
-        .array_mul,
         => {
             const pl_node = data[@intFromEnum(inst)].pl_node;
             const extra = file.zir.extraData(Zir.Inst.Bin, pl_node.payload_index);
# Sema (`zirArrayMul`), four hunks:
#   1. Reads the payload as `ArrayMul`. When `res_ty` is present, is not generic poison,
#      is an array or vector type, and the LHS is a tuple, the LHS is coerced up front to
#      an array/vector of the result's child type with the tuple's field count as length
#      (arrays also take the result type's sentinel). `error.NotCoercible` silently falls
#      back to the uncoerced LHS; any other error is returned.
#   2. Resolves every LHS element once into `lhs_vals` before emitting stores/inits.
#      NOTE(review): the last argument to `sema.elemVal` changes from `true` (old loops)
#      to `false` here; presumably a safety-check flag, confirm this is intended.
#   3. The store loop on the `ptr_addrspace` path iterates `lhs_vals` instead of
#      re-fetching each element; `elem_i` is now incremented after the store.
#   4. The aggregate-init path fills `element_refs` by copying `lhs_vals` into each
#      `lhs_len`-sized window, once per `factor`, via `@memcpy`.
diff --git a/src/Sema.zig b/src/Sema.zig
index 9801ce0040..f79f29dc0c 100644
--- a/src/Sema.zig
+++ b/src/Sema.zig
@@ -13998,14 +13998,49 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
 
     const mod = sema.mod;
     const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].pl_node;
-    const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
-    const lhs = try sema.resolveInst(extra.lhs);
-    const lhs_ty = sema.typeOf(lhs);
+    const extra = sema.code.extraData(Zir.Inst.ArrayMul, inst_data.payload_index).data;
+    const uncoerced_lhs = try sema.resolveInst(extra.lhs);
+    const uncoerced_lhs_ty = sema.typeOf(uncoerced_lhs);
     const src: LazySrcLoc = inst_data.src();
     const lhs_src: LazySrcLoc = .{ .node_offset_bin_lhs = inst_data.src_node };
     const operator_src: LazySrcLoc = .{ .node_offset_main_token = inst_data.src_node };
     const rhs_src: LazySrcLoc = .{ .node_offset_bin_rhs = inst_data.src_node };
 
+    const lhs, const lhs_ty = coerced_lhs: {
+        // If we have a result type, we might be able to do this more efficiently
+        // by coercing the LHS first. Specifically, if we want an array or vector
+        // and have a tuple, coerce the tuple immediately.
+        no_coerce: {
+            if (extra.res_ty == .none) break :no_coerce;
+            const res_ty_inst = try sema.resolveInst(extra.res_ty);
+            const res_ty = try sema.analyzeAsType(block, src, res_ty_inst);
+            if (res_ty.isGenericPoison()) break :no_coerce;
+            if (!uncoerced_lhs_ty.isTuple(mod)) break :no_coerce;
+            const lhs_len = uncoerced_lhs_ty.structFieldCount(mod);
+            const lhs_dest_ty = switch (res_ty.zigTypeTag(mod)) {
+                else => break :no_coerce,
+                .Array => try mod.arrayType(.{
+                    .child = res_ty.childType(mod).toIntern(),
+                    .len = lhs_len,
+                    .sentinel = if (res_ty.sentinel(mod)) |s| s.toIntern() else .none,
+                }),
+                .Vector => try mod.vectorType(.{
+                    .child = res_ty.childType(mod).toIntern(),
+                    .len = lhs_len,
+                }),
+            };
+            // Attempt to coerce to this type, but don't emit an error if it fails. Instead,
+            // just exit out of this path and let the usual error happen later, so that error
+            // messages are consistent.
+            const coerced = sema.coerceExtra(block, lhs_dest_ty, uncoerced_lhs, lhs_src, .{ .report_err = false }) catch |err| switch (err) {
+                error.NotCoercible => break :no_coerce,
+                else => |e| return e,
+            };
+            break :coerced_lhs .{ coerced, lhs_dest_ty };
+        }
+        break :coerced_lhs .{ uncoerced_lhs, uncoerced_lhs_ty };
+    };
+
     if (lhs_ty.isTuple(mod)) {
         // In `**` rhs must be comptime-known, but lhs can be runtime-known
         const factor = try sema.resolveInt(block, rhs_src, extra.rhs, Type.usize, .{
@@ -14086,6 +14121,14 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
 
     try sema.requireRuntimeBlock(block, src, lhs_src);
 
+    // Grab all the LHS values ahead of time, rather than repeatedly emitting instructions
+    // to get the same elem values.
+    const lhs_vals = try sema.arena.alloc(Air.Inst.Ref, lhs_len);
+    for (lhs_vals, 0..) |*lhs_val, idx| {
+        const idx_ref = try mod.intRef(Type.usize, idx);
+        lhs_val.* = try sema.elemVal(block, lhs_src, lhs, idx_ref, src, false);
+    }
+
     if (ptr_addrspace) |ptr_as| {
         const alloc_ty = try sema.ptrType(.{
             .child = result_ty.toIntern(),
@@ -14099,14 +14142,11 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
 
         var elem_i: usize = 0;
         while (elem_i < result_len) {
-            var lhs_i: usize = 0;
-            while (lhs_i < lhs_len) : (lhs_i += 1) {
+            for (lhs_vals) |lhs_val| {
                 const elem_index = try mod.intRef(Type.usize, elem_i);
-                elem_i += 1;
-                const lhs_index = try mod.intRef(Type.usize, lhs_i);
                 const elem_ptr = try block.addPtrElemPtr(alloc, elem_index, elem_ptr_ty);
-                const init = try sema.elemVal(block, lhs_src, lhs, lhs_index, src, true);
-                try sema.storePtr2(block, src, elem_ptr, src, init, lhs_src, .store);
+                try sema.storePtr2(block, src, elem_ptr, src, lhs_val, lhs_src, .store);
+                elem_i += 1;
             }
         }
         if (lhs_info.sentinel) |sent_val| {
@@ -14120,17 +14160,9 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
     }
 
     const element_refs = try sema.arena.alloc(Air.Inst.Ref, result_len);
-    var elem_i: usize = 0;
-    while (elem_i < result_len) {
-        var lhs_i: usize = 0;
-        while (lhs_i < lhs_len) : (lhs_i += 1) {
-            const lhs_index = try mod.intRef(Type.usize, lhs_i);
-            const init = try sema.elemVal(block, lhs_src, lhs, lhs_index, src, true);
-            element_refs[elem_i] = init;
-            elem_i += 1;
-        }
+    for (0..try sema.usizeCast(block, rhs_src, factor)) |i| {
+        @memcpy(element_refs[i * lhs_len ..][0..lhs_len], lhs_vals);
     }
-
     return block.addAggregateInit(result_ty, element_refs);
 }
 
# Zir: documents `array_mul` as carrying an `ArrayMul` payload and declares that payload
# struct with three `Ref` fields in the order `res_ty`, `lhs`, `rhs`; AstGen, Sema and
# print_zir in this patch all encode/decode it through this one declaration.
diff --git a/src/Zir.zig b/src/Zir.zig
index 2aa6c4514c..4fecfd3c50 100644
--- a/src/Zir.zig
+++ b/src/Zir.zig
@@ -250,7 +250,7 @@ pub const Inst = struct {
         /// Uses the `pl_node` union field. Payload is `Bin`.
         array_cat,
         /// Array multiplication `a ** b`
-        /// Uses the `pl_node` union field. Payload is `Bin`.
+        /// Uses the `pl_node` union field. Payload is `ArrayMul`.
         array_mul,
         /// `[N]T` syntax. No source location provided.
         /// Uses the `pl_node` union field. Payload is `Bin`. lhs is length, rhs is element type.
@@ -3373,6 +3373,15 @@ pub const Inst = struct {
         /// The expected field count.
         expect_len: u32,
     };
+
+    pub const ArrayMul = struct {
+        /// The result type of the array multiplication operation, or `.none` if none was available.
+        res_ty: Ref,
+        /// The LHS of the array multiplication.
+        lhs: Ref,
+        /// The RHS of the array multiplication.
+        rhs: Ref,
+    };
 };
 
 pub const SpecialProng = enum { none, @"else", under };
# print_zir: `.array_mul` leaves the shared arm it sat in with the other binary operators
# and gets its own printer, `writeArrayMul`, which writes `res_ty, lhs, rhs) ` followed by
# the source location.
# NOTE(review): `writeArrayMul` writes only the closing parenthesis; presumably the opening
# one is emitted by the caller before dispatch, confirm against the instruction printer.
diff --git a/src/print_zir.zig b/src/print_zir.zig
index 82eca87e15..3f2334e18d 100644
--- a/src/print_zir.zig
+++ b/src/print_zir.zig
@@ -370,7 +370,6 @@ const Writer = struct {
             .add_sat,
             .add_unsafe,
             .array_cat,
-            .array_mul,
             .mul,
             .mulwrap,
             .mul_sat,
@@ -431,6 +430,8 @@ const Writer = struct {
 
             .for_len => try self.writePlNodeMultiOp(stream, inst),
 
+            .array_mul => try self.writeArrayMul(stream, inst),
+
             .elem_val_imm => try self.writeElemValImm(stream, inst),
 
             .@"export" => try self.writePlNodeExport(stream, inst),
@@ -977,6 +978,18 @@ const Writer = struct {
         try self.writeSrc(stream, inst_data.src());
     }
 
+    fn writeArrayMul(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
+        const inst_data = self.code.instructions.items(.data)[@intFromEnum(inst)].pl_node;
+        const extra = self.code.extraData(Zir.Inst.ArrayMul, inst_data.payload_index).data;
+        try self.writeInstRef(stream, extra.res_ty);
+        try stream.writeAll(", ");
+        try self.writeInstRef(stream, extra.lhs);
+        try stream.writeAll(", ");
+        try self.writeInstRef(stream, extra.rhs);
+        try stream.writeAll(") ");
+        try self.writeSrc(stream, inst_data.src());
+    }
+
     fn writeElemValImm(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
         const inst_data = self.code.instructions.items(.data)[@intFromEnum(inst)].elem_val_imm;
         try self.writeInstRef(stream, inst_data.operand);