Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add assert statement to check model structure on model_visualization_test #20208

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 109 additions & 32 deletions integration_tests/model_visualization_test.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,61 @@
import re
from pathlib import Path

import keras
from keras.src.utils import model_to_dot
from keras.src.utils import plot_model


def assert_file_exists(path):
assert Path(path).is_file(), "File does not exist"


def parse_text_from_html(html):
pattern = r"<font[^>]*>(.*?)</font>"
matches = re.findall(pattern, html)

for match in matches:
clean_text = re.sub(r"<[^>]*>", "", match)
return clean_text
return ""


def get_node_text(node):
attributes = node.get_attributes()

if "label" in attributes:
html = node.get_attributes()["label"]
return parse_text_from_html(html)
else:
return None


def get_edge_dict(dot):
node_dict = dict()
for node in dot.get_nodes():
node_dict[node.get_name()] = get_node_text(node)

edge_dict = dict()
for edge in dot.get_edges():
edge_dict[node_dict[edge.get_source()]] = node_dict[
edge.get_destination()
]

return edge_dict


def test_plot_sequential_model():
model = keras.Sequential(
[
keras.Input((3,)),
keras.layers.Dense(4, activation="relu"),
keras.layers.Dense(1, activation="sigmoid"),
keras.Input((3,), name="input"),
keras.layers.Dense(4, activation="relu", name="dense"),
keras.layers.Dense(1, activation="sigmoid", name="dense_1"),
]
)

edge_dict = get_edge_dict(model_to_dot(model))
assert edge_dict["dense (Dense)"] == "dense_1 (Dense)"

file_name = "sequential.png"
plot_model(model, file_name)
assert_file_exists(file_name)
Expand Down Expand Up @@ -90,23 +130,39 @@ def test_plot_sequential_model():


def test_plot_functional_model():
inputs = keras.Input((3,))
x = keras.layers.Dense(4, activation="relu", trainable=False)(inputs)
inputs = keras.Input((3,), name="input")
x = keras.layers.Dense(4, activation="relu", trainable=False, name="dense")(
inputs
)
residual = x
x = keras.layers.Dense(4, activation="relu")(x)
x = keras.layers.Dense(4, activation="relu")(x)
x = keras.layers.Dense(4, activation="relu")(x)
x = keras.layers.Dense(4, activation="relu", name="dense_1")(x)
x = keras.layers.Dense(4, activation="relu", name="dense_2")(x)
x = keras.layers.Dense(4, activation="relu", name="dense_3")(x)
x += residual
residual = x
x = keras.layers.Dense(4, activation="relu")(x)
x = keras.layers.Dense(4, activation="relu")(x)
x = keras.layers.Dense(4, activation="relu")(x)
x = keras.layers.Dense(4, activation="relu", name="dense_4")(x)
x = keras.layers.Dense(4, activation="relu", name="dense_5")(x)
x = keras.layers.Dense(4, activation="relu", name="dense_6")(x)
x += residual
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
x = keras.layers.Dropout(0.5, name="dropout")(x)
outputs = keras.layers.Dense(1, activation="sigmoid", name="dense_7")(x)

model = keras.Model(inputs, outputs)

edge_dict = get_edge_dict(model_to_dot(model))

assert edge_dict["input (InputLayer)"] == "dense (Dense)"
assert edge_dict["dense (Dense)"] == "add (Add)"
assert edge_dict["dense_1 (Dense)"] == "dense_2 (Dense)"
assert edge_dict["dense_2 (Dense)"] == "dense_3 (Dense)"
assert edge_dict["dense_3 (Dense)"] == "add (Add)"
assert edge_dict["add (Add)"] == "add_1 (Add)"
assert edge_dict["dense_4 (Dense)"] == "dense_5 (Dense)"
assert edge_dict["dense_5 (Dense)"] == "dense_6 (Dense)"
assert edge_dict["dense_6 (Dense)"] == "add_1 (Add)"
assert edge_dict["add_1 (Add)"] == "dropout (Dropout)"
assert edge_dict["dropout (Dropout)"] == "dense_7 (Dense)"

file_name = "functional.png"
plot_model(model, file_name)
assert_file_exists(file_name)
Expand Down Expand Up @@ -291,26 +347,40 @@ def call(self, x):


def test_plot_nested_functional_model():
inputs = keras.Input((3,))
x = keras.layers.Dense(4, activation="relu")(inputs)
x = keras.layers.Dense(4, activation="relu")(x)
outputs = keras.layers.Dense(3, activation="relu")(x)
inner_model = keras.Model(inputs, outputs)

inputs = keras.Input((3,))
x = keras.layers.Dense(3, activation="relu", trainable=False)(inputs)
inputs = keras.Input((3,), name="input")
x = keras.layers.Dense(4, activation="relu", name="dense")(inputs)
x = keras.layers.Dense(4, activation="relu", name="dense_1")(x)
outputs = keras.layers.Dense(3, activation="relu", name="dense_2")(x)
inner_model = keras.Model(inputs, outputs, name="inner_model")

inputs = keras.Input((3,), name="input_1")
x = keras.layers.Dense(
3, activation="relu", trainable=False, name="dense_3"
)(inputs)
residual = x
x = inner_model(x)
x += residual
x = keras.layers.Add(name="add")([x, residual])
residual = x
x = keras.layers.Dense(4, activation="relu")(x)
x = keras.layers.Dense(4, activation="relu")(x)
x = keras.layers.Dense(3, activation="relu")(x)
x += residual
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
x = keras.layers.Dense(4, activation="relu", name="dense_4")(x)
x = keras.layers.Dense(4, activation="relu", name="dense_5")(x)
x = keras.layers.Dense(3, activation="relu", name="dense_6")(x)
x = keras.layers.Add(name="add_1")([x, residual])
x = keras.layers.Dropout(0.5, name="dropout")(x)
outputs = keras.layers.Dense(1, activation="sigmoid", name="dense_7")(x)
model = keras.Model(inputs, outputs)

edge_dict = get_edge_dict(model_to_dot(model))

assert edge_dict["input_1 (InputLayer)"] == "dense_3 (Dense)"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Such statements should be made via self.assertEqual

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the source code as your advice.
Additionally, I removed the main statement also, since, I think it can be replaced by just running ModelVisualizationTest class by pytest.

assert edge_dict["dense_3 (Dense)"] == "add (Add)"
assert edge_dict["inner_model (Functional)"] == "add (Add)"
assert edge_dict["add (Add)"] == "add_1 (Add)"
assert edge_dict["dense_4 (Dense)"] == "dense_5 (Dense)"
assert edge_dict["dense_5 (Dense)"] == "dense_6 (Dense)"
assert edge_dict["dense_6 (Dense)"] == "add_1 (Add)"
assert edge_dict["add_1 (Add)"] == "dropout (Dropout)"
assert edge_dict["dropout (Dropout)"] == "dense_7 (Dense)"

file_name = "nested-functional.png"
plot_model(model, file_name, expand_nested=True)
assert_file_exists(file_name)
Expand Down Expand Up @@ -415,15 +485,22 @@ class ConcatLayer(keras.Layer):
def call(self, xs):
return keras.ops.concatenate(xs, axis=1)

inputs = keras.Input((2,))
inputs = keras.Input((2,), name="input")
a, b = SplitLayer()(inputs)

a = keras.layers.Dense(2)(a)
b = keras.layers.Dense(2)(b)
a = keras.layers.Dense(2, name="dense")(a)
b = keras.layers.Dense(2, name="dense_1")(b)

outputs = ConcatLayer()([a, b])
outputs = ConcatLayer(name="concat_layer")([a, b])
model = keras.Model(inputs, outputs)

edge_dict = get_edge_dict(model_to_dot(model))

assert edge_dict["input (InputLayer)"] == "split_layer (SplitLayer)"
assert edge_dict["split_layer (SplitLayer)"] == "dense_1 (Dense)"
assert edge_dict["dense (Dense)"] == "concat_layer (ConcatLayer)"
assert edge_dict["dense_1 (Dense)"] == "concat_layer (ConcatLayer)"

file_name = "split-functional.png"
plot_model(model, file_name, expand_nested=True)
assert_file_exists(file_name)
Expand Down