diff --git a/README.md b/README.md index 4087a3dcd..93d9ff852 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ Hardware Related * [Perform Inference From Custom Model](#perform-inference-from-custom-model) * [Explorations](#explorations) * [Start Exploration](#start-exploration) + * [Inspect and Monitor Best Val Losses](#inspect-and-monitor-best-val-losses) * [Start Tensorboard Logging](#start-tensorboard-logging) * [Troubleshooting](#troubleshooting) * [Creating New Features and Exploration Scripts](#creating-new-features-and-exploration-scripts) @@ -157,14 +158,17 @@ logs/ and save checkpoints for inference in `out_test` -### Inspect best losses +### Inspect and Monitor Best Val Losses -Often we want to run a large number of experiments and find the best validation -loss (a metric for how well the model does on next token prediction on a given -dataset). +Often for large explorations with `run_experiments` one wants to monitor the +the best validation losses so far (a metric for how well the model does on next +token prediction on the current dataset). -The included `inspect_ckpts.py` script to recursively check the best valiation -loss and associated iteration number for all ckpt.pt files in a given directory: +The included `inspect_ckpts.py` script reports the best valiation loss and +associated iteration number for all ckpt.pt files recursivel for a specified +parent directory. + +Example usage: ```bash python3 inspect_ckpts.py --directory ./out --sort loss ``` diff --git a/inspect_ckpts.py b/inspect_ckpts.py index eb124172a..5e7f92776 100644 --- a/inspect_ckpts.py +++ b/inspect_ckpts.py @@ -2,10 +2,10 @@ import os import torch import csv +import re from rich.console import Console from rich.table import Table - def get_best_val_loss_and_iter_num(checkpoint_file): """ Extracts the best validation loss and the corresponding iteration number from a PyTorch checkpoint file. @@ -25,12 +25,13 @@ def get_best_val_loss_and_iter_num(checkpoint_file): return best_val_loss, iter_num -def find_ckpt_files(directory): +def find_ckpt_files(directory, path_regex=None): """ Recursively finds all 'ckpt.pt' files in the given directory. Args: directory (str): The directory to search. + path_regex (str): Regular expression to filter the checkpoint file paths. Returns: list: A list of paths to the 'ckpt.pt' files. @@ -39,30 +40,48 @@ def find_ckpt_files(directory): for root, dirs, files in os.walk(directory): for file in files: if file.endswith('ckpt.pt'): - ckpt_files.append(os.path.join(root, file)) + ckpt_file = os.path.join(root, file) + if path_regex is None or re.search(path_regex, ckpt_file): + ckpt_files.append(ckpt_file) return ckpt_files +def get_short_ckpt_file(ckpt_file): + """ + Removes the '/ckpt.pt' suffix from the checkpoint file path. + + Args: + ckpt_file (str): The full checkpoint file path. + + Returns: + str: The checkpoint file path with the '/ckpt.pt' suffix removed. + """ + if ckpt_file.endswith('/ckpt.pt'): + return ckpt_file[:-8] + else: + return ckpt_file + def main(): parser = argparse.ArgumentParser(description='Extract best validation loss and iteration number from PyTorch checkpoint files.') parser.add_argument('--directory', type=str, help='Path to the directory containing the checkpoint files.') parser.add_argument('--csv_file', type=str, help='Path to the CSV file containing the checkpoint data.') + parser.add_argument('--path_regex', type=str, help='Regular expression to filter the checkpoint file paths.') parser.add_argument('--sort', type=str, choices=['path', 'loss', 'iter'], default='path', help='Sort the table by checkpoint file path, best validation loss, or iteration number.') parser.add_argument('--reverse', action='store_true', help='Reverse the sort order.') parser.add_argument('--output', type=str, help='Path to the output CSV file.') args = parser.parse_args() if args.directory: - ckpt_files = find_ckpt_files(args.directory) + ckpt_files = find_ckpt_files(args.directory, args.path_regex) # Extract the best validation loss and iteration number for each checkpoint file - ckpt_data = [(ckpt_file, *get_best_val_loss_and_iter_num(ckpt_file)) for ckpt_file in ckpt_files] + ckpt_data = [(get_short_ckpt_file(ckpt_file), *get_best_val_loss_and_iter_num(ckpt_file)) for ckpt_file in ckpt_files] elif args.csv_file: ckpt_data = [] with open(args.csv_file, 'r') as csvfile: csv_reader = csv.reader(csvfile) next(csv_reader) # Skip the header row for row in csv_reader: - ckpt_data.append((row[0], float(row[1]), int(row[2]))) + ckpt_data.append((get_short_ckpt_file(row[0]), float(row[1]), int(row[2]))) else: print("Please provide either a directory or a CSV file.") return @@ -75,14 +94,13 @@ def main(): elif args.sort == 'iter': ckpt_data.sort(key=lambda x: x[2], reverse=args.reverse) - console = None - # Check if the TERM environment variable is set to a value that supports ANSI escape codes - if 'TERM' in os.environ and os.environ['TERM'] in ['xterm', 'xterm-color', 'xterm-256color', 'screen', 'screen-256color', 'tmux', 'tmux-256color']: - console = Console(color_system="standard") - else: - console = Console() + console = Console() + + # Determine the maximum length of the checkpoint file paths + max_path_length = max(len(ckpt_file) for ckpt_file, _, _ in ckpt_data) + table = Table(show_header=True, header_style="bold magenta") - table.add_column("Checkpoint File", style="dim", width=50) + table.add_column("Checkpoint File", style="dim", width=max_path_length + 2) table.add_column("Best Validation Loss", justify="right") table.add_column("Iteration Number", justify="right")