diff --git a/brainforge/learner/abstract_learner.py b/brainforge/learner/abstract_learner.py index edeff8f..bba4d4b 100644 --- a/brainforge/learner/abstract_learner.py +++ b/brainforge/learner/abstract_learner.py @@ -95,11 +95,13 @@ def evaluate_batch(self, x, y, metrics=()): def evaluate_stream(self, stream, steps, metrics=(), verbose=False): history = logging.MetricLogs.from_metric_list(steps, ["cost"], metrics) metrics = [_metrics.get(metric) for metric in metrics] - for x, y in stream: + for i, (x, y) in enumerate(stream, start=1): eval_metrics = self.evaluate_batch(x, y, metrics) history.record(eval_metrics) if verbose: history.log("\r", end="") + if i >= steps: + break if verbose: print() history.reduce_mean()