diff --git a/examples/demo_jax_distributed.py b/examples/demo_jax_distributed.py index 8e679f33211..3ea434017f9 100644 --- a/examples/demo_jax_distributed.py +++ b/examples/demo_jax_distributed.py @@ -287,6 +287,7 @@ def train_step(train_state, x, y): print("\nTraining:") data_iter = iter(train_data) for epoch in range(EPOCHS): + loss_value = None # default for i in tqdm(range(STEPS_PER_EPOCH)): x, y = next(data_iter) sharded_x = jax.device_put(x.numpy(), data_sharding) diff --git a/guides/custom_train_step_in_jax.py b/guides/custom_train_step_in_jax.py index 46dd85e1495..2085b202868 100644 --- a/guides/custom_train_step_in_jax.py +++ b/guides/custom_train_step_in_jax.py @@ -124,7 +124,7 @@ def train_step(self, state, data): ) # Update metrics. - new_metrics_vars = [] + new_metrics_vars, logs = [], [] for metric in self.metrics: this_metric_vars = metrics_variables[ len(new_metrics_vars) : len(new_metrics_vars) @@ -314,7 +314,7 @@ def test_step(self, state, data): loss = self.compute_loss(x, y, y_pred) # Update metrics. - new_metrics_vars = [] + new_metrics_vars, logs = [], [] for metric in self.metrics: this_metric_vars = metrics_variables[ len(new_metrics_vars) : len(new_metrics_vars) diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py index 3babe17b8d7..6f6dbbf25d7 100644 --- a/guides/distributed_training_with_jax.py +++ b/guides/distributed_training_with_jax.py @@ -252,6 +252,7 @@ def get_replicated_train_state(devices): # Custom training loop for epoch in range(num_epochs): data_iter = iter(train_data) + loss_value = None # default for data in data_iter: x, y = data sharded_x = jax.device_put(x.numpy(), data_sharding) diff --git a/keras/src/activations/activations.py b/keras/src/activations/activations.py index c21d2d279bc..3f875b64b15 100644 --- a/keras/src/activations/activations.py +++ b/keras/src/activations/activations.py @@ -83,6 +83,8 @@ def static_call(x, negative_slope=0.0, max_value=None, threshold=0.0): negative_part = backend.nn.relu(-x + threshold) else: negative_part = backend.nn.relu(-x) + else: + negative_part = 1 clip_max = max_value is not None if threshold != 0: diff --git a/keras/src/applications/densenet.py b/keras/src/applications/densenet.py index 886b6bc16bd..436401d258b 100644 --- a/keras/src/applications/densenet.py +++ b/keras/src/applications/densenet.py @@ -289,6 +289,8 @@ def DenseNet( cache_subdir="models", file_hash="1ceb130c1ea1b78c3bf6114dbdfd8807", ) + else: + raise ValueError("weights_path undefined") else: if blocks == [6, 12, 24, 16]: weights_path = file_utils.get_file( @@ -311,6 +313,8 @@ def DenseNet( cache_subdir="models", file_hash="c13680b51ded0fb44dff2d8f86ac8bb1", ) + else: + raise ValueError("weights_path undefined") model.load_weights(weights_path) elif weights is not None: model.load_weights(weights) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 41f7674b1b9..e925f1f078c 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -397,6 +397,7 @@ def fit( self.make_train_function() self.stop_training = False + training_logs = {} callbacks.on_train_begin() initial_epoch = self._initial_epoch or initial_epoch for epoch in range(initial_epoch, epochs): diff --git a/keras/src/backend/torch/trainer.py b/keras/src/backend/torch/trainer.py index ce228090544..9ba0162d459 100644 --- a/keras/src/backend/torch/trainer.py +++ b/keras/src/backend/torch/trainer.py @@ -236,6 +236,7 @@ def fit( ) self.stop_training = False + training_logs = {} self.make_train_function() callbacks.on_train_begin() initial_epoch = self._initial_epoch or initial_epoch diff --git a/keras/src/layers/attention/attention.py b/keras/src/layers/attention/attention.py index d863f5639f8..d8de918c2b9 100644 --- a/keras/src/layers/attention/attention.py +++ b/keras/src/layers/attention/attention.py @@ -134,6 +134,8 @@ def _calculate_scores(self, query, key): scores = self.concat_score_weight * ops.sum( ops.tanh(q_reshaped + k_reshaped), axis=-1 ) + else: + raise ValueError("scores not computed") return scores diff --git a/keras/src/layers/preprocessing/feature_space.py b/keras/src/layers/preprocessing/feature_space.py index 357d64cf8a3..5fc5e34afaf 100644 --- a/keras/src/layers/preprocessing/feature_space.py +++ b/keras/src/layers/preprocessing/feature_space.py @@ -517,8 +517,7 @@ def adapt(self, dataset): preprocessor = self.preprocessors[name] # TODO: consider adding an adapt progress bar. # Sample 1 element to check the rank - for x in feature_dataset.take(1): - pass + x = next(iter(feature_dataset)) if len(x.shape) == 0: # The dataset yields unbatched scalars; batch it. feature_dataset = feature_dataset.batch(32) diff --git a/keras/src/layers/preprocessing/hashed_crossing_test.py b/keras/src/layers/preprocessing/hashed_crossing_test.py index 9e74b876362..6e9d3648e89 100644 --- a/keras/src/layers/preprocessing/hashed_crossing_test.py +++ b/keras/src/layers/preprocessing/hashed_crossing_test.py @@ -86,8 +86,7 @@ def test_tf_data_compatibility(self): .batch(5) .map(lambda x1, x2: layer((x1, x2))) ) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertAllClose(np.array([1, 4, 1, 1, 3]), output) def test_unsupported_shape_input_fails(self): diff --git a/keras/src/layers/preprocessing/hashing_test.py b/keras/src/layers/preprocessing/hashing_test.py index 614d575633f..3a7966f8161 100644 --- a/keras/src/layers/preprocessing/hashing_test.py +++ b/keras/src/layers/preprocessing/hashing_test.py @@ -60,8 +60,7 @@ def test_tf_data_compatibility(self): layer = layers.Hashing(num_bins=3) inp = [["A"], ["B"], ["C"], ["D"], ["E"]] ds = tf.data.Dataset.from_tensor_slices(inp).batch(5).map(layer) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertAllClose(output, np.array([[1], [0], [1], [1], [2]])) @parameterized.named_parameters( @@ -306,6 +305,8 @@ def test_count_output(self, input_value, expected_output, output_shape): symbolic_sample_shape = () elif input_array.ndim == 2: symbolic_sample_shape = (None,) + else: + raise TypeError("Unknown `symbolic_sample_shape`") inputs = layers.Input(shape=symbolic_sample_shape, dtype="int32") layer = layers.Hashing(num_bins=3, output_mode="count") outputs = layer(inputs) diff --git a/keras/src/layers/preprocessing/image_preprocessing/center_crop_test.py b/keras/src/layers/preprocessing/image_preprocessing/center_crop_test.py index 78233d0a1c5..7eaa32e08bd 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/center_crop_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/center_crop_test.py @@ -171,8 +171,7 @@ def test_tf_data_compatibility(self): layer = layers.CenterCrop(8, 9) input_data = np.random.random(input_shape) ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertEqual(tuple(output.shape), output_shape) # TODO diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py index d3db4366f13..8972d88f33e 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py @@ -54,8 +54,7 @@ def test_tf_data_compatibility(self): layer = layers.RandomContrast(factor=0.5, seed=1337) input_data = np.random.random((2, 8, 8, 3)) ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) - for output in ds.take(1): - output.numpy() + next(iter(ds)).numpy() def test_dict_input(self): layer = layers.RandomContrast(factor=0.1, bounding_box_format="xyxy") diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py index 77c2b0a3c9e..c4796a2b224 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py @@ -136,8 +136,7 @@ def test_tf_data_compatibility(self): output_shape = (2, 3, 8, 9) input_data = np.random.random(input_shape) ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertEqual(tuple(output.shape), output_shape) def test_dict_input(self): diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_flip_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_flip_test.py index aba8d30c9b9..4d7b25a9e88 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_flip_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_flip_test.py @@ -141,8 +141,7 @@ def test_tf_data_compatibility(self): input_data = np.array([[[2, 3, 4]], [[5, 6, 7]]]) expected_output = np.array([[[5, 6, 7]], [[2, 3, 4]]]) ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertAllClose(output, expected_output) # Test 4D input: shape (2, 2, 1, 3) layer = layers.RandomFlip( @@ -167,6 +166,5 @@ def test_tf_data_compatibility(self): ] ) ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertAllClose(output, expected_output) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_rotation_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_rotation_test.py index 005110ef2c5..7350c550ede 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_rotation_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_rotation_test.py @@ -73,6 +73,5 @@ def test_tf_data_compatibility(self): [4, 3, 2, 1, 0], ] ).reshape(input_shape[1:]) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertAllClose(expected_output, output) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_translation_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_translation_test.py index ff6b97e7ffc..7b9fcbf029f 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_translation_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_translation_test.py @@ -327,5 +327,4 @@ def test_tf_data_compatibility(self): layer = layers.RandomTranslation(0.2, 0.1) input_data = np.random.random((1, 4, 4, 3)) ds = tf_data.Dataset.from_tensor_slices(input_data).batch(1).map(layer) - for output in ds.take(1): - output.numpy() + next(iter(ds)).numpy() diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py index f4ce59c77f5..52e08492457 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py @@ -119,8 +119,7 @@ def test_tf_data_compatibility(self): [0, 0, 0, 0, 0], ] ).reshape(input_shape) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertAllClose(expected_output, output) def test_dynamic_shape(self): diff --git a/keras/src/layers/preprocessing/image_preprocessing/resizing_test.py b/keras/src/layers/preprocessing/image_preprocessing/resizing_test.py index b0c138550ff..4d0374238dd 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/resizing_test.py +++ b/keras/src/layers/preprocessing/image_preprocessing/resizing_test.py @@ -186,8 +186,7 @@ def test_tf_data_compatibility(self): layer = layers.Resizing(8, 9) input_data = np.random.random(input_shape) ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertEqual(tuple(output.shape), output_shape) @pytest.mark.skipif( @@ -210,8 +209,7 @@ def test_tf_data_compatibility_sequential(self): .batch(2) .map(Sequential([layer])) ) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertEqual(tuple(output.shape), output_shape) @parameterized.parameters( diff --git a/keras/src/layers/preprocessing/integer_lookup_test.py b/keras/src/layers/preprocessing/integer_lookup_test.py index d1c6a732cbe..9247a98bad8 100644 --- a/keras/src/layers/preprocessing/integer_lookup_test.py +++ b/keras/src/layers/preprocessing/integer_lookup_test.py @@ -102,6 +102,5 @@ def test_tf_data_compatibility(self): ) input_data = [2, 3, 4, 5] ds = tf_data.Dataset.from_tensor_slices(input_data).batch(4).map(layer) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertAllClose(output, np.array([2, 3, 4, 0])) diff --git a/keras/src/layers/preprocessing/normalization.py b/keras/src/layers/preprocessing/normalization.py index cfaa649a0e1..2f6000165e3 100644 --- a/keras/src/layers/preprocessing/normalization.py +++ b/keras/src/layers/preprocessing/normalization.py @@ -277,6 +277,8 @@ def adapt(self, data): batch_var + (batch_mean - new_total_mean) ** 2 ) * batch_weight total_mean = new_total_mean + else: + raise NotImplementedError(type(data)) self.adapt_mean.assign(total_mean) self.adapt_variance.assign(total_var) diff --git a/keras/src/layers/preprocessing/normalization_test.py b/keras/src/layers/preprocessing/normalization_test.py index b76ba5e4fa8..d524e6cd866 100644 --- a/keras/src/layers/preprocessing/normalization_test.py +++ b/keras/src/layers/preprocessing/normalization_test.py @@ -65,6 +65,8 @@ def test_normalization_adapt(self, input_type): data = backend.convert_to_tensor(x) elif input_type == "tf.data": data = tf_data.Dataset.from_tensor_slices(x).batch(8) + else: + raise NotImplementedError(input_type) layer = layers.Normalization() layer.adapt(data) diff --git a/keras/src/layers/preprocessing/rescaling_test.py b/keras/src/layers/preprocessing/rescaling_test.py index bd0a7742328..81634d3801b 100644 --- a/keras/src/layers/preprocessing/rescaling_test.py +++ b/keras/src/layers/preprocessing/rescaling_test.py @@ -72,8 +72,7 @@ def test_tf_data_compatibility(self): layer = layers.Rescaling(scale=1.0 / 255, offset=0.5) x = np.random.random((3, 10, 10, 3)) * 255 ds = tf_data.Dataset.from_tensor_slices(x).batch(3).map(layer) - for output in ds.take(1): - output.numpy() + next(iter(ds)).numpy() def test_rescaling_with_channels_first_and_vector_scale(self): config = backend.image_data_format() diff --git a/keras/src/layers/preprocessing/string_lookup_test.py b/keras/src/layers/preprocessing/string_lookup_test.py index 4319d511a9a..b735546ab43 100644 --- a/keras/src/layers/preprocessing/string_lookup_test.py +++ b/keras/src/layers/preprocessing/string_lookup_test.py @@ -77,8 +77,7 @@ def test_tf_data_compatibility(self): ) input_data = ["b", "c", "d"] ds = tf_data.Dataset.from_tensor_slices(input_data).batch(3).map(layer) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertAllClose(output, np.array([2, 3, 0])) @pytest.mark.skipif(not backend.backend() == "tensorflow", reason="tf only") diff --git a/keras/src/layers/preprocessing/text_vectorization_test.py b/keras/src/layers/preprocessing/text_vectorization_test.py index 1f641e5a92d..c5a8c593f41 100644 --- a/keras/src/layers/preprocessing/text_vectorization_test.py +++ b/keras/src/layers/preprocessing/text_vectorization_test.py @@ -95,8 +95,7 @@ def test_tf_data_compatibility(self): ) input_data = [["foo qux bar"], ["qux baz"]] ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) - for output in ds.take(1): - output = output.numpy() + output = next(iter(ds)).numpy() self.assertAllClose(output, np.array([[4, 1, 3, 0], [1, 2, 0, 0]])) # Test adapt flow @@ -107,8 +106,7 @@ def test_tf_data_compatibility(self): ) layer.adapt(input_data) ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) - for output in ds.take(1): - output.numpy() + next(iter(ds)).numpy() @pytest.mark.skipif( backend.backend() != "tensorflow", reason="Requires string tensors." diff --git a/keras/src/legacy/backend.py b/keras/src/legacy/backend.py index dbb933112ad..1c3876d8583 100644 --- a/keras/src/legacy/backend.py +++ b/keras/src/legacy/backend.py @@ -1279,6 +1279,8 @@ def relu(x, alpha=0.0, max_value=None, threshold=0.0): negative_part = tf.nn.relu(-x + threshold) else: negative_part = tf.nn.relu(-x) + else: + negative_part = 1 clip_max = max_value is not None diff --git a/keras/src/models/model.py b/keras/src/models/model.py index e03e6dc97bd..42e5cdddba0 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -397,6 +397,7 @@ def quantize(self, mode, **kwargs): def build_from_config(self, config): if not config: return + status = False if "input_shape" in config: # Case: all inputs are in the first arg (possibly nested). if utils.is_default(self.build): @@ -408,7 +409,7 @@ def build_from_config(self, config): self.build(config["input_shape"]) status = True except: - status = False + pass self._build_shapes_dict = config elif "shapes_dict" in config: @@ -420,7 +421,7 @@ def build_from_config(self, config): self.build(**config["shapes_dict"]) status = True except: - status = False + pass self._build_shapes_dict = config["shapes_dict"] if not status: diff --git a/keras/src/models/sequential.py b/keras/src/models/sequential.py index 010460548ec..5815add1c14 100644 --- a/keras/src/models/sequential.py +++ b/keras/src/models/sequential.py @@ -359,6 +359,7 @@ def from_config(cls, config, custom_objects=None): model.add(layer) if ( not model._functional + and "build_input_shape" in locals() and build_input_shape and isinstance(build_input_shape, (tuple, list)) ): diff --git a/keras/src/ops/function.py b/keras/src/ops/function.py index 3e5daf035b0..5faea6e8122 100644 --- a/keras/src/ops/function.py +++ b/keras/src/ops/function.py @@ -318,7 +318,7 @@ def map_graph(inputs, outputs): "The following previous operations were accessed " f"without issue: {operations_with_complete_input}" ) - operations_with_complete_input.append(operation.name) + operations_with_complete_input.append(node.operation.name) for x in tree.flatten(node.outputs): computable_tensors.add(x) diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py index 4addf21342b..e40117df743 100644 --- a/keras/src/optimizers/base_optimizer.py +++ b/keras/src/optimizers/base_optimizer.py @@ -914,6 +914,8 @@ def get_config(self): learning_rate = serialization_lib.serialize_keras_object( self._learning_rate ) + else: + learning_rate = 0.5 config = { "name": self.name, diff --git a/keras/src/trainers/compile_utils.py b/keras/src/trainers/compile_utils.py index 410a782dbcd..d87afe128ce 100644 --- a/keras/src/trainers/compile_utils.py +++ b/keras/src/trainers/compile_utils.py @@ -170,6 +170,7 @@ def variables(self): return vars def build(self, y_true, y_pred): + num_outputs = 1 # default if self.output_names: output_names = self.output_names elif isinstance(y_pred, dict): @@ -182,7 +183,6 @@ def build(self, y_true, y_pred): output_names = None else: output_names = None - num_outputs = 1 if output_names: num_outputs = len(output_names) diff --git a/keras/src/tree/dmtree_impl.py b/keras/src/tree/dmtree_impl.py index 84466417514..c0ba934f2f4 100644 --- a/keras/src/tree/dmtree_impl.py +++ b/keras/src/tree/dmtree_impl.py @@ -60,6 +60,7 @@ def truncate(value, length): ) return flat_sequence[0] + packed = [] try: final_index, packed = packed_nest_with_indices( structure, flat_sequence, 0, is_nested_fn, sequence_fn diff --git a/keras/src/utils/file_utils.py b/keras/src/utils/file_utils.py index 80268cac662..e18daed07e2 100644 --- a/keras/src/utils/file_utils.py +++ b/keras/src/utils/file_utils.py @@ -100,9 +100,11 @@ def extract_archive(file_path, path=".", archive_format="auto"): if archive_type == "tar": open_fn = tarfile.open is_match_fn = tarfile.is_tarfile - if archive_type == "zip": + elif archive_type == "zip": open_fn = zipfile.ZipFile is_match_fn = zipfile.is_zipfile + else: + raise NotImplementedError(archive_type) if is_match_fn(file_path): with open_fn(file_path) as archive: