Skip to content

Commit

Permalink
Fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
lukamac committed Feb 28, 2024
1 parent c283557 commit d746346
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 14 deletions.
8 changes: 6 additions & 2 deletions test/HeaderWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ def generate_source(self, name, body):
with open(filepath, "w") as file:
file.write(body)

def generate_vector_source(self, name, size, _type, init=None, golden=None, section="PI_L1"):
def generate_vector_source(
self, name, size, _type, init=None, golden=None, section="PI_L1"
):
render = ""
render += f'#include "{name}.h"\n\n'
render += self.render_vector(name, f"{section} {_type}", size, init=init)
Expand All @@ -168,7 +170,9 @@ def generate_vector_source(self, name, size, _type, init=None, golden=None, sect

self.generate_source(name, render)

def generate_vector_files(self, name, size, _type, init=None, golden=None, section="PI_L1"):
def generate_vector_files(
self, name, size, _type, init=None, golden=None, section="PI_L1"
):
self.generate_vector_source(name, size, _type, init, golden, section)
self.generate_vector_header(name, size, _type, init, golden)

Expand Down
4 changes: 3 additions & 1 deletion test/NeurekaTestConf.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,7 @@ def check_valid_out_type_with_norm_quant(self) -> NeurekaTestConf:
@classmethod
def check_valid_wmem(cls, v: WmemLiteral) -> WmemLiteral:
_supported_wmem = ["tcdm", "sram"]
assert v in _supported_wmem, f"Unsupported wmem {v}. Supported {_supported_wmem}."
assert (
v in _supported_wmem
), f"Unsupported wmem {v}. Supported {_supported_wmem}."
return v
25 changes: 16 additions & 9 deletions test/NnxTestClasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from NeuralEngineFunctionalModel import NeuralEngineFunctionalModel
from TestClasses import IntegerType, KernelShape, Padding, Stride, implies


WmemLiteral = Literal["tcdm", "sram"]


Expand Down Expand Up @@ -351,11 +350,15 @@ def generate(self, test_name: str, test: NnxTest):
test.conf.depthwise,
)
if test.conf.wmem == "sram":
section = "__attribute__((section(\".weightmem_sram\")))"
section = '__attribute__((section(".weightmem_sram")))'
else:
section = "PI_L1"
self.header_writer.generate_vector_files(
"weight", _type="uint8_t", size=weight_init.size, init=weight_init, section=section
"weight",
_type="uint8_t",
size=weight_init.size,
init=weight_init,
section=section,
)

# Render scale
Expand Down Expand Up @@ -406,14 +409,18 @@ def generate(self, test_name: str, test: NnxTest):
"offset": weight_offset,
},
"scale": {
"bits": test.conf.scale_type._bits
if test.conf.scale_type is not None
else 0
"bits": (
test.conf.scale_type._bits
if test.conf.scale_type is not None
else 0
)
},
"bias": {
"bits": test.conf.bias_type._bits
if test.conf.bias_type is not None
else 0
"bits": (
test.conf.bias_type._bits
if test.conf.bias_type is not None
else 0
)
},
"padding": {
"top": test.conf.padding.top,
Expand Down
3 changes: 1 addition & 2 deletions test/TestClasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,5 +123,4 @@ def model_dump(
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
) -> dict[str, Any]:
...
) -> dict[str, Any]: ...

0 comments on commit d746346

Please sign in to comment.