Skip to content

LIONS-EPFL/storm-plus-code

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

5 Commits
 
 
 
 
 
 

Repository files navigation

storm-plus-code

PyTorch optimizer implementation for STORM+ algorithm

How to use the optimizer?

On the contrary to mainstream (PyTorch) optimizers like SGD, our STORM+ (and STORM algorith itself) requires two backward passes to compute two gradients using the same randomness/data sample at two different point/model weights. To do so, STORM+ optimizer has an auxiliary function called compute_estimator() that takes care of this extra task. We will give a brief description of how to loop over optimizer steps.

for epoch in range(num_epochs):

    # Generate data iterator, CIFAR10 in this case
    train_data_loader = torch.utils.data.DataLoader(CIFAR10_train_set, ...)

    # Enter train mode
    model.train()

    # ONLY FOR THE FIRST BATCH: We need to compute the initial estimator d_1, 
    # which is the first (mini-batch) stochastic gradient g_1. To set the estimator
    # we need to call compute_step() with the first batch.
    (data, label) = iter(train_data_loader).next()
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, label)
    loss.backward()
    optimizer.compute_estimator(normalized_norm=True)

    for i, (data, label) in enumerate(training_data, 0):


        # main optimization step
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        loss.backward()

        # uses \tilde g_t from the backward() call above
        # uses d_t already saved as parameter group state from previous iteration
        optimizer.step()  

        # makes the second pass, backpropagation for the NEXT iterate using the current data batch
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        loss.backward()

        # updates estimate d_{t+1} for the next iteration, saves g_{t+1} for next iteration
        optimizer.compute_estimator(normalized_norm=True)

About

PyTorch optimizer implementation for STORM+ algorithm

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages