
Could you give an example of the parameters in loss function? #1

Open
teinhonglo opened this issue Feb 4, 2023 · 2 comments

@teinhonglo

Hi,

Thanks for the code you shared.
I would like to use the code in my project, but I ran into some problems.
Could you give an example of the parameters below (such as logits, probas, true_labels, dist_matrix, distances)?

Thanks in advance.

import torch
import torch.nn.functional as F
from transformers import Trainer

class OLL1Trainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        num_classes = model.module.num_labels
        dist_matrix = model.module.dist_matrix  # (num_classes x num_classes) label-distance matrix
        labels = inputs["labels"]
        outputs = model(**inputs)
        logits = outputs.logits
        probas = F.softmax(logits, dim=1)
        # repeat each example's true label num_classes times: shape (batch, num_classes)
        true_labels = [num_classes*[labels[k].item()] for k in range(len(labels))]
        # every row enumerates the class ids 0..num_classes-1
        label_ids = len(labels)*[[k for k in range(num_classes)]]
        # distances[j][i] = label distance between example j's true class and class i
        distances = [[float(dist_matrix[true_labels[j][i]][label_ids[j][i]]) for i in range(num_classes)] for j in range(len(labels))]
        distances_tensor = torch.tensor(distances, device=logits.device)  # constant weights, no grad needed
        # OLL-1: penalize probability mass on each class in proportion to its distance from the true label
        err = -torch.log(1-probas)*distances_tensor
        loss = torch.sum(err, dim=1).mean()
        return (loss, outputs) if return_outputs else loss
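
For reference, a minimal standalone sketch with toy values (3 classes, batch of 2; the numbers are illustrative assumptions, not from the repository) showing what each of those parameters looks like:

import torch
import torch.nn.functional as F

# Toy values (illustrative only): batch of 2 examples, 3 ordinal classes
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 1.2, 0.3]])   # raw model outputs, shape (2, 3)
labels = torch.tensor([0, 2])              # true class id per example
dist_matrix = [[0, 1, 2],
               [1, 0, 1],
               [2, 1, 0]]                  # dist_matrix[i][j] = |i - j|

probas = F.softmax(logits, dim=1)          # per-class probabilities, rows sum to 1
# distances[j][i] = distance from example j's true label to class i
distances = torch.tensor([[float(dist_matrix[labels[j].item()][i]) for i in range(3)]
                          for j in range(len(labels))])
loss = torch.sum(-torch.log(1 - probas) * distances, dim=1).mean()
print(loss)  # scalar OLL-1 loss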
@castafra
Collaborator

Hello,

Sorry for my late reply; I did not see that an issue had been opened. I hope it's not too late. You can find example values of dist_matrix in the datasets.json file. It is a list of lists (for example [[0, 1, 2, 3, 4], [1, 0, 1, 2, 3], [2, 1, 0, 1, 2], [3, 2, 1, 0, 1], [4, 3, 2, 1, 0]]), where entry [i][j] is the distance between labels i and j.
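
Assuming the distance is the absolute difference between ordinal labels, as in the example above, the matrix can be built with a one-liner (a sketch, not the repository's code):

num_classes = 5
# dist_matrix[i][j] = |i - j|: the ordinal distance between labels i and j
dist_matrix = [[abs(i - j) for j in range(num_classes)] for i in range(num_classes)]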

logits is a tensor containing the model's raw output scores for each class, for each input; applying a softmax (probas) turns them into per-class probabilities.

Could you send me the error you're getting so I can try to debug it?

@teinhonglo
Author

teinhonglo commented May 3, 2023

Hello,

Sorry for the misleading wording above.
The "mistakes" I mentioned mean that the OLL loss performs worse than the cross-entropy (CE) loss in our experiment.
I just replaced CE with OLL in our setup (see the sketch after the links below):

  1. OLLoss:
    https://github.com/teinhonglo/CEFR-SP/blob/main/src/losses/losses.py#L239-L276
  2. Replaced CE with OLL
    https://github.com/teinhonglo/CEFR-SP/blob/main/src/model.py#L69
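
For anyone who wants to make the same swap, a self-contained functional version of the loss from the first comment might look like this; oll_loss is a hypothetical name and the snippet is a sketch, not the CEFR-SP implementation:

import torch
import torch.nn.functional as F

def oll_loss(logits, labels, dist_matrix):
    """OLL-1: weight -log(1 - p_i) for every class i by its label distance to the true class.

    logits: (batch, num_classes) raw scores; labels: (batch,) class ids;
    dist_matrix: (num_classes, num_classes) tensor of label distances.
    """
    probas = F.softmax(logits, dim=1)
    # pick the row of label distances for each example's true class
    distances = dist_matrix.to(labels.device)[labels].float()
    return torch.sum(-torch.log(1 - probas) * distances, dim=1).mean()

# usage, replacing e.g. F.cross_entropy(logits, labels):
# loss = oll_loss(logits, labels, torch.tensor(dist_matrix))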
