
Nanotron tests #160

Draft
wants to merge 14 commits into
base: main
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -26,7 +26,7 @@ jobs:
cache: 'pip'
- name: Install lighteval in editable mode
run: |
pip install -e .[dev,extended_tasks]
pip install -e .[dev,nanotron,extended_tasks]
- name: Get cached files
uses: actions/cache@v2
id: get-cache
6 changes: 6 additions & 0 deletions run_evals_nanotron.py
100644 → 100755
@@ -45,6 +45,12 @@ def get_parser():
default=None,
help="Cache directory",
)
parser.add_argument(
"--max_samples",
type=int,
default=10,
help="number of samples used for evaluation",
)

return parser

19 changes: 19 additions & 0 deletions src/lighteval/main_nanotron.py
@@ -23,9 +23,11 @@
# flake8: noqa: C901
import os
import random
from argparse import Namespace
from typing import Optional, Type

import numpy as np
import torch

from lighteval.evaluator import evaluate, make_results_table
from lighteval.logging.evaluation_tracker import EvaluationTracker
@@ -64,7 +66,15 @@ def main(
config_cls: Type = Config,
model_config_cls: Optional[Type] = None,
model_cls: Optional[Type] = None,
args: Optional[Namespace] = None, # accept args for more flexibility
):
if args is not None:
checkpoint_config_path = (
args.checkpoint_config_path if checkpoint_config_path is None else checkpoint_config_path
)
lighteval_config_path = args.lighteval_override if lighteval_config_path is None else lighteval_config_path
cache_dir = args.cache_dir if cache_dir is None else cache_dir

if cache_dir is None:
cache_dir = CACHE_DIR

@@ -82,6 +92,7 @@
model_config_class=model_config_cls,
skip_unused_config_keys=True,
skip_null_keys=True,
igonore_all_unused_keys=True,
)

if lighteval_config_path:
@@ -90,6 +101,9 @@
else:
lighteval_config = nanotron_config.lighteval

if args is not None and args.max_samples is not None:
lighteval_config.tasks.max_samples = args.max_samples

parallel_context = ParallelContext(
tensor_parallel_size=lighteval_config.parallelism.tp,
pipeline_parallel_size=lighteval_config.parallelism.pp,
@@ -157,8 +171,13 @@ def main(

with htrack_block("Setting seeds and waiting for all processes"):
hlog(f"setting seed to {SEED} for random and numpy")

torch.manual_seed(SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

dist.barrier()

with htrack_block("Evaluation"):
1 change: 1 addition & 0 deletions src/lighteval/metrics/metrics.py
@@ -225,6 +225,7 @@ class Metrics(Enum):
corpus_level_fn=np.mean,
higher_is_better=True,
)
# this takes noticeable time on every test run, even when it is not needed
llm_judge_multi_turn = SampleLevelMetricGrouping(
metric=["single_turn", "multi_turn"],
higher_is_better=True,
3 changes: 2 additions & 1 deletion src/lighteval/models/nanotron_model.py
@@ -570,6 +570,7 @@ def prepare_batch(

# since in _collate we make sure length is descending, the longest is always the first one.
padding_length = padding_length if padding_length is not None else inplen

if padding_length - inplen < 0:
raise ValueError("Negative padding")
padded.append(padding_length - inplen)
@@ -655,7 +656,7 @@ def _get_subsets(self, dataset, dataset_splits):
def _loglikelihood_single_token(
self, requests, disable_tqdm: bool = False, override_bs: int = -1, dataset_splits: int = 1
) -> List[LoglikelihoodSingleTokenReturn]:
dataset = LoglikelihoodSingleTokenDataset(requests=requests)
dataset = LoglikelihoodSingleTokenDataset(requests=requests, dataset_splits=dataset_splits)
res = []

# Dataset is sorted in descending size.
55 changes: 55 additions & 0 deletions tests/config/README.md
@@ -0,0 +1,55 @@
# Nanotron tests guide
## How it works:
First, select some tasks and use the model to generate reference scores, which are saved in the `reference_task_scores_nanotron.py` file. This has already been done, but if you want to add a new task you need to re-run this step.
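
For illustration, the reference file is essentially a mapping from model and task to expected metric values. The sketch below is hypothetical (the real `reference_task_scores_nanotron.py` may use different keys and numbers); it only shows the idea:
```
# Hypothetical sketch of reference_task_scores_nanotron.py: expected per-task
# metric values that later test runs are compared against. Keys and numbers
# are illustrative, not the real reference data.
RESULTS_NANOTRON = {
    "nanotron-119M": {
        "lighteval|anli:r1|0": {"acc": 0.30, "acc_stderr": 0.15},
        "leaderboard|hellaswag|10": {"acc": 0.20, "acc_norm": 0.30},
    },
}
```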

After that, each time a test needs to be conducted, the evaluation is run and the results are compared to the previously saved reference scores.
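
A minimal sketch of that comparison step, assuming a `results` dict produced by the evaluation and a `reference` dict shaped like the sketch above (both names and the tolerance are illustrative, not the actual test code):
```
# Illustrative only: the real test may use different tolerances and structures.
import pytest

def compare_to_reference(results: dict, reference: dict, rel_tol: float = 1e-3):
    # Every reference metric must be reproduced within the given tolerance.
    for task, metrics in reference.items():
        for metric_name, expected in metrics.items():
            got = results[task][metric_name]
            assert got == pytest.approx(expected, rel=rel_tol), (
                f"{task}/{metric_name}: expected {expected}, got {got}"
            )
```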

## To run the nanotron tests:
```
pytest tests/test_main_nanotron.py -sv
```

## Choose your own tasks for evaluation:
Modify **tasks.tasks** in the config file (`lighteval/tests/config/lighteval_config_override_custom.yaml`) to set the tasks (a small sketch of the task-string format follows the example below).
Example:
```
tasks:
custom_tasks: null
dataset_loading_processes: 1
max_samples: 10
multichoice_continuations_start_space: null
no_multichoice_continuations_start_space: null
num_fewshot_seeds: null
tasks: lighteval|anli:r1|0|0,lighteval|blimp:adjunct_island|0|0,...
```
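
Each entry in `tasks.tasks` appears to follow a `suite|task|num_fewshot|truncate_fewshot` pattern. The helper below is an assumption for illustration only, not lighteval's actual parser:
```
# Hypothetical helper: splits one task specification into its four fields.
def split_task_spec(spec: str):
    suite, task, num_fewshot, truncate = spec.split("|")
    return suite, task, int(num_fewshot), int(truncate)

print(split_task_spec("lighteval|anli:r1|0|0"))
# -> ('lighteval', 'anli:r1', 0, 0)
```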

## Randomized results
To avoid non-reproducible results, make sure **for_inference** is set to true. This loads the model with a fixed output layer-norm implementation. It is set to false by default for training.
```
model:
ddp_bucket_cap_mb: 25
dtype: float64
init_method:
std: 0.02
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 512
initializer_range: 0.02
intermediate_size: 2048
is_llama_config: true
max_position_embeddings: 2048
num_attention_heads: 16
num_hidden_layers: 16
num_key_value_heads: 16
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 50272
for_inference: true
```
30 changes: 30 additions & 0 deletions tests/config/lighteval_config_override_custom.yaml
@@ -0,0 +1,30 @@
batch_size: 16
checkpoints_path: null
generation: null
logging:
hub_repo_details: null
hub_repo_results: null
hub_repo_tensorboard: zzhhjjj/debug-nanotron
local_output_path: /scratch/haojun/lighteval/nanotron-119M-seed-6-3188821
push_details_to_hub: null
push_results_to_hub: null
push_results_to_tensorboard: true
tensorboard_metric_prefix: e
parallelism:
dp: 1
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
tp: 1
tp_linear_async_communication: false
tp_mode: ALL_REDUCE
slurm_script_dir: /fsx/haojun/logs_evals
slurm_template: /fsx/haojun/brrr/examples/get-started-kit/run_eval.slurm.jinja
tasks:
custom_tasks: null
dataset_loading_processes: 1
max_samples: 10
multichoice_continuations_start_space: null
no_multichoice_continuations_start_space: null
num_fewshot_seeds: null
tasks: lighteval|anli:r1|0|0,lighteval|blimp:adjunct_island|0|0,lighteval|blimp:ellipsis_n_bar_1|0|0,leaderboard|arc:challenge|25|0,leaderboard|hellaswag|10|0,leaderboard|mmlu:abstract_algebra|5|0,leaderboard|mmlu:college_chemistry|5|0,leaderboard|mmlu:computer_security|5|0,leaderboard|mmlu:us_foreign_policy|5|0,leaderboard|truthfulqa:mc|0|0,helm|mmlu:abstract_algebra|5|0,helm|mmlu:college_chemistry|5|0,helm|mmlu:computer_security|5|0,helm|mmlu:us_foreign_policy|5|0,helm|boolq|5|0,helm|hellaswag|5|0,leaderboard|gsm8k|5|0