You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Because `tf.keras.utils.plot_model` does not work with TF-Agents networks, I wrote a custom function to plot them.
Example plot:
Here is the code:
import os
import sys

# Resolve a pydot implementation, trying the forks before plain pydot.
try:
    # pydot-ng is a fork of pydot that is better maintained.
    import pydot_ng as pydot
except ImportError:
    try:
        # pydotplus is an improved version of pydot
        import pydotplus as pydot
    except ImportError:
        try:
            # Fall back on pydot if necessary.
            import pydot
        except ImportError:
            # Nothing importable: check_pydot() will report False.
            pydot = None


def check_pydot():
    """Report whether one of the pydot implementations above was imported."""
    return pydot is not None


def check_graphviz():
    """Report whether pydot is importable and can render a graph.

    An empty graph is rendered as a smoke test of the pydot/Graphviz
    installation; ``OSError`` or ``pydot.InvocationException`` during that
    render makes the check fail.
    """
    if not check_pydot():
        return False
    try:
        pydot.Dot.create(pydot.Dot())
    except (OSError, pydot.InvocationException):
        return False
    return True


def add_edge(dot, src, dst):
    """Add the edge ``src -> dst`` to ``dot`` unless ``dot`` already has it."""
    if dot.get_edge(src, dst):
        return
    dot.add_edge(pydot.Edge(src, dst))
def _edge_endpoints(struct, edges):
    """Collect ``struct``'s internal edges into ``edges``; return (entries, exits).

    ``struct`` is a descriptor dict: either a leaf ``{"id": ..., "type": "node"}``
    (see ``make_node``) or a container ``{"type": "sequential" | "nest",
    "nodes": [...]}``. ``entries`` are the leaf ids an incoming edge must
    reach; ``exits`` are the leaf ids an outgoing edge must leave from.
    Any other ``type`` is treated as a leaf, as before.
    """
    kind = struct["type"]
    if kind == "sequential":
        parts = [_edge_endpoints(child, edges) for child in struct["nodes"]]
        # Chain consecutive children: every exit of one feeds every entry
        # of the next.
        for (_, prev_exits), (next_entries, _) in zip(parts, parts[1:]):
            edges.extend(
                (src, dst) for src in prev_exits for dst in next_entries
            )
        if not parts:
            # An empty sequential has no leaves to connect (previously an
            # IndexError).
            return [], []
        return parts[0][0], parts[-1][1]
    if kind == "nest":
        # Nest members are parallel: all of them are entries and exits.
        entries, exits = [], []
        for child in struct["nodes"]:
            child_entries, child_exits = _edge_endpoints(child, edges)
            entries.extend(child_entries)
            exits.extend(child_exits)
        return entries, exits
    return [struct["id"]], [struct["id"]]


def add_edge_node(dot, node, next_node):
    """Draw the edges from descriptor ``node`` into descriptor ``next_node``.

    The internal edges of both descriptors are drawn as well: a "sequential"
    chains its children, a "nest" fans its children in/out. Each descriptor is
    traversed once. ``add_edge`` skips edges already present in ``dot``.
    """
    edges = []
    _, exits = _edge_endpoints(node, edges)
    entries, _ = _edge_endpoints(next_node, edges)
    edges.extend((src, dst) for src in exits for dst in entries)
    for src, dst in edges:
        add_edge(dot, src, dst)
def make_node(id):  # parameter name kept for keyword compatibility
    """Return the leaf descriptor ``{"id": id, "type": "node"}``."""
    return dict(id=id, type="node")
def _format_shape(shape):
    """Render ``shape`` as text safe for a graphviz record label.

    Record labels use ``{``/``}`` as field delimiters, so literal braces are
    backslash-escaped.
    """
    return str(shape).replace("{", r"\{").replace("}", r"\}")


def _layer_label(layer):
    """Build the record label ``{class|name}|{input:|output:}|{{in}|{out}}``.

    Shapes that cannot be read from the layer are shown as ``?``.
    """
    try:
        output_labels = _format_shape(layer.output_shape)
    except AttributeError:
        output_labels = "?"
    if hasattr(layer, "input_shape"):
        input_labels = _format_shape(layer.input_shape)
    elif hasattr(layer, "input_shapes"):
        input_labels = ", ".join(
            _format_shape(ishape) for ishape in layer.input_shapes
        )
    else:
        input_labels = "?"
    header = "{0}|{1}".format(layer.__class__.__name__, layer.name)
    return "{%s}|{input:|output:}|{{%s}|{%s}}" % (
        header,
        input_labels,
        output_labels,
    )


def _add_internal_edges(dot, struct):
    """Draw every edge that lies inside the descriptor ``struct``.

    Consecutive children of a "sequential" descriptor are wired with
    ``add_edge_node``, which also draws the internal edges of both operands.
    A "nest", or a "sequential" with fewer than two children, has no
    consecutive pair to wire, so each child is visited on its own. Without
    this, e.g. a NestMap of Sequentials would be drawn with no edges at all.
    """
    children = struct.get("nodes", [])
    if struct["type"] == "sequential" and len(children) > 1:
        for node, next_node in zip(children, children[1:]):
            add_edge_node(dot, node, next_node)
    else:
        for child in children:
            _add_internal_edges(dot, child)


def agent_model_to_dot(
    model,
    subgraph=False,
    dpi=96,
    depth=4,
):
    """Convert a TF-Agents network into a pydot graph.

    Args:
      model: A built network. Layers that are ``sequential.Sequential`` or
        ``NestMap`` instances are expanded recursively into dashed clusters;
        every other layer becomes a single record node.
      subgraph: If True, build a ``pydot.Cluster`` (used by the recursive
        calls) instead of a top-level ``pydot.Dot``. Only the top-level call
        draws edges.
      dpi: Resolution set on the top-level graph.
      depth: Remaining nesting levels to expand. When it is 0, nested
        networks are drawn as a single node.

    Returns:
      ``(dot, structure)``. ``structure`` is a descriptor dict
      ``{"type": "nest" | "sequential", "nodes": [...]}``: "nest" for a
      ``NestMap``, "sequential" for anything else. Its entries are leaf dicts
      from ``make_node`` or the descriptors of expanded sub-networks.

    Raises:
      ValueError: if ``model`` is not built.
      ImportError: if no pydot implementation is available.
    """
    if not model.built:
        raise ValueError(
            "This model has not yet been built. "
            "Build the model first by calling `build()` or by calling "
            "the model on a batch of data."
        )

    # Imported here rather than at module level, so importing this file does
    # not require TF-Agents.
    from tf_agents.networks import NestMap
    from tf_agents.networks import sequential

    if not check_pydot():
        raise ImportError(
            "You must install pydot (`pip install pydot`) for "
            "model_to_dot to work."
        )

    if subgraph:
        dot = pydot.Cluster(style="dashed", graph_name=model.name)
        dot.set("label", model.name)
        dot.set("labeljust", "l")
    else:
        dot = pydot.Dot()
        dot.set("rankdir", "TB")
        dot.set("dpi", dpi)
        dot.set_node_defaults(shape="record")

    structure = {
        "nodes": [],
        "type": "nest" if isinstance(model, NestMap) else "sequential",
    }

    # Create graph nodes (or nested clusters).
    for layer in model.layers:
        expand = depth != 0 and isinstance(
            layer, (sequential.Sequential, NestMap)
        )
        if expand:
            sub_dot, sub_structure = agent_model_to_dot(
                layer, subgraph=True, dpi=dpi, depth=depth - 1
            )
            structure["nodes"].append(sub_structure)
            dot.add_subgraph(sub_dot)
        else:
            layer_id = str(id(layer))
            structure["nodes"].append(make_node(layer_id))
            dot.add_node(pydot.Node(layer_id, label=_layer_label(layer)))

    # Edges are drawn once, by the top-level call, over the whole descriptor
    # tree. Only a Sequential's own layers are chained. For any other
    # top-level network its layers are not chained (their data flow is not
    # known here), but edges inside nested Sequential/NestMap layers still are.
    if not subgraph:
        if isinstance(model, sequential.Sequential):
            _add_internal_edges(dot, structure)
        else:
            for child in structure["nodes"]:
                _add_internal_edges(dot, child)

    return dot, structure


def print_msg(message, line_break=True):
    """Write ``message`` to stdout, newline-terminated by default, then flush."""
    sys.stdout.write(message + "\n" if line_break else message)
    sys.stdout.flush()
def path_to_string(path):
    """Turn an ``os.PathLike`` into its filesystem form; pass anything else through."""
    if not isinstance(path, os.PathLike):
        return path
    return os.fspath(path)


def plot_agent_model(
    model,
    to_file="model.png",
    subgraph=False,
    dpi=96,
    depth=4,
):
    """Render a TF-Agents network to an image file with pydot/Graphviz.

    The file format is taken from ``to_file``'s extension ("png" when it has
    none). When the format is not "pdf" and IPython is importable, an
    ``IPython.display.Image`` of the written file is returned; otherwise
    ``None``.

    Raises:
      ValueError: if ``model`` is not built.
      ImportError: if pydot/Graphviz are unusable and IPython is not loaded.
        Inside IPython, the message is printed instead.

    NOTE(review): with ``subgraph=True`` the graph is a ``pydot.Cluster``;
    confirm it supports ``write`` before relying on that option.
    """
    if not model.built:
        raise ValueError(
            "This model has not yet been built. "
            "Build the model first by calling `build()` or by calling "
            "the model on a batch of data."
        )

    if not check_graphviz():
        message = (
            "You must install pydot (`pip install pydot`) "
            "and install graphviz "
            "(see instructions at https://graphviz.gitlab.io/download/) "
            "for plot_model to work."
        )
        ipython_loaded = "IPython.core.magics.namespace" in sys.modules
        if not ipython_loaded:
            raise ImportError(message)
        # Don't raise under IPython, so notebook tests without graphviz
        # do not crash.
        print_msg(message)
        return

    graph, _structure = agent_model_to_dot(
        model,
        subgraph=subgraph,
        dpi=dpi,
        depth=depth,
    )
    target = path_to_string(to_file)
    if graph is None:
        return

    suffix = os.path.splitext(target)[1]
    extension = suffix[1:] if suffix else "png"

    # Save image to disk.
    graph.write(target, format=extension)

    if extension == "pdf":
        return
    # Return the image as a Jupyter Image object, to be displayed in-line.
    # Whether we are inside a notebook cannot easily be detected, so the
    # Image is returned whenever IPython is importable.
    try:
        from IPython import display
        return display.Image(filename=target)
    except ImportError:
        pass
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
-
Because `tf.keras.utils.plot_model` does not work with TF-Agents networks, I wrote a custom function to plot them.
Example plot:
Here is the code:
Beta Was this translation helpful? Give feedback.
All reactions