Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] Entmax loss #3

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open

[wip] Entmax loss #3

wants to merge 26 commits into from

Conversation

bpopeters
Copy link
Collaborator

This pull request adds support for entmax loss for training GPT models. This can be done through the --loss_function argument, which supports the following values: 'cross_entropy' (default), 'entmax15', 'sparsemax', and 'entmax_bisect'. 'entmax15' and 'sparsemax' make use of an additional --entmax-topk argument which sensibly defaults to 512. If using 'entmax_bisect', the alpha can be specified with --entmax-alpha (defaulting to 1.5) and --entmax-n-iter (defaulting to 30). Note that these flags work only for GPT models without pipeline parallelism (supporting other models should be easy, although I doubt anyone is interested right now; I don't know what would be required for pipeline parallelism).

I've run some quick tests with entmax15 on artemis with a very small (i.e. 3-layer, 128dim) model on {1, 2, 4} GPUs. Performance is quite a bit worse than cross entropy, but I believe this is (at least partially) an artifact of how small the model was -- the output layer and loss computation probably dominated the runtime in a away that it would not with a more reasonably-sized model. However, my attempts to train bigger models have been unsuccessful because memory usage is shockingly high (not just with entmax loss, also with cross entropy).

Note also that entmax loss does not currently support sequence-parallel loss computation. I'm not sure if this is relevant for our case (meaning, scaling up to 1B parameter models). However, it shouldn't be difficult to implement if we need to.

Before merging, we should probably think more about these performance issues.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
1 participant