diff --git a/build.zig b/build.zig index 31ad47d..f335ba0 100644 --- a/build.zig +++ b/build.zig @@ -55,6 +55,22 @@ pub fn build(b: *std.Build) void { break :blk module; }; + // codegen + // The codegen produces Zig files that are imported as modules by pgzx. + const node_tags_src = blk: { + const tool = b.addExecutable(.{ + .name = "gennodetags", + .root_source_file = b.path("./tools/gennodetags/main.zig"), + .target = b.host, + .link_libc = true, + }); + tool.root_module.addIncludePath(.{ .cwd_relative = pgbuild.getIncludeServerDir() }); + tool.root_module.addIncludePath(.{ .cwd_relative = pgbuild.getIncludeDir() }); + + const tool_step = b.addRunArtifact(tool); + break :blk tool_step.addOutputFileArg("nodetags.zig"); + }; + // pgzx: main project module. // This module re-exports pgzx_pgsys, other generated modules, and utility functions. const pgzx = blk: { @@ -64,6 +80,12 @@ pub fn build(b: *std.Build) void { .optimize = optimize, }); module.addImport("pgzx_pgsys", pgzx_pgsys); + module.addAnonymousImport("gen_node_tags", .{ + .root_source_file = node_tags_src, + .imports = &.{ + .{ .name = "pgzx_pgsys", .module = pgzx_pgsys }, + }, + }); break :blk module; }; @@ -87,6 +109,12 @@ pub fn build(b: *std.Build) void { tests.lib.root_module.addImport("pgzx_pgsys", pgzx_pgsys); tests.lib.root_module.addImport("pgzx", pgzx); + tests.lib.root_module.addAnonymousImport("gen_node_tags", .{ + .root_source_file = node_tags_src, + .imports = &.{ + .{ .name = "pgzx_pgsys", .module = pgzx_pgsys }, + }, + }); break :blk tests; }; diff --git a/src/pgzx.zig b/src/pgzx.zig index 19507b1..2ebf417 100644 --- a/src/pgzx.zig +++ b/src/pgzx.zig @@ -49,4 +49,6 @@ pub const testing = @import("pgzx/testing.zig"); // helpers around at times. pub const meta = @import("pgzx/meta.zig"); +pub const node = @import("pgzx/node.zig"); + pub const guc = utils.guc; diff --git a/src/pgzx/c.zig b/src/pgzx/c.zig index c7147c2..f5dced7 100644 --- a/src/pgzx/c.zig +++ b/src/pgzx/c.zig @@ -9,6 +9,9 @@ const includes = @cImport({ @cInclude("varatt.h"); @cInclude("access/reloptions.h"); + @cInclude("access/tsmapi.h"); + + @cInclude("commands/event_trigger.h"); @cInclude("catalog/binary_upgrade.h"); @cInclude("catalog/catalog.h"); @@ -98,6 +101,34 @@ const includes = @cImport({ @cInclude("executor/spi.h"); @cInclude("executor/executor.h"); + @cInclude("windowapi.h"); + + @cInclude("lib/ilist.h"); + + @cInclude("nodes/bitmapset.h"); + @cInclude("nodes/execnodes.h"); + @cInclude("nodes/extensible.h"); + @cInclude("nodes/lockoptions.h"); + @cInclude("nodes/makefuncs.h"); + @cInclude("nodes/memnodes.h"); + @cInclude("nodes/miscnodes.h"); + @cInclude("nodes/multibitmapset.h"); + @cInclude("nodes/nodeFuncs.h"); + @cInclude("nodes/nodes.h"); + @cInclude("nodes/params.h"); + @cInclude("nodes/parsenodes.h"); + @cInclude("nodes/pathnodes.h"); + @cInclude("nodes/pg_list.h"); + @cInclude("nodes/plannodes.h"); + @cInclude("nodes/primnodes.h"); + @cInclude("nodes/print.h"); + @cInclude("nodes/queryjumble.h"); + @cInclude("nodes/readfuncs.h"); + @cInclude("nodes/replnodes.h"); + @cInclude("nodes/subscripting.h"); + @cInclude("nodes/supportnodes.h"); + @cInclude("nodes/tidbitmap.h"); + @cInclude("nodes/value.h"); @cInclude("optimizer/appendinfo.h"); @cInclude("optimizer/clauses.h"); @@ -135,8 +166,6 @@ const includes = @cImport({ @cInclude("commands/extension.h"); - @cInclude("lib/ilist.h"); - @cInclude("storage/ipc.h"); @cInclude("storage/proc.h"); @cInclude("storage/latch.h"); diff --git a/src/pgzx/node.zig b/src/pgzx/node.zig new file mode 100644 index 0000000..27b3c91 --- /dev/null +++ b/src/pgzx/node.zig @@ -0,0 +1,110 @@ +const pg = @import("pgzx_pgsys"); + +const gen = @import("gen_node_tags"); + +pub const Tag = gen.Tag; + +pub inline fn make(comptime T: type) *T { + const node: *pg.Node = @ptrCast(@alignCast(pg.palloc0fast(@sizeOf(T)))); + node.*.type = @intFromEnum(mustFindTag(T)); + return @ptrCast(@alignCast(node)); +} + +pub inline fn create(initFrom: anytype) *@TypeOf(initFrom) { + const node = make(@TypeOf(initFrom)); + node.* = initFrom; + setTag(node, mustFindTag(@TypeOf(initFrom))); + return node; +} + +fn mustFindTag(comptime T: type) Tag { + return gen.findTag(T) orelse @compileError("No tag found for type"); +} + +pub inline fn tag(node: anytype) Tag { + return @enumFromInt(asNodePtr(node).*.type); +} + +pub inline fn setTag(node: anytype, t: Tag) void { + asNodePtr(node).*.type = @intFromEnum(t); +} + +pub inline fn isA(node: anytype, t: Tag) bool { + return tag(node) == t; +} + +pub inline fn castNode(comptime T: type, node: anytype) *T { + return @ptrCast(@alignCast(asNodePtr(node))); +} + +pub inline fn safeCastNode(comptime T: type, node: anytype) ?*T { + if (tag(node) != gen.findTag(T)) { + return null; + } + return castNode(T, node); +} + +inline fn asNodePtr(node: anytype) *pg.Node { + checkIsPotentialNodePtr(node); + return @ptrCast(@alignCast(node)); +} + +inline fn checkIsPotentialNodePtr(node: anytype) void { + const nodeType = @typeInfo(@TypeOf(node)); + if (nodeType != .Pointer or nodeType.Pointer.size != .One) { + @compileError("Expected single node pointer"); + } +} + +pub const TestSuite_Node = struct { + const std = @import("std"); + + pub fn testMakeAndTag() !void { + const node = make(pg.FdwRoutine); + try std.testing.expectEqual(tag(node), .FdwRoutine); + } + + pub fn testCreate() !void { + const node = create(pg.Query{ + .commandType = pg.CMD_SELECT, + }); + try std.testing.expectEqual(tag(node), .Query); + try std.testing.expectEqual(node.*.commandType, pg.CMD_SELECT); + } + + pub fn testSetTag() !void { + const node = make(pg.Query); + setTag(node, .FdwRoutine); + try std.testing.expectEqual(tag(node), .FdwRoutine); + } + + pub fn testIsA_Ok() !void { + const node = make(pg.Query); + try std.testing.expect(isA(node, .Query)); + } + + pub fn testIsA_Fail() !void { + const node = make(pg.Query); + try std.testing.expect(!isA(node, .FdwRoutine)); + } + + pub fn testCastNode() !void { + const node: *pg.Node = @ptrCast(@alignCast(make(pg.Query))); + const query: *pg.Query = castNode(pg.Query, node); + try std.testing.expect(isA(query, .Query)); + } + + pub fn testSafeCast_Ok() !void { + const node = make(pg.Query); + const query = safeCastNode(pg.Query, node) orelse { + return error.UnexpectedCastFailure; + }; + try std.testing.expect(isA(query, .Query)); + } + + pub fn testSafeCast_Fail() !void { + const node = make(pg.Query); + const fdw = safeCastNode(pg.FdwRoutine, node); + try std.testing.expect(fdw == null); + } +}; diff --git a/src/testing.zig b/src/testing.zig index 287218d..9c88e08 100644 --- a/src/testing.zig +++ b/src/testing.zig @@ -12,8 +12,8 @@ comptime { pgzx.collections.htab.TestSuite_HTab, pgzx.meta.TestSuite_Meta, - pgzx.mem.TestSuite_Mem, + pgzx.node.TestSuite_Node, }, ); } diff --git a/tools/gennodetags/main.zig b/tools/gennodetags/main.zig new file mode 100644 index 0000000..e285a93 --- /dev/null +++ b/tools/gennodetags/main.zig @@ -0,0 +1,108 @@ +const std = @import("std"); + +const pg = @cImport({ + @cInclude("c.h"); + @cInclude("nodes/nodes.h"); +}); + +const tagsOnly = std.StaticStringMap(void).initComptime([_]struct { []const u8 }{ + // Internal markers + .{"T_Invalid"}, + + // Tags for internal types + .{"T_AllocSetContext"}, + .{"T_GenerationContext"}, + .{"T_SlabContext"}, + .{"T_WindowObjectData"}, + + // List types (only tags, all use the `List` type) + .{"T_IntList"}, + .{"T_OidList"}, + .{"T_XidList"}, +}); + +pub fn main() !void { + var arena_state = std.heap.ArenaAllocator.init(std.heap.page_allocator); + defer arena_state.deinit(); + const arena = arena_state.allocator(); + + const args = try std.process.argsAlloc(arena); + if (args.len != 2) + fatal("wrong number of arguments", .{}); + + var out = std.fs.cwd().createFile(args[1], .{}) catch |err| { + fatal("create file {s}: {}\n", .{ args[1], err }); + }; + defer out.close(); + + try out.writeAll( + \\pub const std = @import("std"); + \\ + \\pub const pg = @import("pgzx_pgsys"); + \\ + \\ + ); + + // 1. collect all node tags into `node_tags` list using comptime reflection. + @setEvalBranchQuota(50000); + var node_tags = std.ArrayList([]const u8).init(arena); + defer node_tags.deinit(); + const pg_mod = @typeInfo(pg).Struct; + inline for (pg_mod.decls) |decl| { + const name = decl.name; + if (std.mem.startsWith(u8, name, "T_")) { + node_tags.append(decl.name) catch |err| { + fatal("build node tags list: {}\n", .{err}); + }; + } + } + + // 2. Create `Tag enum` with all known node tags. + try out.writeAll("pub const Tag = enum (pg.NodeTag) {\n"); + for (node_tags.items) |tag| { + const name = tag[2..]; + try out.writer().print("{s} = pg.{s},\n", .{ name, tag }); + } + try out.writeAll("};\n\n"); + + // 3. Create types -> tags mappings. Only add tags for valid types. + try out.writeAll("pub const TypeTagTable = .{\n"); + for (node_tags.items) |tag| { + if (tagsOnly.has(tag)) + continue; + + const typeName = tag[2..]; + try out.writeAll(".{"); + try out.writer().print("pg.{s}, pg.{s}", .{ tag, typeName }); + try out.writeAll("},\n"); + } + try out.writeAll("};\n"); + + try out.writeAll( + \\pub inline fn findTag(comptime T: type) ?Tag { + \\ inline for (TypeTagTable) |entry| { + \\ if (entry[1] == T) { + \\ return @enumFromInt(entry[0]); + \\ } + \\ } + \\ return null; + \\} + \\ + \\pub inline fn findType(comptime tag: Tag) ?type { + \\ const tag_int: c_int = @intCast(@intFromEnum(tag)); + \\ inline for (TypeTagTable) |entry| { + \\ if (entry[0] == tag_int) { + \\ return entry[1]; + \\ } + \\ } + \\ return null; + \\} + ); + + return std.process.cleanExit(); +} + +fn fatal(comptime format: []const u8, args: anytype) noreturn { + std.debug.print(format, args); + std.process.exit(1); +}