Skip to content

Commit

Permalink
Add table customization
Browse files Browse the repository at this point in the history
This adds regex for filtering filepath results.

This also adds adaptable column width for the filepath to prevent name
clipping.

To help reduce excess chars in the filepath, we trim the `/ckpt.pt` as
all of the paths will have this in common.
  • Loading branch information
gkielian committed Apr 18, 2024
1 parent 386bd36 commit 57b0688
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
16 changes: 10 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Hardware Related
* [Perform Inference From Custom Model](#perform-inference-from-custom-model)
* [Explorations](#explorations)
* [Start Exploration](#start-exploration)
* [Inspect and Monitor Best Val Losses](#inspect-and-monitor-best-val-losses)
* [Start Tensorboard Logging](#start-tensorboard-logging)
* [Troubleshooting](#troubleshooting)
* [Creating New Features and Exploration Scripts](#creating-new-features-and-exploration-scripts)
Expand Down Expand Up @@ -157,14 +158,17 @@ logs/

and save checkpoints for inference in `out_test`

### Inspect best losses
### Inspect and Monitor Best Val Losses

Often we want to run a large number of experiments and find the best validation
loss (a metric for how well the model does on next token prediction on a given
dataset).
Often for large explorations with `run_experiments` one wants to monitor the
the best validation losses so far (a metric for how well the model does on next
token prediction on the current dataset).

The included `inspect_ckpts.py` script to recursively check the best valiation
loss and associated iteration number for all ckpt.pt files in a given directory:
The included `inspect_ckpts.py` script reports the best valiation loss and
associated iteration number for all ckpt.pt files recursivel for a specified
parent directory.

Example usage:
```bash
python3 inspect_ckpts.py --directory ./out --sort loss
```
Expand Down
44 changes: 31 additions & 13 deletions inspect_ckpts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import os
import torch
import csv
import re
from rich.console import Console
from rich.table import Table


def get_best_val_loss_and_iter_num(checkpoint_file):
"""
Extracts the best validation loss and the corresponding iteration number from a PyTorch checkpoint file.
Expand All @@ -25,12 +25,13 @@ def get_best_val_loss_and_iter_num(checkpoint_file):

return best_val_loss, iter_num

def find_ckpt_files(directory):
def find_ckpt_files(directory, path_regex=None):
"""
Recursively finds all 'ckpt.pt' files in the given directory.
Args:
directory (str): The directory to search.
path_regex (str): Regular expression to filter the checkpoint file paths.
Returns:
list: A list of paths to the 'ckpt.pt' files.
Expand All @@ -39,30 +40,48 @@ def find_ckpt_files(directory):
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith('ckpt.pt'):
ckpt_files.append(os.path.join(root, file))
ckpt_file = os.path.join(root, file)
if path_regex is None or re.search(path_regex, ckpt_file):
ckpt_files.append(ckpt_file)
return ckpt_files

def get_short_ckpt_file(ckpt_file):
"""
Removes the '/ckpt.pt' suffix from the checkpoint file path.
Args:
ckpt_file (str): The full checkpoint file path.
Returns:
str: The checkpoint file path with the '/ckpt.pt' suffix removed.
"""
if ckpt_file.endswith('/ckpt.pt'):
return ckpt_file[:-8]
else:
return ckpt_file

def main():
parser = argparse.ArgumentParser(description='Extract best validation loss and iteration number from PyTorch checkpoint files.')
parser.add_argument('--directory', type=str, help='Path to the directory containing the checkpoint files.')
parser.add_argument('--csv_file', type=str, help='Path to the CSV file containing the checkpoint data.')
parser.add_argument('--path_regex', type=str, help='Regular expression to filter the checkpoint file paths.')
parser.add_argument('--sort', type=str, choices=['path', 'loss', 'iter'], default='path', help='Sort the table by checkpoint file path, best validation loss, or iteration number.')
parser.add_argument('--reverse', action='store_true', help='Reverse the sort order.')
parser.add_argument('--output', type=str, help='Path to the output CSV file.')
args = parser.parse_args()

if args.directory:
ckpt_files = find_ckpt_files(args.directory)
ckpt_files = find_ckpt_files(args.directory, args.path_regex)

# Extract the best validation loss and iteration number for each checkpoint file
ckpt_data = [(ckpt_file, *get_best_val_loss_and_iter_num(ckpt_file)) for ckpt_file in ckpt_files]
ckpt_data = [(get_short_ckpt_file(ckpt_file), *get_best_val_loss_and_iter_num(ckpt_file)) for ckpt_file in ckpt_files]
elif args.csv_file:
ckpt_data = []
with open(args.csv_file, 'r') as csvfile:
csv_reader = csv.reader(csvfile)
next(csv_reader) # Skip the header row
for row in csv_reader:
ckpt_data.append((row[0], float(row[1]), int(row[2])))
ckpt_data.append((get_short_ckpt_file(row[0]), float(row[1]), int(row[2])))
else:
print("Please provide either a directory or a CSV file.")
return
Expand All @@ -75,14 +94,13 @@ def main():
elif args.sort == 'iter':
ckpt_data.sort(key=lambda x: x[2], reverse=args.reverse)

console = None
# Check if the TERM environment variable is set to a value that supports ANSI escape codes
if 'TERM' in os.environ and os.environ['TERM'] in ['xterm', 'xterm-color', 'xterm-256color', 'screen', 'screen-256color', 'tmux', 'tmux-256color']:
console = Console(color_system="standard")
else:
console = Console()
console = Console()

# Determine the maximum length of the checkpoint file paths
max_path_length = max(len(ckpt_file) for ckpt_file, _, _ in ckpt_data)

table = Table(show_header=True, header_style="bold magenta")
table.add_column("Checkpoint File", style="dim", width=50)
table.add_column("Checkpoint File", style="dim", width=max_path_length + 2)
table.add_column("Best Validation Loss", justify="right")
table.add_column("Iteration Number", justify="right")

Expand Down

0 comments on commit 57b0688

Please sign in to comment.