diff --git a/integration_tests/model_visualization_test.py b/integration_tests/model_visualization_test.py index 61a06f1288c..cec3299e67d 100644 --- a/integration_tests/model_visualization_test.py +++ b/integration_tests/model_visualization_test.py @@ -1,6 +1,9 @@ +import re from pathlib import Path import keras +from keras.src import testing +from keras.src.utils import model_to_dot from keras.src.utils import plot_model @@ -8,449 +11,530 @@ def assert_file_exists(path): assert Path(path).is_file(), "File does not exist" -def test_plot_sequential_model(): - model = keras.Sequential( - [ - keras.Input((3,)), - keras.layers.Dense(4, activation="relu"), - keras.layers.Dense(1, activation="sigmoid"), +def parse_text_from_html(html): + pattern = r"]*>(.*?)" + matches = re.findall(pattern, html) + + for match in matches: + clean_text = re.sub(r"<[^>]*>", "", match) + return clean_text + return "" + + +def get_node_text(node): + attributes = node.get_attributes() + + if "label" in attributes: + html = node.get_attributes()["label"] + return parse_text_from_html(html) + else: + return None + + +def get_edge_dict(dot): + node_dict = dict() + for node in dot.get_nodes(): + node_dict[node.get_name()] = get_node_text(node) + + edge_dict = dict() + for edge in dot.get_edges(): + edge_dict[node_dict[edge.get_source()]] = node_dict[ + edge.get_destination() ] - ) - file_name = "sequential.png" - plot_model(model, file_name) - assert_file_exists(file_name) - - file_name = "sequential-show_shapes.png" - plot_model(model, file_name, show_shapes=True) - assert_file_exists(file_name) - - file_name = "sequential-show_shapes-show_dtype.png" - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - ) - assert_file_exists(file_name) - - file_name = "sequential-show_shapes-show_dtype-show_layer_names.png" - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - ) - assert_file_exists(file_name) - - file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations.png" # noqa: E501 - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - show_layer_activations=True, - ) - assert_file_exists(file_name) - - file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501 - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - show_layer_activations=True, - show_trainable=True, - ) - assert_file_exists(file_name) - - file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501 - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - show_layer_activations=True, - show_trainable=True, - rankdir="LR", - ) - assert_file_exists(file_name) - - file_name = "sequential-show_layer_activations-show_trainable.png" - plot_model( - model, - file_name, - show_layer_activations=True, - show_trainable=True, - ) - assert_file_exists(file_name) - - -def test_plot_functional_model(): - inputs = keras.Input((3,)) - x = keras.layers.Dense(4, activation="relu", trainable=False)(inputs) - residual = x - x = keras.layers.Dense(4, activation="relu")(x) - x = keras.layers.Dense(4, activation="relu")(x) - x = keras.layers.Dense(4, activation="relu")(x) - x += residual - residual = x - x = keras.layers.Dense(4, activation="relu")(x) - x = keras.layers.Dense(4, activation="relu")(x) - x = keras.layers.Dense(4, activation="relu")(x) - x += residual - x = keras.layers.Dropout(0.5)(x) - outputs = keras.layers.Dense(1, activation="sigmoid")(x) - - model = keras.Model(inputs, outputs) - - file_name = "functional.png" - plot_model(model, file_name) - assert_file_exists(file_name) - - file_name = "functional-show_shapes.png" - plot_model(model, file_name, show_shapes=True) - assert_file_exists(file_name) - - file_name = "functional-show_shapes-show_dtype.png" - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - ) - assert_file_exists(file_name) - - file_name = "functional-show_shapes-show_dtype-show_layer_names.png" - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - ) - assert_file_exists(file_name) - - file_name = "functional-show_shapes-show_dtype-show_layer_activations.png" - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - show_layer_activations=True, - ) - assert_file_exists(file_name) - - file_name = "functional-show_shapes-show_dtype-show_layer_activations-show_trainable.png" # noqa: E501 - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - show_layer_activations=True, - show_trainable=True, - ) - assert_file_exists(file_name) - - file_name = "functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501 - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - show_layer_activations=True, - show_trainable=True, - rankdir="LR", - ) - assert_file_exists(file_name) - - file_name = "functional-show_layer_activations-show_trainable.png" - plot_model( - model, - file_name, - show_layer_activations=True, - show_trainable=True, - ) - assert_file_exists(file_name) - - file_name = ( - "functional-show_shapes-show_layer_activations-show_trainable.png" - ) - plot_model( - model, - file_name, - show_shapes=True, - show_layer_activations=True, - show_trainable=True, - ) - assert_file_exists(file_name) - - -def test_plot_subclassed_model(): - class MyModel(keras.Model): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.dense_1 = keras.layers.Dense(3, activation="relu") - self.dense_2 = keras.layers.Dense(1, activation="sigmoid") - - def call(self, x): - return self.dense_2(self.dense_1(x)) - - model = MyModel() - model.build((None, 3)) - - file_name = "subclassed.png" - plot_model(model, file_name) - assert_file_exists(file_name) - - file_name = "subclassed-show_shapes.png" - plot_model(model, file_name, show_shapes=True) - assert_file_exists(file_name) - - file_name = "subclassed-show_shapes-show_dtype.png" - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - ) - assert_file_exists(file_name) - - file_name = "subclassed-show_shapes-show_dtype-show_layer_names.png" - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - ) - assert_file_exists(file_name) - - file_name = "subclassed-show_shapes-show_dtype-show_layer_activations.png" - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - show_layer_activations=True, - ) - assert_file_exists(file_name) - - file_name = "subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501 - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - show_layer_activations=True, - show_trainable=True, - ) - assert_file_exists(file_name) - - file_name = "subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501 - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - show_layer_activations=True, - show_trainable=True, - rankdir="LR", - ) - assert_file_exists(file_name) - - file_name = "subclassed-show_layer_activations-show_trainable.png" - plot_model( - model, - file_name, - show_layer_activations=True, - show_trainable=True, - ) - assert_file_exists(file_name) - - file_name = ( - "subclassed-show_shapes-show_layer_activations-show_trainable.png" - ) - plot_model( - model, - file_name, - show_shapes=True, - show_layer_activations=True, - show_trainable=True, - ) - assert_file_exists(file_name) - - -def test_plot_nested_functional_model(): - inputs = keras.Input((3,)) - x = keras.layers.Dense(4, activation="relu")(inputs) - x = keras.layers.Dense(4, activation="relu")(x) - outputs = keras.layers.Dense(3, activation="relu")(x) - inner_model = keras.Model(inputs, outputs) - - inputs = keras.Input((3,)) - x = keras.layers.Dense(3, activation="relu", trainable=False)(inputs) - residual = x - x = inner_model(x) - x += residual - residual = x - x = keras.layers.Dense(4, activation="relu")(x) - x = keras.layers.Dense(4, activation="relu")(x) - x = keras.layers.Dense(3, activation="relu")(x) - x += residual - x = keras.layers.Dropout(0.5)(x) - outputs = keras.layers.Dense(1, activation="sigmoid")(x) - model = keras.Model(inputs, outputs) - - file_name = "nested-functional.png" - plot_model(model, file_name, expand_nested=True) - assert_file_exists(file_name) - - file_name = "nested-functional-show_shapes.png" - plot_model( - model, - file_name, - show_shapes=True, - expand_nested=True, - ) - assert_file_exists(file_name) - - file_name = "nested-functional-show_shapes-show_dtype.png" - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - expand_nested=True, - ) - assert_file_exists(file_name) - - file_name = "nested-functional-show_shapes-show_dtype-show_layer_names.png" - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - expand_nested=True, - ) - assert_file_exists(file_name) - - file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations.png" # noqa: E501 - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - show_layer_activations=True, - expand_nested=True, - ) - assert_file_exists(file_name) - - file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501 - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - show_layer_activations=True, - show_trainable=True, - expand_nested=True, - ) - assert_file_exists(file_name) - - file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501 - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - show_layer_names=True, - show_layer_activations=True, - show_trainable=True, - rankdir="LR", - expand_nested=True, - ) - assert_file_exists(file_name) - - file_name = "nested-functional-show_layer_activations-show_trainable.png" - plot_model( - model, - file_name, - show_layer_activations=True, - show_trainable=True, - expand_nested=True, - ) - assert_file_exists(file_name) - - file_name = "nested-functional-show_shapes-show_layer_activations-show_trainable.png" # noqa: E501 - plot_model( - model, - file_name, - show_shapes=True, - show_layer_activations=True, - show_trainable=True, - expand_nested=True, - ) - assert_file_exists(file_name) - - -def test_plot_functional_model_with_splits_and_merges(): - class SplitLayer(keras.Layer): - def call(self, x): - return list(keras.ops.split(x, 2, axis=1)) - - class ConcatLayer(keras.Layer): - def call(self, xs): - return keras.ops.concatenate(xs, axis=1) - - inputs = keras.Input((2,)) - a, b = SplitLayer()(inputs) - - a = keras.layers.Dense(2)(a) - b = keras.layers.Dense(2)(b) - - outputs = ConcatLayer()([a, b]) - model = keras.Model(inputs, outputs) - - file_name = "split-functional.png" - plot_model(model, file_name, expand_nested=True) - assert_file_exists(file_name) - - file_name = "split-functional-show_shapes.png" - plot_model( - model, - file_name, - show_shapes=True, - expand_nested=True, - ) - assert_file_exists(file_name) - - file_name = "split-functional-show_shapes-show_dtype.png" - plot_model( - model, - file_name, - show_shapes=True, - show_dtype=True, - expand_nested=True, - ) - assert_file_exists(file_name) - - -if __name__ == "__main__": - test_plot_sequential_model() - test_plot_functional_model() - test_plot_subclassed_model() - test_plot_nested_functional_model() - test_plot_functional_model_with_splits_and_merges() + + return edge_dict + + +class ModelVisualizationTest(testing.TestCase): + + def test_plot_sequential_model(self): + model = keras.Sequential( + [ + keras.Input((3,), name="input"), + keras.layers.Dense(4, activation="relu", name="dense"), + keras.layers.Dense(1, activation="sigmoid", name="dense_1"), + ] + ) + + edge_dict = get_edge_dict(model_to_dot(model)) + self.assertEqual(edge_dict["dense (Dense)"], "dense_1 (Dense)") + + file_name = "sequential.png" + plot_model(model, file_name) + assert_file_exists(file_name) + + file_name = "sequential-show_shapes.png" + plot_model(model, file_name, show_shapes=True) + assert_file_exists(file_name) + + file_name = "sequential-show_shapes-show_dtype.png" + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + ) + assert_file_exists(file_name) + + file_name = "sequential-show_shapes-show_dtype-show_layer_names.png" + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + ) + assert_file_exists(file_name) + + file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations.png" # noqa: E501 + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + show_layer_activations=True, + ) + assert_file_exists(file_name) + + file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501 + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + show_layer_activations=True, + show_trainable=True, + ) + assert_file_exists(file_name) + + file_name = "sequential-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501 + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + show_layer_activations=True, + show_trainable=True, + rankdir="LR", + ) + assert_file_exists(file_name) + + file_name = "sequential-show_layer_activations-show_trainable.png" + plot_model( + model, + file_name, + show_layer_activations=True, + show_trainable=True, + ) + assert_file_exists(file_name) + + def test_plot_functional_model(self): + inputs = keras.Input((3,), name="input") + x = keras.layers.Dense( + 4, activation="relu", trainable=False, name="dense" + )(inputs) + residual = x + x = keras.layers.Dense(4, activation="relu", name="dense_1")(x) + x = keras.layers.Dense(4, activation="relu", name="dense_2")(x) + x = keras.layers.Dense(4, activation="relu", name="dense_3")(x) + x += residual + residual = x + x = keras.layers.Dense(4, activation="relu", name="dense_4")(x) + x = keras.layers.Dense(4, activation="relu", name="dense_5")(x) + x = keras.layers.Dense(4, activation="relu", name="dense_6")(x) + x += residual + x = keras.layers.Dropout(0.5, name="dropout")(x) + outputs = keras.layers.Dense(1, activation="sigmoid", name="dense_7")(x) + + model = keras.Model(inputs, outputs) + + edge_dict = get_edge_dict(model_to_dot(model)) + + self.assertEqual(edge_dict["input (InputLayer)"], "dense (Dense)") + self.assertEqual(edge_dict["dense (Dense)"], "add (Add)") + self.assertEqual(edge_dict["dense_1 (Dense)"], "dense_2 (Dense)") + self.assertEqual(edge_dict["dense_2 (Dense)"], "dense_3 (Dense)") + self.assertEqual(edge_dict["dense_3 (Dense)"], "add (Add)") + self.assertEqual(edge_dict["add (Add)"], "add_1 (Add)") + self.assertEqual(edge_dict["dense_4 (Dense)"], "dense_5 (Dense)") + self.assertEqual(edge_dict["dense_5 (Dense)"], "dense_6 (Dense)") + self.assertEqual(edge_dict["dense_6 (Dense)"], "add_1 (Add)") + self.assertEqual(edge_dict["add_1 (Add)"], "dropout (Dropout)") + self.assertEqual(edge_dict["dropout (Dropout)"], "dense_7 (Dense)") + + file_name = "functional.png" + plot_model(model, file_name) + assert_file_exists(file_name) + + file_name = "functional-show_shapes.png" + plot_model(model, file_name, show_shapes=True) + assert_file_exists(file_name) + + file_name = "functional-show_shapes-show_dtype.png" + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + ) + assert_file_exists(file_name) + + file_name = "functional-show_shapes-show_dtype-show_layer_names.png" + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + ) + assert_file_exists(file_name) + + file_name = ( + "functional-show_shapes-show_dtype-show_layer_activations.png" + ) + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + show_layer_activations=True, + ) + assert_file_exists(file_name) + + file_name = "functional-show_shapes-show_dtype-show_layer_activations-show_trainable.png" # noqa: E501 + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + show_layer_activations=True, + show_trainable=True, + ) + assert_file_exists(file_name) + + file_name = "functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501 + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + show_layer_activations=True, + show_trainable=True, + rankdir="LR", + ) + assert_file_exists(file_name) + + file_name = "functional-show_layer_activations-show_trainable.png" + plot_model( + model, + file_name, + show_layer_activations=True, + show_trainable=True, + ) + assert_file_exists(file_name) + + file_name = ( + "functional-show_shapes-show_layer_activations-show_trainable.png" + ) + plot_model( + model, + file_name, + show_shapes=True, + show_layer_activations=True, + show_trainable=True, + ) + assert_file_exists(file_name) + + def test_plot_subclassed_model(self): + class MyModel(keras.Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dense_1 = keras.layers.Dense(3, activation="relu") + self.dense_2 = keras.layers.Dense(1, activation="sigmoid") + + def call(self, x): + return self.dense_2(self.dense_1(x)) + + model = MyModel() + model.build((None, 3)) + + file_name = "subclassed.png" + plot_model(model, file_name) + assert_file_exists(file_name) + + file_name = "subclassed-show_shapes.png" + plot_model(model, file_name, show_shapes=True) + assert_file_exists(file_name) + + file_name = "subclassed-show_shapes-show_dtype.png" + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + ) + assert_file_exists(file_name) + + file_name = "subclassed-show_shapes-show_dtype-show_layer_names.png" + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + ) + assert_file_exists(file_name) + + file_name = ( + "subclassed-show_shapes-show_dtype-show_layer_activations.png" + ) + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + show_layer_activations=True, + ) + assert_file_exists(file_name) + + file_name = "subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501 + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + show_layer_activations=True, + show_trainable=True, + ) + assert_file_exists(file_name) + + file_name = "subclassed-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501 + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + show_layer_activations=True, + show_trainable=True, + rankdir="LR", + ) + assert_file_exists(file_name) + + file_name = "subclassed-show_layer_activations-show_trainable.png" + plot_model( + model, + file_name, + show_layer_activations=True, + show_trainable=True, + ) + assert_file_exists(file_name) + + file_name = ( + "subclassed-show_shapes-show_layer_activations-show_trainable.png" + ) + plot_model( + model, + file_name, + show_shapes=True, + show_layer_activations=True, + show_trainable=True, + ) + assert_file_exists(file_name) + + def test_plot_nested_functional_model(self): + inputs = keras.Input((3,), name="input") + x = keras.layers.Dense(4, activation="relu", name="dense")(inputs) + x = keras.layers.Dense(4, activation="relu", name="dense_1")(x) + outputs = keras.layers.Dense(3, activation="relu", name="dense_2")(x) + inner_model = keras.Model(inputs, outputs, name="inner_model") + + inputs = keras.Input((3,), name="input_1") + x = keras.layers.Dense( + 3, activation="relu", trainable=False, name="dense_3" + )(inputs) + residual = x + x = inner_model(x) + x = keras.layers.Add(name="add")([x, residual]) + residual = x + x = keras.layers.Dense(4, activation="relu", name="dense_4")(x) + x = keras.layers.Dense(4, activation="relu", name="dense_5")(x) + x = keras.layers.Dense(3, activation="relu", name="dense_6")(x) + x = keras.layers.Add(name="add_1")([x, residual]) + x = keras.layers.Dropout(0.5, name="dropout")(x) + outputs = keras.layers.Dense(1, activation="sigmoid", name="dense_7")(x) + model = keras.Model(inputs, outputs) + + edge_dict = get_edge_dict(model_to_dot(model)) + + self.assertEqual(edge_dict["input_1 (InputLayer)"], "dense_3 (Dense)") + self.assertEqual(edge_dict["dense_3 (Dense)"], "add (Add)") + self.assertEqual(edge_dict["inner_model (Functional)"], "add (Add)") + self.assertEqual(edge_dict["add (Add)"], "add_1 (Add)") + self.assertEqual(edge_dict["dense_4 (Dense)"], "dense_5 (Dense)") + self.assertEqual(edge_dict["dense_5 (Dense)"], "dense_6 (Dense)") + self.assertEqual(edge_dict["dense_6 (Dense)"], "add_1 (Add)") + self.assertEqual(edge_dict["add_1 (Add)"], "dropout (Dropout)") + self.assertEqual(edge_dict["dropout (Dropout)"], "dense_7 (Dense)") + + file_name = "nested-functional.png" + plot_model(model, file_name, expand_nested=True) + assert_file_exists(file_name) + + file_name = "nested-functional-show_shapes.png" + plot_model( + model, + file_name, + show_shapes=True, + expand_nested=True, + ) + assert_file_exists(file_name) + + file_name = "nested-functional-show_shapes-show_dtype.png" + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + expand_nested=True, + ) + assert_file_exists(file_name) + + file_name = ( + "nested-functional-show_shapes-show_dtype-show_layer_names.png" + ) + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + expand_nested=True, + ) + assert_file_exists(file_name) + + file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations.png" # noqa: E501 + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + show_layer_activations=True, + expand_nested=True, + ) + assert_file_exists(file_name) + + file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable.png" # noqa: E501 + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + show_layer_activations=True, + show_trainable=True, + expand_nested=True, + ) + assert_file_exists(file_name) + + file_name = "nested-functional-show_shapes-show_dtype-show_layer_names-show_layer_activations-show_trainable-LR.png" # noqa: E501 + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + show_layer_names=True, + show_layer_activations=True, + show_trainable=True, + rankdir="LR", + expand_nested=True, + ) + assert_file_exists(file_name) + + file_name = ( + "nested-functional-show_layer_activations-show_trainable.png" + ) + plot_model( + model, + file_name, + show_layer_activations=True, + show_trainable=True, + expand_nested=True, + ) + assert_file_exists(file_name) + + file_name = "nested-functional-show_shapes-show_layer_activations-show_trainable.png" # noqa: E501 + plot_model( + model, + file_name, + show_shapes=True, + show_layer_activations=True, + show_trainable=True, + expand_nested=True, + ) + assert_file_exists(file_name) + + def test_plot_functional_model_with_splits_and_merges(self): + class SplitLayer(keras.Layer): + def call(self, x): + return list(keras.ops.split(x, 2, axis=1)) + + class ConcatLayer(keras.Layer): + def call(self, xs): + return keras.ops.concatenate(xs, axis=1) + + inputs = keras.Input((2,), name="input") + a, b = SplitLayer()(inputs) + + a = keras.layers.Dense(2, name="dense")(a) + b = keras.layers.Dense(2, name="dense_1")(b) + + outputs = ConcatLayer(name="concat_layer")([a, b]) + model = keras.Model(inputs, outputs) + + edge_dict = get_edge_dict(model_to_dot(model)) + + self.assertEqual( + edge_dict["input (InputLayer)"], "split_layer (SplitLayer)" + ) + self.assertEqual( + edge_dict["split_layer (SplitLayer)"], "dense_1 (Dense)" + ) + self.assertEqual( + edge_dict["dense (Dense)"], "concat_layer (ConcatLayer)" + ) + self.assertEqual( + edge_dict["dense_1 (Dense)"], "concat_layer (ConcatLayer)" + ) + + file_name = "split-functional.png" + plot_model(model, file_name, expand_nested=True) + assert_file_exists(file_name) + + file_name = "split-functional-show_shapes.png" + plot_model( + model, + file_name, + show_shapes=True, + expand_nested=True, + ) + assert_file_exists(file_name) + + file_name = "split-functional-show_shapes-show_dtype.png" + plot_model( + model, + file_name, + show_shapes=True, + show_dtype=True, + expand_nested=True, + ) + assert_file_exists(file_name)