diff --git a/test/testgen.py b/test/testgen.py index 84691e4..5c0d1ba 100644 --- a/test/testgen.py +++ b/test/testgen.py @@ -22,8 +22,6 @@ import typing from typing import Optional, Set, Type, Union -import numpy as np -import numpy.typing as npt import toml from HeaderWriter import HeaderWriter @@ -34,7 +32,6 @@ NnxTestGenerator, NnxTestHeaderGenerator, NnxWeight, - WmemLiteral, ) @@ -72,6 +69,10 @@ def test_gen( nnxTestConfCls: Type[NnxTestConf], nnxWeightCls: Type[NnxWeight], ): + assert not ( + args.gen_ones and args.gen_incremented + ), "You can choose only one method for input generation." + if args.conf.endswith(".toml"): test_conf_dict = toml.load(args.conf) elif args.conf.endswith(".json"): @@ -238,10 +239,6 @@ def add_common_arguments(parser: argparse.ArgumentParser): args = parser.parse_args() -assert not ( - args.gen_ones and args.gen_incremented -), "You can choose only one method for input generation." - testConfCls, weightCls = NnxMapping[args.accelerator] args.func(args, testConfCls, weightCls)