Skip to content

Commit

Permalink
node: Introduce new module for working with postgres Node types (#68)
Browse files Browse the repository at this point in the history
Requires #67 

Closes #66 

Postgres sources use tagging and the `Node` (and `NodeTag`) types for
many internal types. Unfortunately the macros did not translate well to
Zig, which makes it annoying to work with the low level types.

We introduce the `pgzx.node` module that provides a list of helper
functions like `pgzx.node.make`, `pgzx.node.create`, `pgzx.node.tag`, or
`pgzx.node.isA`.

We also wrap the constants into our own `pgzx.node.Tag` enum, so you can
use `pgzx.node.isA(node, .Query)` in code.

Postgres uses code generation to collect the node tags and supported
types into `nodetags.h`. We also use code generation to produce the list
of tags and type mappings. The generated module is then imported by
`pgzx.node`.

The `tools/gennodetags` tool imports the `nodes/nodes.h` header file at
compile time and uses comptime to collect all known tags into an
`ArrayList` which is then used to produce the sources for the anonymous
module imported into `pgzx`.
  • Loading branch information
urso authored Jun 23, 2024
1 parent 466ac4c commit a1f6f4c
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 3 deletions.
28 changes: 28 additions & 0 deletions build.zig
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,22 @@ pub fn build(b: *std.Build) void {
break :blk module;
};

// codegen
// The codegen produces Zig files that are imported as modules by pgzx.
const node_tags_src = blk: {
const tool = b.addExecutable(.{
.name = "gennodetags",
.root_source_file = b.path("./tools/gennodetags/main.zig"),
.target = b.host,
.link_libc = true,
});
tool.root_module.addIncludePath(.{ .cwd_relative = pgbuild.getIncludeServerDir() });
tool.root_module.addIncludePath(.{ .cwd_relative = pgbuild.getIncludeDir() });

const tool_step = b.addRunArtifact(tool);
break :blk tool_step.addOutputFileArg("nodetags.zig");
};

// pgzx: main project module.
// This module re-exports pgzx_pgsys, other generated modules, and utility functions.
const pgzx = blk: {
Expand All @@ -64,6 +80,12 @@ pub fn build(b: *std.Build) void {
.optimize = optimize,
});
module.addImport("pgzx_pgsys", pgzx_pgsys);
module.addAnonymousImport("gen_node_tags", .{
.root_source_file = node_tags_src,
.imports = &.{
.{ .name = "pgzx_pgsys", .module = pgzx_pgsys },
},
});

break :blk module;
};
Expand All @@ -87,6 +109,12 @@ pub fn build(b: *std.Build) void {

tests.lib.root_module.addImport("pgzx_pgsys", pgzx_pgsys);
tests.lib.root_module.addImport("pgzx", pgzx);
tests.lib.root_module.addAnonymousImport("gen_node_tags", .{
.root_source_file = node_tags_src,
.imports = &.{
.{ .name = "pgzx_pgsys", .module = pgzx_pgsys },
},
});

break :blk tests;
};
Expand Down
2 changes: 2 additions & 0 deletions src/pgzx.zig
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,6 @@ pub const testing = @import("pgzx/testing.zig");
// helpers around at times.
pub const meta = @import("pgzx/meta.zig");

pub const node = @import("pgzx/node.zig");

pub const guc = utils.guc;
33 changes: 31 additions & 2 deletions src/pgzx/c.zig
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ const includes = @cImport({
@cInclude("varatt.h");

@cInclude("access/reloptions.h");
@cInclude("access/tsmapi.h");

@cInclude("commands/event_trigger.h");

@cInclude("catalog/binary_upgrade.h");
@cInclude("catalog/catalog.h");
Expand Down Expand Up @@ -98,6 +101,34 @@ const includes = @cImport({

@cInclude("executor/spi.h");
@cInclude("executor/executor.h");
@cInclude("windowapi.h");

@cInclude("lib/ilist.h");

@cInclude("nodes/bitmapset.h");
@cInclude("nodes/execnodes.h");
@cInclude("nodes/extensible.h");
@cInclude("nodes/lockoptions.h");
@cInclude("nodes/makefuncs.h");
@cInclude("nodes/memnodes.h");
@cInclude("nodes/miscnodes.h");
@cInclude("nodes/multibitmapset.h");
@cInclude("nodes/nodeFuncs.h");
@cInclude("nodes/nodes.h");
@cInclude("nodes/params.h");
@cInclude("nodes/parsenodes.h");
@cInclude("nodes/pathnodes.h");
@cInclude("nodes/pg_list.h");
@cInclude("nodes/plannodes.h");
@cInclude("nodes/primnodes.h");
@cInclude("nodes/print.h");
@cInclude("nodes/queryjumble.h");
@cInclude("nodes/readfuncs.h");
@cInclude("nodes/replnodes.h");
@cInclude("nodes/subscripting.h");
@cInclude("nodes/supportnodes.h");
@cInclude("nodes/tidbitmap.h");
@cInclude("nodes/value.h");

@cInclude("optimizer/appendinfo.h");
@cInclude("optimizer/clauses.h");
Expand Down Expand Up @@ -135,8 +166,6 @@ const includes = @cImport({

@cInclude("commands/extension.h");

@cInclude("lib/ilist.h");

@cInclude("storage/ipc.h");
@cInclude("storage/proc.h");
@cInclude("storage/latch.h");
Expand Down
110 changes: 110 additions & 0 deletions src/pgzx/node.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
const pg = @import("pgzx_pgsys");

const gen = @import("gen_node_tags");

pub const Tag = gen.Tag;

pub inline fn make(comptime T: type) *T {
const node: *pg.Node = @ptrCast(@alignCast(pg.palloc0fast(@sizeOf(T))));
node.*.type = @intFromEnum(mustFindTag(T));
return @ptrCast(@alignCast(node));
}

pub inline fn create(initFrom: anytype) *@TypeOf(initFrom) {
const node = make(@TypeOf(initFrom));
node.* = initFrom;
setTag(node, mustFindTag(@TypeOf(initFrom)));
return node;
}

fn mustFindTag(comptime T: type) Tag {
return gen.findTag(T) orelse @compileError("No tag found for type");
}

pub inline fn tag(node: anytype) Tag {
return @enumFromInt(asNodePtr(node).*.type);
}

pub inline fn setTag(node: anytype, t: Tag) void {
asNodePtr(node).*.type = @intFromEnum(t);
}

pub inline fn isA(node: anytype, t: Tag) bool {
return tag(node) == t;
}

pub inline fn castNode(comptime T: type, node: anytype) *T {
return @ptrCast(@alignCast(asNodePtr(node)));
}

pub inline fn safeCastNode(comptime T: type, node: anytype) ?*T {
if (tag(node) != gen.findTag(T)) {
return null;
}
return castNode(T, node);
}

inline fn asNodePtr(node: anytype) *pg.Node {
checkIsPotentialNodePtr(node);
return @ptrCast(@alignCast(node));
}

inline fn checkIsPotentialNodePtr(node: anytype) void {
const nodeType = @typeInfo(@TypeOf(node));
if (nodeType != .Pointer or nodeType.Pointer.size != .One) {
@compileError("Expected single node pointer");
}
}

pub const TestSuite_Node = struct {
const std = @import("std");

pub fn testMakeAndTag() !void {
const node = make(pg.FdwRoutine);
try std.testing.expectEqual(tag(node), .FdwRoutine);
}

pub fn testCreate() !void {
const node = create(pg.Query{
.commandType = pg.CMD_SELECT,
});
try std.testing.expectEqual(tag(node), .Query);
try std.testing.expectEqual(node.*.commandType, pg.CMD_SELECT);
}

pub fn testSetTag() !void {
const node = make(pg.Query);
setTag(node, .FdwRoutine);
try std.testing.expectEqual(tag(node), .FdwRoutine);
}

pub fn testIsA_Ok() !void {
const node = make(pg.Query);
try std.testing.expect(isA(node, .Query));
}

pub fn testIsA_Fail() !void {
const node = make(pg.Query);
try std.testing.expect(!isA(node, .FdwRoutine));
}

pub fn testCastNode() !void {
const node: *pg.Node = @ptrCast(@alignCast(make(pg.Query)));
const query: *pg.Query = castNode(pg.Query, node);
try std.testing.expect(isA(query, .Query));
}

pub fn testSafeCast_Ok() !void {
const node = make(pg.Query);
const query = safeCastNode(pg.Query, node) orelse {
return error.UnexpectedCastFailure;
};
try std.testing.expect(isA(query, .Query));
}

pub fn testSafeCast_Fail() !void {
const node = make(pg.Query);
const fdw = safeCastNode(pg.FdwRoutine, node);
try std.testing.expect(fdw == null);
}
};
2 changes: 1 addition & 1 deletion src/testing.zig
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ comptime {
pgzx.collections.htab.TestSuite_HTab,

pgzx.meta.TestSuite_Meta,

pgzx.mem.TestSuite_Mem,
pgzx.node.TestSuite_Node,
},
);
}
108 changes: 108 additions & 0 deletions tools/gennodetags/main.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
const std = @import("std");

const pg = @cImport({
@cInclude("c.h");
@cInclude("nodes/nodes.h");
});

const tagsOnly = std.StaticStringMap(void).initComptime([_]struct { []const u8 }{
// Internal markers
.{"T_Invalid"},

// Tags for internal types
.{"T_AllocSetContext"},
.{"T_GenerationContext"},
.{"T_SlabContext"},
.{"T_WindowObjectData"},

// List types (only tags, all use the `List` type)
.{"T_IntList"},
.{"T_OidList"},
.{"T_XidList"},
});

pub fn main() !void {
var arena_state = std.heap.ArenaAllocator.init(std.heap.page_allocator);
defer arena_state.deinit();
const arena = arena_state.allocator();

const args = try std.process.argsAlloc(arena);
if (args.len != 2)
fatal("wrong number of arguments", .{});

var out = std.fs.cwd().createFile(args[1], .{}) catch |err| {
fatal("create file {s}: {}\n", .{ args[1], err });
};
defer out.close();

try out.writeAll(
\\pub const std = @import("std");
\\
\\pub const pg = @import("pgzx_pgsys");
\\
\\
);

// 1. collect all node tags into `node_tags` list using comptime reflection.
@setEvalBranchQuota(50000);
var node_tags = std.ArrayList([]const u8).init(arena);
defer node_tags.deinit();
const pg_mod = @typeInfo(pg).Struct;
inline for (pg_mod.decls) |decl| {
const name = decl.name;
if (std.mem.startsWith(u8, name, "T_")) {
node_tags.append(decl.name) catch |err| {
fatal("build node tags list: {}\n", .{err});
};
}
}

// 2. Create `Tag enum` with all known node tags.
try out.writeAll("pub const Tag = enum (pg.NodeTag) {\n");
for (node_tags.items) |tag| {
const name = tag[2..];
try out.writer().print("{s} = pg.{s},\n", .{ name, tag });
}
try out.writeAll("};\n\n");

// 3. Create types -> tags mappings. Only add tags for valid types.
try out.writeAll("pub const TypeTagTable = .{\n");
for (node_tags.items) |tag| {
if (tagsOnly.has(tag))
continue;

const typeName = tag[2..];
try out.writeAll(".{");
try out.writer().print("pg.{s}, pg.{s}", .{ tag, typeName });
try out.writeAll("},\n");
}
try out.writeAll("};\n");

try out.writeAll(
\\pub inline fn findTag(comptime T: type) ?Tag {
\\ inline for (TypeTagTable) |entry| {
\\ if (entry[1] == T) {
\\ return @enumFromInt(entry[0]);
\\ }
\\ }
\\ return null;
\\}
\\
\\pub inline fn findType(comptime tag: Tag) ?type {
\\ const tag_int: c_int = @intCast(@intFromEnum(tag));
\\ inline for (TypeTagTable) |entry| {
\\ if (entry[0] == tag_int) {
\\ return entry[1];
\\ }
\\ }
\\ return null;
\\}
);

return std.process.cleanExit();
}

fn fatal(comptime format: []const u8, args: anytype) noreturn {
std.debug.print(format, args);
std.process.exit(1);
}

0 comments on commit a1f6f4c

Please sign in to comment.