spirv: define and use extended instruction set opcodes

This commit is contained in:
Ali Cheraghi
2025-08-03 17:02:51 +03:30
parent 246e1de554
commit cd4b03c5ed
7 changed files with 977 additions and 3679 deletions

View File

@@ -12,6 +12,7 @@ const ExtendedStructSet = std.StringHashMap(void);
const Extension = struct {
name: []const u8,
opcode_name: []const u8,
spec: ExtensionRegistry,
};
@@ -44,23 +45,11 @@ const OperandKindMap = std.ArrayHashMap(StringPair, OperandKind, StringPairConte
/// Khronos made it so that these names are not defined explicitly, so
/// we need to hardcode it (like they did).
/// See https://github.com/KhronosGroup/SPIRV-Registry/
const set_names = std.StaticStringMap([]const u8).initComptime(.{
.{ "opencl.std.100", "OpenCL.std" },
.{ "glsl.std.450", "GLSL.std.450" },
.{ "opencl.debuginfo.100", "OpenCL.DebugInfo.100" },
.{ "spv-amd-shader-ballot", "SPV_AMD_shader_ballot" },
.{ "nonsemantic.shader.debuginfo.100", "NonSemantic.Shader.DebugInfo.100" },
.{ "nonsemantic.vkspreflection", "NonSemantic.VkspReflection" },
.{ "nonsemantic.clspvreflection", "NonSemantic.ClspvReflection.6" }, // This version needs to be handled manually
.{ "spv-amd-gcn-shader", "SPV_AMD_gcn_shader" },
.{ "spv-amd-shader-trinary-minmax", "SPV_AMD_shader_trinary_minmax" },
.{ "debuginfo", "DebugInfo" },
.{ "nonsemantic.debugprintf", "NonSemantic.DebugPrintf" },
.{ "spv-amd-shader-explicit-vertex-parameter", "SPV_AMD_shader_explicit_vertex_parameter" },
.{ "nonsemantic.debugbreak", "NonSemantic.DebugBreak" },
.{ "tosa.001000.1", "SPV_EXT_INST_TYPE_TOSA_001000_1" },
.{ "zig", "zig" },
/// See https://github.com/KhronosGroup/SPIRV-Registry
const set_names = std.StaticStringMap(struct { []const u8, []const u8 }).initComptime(.{
.{ "opencl.std.100", .{ "OpenCL.std", "OpenClOpcode" } },
.{ "glsl.std.450", .{ "GLSL.std.450", "GlslOpcode" } },
.{ "zig", .{ "zig", "Zig" } },
});
var arena = std.heap.ArenaAllocator.init(std.heap.smp_allocator);
@@ -78,7 +67,7 @@ pub fn main() !void {
const dir = try std.fs.cwd().openDir(json_path, .{ .iterate = true });
const core_spec = try readRegistry(CoreRegistry, dir, "spirv.core.grammar.json");
std.sort.block(Instruction, core_spec.instructions, CmpInst{}, CmpInst.lt);
std.mem.sortUnstable(Instruction, core_spec.instructions, CmpInst{}, CmpInst.lt);
var exts = std.ArrayList(Extension).init(allocator);
@@ -134,14 +123,24 @@ fn readExtRegistry(exts: *std.ArrayList(Extension), dir: std.fs.Dir, sub_path: [
const name = filename["extinst.".len .. filename.len - ".grammar.json".len];
const spec = try readRegistry(ExtensionRegistry, dir, sub_path);
const set_name = set_names.get(name) orelse {
std.log.info("ignored instruction set '{s}'", .{name});
return;
};
std.sort.block(Instruction, spec.instructions, CmpInst{}, CmpInst.lt);
try exts.append(.{ .name = set_names.get(name).?, .spec = spec });
try exts.append(.{
.name = set_name.@"0",
.opcode_name = set_name.@"1",
.spec = spec,
});
}
fn readRegistry(comptime RegistryType: type, dir: std.fs.Dir, path: []const u8) !RegistryType {
const spec = try dir.readFileAlloc(allocator, path, std.math.maxInt(usize));
// Required for json parsing.
// TODO: ALI
@setEvalBranchQuota(10000);
var scanner = std.json.Scanner.initCompleteInput(allocator, spec);
@@ -191,7 +190,11 @@ fn tagPriorityScore(tag: []const u8) usize {
}
}
fn render(writer: *std.io.Writer, registry: CoreRegistry, extensions: []const Extension) !void {
fn render(
writer: *std.io.Writer,
registry: CoreRegistry,
extensions: []const Extension,
) !void {
try writer.writeAll(
\\//! This file is auto-generated by tools/gen_spirv_spec.zig.
\\
@@ -317,13 +320,18 @@ fn render(writer: *std.io.Writer, registry: CoreRegistry, extensions: []const Ex
// Note: extensions don't seem to have class.
try renderClass(writer, registry.instructions);
try renderOperandKind(writer, all_operand_kinds.values());
try renderOpcodes(writer, registry.instructions, extended_structs);
try renderOpcodes(writer, "Opcode", true, registry.instructions, extended_structs);
for (extensions) |ext| {
try renderOpcodes(writer, ext.opcode_name, false, ext.spec.instructions, extended_structs);
}
try renderOperandKinds(writer, all_operand_kinds.values(), extended_structs);
try renderInstructionSet(writer, registry, extensions, all_operand_kinds);
}
fn renderInstructionSet(
writer: anytype,
writer: *std.io.Writer,
core: CoreRegistry,
extensions: []const Extension,
all_operand_kinds: OperandKindMap,
@@ -358,7 +366,7 @@ fn renderInstructionSet(
}
fn renderInstructionsCase(
writer: anytype,
writer: *std.io.Writer,
set_name: []const u8,
instructions: []const Instruction,
all_operand_kinds: OperandKindMap,
@@ -405,7 +413,7 @@ fn renderInstructionsCase(
);
}
fn renderClass(writer: anytype, instructions: []const Instruction) !void {
fn renderClass(writer: *std.io.Writer, instructions: []const Instruction) !void {
var class_map = std.StringArrayHashMap(void).init(allocator);
for (instructions) |inst| {
@@ -454,7 +462,7 @@ fn formatId(identifier: []const u8) std.fmt.Alt(Formatter, Formatter.format) {
return .{ .data = .{ .data = identifier } };
}
fn renderOperandKind(writer: anytype, operands: []const OperandKind) !void {
fn renderOperandKind(writer: *std.io.Writer, operands: []const OperandKind) !void {
try writer.writeAll(
\\pub const OperandKind = enum {
\\ opcode,
@@ -510,7 +518,7 @@ fn renderOperandKind(writer: anytype, operands: []const OperandKind) !void {
try writer.writeAll("};\n}\n};\n");
}
fn renderEnumerant(writer: anytype, enumerant: Enumerant) !void {
fn renderEnumerant(writer: *std.io.Writer, enumerant: Enumerant) !void {
try writer.print(".{{.name = \"{s}\", .value = ", .{enumerant.enumerant});
switch (enumerant.value) {
.bitflag => |flag| try writer.writeAll(flag),
@@ -527,7 +535,9 @@ fn renderEnumerant(writer: anytype, enumerant: Enumerant) !void {
}
fn renderOpcodes(
writer: anytype,
writer: *std.io.Writer,
opcode_type_name: []const u8,
want_operands: bool,
instructions: []const Instruction,
extended_structs: ExtendedStructSet,
) !void {
@@ -538,7 +548,9 @@ fn renderOpcodes(
try aliases.ensureTotalCapacity(instructions.len);
for (instructions, 0..) |inst, i| {
if (std.mem.eql(u8, inst.class.?, "@exclude")) continue;
if (inst.class) |class| {
if (std.mem.eql(u8, class, "@exclude")) continue;
}
const result = inst_map.getOrPutAssumeCapacity(inst.opcode);
if (!result.found_existing) {
@@ -562,58 +574,67 @@ fn renderOpcodes(
const instructions_indices = inst_map.values();
try writer.writeAll("pub const Opcode = enum(u16) {\n");
try writer.print("\npub const {f} = enum(u16) {{\n", .{std.zig.fmtId(opcode_type_name)});
for (instructions_indices) |i| {
const inst = instructions[i];
try writer.print("{f} = {},\n", .{ std.zig.fmtId(inst.opname), inst.opcode });
}
try writer.writeAll(
\\
);
try writer.writeAll("\n");
for (aliases.items) |alias| {
try writer.print("pub const {f} = Opcode.{f};\n", .{
try writer.print("pub const {f} = {f}.{f};\n", .{
formatId(instructions[alias.inst].opname),
std.zig.fmtId(opcode_type_name),
formatId(instructions[alias.alias].opname),
});
}
try writer.writeAll(
\\
\\pub fn Operands(comptime self: Opcode) type {
\\ return switch (self) {
\\
);
if (want_operands) {
try writer.print(
\\
\\pub fn Operands(comptime self: {f}) type {{
\\ return switch (self) {{
\\
, .{std.zig.fmtId(opcode_type_name)});
for (instructions_indices) |i| {
const inst = instructions[i];
try renderOperand(writer, .instruction, inst.opname, inst.operands, extended_structs, false);
for (instructions_indices) |i| {
const inst = instructions[i];
try renderOperand(writer, .instruction, inst.opname, inst.operands, extended_structs, false);
}
try writer.writeAll(
\\ };
\\}
\\
);
try writer.print(
\\pub fn class(self: {f}) Class {{
\\ return switch (self) {{
\\
, .{std.zig.fmtId(opcode_type_name)});
for (instructions_indices) |i| {
const inst = instructions[i];
try writer.print(".{f} => .{f},\n", .{ std.zig.fmtId(inst.opname), formatId(inst.class.?) });
}
try writer.writeAll(
\\ };
\\}
\\
);
}
try writer.writeAll(
\\ };
\\}
\\pub fn class(self: Opcode) Class {
\\ return switch (self) {
\\
);
for (instructions_indices) |i| {
const inst = instructions[i];
try writer.print(".{f} => .{f},\n", .{ std.zig.fmtId(inst.opname), formatId(inst.class.?) });
}
try writer.writeAll(
\\ };
\\}
\\};
\\
);
}
fn renderOperandKinds(
writer: anytype,
writer: *std.io.Writer,
kinds: []const OperandKind,
extended_structs: ExtendedStructSet,
) !void {
@@ -627,7 +648,7 @@ fn renderOperandKinds(
}
fn renderValueEnum(
writer: anytype,
writer: *std.io.Writer,
enumeration: OperandKind,
extended_structs: ExtendedStructSet,
) !void {
@@ -705,7 +726,7 @@ fn renderValueEnum(
}
fn renderBitEnum(
writer: anytype,
writer: *std.io.Writer,
enumeration: OperandKind,
extended_structs: ExtendedStructSet,
) !void {
@@ -788,7 +809,7 @@ fn renderBitEnum(
}
fn renderOperand(
writer: anytype,
writer: *std.io.Writer,
kind: enum {
@"union",
instruction,
@@ -872,7 +893,7 @@ fn renderOperand(
try writer.writeAll(",\n");
}
fn renderFieldName(writer: anytype, operands: []const Operand, field_index: usize) !void {
fn renderFieldName(writer: *std.io.Writer, operands: []const Operand, field_index: usize) !void {
const operand = operands[field_index];
derive_from_kind: {