diff --git a/test/NnxTestClasses.py b/test/NnxTestClasses.py index 41d5131..8d4eed1 100644 --- a/test/NnxTestClasses.py +++ b/test/NnxTestClasses.py @@ -48,6 +48,8 @@ class NnxTestConf(BaseModel): has_norm_quant: bool has_bias: bool has_relu: bool + synthetic_weights: bool + synthetic_inputs: bool @model_validator(mode="after") # type: ignore def check_valid_depthwise_channels(self) -> NnxTestConf: @@ -116,6 +118,8 @@ def __init__( scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, global_shift: Optional[torch.Tensor] = torch.Tensor([0]), + synthetic_weights: Optional[bool] = False, + synthetic_inputs: Optional[bool] = False, ) -> None: self.conf = conf self.input = input @@ -124,6 +128,8 @@ def __init__( self.scale = scale self.bias = bias self.global_shift = global_shift + self.synthetic_weights = synthetic_weights + self.synthetic_inputs = synthetic_inputs def is_valid(self) -> bool: return all( @@ -243,20 +249,30 @@ def from_conf( bias_shape = (1, conf.out_channel, 1, 1) if input is None: - input = NnxTestGenerator._random_data( - _type=conf.in_type, - shape=input_shape, - ) + if conf.synthetic_inputs: + inputs = torch.zeros((1, conf.in_channel, conf.in_height, conf.in_width), dtype=torch.int64) + for i in range(conf.in_channel): + inputs[:, i,0,0] = i + else: + input = NnxTestGenerator._random_data( + _type=conf.in_type, + shape=input_shape, + ) if weight is None: - weight_mean = NnxTestGenerator._DEFAULT_WEIGHT_MEAN - weight_std = NnxTestGenerator._DEFAULT_WEIGHT_STDEV * (1<<(conf.weight_type._bits-1)-1) - weight = NnxTestGenerator._random_data_normal( - mean = weight_mean, - std = weight_std, - _type=conf.weight_type, - shape=weight_shape, - ) + if conf.synthetic_weights: + weight = torch.zeros((conf.out_channel, 1 if conf.depthwise else conf.in_channel, conf.kernel_shape.height, conf.kernel_shape.width), dtype=torch.int64) + for i in range(0, min(weight.shape[0], weight.shape[1])): + weight[i,i,0,0] = 1 + else: + weight_mean = NnxTestGenerator._DEFAULT_WEIGHT_MEAN + weight_std = NnxTestGenerator._DEFAULT_WEIGHT_STDEV * (1<<(conf.weight_type._bits-1)-1) + weight = NnxTestGenerator._random_data_normal( + mean = weight_mean, + std = weight_std, + _type=conf.weight_type, + shape=weight_shape, + ) if conf.has_norm_quant: if scale is None: @@ -306,6 +322,8 @@ def from_conf( scale=scale, bias=bias, global_shift=global_shift, + synthetic_inputs=conf.synthetic_inputs, + synthetic_weights=conf.synthetic_weights, ) @staticmethod @@ -361,7 +379,10 @@ def generate(self, test_name: str, test: NnxTest): weight_type = test.conf.weight_type weight_bits = weight_type._bits assert weight_bits > 1 and weight_bits <= 8 - weight_offset = -(2 ** (weight_bits - 1)) + if test.synthetic_weights: + weight_offset = 0 + else: + weight_offset = -(2 ** (weight_bits - 1)) weight_out_ch, weight_in_ch, weight_ks_h, weight_ks_w = test.weight.shape weight_data: np.ndarray = test.weight.numpy() - weight_offset weight_init = self.weightEncode( diff --git a/test/testgen.py b/test/testgen.py index c128ff4..e5378e4 100644 --- a/test/testgen.py +++ b/test/testgen.py @@ -86,21 +86,7 @@ def test_gen( exit(-1) test_conf = nnxTestConfCls.model_validate(test_conf_dict) - if test_conf_dict['synthetic_weights']: - import torch - weight = torch.zeros((test_conf.out_channel, 1 if test_conf.depthwise else test_conf.in_channel, test_conf.kernel_shape.height, test_conf.kernel_shape.width), dtype=torch.int64) - for i in range(0, min(weight.shape[0], weight.shape[1])): - weight[i,i,0,0] = 1 - else: - weight = None - if test_conf_dict['synthetic_inputs']: - import torch - inputs = torch.zeros((1, test_conf.in_channel, test_conf.in_height, test_conf.in_width), dtype=torch.int64) - for i in range(test_conf.in_channel): - inputs[:, i,0,0] = i - else: - inputs = None - test = NnxTestGenerator.from_conf(test_conf, verbose=args.print_tensors, weight=weight, input=inputs) + test = NnxTestGenerator.from_conf(test_conf, verbose=args.print_tensors) if not args.skip_save: test.save(args.test_dir) if args.headers: