diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 0c7880db7a5..f47b116a36b 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -1001,12 +1001,6 @@ def _symbolic_build(self, iterator=None, data_batch=None): optimizer_unbuilt = ( self.optimizer is not None and not self.optimizer.built ) - if model_unbuilt or compile_metrics_unbuilt or optimizer_unbuilt: - if data_batch is None: - for _, data in iterator.enumerate_epoch(): - data_batch = data[0] - break - if model_unbuilt or compile_metrics_unbuilt: # Create symbolic tensors matching an input batch. @@ -1017,6 +1011,10 @@ def to_symbolic_input(v): v.shape, backend.standardize_dtype(v.dtype) ) + if data_batch is None: + for _, data in iterator.enumerate_epoch(): + data_batch = data[0] + break data_batch = tree.map_structure(to_symbolic_input, data_batch) ( x,