Skip to content

Commit

Permalink
Fix Trainer._symbolic_build (keras-team#19945)
Browse files Browse the repository at this point in the history
  • Loading branch information
Grvzard committed Jul 1, 2024
1 parent fbf9d17 commit 38ef5f2
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions keras/src/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,12 +1001,6 @@ def _symbolic_build(self, iterator=None, data_batch=None):
optimizer_unbuilt = (
self.optimizer is not None and not self.optimizer.built
)
if model_unbuilt or compile_metrics_unbuilt or optimizer_unbuilt:
if data_batch is None:
for _, data in iterator.enumerate_epoch():
data_batch = data[0]
break

if model_unbuilt or compile_metrics_unbuilt:
# Create symbolic tensors matching an input batch.

Expand All @@ -1017,6 +1011,10 @@ def to_symbolic_input(v):
v.shape, backend.standardize_dtype(v.dtype)
)

if data_batch is None:
for _, data in iterator.enumerate_epoch():
data_batch = data[0]
break
data_batch = tree.map_structure(to_symbolic_input, data_batch)
(
x,
Expand Down

0 comments on commit 38ef5f2

Please sign in to comment.