Skip to content

Commit

Permalink
Fix allowed types for neureka_v2
Browse files Browse the repository at this point in the history
Neureka_v2 uses signed 8-bit scales. It should also support signed
32-bit scales but it is not tested.
Also for the output type, I have removed the 32-bit unquantized output
for now because there were some issues with the spatial strides (height
and width). There was the issue that output stride in the previous
neureka was multiplied with the number of output bytes, but in
neureka_v2 it seems that gets done in the accelerator, at least by the
look of the model. But when I tried with that change, the whole thing
broke. The output offsets were correct but still got all the wrong
results so I guess the output strides affect some other parts of the
model which I am not aware of.
  • Loading branch information
lukamac committed Dec 6, 2024
1 parent a57ff03 commit f3466a2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/NeurekaV2TestConf.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def check_valid_in_type(cls, v: IntegerType) -> IntegerType:
@field_validator("out_type")
@classmethod
def check_valid_out_type(cls, v: IntegerType) -> IntegerType:
NeurekaV2TestConf._check_type("out_type", v, ["uint8", "int8", "int32"])
NeurekaV2TestConf._check_type("out_type", v, ["uint8", "int8"])
return v

@field_validator("weight_type")
Expand All @@ -72,7 +72,7 @@ def check_valid_weight_type(cls, v: IntegerType) -> IntegerType:
@classmethod
def check_valid_scale_type(cls, v: Optional[IntegerType]) -> Optional[IntegerType]:
if v is not None:
NeurekaV2TestConf._check_type("scale_type", v, ["uint8", "uint32"])
NeurekaV2TestConf._check_type("scale_type", v, ["int8"])
return v

@field_validator("bias_type")
Expand Down

0 comments on commit f3466a2

Please sign in to comment.