说明
pytorch的底层实现是用的c++,导致检查type极其麻烦,需要jit、template之类的技术,但是zig是支持type类型的,comptime如虎添翼。
简单的函数名和参数检查
注:以下代码是deepseek r1 生成的,然后改了下bug,在 0.14-dev下能跑起来。
const std = @import("std");
const Tensor = struct {};
// 定义 Tensor 结构体示例
const Dpu = struct {
// 示例方法,符合要求的签名
pub fn add(self: *Tensor, dim: i32, index: *Tensor, src: *Tensor) Tensor {
std.debug.print("add {}, {}, {}, {}\n", .{ self, dim, index, src });
return Tensor{};
}
};
// 编译时检查方法签名的函数
fn checkMethodSignature(
comptime Struct: type,
comptime methodName: []const u8,
comptime expectedParamTypes: []const type,
comptime expectedReturnType: type,
) void {
// 检查方法是否存在于结构体中
if (!@hasDecl(Struct, methodName)) {
@compileError("方法 '" ++ methodName ++ "' 不存在于 " ++ @typeName(Struct));
}
// 获取方法实例及其类型信息
const method = @field(Struct, methodName);
const FuncType = @TypeOf(method);
const funcInfo = @typeInfo(FuncType).@"fn";
// 检查参数数量(包含 self)
const expectedParamCount = expectedParamTypes.len;
if (funcInfo.params.len != expectedParamCount) {
@compileError("参数数量错误,期望 " ++ std.fmt.comptimePrint("{}", .{expectedParamCount}) ++ " 个,实际 " ++ std.fmt.comptimePrint("{}", .{funcInfo.params.len}));
}
// 检查参数类型
for (funcInfo.params[0..], 0..) |param, i| {
const expectedType = expectedParamTypes[i];
if (param.type != expectedType) {
@compileError("参数 " ++ std.fmt.comptimePrint("{}", .{i}) ++ " 类型应为 " ++ @typeName(expectedType) ++ ",实际为 " ++ @typeName(param.type.?));
}
}
// 检查返回类型
if (funcInfo.return_type != expectedReturnType) {
@compileError("返回类型应为 " ++ @typeName(expectedReturnType) ++ ",实际为 " ++ @typeName(funcInfo.return_type.?));
}
}
// 编译时执行检查
comptime {
checkMethodSignature(Dpu, "add", &[_]type{ *Tensor, i32, *Tensor, *Tensor }, Tensor);
}
pub fn main() void {
std.debug.print("函数签名检查通过!\n", .{});
}