Skip to content

Commit

Permalink
keras v3 converter clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
calad0i committed Dec 17, 2024
1 parent 8e61d91 commit 22667a3
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 19 deletions.
56 changes: 47 additions & 9 deletions hls4ml/converters/keras_v3/_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing
from types import FunctionType
from typing import Any, Callable, Sequence, TypedDict
from typing import Any, Callable, Sequence, TypedDict, overload


class DefaultConfig(TypedDict, total=False):
Expand All @@ -26,6 +26,14 @@ class DefaultConfig(TypedDict, total=False):
registry: dict[str, T_kv3_handler] = {}


@overload
def register(cls: type) -> type: ...


@overload
def register(cls: str) -> Callable[[T_kv3_handler], T_kv3_handler]: ...


def register(cls: str | type):
"""Decorator to register a handler for a specific layer class. Suggested to decorate the `KerasV3LayerHandler` class.
Expand All @@ -51,11 +59,13 @@ def my_layer_handler(layer, inp_tensors, out_tensors):
```
"""

def deco(func: T_kv3_handler):
def deco(func):
if isinstance(cls, str):
registry[cls] = func
for k in getattr(func, 'handles', ()):
registry[k] = func
if isinstance(cls, type):
return cls
return func

if isinstance(cls, type):
Expand All @@ -79,7 +89,7 @@ def __call__(
layer: 'keras.Layer',
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):
) -> tuple[dict[str, Any], ...]:
"""Handle a keras layer. Return a tuple of dictionaries, each
dictionary representing a layer (module) in the HLS model. One
layer may correspond one or more dictionaries (e.g., layers with
Expand Down Expand Up @@ -114,8 +124,7 @@ def __call__(
dict[str, Any] | tuple[dict[str, Any], ...]
layer configuration(s) for the HLS model to be consumed by
the ModelGraph constructor
""" # noqa: E501
import keras
"""

name = layer.name
class_name = layer.__class__.__name__
Expand Down Expand Up @@ -150,12 +159,23 @@ def __call__(
ret = (config,)

# If activation exists, append it

act_config, intermediate_tensor_name = self.maybe_get_activation_config(layer, out_tensors)
if act_config is not None:
ret[0]['output_keras_tensor_names'] = [intermediate_tensor_name]
ret = *ret, act_config

return ret

def maybe_get_activation_config(self, layer, out_tensors):
import keras

activation = getattr(layer, 'activation', None)
name = layer.name
if activation not in (keras.activations.linear, None):
assert len(out_tensors) == 1, f"Layer {name} has more than one output, but has an activation function"
assert isinstance(activation, FunctionType), f"Activation function for layer {name} is not a function"
intermediate_tensor_name = f'{out_tensors[0].name}_activation'
ret[0]['output_keras_tensor_names'] = [intermediate_tensor_name]
act_cls_name = activation.__name__
act_config = {
'class_name': 'Activation',
Expand All @@ -164,9 +184,8 @@ def __call__(
'input_keras_tensor_names': [intermediate_tensor_name],
'output_keras_tensor_names': [out_tensors[0].name],
}
ret = *ret, act_config

return ret
return act_config, intermediate_tensor_name
return None, None

def handle(
self,
Expand All @@ -175,3 +194,22 @@ def handle(
out_tensors: Sequence['KerasTensor'],
) -> dict[str, Any] | tuple[dict[str, Any], ...]:
return {}

def load_weight(self, layer: 'keras.Layer', key: str):
"""Load a weight from a layer.
Parameters
----------
layer : keras.Layer
The layer to load the weight from.
key : str
The key of the weight to load.
Returns
-------
np.ndarray
The weight.
"""
import keras

return keras.ops.convert_to_numpy(getattr(layer, key))
8 changes: 3 additions & 5 deletions hls4ml/converters/keras_v3/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from math import ceil
from typing import Sequence

import numpy as np

from ._base import KerasV3LayerHandler, register

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -40,9 +38,9 @@ def handle(
assert all(isinstance(x, int) for x in in_shape), f"Layer {layer.name} has non-fixed size input: {in_shape}"
assert all(isinstance(x, int) for x in out_shape), f"Layer {layer.name} has non-fixed size output: {out_shape}"

kernel = np.array(layer.kernel)
kernel = self.load_weight(layer, 'kernel')
if layer.use_bias:
bias = np.array(layer.bias)
bias = self.load_weight(layer, 'bias')
else:
bias = None

Expand Down Expand Up @@ -113,7 +111,7 @@ def handle(
config['depth_multiplier'] = layer.depth_multiplier
elif isinstance(layer, BaseSeparableConv):
config['depthwise_data'] = kernel
config['pointwise_data'] = np.array(layer.pointwise_kernel)
config['pointwise_data'] = self.load_weight(layer, 'pointwise_kernel')
config['depth_multiplier'] = layer.depth_multiplier
elif isinstance(layer, BaseConv):
config['weight_data'] = kernel
Expand Down
2 changes: 1 addition & 1 deletion hls4ml/converters/keras_v3/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def handle(
config = {
'data_format': 'channels_last',
'weight_data': kernel,
'bias_data': np.array(layer.bias) if layer.use_bias else None,
'bias_data': self.load_weight(layer, 'bias') if layer.use_bias else None,
'n_out': kernel.shape[1],
'n_in': kernel.shape[0],
}
Expand Down
6 changes: 2 additions & 4 deletions hls4ml/converters/keras_v3/einsum_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ def handle(
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):
import keras

assert len(in_tensors) == 1, 'EinsumDense layer must have exactly one input tensor'
assert len(out_tensors) == 1, 'EinsumDense layer must have exactly one output tensor'

Expand All @@ -56,11 +54,11 @@ def handle(

equation = strip_batch_dim(layer.equation)

kernel = keras.ops.convert_to_numpy(layer.kernel)
kernel = self.load_weight(layer, 'kernel')

bias = None
if layer.bias_axes:
bias = keras.ops.convert_to_numpy(layer.bias)
bias = self.load_weight(layer, 'bias')

return {
'class_name': 'EinsumDense',
Expand Down

0 comments on commit 22667a3

Please sign in to comment.