From 38ef5f26244ab8042ad62ee2b52c5303f2539590 Mon Sep 17 00:00:00 2001 From: Grvzard Date: Tue, 2 Jul 2024 01:56:27 +0800 Subject: [PATCH] Fix `Trainer._symbolic_build` (#19945) --- keras/src/trainers/trainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 0c7880db7a5..f47b116a36b 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -1001,12 +1001,6 @@ def _symbolic_build(self, iterator=None, data_batch=None): optimizer_unbuilt = ( self.optimizer is not None and not self.optimizer.built ) - if model_unbuilt or compile_metrics_unbuilt or optimizer_unbuilt: - if data_batch is None: - for _, data in iterator.enumerate_epoch(): - data_batch = data[0] - break - if model_unbuilt or compile_metrics_unbuilt: # Create symbolic tensors matching an input batch. @@ -1017,6 +1011,10 @@ def to_symbolic_input(v): v.shape, backend.standardize_dtype(v.dtype) ) + if data_batch is None: + for _, data in iterator.enumerate_epoch(): + data_batch = data[0] + break data_batch = tree.map_structure(to_symbolic_input, data_batch) ( x,