
Commit

Merge pull request #203 from mlfoundations/revert-201-classification-refactor

Revert "Classification refactor"
anas-awadalla committed Jun 8, 2023
2 parents dd88d06 + fbad1be commit c2e80b4
Showing 2 changed files with 39 additions and 111 deletions.
72 changes: 0 additions & 72 deletions open_flamingo/eval/eval_datasets.py
@@ -1,9 +1,6 @@
import json
import os
from dataclasses import dataclass, field
from typing import Optional, Sequence, Mapping

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
@@ -109,75 +106,6 @@ def __getitem__(self, idx):
}


def topk(probs_ary: np.ndarray, k: int) -> np.ndarray:
"""Return the indices of the top k elements in probs_ary."""
return np.argsort(probs_ary)[::-1][:k]


@dataclass
class ClassificationDataset:
"""Class to hold a classification dataset for evals.
All Dataset objects (train_dataset, val_dataset, test_dataset)
should return a dictionary containing at least the
following keys: image, class_id, class_name. See
ImageNetDataset for an example.
"""

train_dataset: Dataset
prompts: Sequence[str] = field(
metadata={
"help": "A sequence of prompts to be used during evaluation;"
"e.g. 'A photo of a'. It is recommended to 'strip' the prompt (remove leading/trailing "
"spaces) for best performance."
}
)
class_id_to_label: Mapping[int, str] = field(
metadata={
"help": "mapping of numeric class IDs to string class names/labels."
"Downstream metrics will be evaluated against the mapped strings."
}
)
val_dataset: Optional[Dataset] = None
test_dataset: Optional[Dataset] = None

def get_in_context_samples(self, num: int, **kwargs) -> Sequence[int]:
"""Fetch a set of `num` in-context sample indices."""
return np.random.choice(len(self.train_dataset), num, replace=False)

def metric_fn(
self, labels: Sequence[int], outputs: Sequence[float]
) -> Mapping[str, float]:
"""
Compute metrics for a set of labels and predictions.
labels: An array-like of shape [batch_size,]
outputs: Model outputs; an array-like of shape [batch_size, num_classes]. The
[i,j]^th element of outputs should correspond to the probability
that the i^th observation has numeric class label j.
"""
batch_size = len(labels)

# Sanity check that batch size is consistent
assert len(outputs) == len(labels)

# Sanity check that outputs has same dimension as class mapping.
assert outputs.shape[1] == len(self.class_id_to_label)

acc5 = 0.0
acc1 = 0.0

for i in range(batch_size):
top5 = [self.class_id_to_label[pred] for pred in topk(outputs[i], 5)]

y_i = labels[i]["class_name"]
acc5 += int(y_i in set(top5))
acc1 += int(y_i == top5[0])

print(f"[DEBUG]: elem {i} of {batch_size}:" f"label {y_i} // top5 {top5}")
return {"acc1": acc1, "acc5": acc5}


class ImageNetDataset(ImageFolder):
"""Class to represent the ImageNet1k dataset."""

78 changes: 39 additions & 39 deletions open_flamingo/eval/evaluate.py
@@ -5,19 +5,21 @@
import random
import uuid
from collections import defaultdict
from typing import Mapping

from einops import repeat
import more_itertools
import numpy as np
import torch


from coco_metric import compute_cider, postprocess_captioning_generation
from eval_datasets import CaptionDataset, ImageNetDataset, ClassificationDataset
from eval_datasets import CaptionDataset, VQADataset, ImageNetDataset
from tqdm import tqdm


from eval_datasets import VQADataset, ImageNetDataset
from open_flamingo.eval.imagenet_utils import (
openai_imagenet_classnames,
IMAGENET_1K_CLASS_ID_TO_LABEL,
)

@@ -429,25 +431,13 @@ def main():
for shot in args.shots:
scores = []
for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
imagenet_dataset = ClassificationDataset(
train_dataset=ImageNetDataset(
os.path.join(args.imagenet_root, "train")
),
val_dataset=ImageNetDataset(
os.path.join(args.imagenet_root, "val")
),
class_id_to_label=IMAGENET_1K_CLASS_ID_TO_LABEL,
prompts=[
"A photo of a",
],
)
imagenet_score = evaluate_classification(
imagenet_score = evaluate_imagenet(
eval_model=eval_model,
batch_size=args.batch_size,
num_samples=args.num_samples,
num_shots=shot,
seed=seed,
classification_dataset=imagenet_dataset,
imagenet_root=args.imagenet_root,
)
print(
f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}"
@@ -797,20 +787,21 @@ def evaluate_vqa(
return acc


def evaluate_classification(
def evaluate_imagenet(
eval_model,
batch_size: int,
classification_dataset: ClassificationDataset,
imagenet_root: str,
seed: int = 42,
num_samples: int = 5000,
num_shots: int = 8,
):
"""
Evaluate a model on a classification dataset.
Evaluate a model on ImageNet dataset.
Args:
eval_model (BaseEvalModel): model to evaluate
batch_size (int): batch size
imagenet_root (str): path to imagenet root for the specified split.
seed (int, optional): random seed. Defaults to 42.
num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
num_shots (int, optional): number of shots to use. Defaults to 8.
@@ -826,20 +817,17 @@ def evaluate_classification(
model, tokenizer = eval_model.model, eval_model.tokenizer
assert isinstance(model, Flamingo)

train_dataset = classification_dataset.train_dataset
val_dataset = classification_dataset.val_dataset
class_id_to_label = classification_dataset.class_id_to_label
train_dataset = ImageNetDataset(os.path.join(imagenet_root, "train"))
val_dataset = ImageNetDataset(os.path.join(imagenet_root, "val"))

effective_num_shots = compute_effective_num_shots(num_shots, args.model)
tokenizer.padding_side = (
"left" # For generation padding tokens should be on the left
)
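
An aside, not part of the diff: left padding matters because generation continues from the final position of each row, so the real prompt tokens must sit at the end. A rough illustration with a Hugging Face tokenizer (the checkpoint name below is a placeholder assumption):

from transformers import AutoTokenizer  # assumes the transformers library is installed

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint for illustration
tok.pad_token = tok.eos_token                # gpt2 has no pad token by default
tok.padding_side = "left"                    # pad on the left so prompts end with real tokens

batch = tok(["A photo of a", "short"], padding=True, return_tensors="pt")
# The shorter row is padded at the start; both rows end with genuine prompt tokens,
# so tokens generated after the last position continue the prompt rather than the padding.
print(batch["input_ids"])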

_metrics = defaultdict(
float
) # accumulates metric values over each batch, to be averaged at end.
# TODO(jpgard): iterate over prompts here.
prompt_text = f"<image>{classification_dataset.prompts[0]}"
acc1 = 0
acc5 = 0
prompt_text = "<image>A photo of a"

val_iterator = more_itertools.chunked(val_dataset, batch_size)
for batch_idx, batch in enumerate(val_iterator):
Expand All @@ -849,8 +837,8 @@ def evaluate_classification(
for idx in range(len(batch)):
# Choose a different set of random context samples for each sample
# from the training set
context_indices = classification_dataset.get_in_context_samples(
effective_num_shots
context_indices = np.random.choice(
len(train_dataset), effective_num_shots, replace=False
)

in_context_samples = [train_dataset[i] for i in context_indices]
@@ -905,12 +893,12 @@ def _detach_pkvs(pkvs):
precomputed_logits = precomputed.logits.detach()

overall_probs = []
for class_name in tqdm(class_id_to_label.values()):
for imagenet_class_name in tqdm(openai_imagenet_classnames):
past_key_values = None
# Tokenize only the class name and iteratively decode the model's
# predictions for this class.
classname_tokens = tokenizer(
class_name, add_special_tokens=False, return_tensors="pt"
imagenet_class_name, add_special_tokens=False, return_tensors="pt"
)["input_ids"].cuda()

if classname_tokens.ndim == 1: # Case: classname is only 1 token
@@ -961,23 +949,35 @@

overall_probs = np.row_stack(overall_probs).T # shape [B, num_classes]

targets = [x["class_name"] for x in batch]
metrics = classification_dataset.metric_fn(targets, overall_probs)
def topk(probs_ary: np.ndarray, k: int) -> np.ndarray:
"""Return the indices of the top k elements in probs_ary."""
return np.argsort(probs_ary)[::-1][:k]

for k in metrics.keys():
_metrics[k] += metrics[k]
for i in range(batch_size):
top5 = [
IMAGENET_1K_CLASS_ID_TO_LABEL[pred]
for pred in topk(overall_probs[i], 5)
]

y_i = batch[i]["class_name"]
acc5 += int(y_i in set(top5))
acc1 += int(y_i == top5[0])

print(
f"DEBUG: batch {idx} elem {i} of {batch_size}:"
f"label {y_i} // top5 {top5}"
)

examples_seen = (batch_idx + 1) * batch_size
print(
f"eval {examples_seen}/{num_samples}: "
+ "\n\t".join(
[f"{k}: {v / examples_seen:.4f}" for k, v in _metrics.items()]
"eval {}/{}: acc@1 ({}), acc@5 ({})".format(
examples_seen, num_samples, acc1 / examples_seen, acc5 / examples_seen
)
)
if batch_idx * batch_size >= num_samples - 1:
break

return float(_metrics["acc1"]) / num_samples
return float(acc1) / num_samples
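
Not part of the diff: the per-class loop above ranks every ImageNet class name by the model's likelihood of generating it after the prompt, reusing cached key/value states for the shared prefix. A stripped-down, text-only analogue with an ordinary Hugging Face causal LM, omitting the key/value caching and the vision inputs (the checkpoint and label names are placeholders), might look like:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer  # assumed dependencies

tokenizer = AutoTokenizer.from_pretrained("gpt2")              # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

prompt = "A photo of a"
class_names = ["tabby cat", "golden retriever", "school bus"]  # toy label set

prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
scores = []
for name in class_names:
    name_ids = tokenizer(" " + name, add_special_tokens=False, return_tensors="pt")["input_ids"]
    input_ids = torch.cat([prompt_ids, name_ids], dim=1)
    with torch.no_grad():
        logits = model(input_ids).logits
    # logits[:, i] predicts input_ids[:, i + 1], so the positions that predict the
    # class-name tokens start at the last prompt position and stop before the end.
    name_logits = logits[:, prompt_ids.shape[1] - 1 : -1]
    logprobs = torch.log_softmax(name_logits, dim=-1)
    token_lp = logprobs.gather(-1, name_ids.unsqueeze(-1)).squeeze(-1)
    scores.append(token_lp.sum().item())   # total log-likelihood of this class name

best = class_names[max(range(len(scores)), key=scores.__getitem__)]
print(best)  # class ranked most likely by the LM; top-5 works the same way on the scores

Caching the prompt's key/value states, as the code in the diff does, avoids re-encoding the shared image-conditioned prefix once per class, which matters when scoring all 1,000 ImageNet class names.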


if __name__ == "__main__":
