Open Flamingo Perplexity Calculation #289

Open · mustafaadogan opened this issue Feb 11, 2024 · 3 comments
mustafaadogan commented Feb 11, 2024

I'm currently working with Open Flamingo on a task that involves calculating perplexity scores for given sentence-image pairs. I've run into an issue where the perplexity scores for two captions (one true and one false) come out the same, even though one of them is incorrect.

I've implemented a perplexity calculation method in Python using PyTorch. The method involves extracting logits from the model output, obtaining true labels from the input text, and then calculating perplexity based on the probabilities assigned to the true labels.

def calculate_perplexity(self, prompt, vision_x):
    """
    Calculate the perplexity score given a prompt and vision data.

    Parameters:
    - prompt (str): The input prompt.
    - vision_x (torch.Tensor): Tensor containing vision data.

    Returns:
    float: Model score.
    """
    if self.model is None:
        raise AttributeError('Model is not initialized. Call load_model first!')

    lang_x = self.tokenizer(
        [prompt],
        return_tensors="pt",
    )

    with torch.no_grad():
        model_output = self.model(
            vision_x=vision_x.to(self.device),
            lang_x=lang_x["input_ids"].to(self.device),
            attention_mask=lang_x["attention_mask"].to(self.device)
        )

    logits = model_output.logits[0]
    true_labels = lang_x["input_ids"].view(-1)  # Flatten the true labels

    # Extract the probabilities assigned to the true labels
    true_probs = torch.gather(logits, dim=1, index=true_labels.unsqueeze(1))

    # Calculate perplexity
    perplexity = torch.exp(-torch.mean(torch.log(true_probs)))

    return float(perplexity)
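
For context, I score the two candidate captions against the same image and compare the results, roughly like this (scorer is just an instance of my wrapper class, so the names here are illustrative):

# Illustrative usage only: `scorer` is an instance of the wrapper class this
# method belongs to, and `vision_x` is the preprocessed image tensor for the
# example image; both names are placeholders.
true_caption = "Breakfast items including juice are on the table."
false_caption = "Breakfast items including juice are off the table."

ppl_true = scorer.calculate_perplexity(true_caption, vision_x)
ppl_false = scorer.calculate_perplexity(false_caption, vision_x)

# The caption with the lower perplexity would be taken as the model's choice.
print(f"true: {ppl_true:.3f}  false: {ppl_false:.3f}")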

I've checked that the token IDs are correctly indexed, and the perplexity calculation seems to be set up correctly. However, the perplexity scores come out as NaN, and I suspect there is an issue with the softmax probabilities or some numerical instability.

To avoid NaN values, I added the following code block:

# Add a small epsilon to avoid taking the log of zero
epsilon = 1e-8
true_probs = torch.clamp(true_probs, epsilon, 1.0)

This time, I get the same score for both captions.

Example captions:

True caption: Breakfast items including juice are on the table.

False caption: Breakfast items including juice are off the table.
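
For reference, the values gathered above are raw logits rather than probabilities, which would explain the NaN (logits can be negative, so their log is undefined) and is likely why clamping to [epsilon, 1] collapses most values and flattens the scores. A log-softmax over the vocabulary dimension gives proper per-token log-probabilities and is numerically stable; a minimal sketch reusing the variable names from the snippet above (the one-position shift assumes a standard causal LM where position i predicts token i + 1):

import torch
import torch.nn.functional as F

# Sketch only: `logits` has shape (seq_len, vocab_size) and `true_labels` has
# shape (seq_len,), as in the snippet above. For a causal LM the logits at
# position i predict the token at position i + 1, so both are shifted by one.
log_probs = F.log_softmax(logits, dim=-1)   # log-probabilities over the vocabulary
shifted_log_probs = log_probs[:-1]          # predictions for tokens 1..N-1
shifted_labels = true_labels[1:]            # targets are the next tokens

token_log_probs = torch.gather(
    shifted_log_probs, dim=1, index=shifted_labels.unsqueeze(1)
).squeeze(1)

perplexity = torch.exp(-token_log_probs.mean())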

@yongliang-wu

Hi Mustafa, have you solved this problem?

@mustafaadogan (Author)

I got the scoring itself working, but I'm seeing poor zero-shot performance on some benchmarks, sometimes even worse than random chance. Here's the code I'm using:

def calculate_perplexity(self, prompt, vision_x):
    """
    Calculate the perplexity score given a prompt and vision data.

    Parameters:
    - prompt (str): The input prompt.
    - vision_x (torch.Tensor): Tensor containing vision data.

    Returns:
    float: Model score.
    """
    if self.model is None:
        raise AttributeError('Model is not initialized. Call load_model first!')

    lang_x = self.tokenizer(
        [prompt],
        return_tensors="pt",
    )

    with torch.no_grad():
        model_output = self.model(
            vision_x=vision_x.to(self.device),
            lang_x=lang_x["input_ids"].to(self.device),
            attention_mask=lang_x["attention_mask"].to(self.device)
        )

    logits = model_output.logits[0].to(self.device)
    true_labels = lang_x["input_ids"].view(-1).to(self.device)  # Flatten the true labels

    # Calculate cross-entropy loss
    loss = self.crit(logits, true_labels)

    # Calculate perplexity
    perplexity = loss.mean().exp()

    return float(perplexity)
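
A note for anyone reusing this: self.crit is a cross-entropy criterion, and depending on how the model aligns its outputs, the logits may need to be shifted by one position relative to the labels so that each logit scores the next token. A standalone sketch of that variant with an explicit nn.CrossEntropyLoss (the shift is an assumption about the model's causal alignment, not something taken from the code above):

import torch
import torch.nn as nn

# Standalone sketch, not the exact code above: `logits` is (seq_len, vocab_size),
# `true_labels` is (seq_len,), and the criterion is a token-level cross entropy.
# The shift assumes logits at position i predict the token at position i + 1.
criterion = nn.CrossEntropyLoss(reduction="none")

shift_logits = logits[:-1, :]
shift_labels = true_labels[1:]

per_token_loss = criterion(shift_logits, shift_labels)  # negative log-likelihood per token
perplexity = per_token_loss.mean().exp()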

@yongliang-wu

Thanks Mustafa!!!
