Skip to content

Commit

Permalink
Add PyMemoryView wrapper (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Oct 13, 2023
1 parent cd88a94 commit f939cb9
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
1 change: 1 addition & 0 deletions pydust/src/types.zig
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub usingnamespace @import("types/gil.zig");
pub usingnamespace @import("types/iter.zig");
pub usingnamespace @import("types/list.zig");
pub usingnamespace @import("types/long.zig");
pub usingnamespace @import("types/memoryview.zig");
pub usingnamespace @import("types/module.zig");
pub usingnamespace @import("types/obj.zig");
pub usingnamespace @import("types/slice.zig");
Expand Down
105 changes: 105 additions & 0 deletions pydust/src/types/memoryview.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

const std = @import("std");
const py = @import("../pydust.zig");
const ffi = py.ffi;
const PyError = @import("../errors.zig").PyError;
const PyObjectMixin = @import("./obj.zig").PyObjectMixin;

pub const PyMemoryView = extern struct {
obj: py.PyObject,

pub const Flags = struct {
const PyBUF_READ: c_int = 0x100;
const PyBUF_WRITE: c_int = 0x200;
};

pub usingnamespace PyObjectMixin("memoryview", "PyMemoryView", @This());

pub fn fromSlice(slice: anytype) !PyMemoryView {
const sliceType = Slice(@TypeOf(slice));
const flag = if (std.meta.trait.isConstPtr(sliceType)) PyMemoryView.Flags.PyBUF_READ else PyMemoryView.Flags.PyBUF_WRITE;
return .{ .obj = .{
.py = py.ffi.PyMemoryView_FromMemory(@constCast(slice.ptr), @intCast(slice.len), flag) orelse return py.PyError.PyRaised,
} };
}

pub fn fromObject(obj: py.PyObject) !PyMemoryView {
return .{ .obj = .{
.py = py.ffi.PyMemoryView_FromObject(obj.py) orelse return py.PyError.PyRaised,
} };
}

fn Slice(comptime T: type) type {
switch (@typeInfo(T)) {
.Pointer => |ptr_info| {
var new_ptr_info = ptr_info;
switch (ptr_info.size) {
.Slice => {},
.One => switch (@typeInfo(ptr_info.child)) {
.Array => |info| new_ptr_info.child = info.child,
else => @compileError("invalid type given to PyMemoryview"),
},
else => @compileError("invalid type given to PyMemoryview"),
}
new_ptr_info.size = .Slice;
return @Type(.{ .Pointer = new_ptr_info });
},
else => @compileError("invalid type given to PyMemoryview"),
}
}
};

test "from array" {
py.initialize();
defer py.finalize();

const array = "static string";
const mv = try PyMemoryView.fromSlice(array);
defer mv.decref();

var buf = try mv.obj.getBuffer(py.PyBuffer.Flags.ANY_CONTIGUOUS);
try std.testing.expectEqualSlices(u8, array, buf.asSlice(u8));
try std.testing.expect(buf.readonly);
}

test "from slice" {
py.initialize();
defer py.finalize();

const array = "This is a static string";
const slice: []const u8 = try std.testing.allocator.dupe(u8, array);
defer std.testing.allocator.free(slice);
const mv = try PyMemoryView.fromSlice(slice);
defer mv.decref();

var buf = try mv.obj.getBuffer(py.PyBuffer.Flags.ANY_CONTIGUOUS);
try std.testing.expectEqualSlices(u8, array, buf.asSlice(u8));
try std.testing.expect(buf.readonly);
}

test "from mutable slice" {
py.initialize();
defer py.finalize();

const array = "This is a static string";
const slice = try std.testing.allocator.alloc(u8, array.len);
defer std.testing.allocator.free(slice);
const mv = try PyMemoryView.fromSlice(slice);
defer mv.decref();
@memcpy(slice, array);

var buf = try mv.obj.getBuffer(py.PyBuffer.Flags.ANY_CONTIGUOUS);
try std.testing.expectEqualSlices(u8, array, buf.asSlice(u8));
try std.testing.expect(!buf.readonly);
}

0 comments on commit f939cb9

Please sign in to comment.