Custom training routines #3
base: master
Conversation
Dropped checkpointing and monitoring in train_celeba_classifier.py from every 5 epochs to every epoch, because the model converges quickly enough that even every 5 epochs seems like overkill.
The training can now be done on a subset of the 40 available CelebA labels. For example, to train only on "lipstick" and "big lips", provide the numbers of the matching columns on the command line: --allow 6,36. All other labels are then zeroed out by a Transformer (sketched below). Also added the ability to specify other options on the command line: --classifier for the output file, defaults to "celeba_classifier.zip"; --batch-size for the size of all mini-batches, defaults to 100.
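The zeroing-out could be done with a small fuel transformer applied to the targets source. Below is a minimal sketch of the idea, not the PR's actual code; the class name, constructor, and the assumption that targets arrive as (batch, 40) arrays are all illustrative.

```python
import numpy

from fuel.transformers import AgnosticSourcewiseTransformer


class ZeroAllButAllowed(AgnosticSourcewiseTransformer):
    """Zero out every label column except the ones listed in `allowed`.

    Illustrative sketch only; the transformer in this PR may differ.
    """
    def __init__(self, data_stream, allowed, **kwargs):
        super(ZeroAllButAllowed, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)
        self.allowed = list(allowed)

    def transform_any_source(self, source_data, _):
        # Build a 0/1 mask over the label columns and multiply it in.
        mask = numpy.zeros(source_data.shape[-1], dtype=source_data.dtype)
        mask[self.allowed] = 1
        return source_data * mask
```

It would be used as, for example, ZeroAllButAllowed(stream, allowed=[6, 36], which_sources=('targets',)).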
Added cli options: --classifier for the input classifier, defaults to "celeba_classifier.zip"; --model for the output model, defaults to "celeba_vae_regularization.zip"; --batch-size for all mini-batches, defaults to 100; --z-dim to control latent dimensionality, defaults to 1000.
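Wired up with argparse, these options might look roughly as follows; the option names and defaults come from the description above, but the parser itself is only a sketch.

```python
import argparse

parser = argparse.ArgumentParser(description='Train a VAE on CelebA-style data')
parser.add_argument('--classifier', default='celeba_classifier.zip',
                    help='input classifier used for the discriminative term')
parser.add_argument('--model', default='celeba_vae_regularization.zip',
                    help='output file for the trained model')
parser.add_argument('--batch-size', type=int, default=100,
                    help='size of all mini-batches')
parser.add_argument('--z-dim', type=int, default=1000,
                    help='dimensionality of the latent space')
args = parser.parse_args()
```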
Added three scaling factors with command line options to change the relative weighting of the loss function: reconstruction_factor, kl_factor, and discriminative_factor.
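The combined cost then becomes a weighted sum of the three terms. A sketch, assuming the terms are Theano scalar variables (the variable and function names here are not from the PR):

```python
def weighted_cost(reconstruction_term, kl_term, discriminative_term,
                  reconstruction_factor=1.0, kl_factor=1.0,
                  discriminative_factor=1.0):
    """Combine the loss terms using the command-line scaling factors.

    Sketch only: assumes the three terms are Theano scalar variables.
    """
    cost = (reconstruction_factor * reconstruction_term +
            kl_factor * kl_term +
            discriminative_factor * discriminative_term)
    cost.name = 'total_cost'  # named so blocks monitoring can report it
    return cost
```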
Added cli options --monitor-every and --checkpoint-every which change how often monitoring and checkpointing occur during training.
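In blocks these two options would typically be forwarded to the extensions' every_n_epochs argument. A sketch under that assumption (function and variable names are illustrative):

```python
from blocks.extensions.monitoring import DataStreamMonitoring
from blocks.extensions.saveload import Checkpoint


def periodic_extensions(cost, monitoring_stream, model_path,
                        monitor_every=1, checkpoint_every=1):
    """Monitor and checkpoint every N epochs instead of a hard-coded value."""
    return [
        DataStreamMonitoring([cost], monitoring_stream, prefix='valid',
                             every_n_epochs=monitor_every),
        Checkpoint(model_path, every_n_epochs=checkpoint_every),
    ]
```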
Added options to allow swapping out the celeba dataset for any other fuel-compatible dataset with 64x64 features. If the dataset is grayscale instead of color, the color-convert option can be used to dynamically transform the data into color as it comes in. Note that for train_celeba_classifier the custom fuel dataset must also have compatible targets, which is unlikely to already be the case for common datasets but is certainly possible with some customization. However, for train_celeba_vae any fuel dataset with 64x64 features can be used, either with regularize left off or by using an existing celeba classifier or any other classifier trained on data close enough to the target dataset.
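The color-convert path can be thought of as a transformer that tiles the single grayscale channel into three channels. The sketch below shows the idea; it is not necessarily how the PR's Colorize transformer (excerpted further down) is implemented, and it assumes channel-first batches of shape (batch, 1, 64, 64).

```python
import numpy

from fuel.transformers import AgnosticSourcewiseTransformer


class GrayToRGB(AgnosticSourcewiseTransformer):
    """Repeat a single grayscale channel three times to produce RGB data.

    Illustrative sketch; assumes channel-first batches of shape
    (batch, 1, height, width).
    """
    def __init__(self, data_stream, **kwargs):
        super(GrayToRGB, self).__init__(
            data_stream, data_stream.produces_examples, **kwargs)

    def transform_any_source(self, source_data, _):
        return numpy.repeat(source_data, 3, axis=1)
```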
Added an option to start training from a previous model checkpoint. Note that this also involved a change to MainLoop to add the model, which sadly means that previously saved checkpoints are not usable. Also updated train_monitoring options to set before_first_epoch=True, as this is useful to verify that the checkpoint was loaded successfully.
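Resuming could look roughly like the following with blocks' serialization; this is a sketch, not the PR's actual --oldmodel handling, and it assumes the checkpoint was written by blocks' Checkpoint extension so that load() returns a MainLoop carrying the model.

```python
from blocks.serialization import load


def load_old_parameters(model, oldmodel_path):
    """Copy parameter values from a previously saved main loop into `model`.

    Sketch only; relies on the saved MainLoop having a `.model` attribute,
    which is why older checkpoints saved without the model cannot be reused.
    """
    with open(oldmodel_path, 'rb') as src:
        old_main_loop = load(src)
    model.set_parameter_values(old_main_loop.model.get_parameter_values())
    return model
```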
Added --oldmodel to start training from a previously saved model state. Increased the width of the labels from 40 to 64 to support different datasets. Will need to write a transformer to deal with filling out from shorter label lengths.
Moved the allowed option used to filter labels from create_celeba_streams to create_custom_streams.
@@ -117,6 +124,79 @@ def create_svhn_streams(training_batch_size, monitoring_batch_size):
                                 monitoring_batch_size)


class Colorize(AgnosticSourcewiseTransformer):
    def __init__(self, data_stream, **kwargs):
Could you add a docstring explaining what this transformer does?
@dribnet I did a first review pass. Could you also …
Thanks @vdumoulin for the constructive feedback, glad to hear you think overall this would be a welcome addition. I'll be reviewing and updating this PR over the next week.
Replaced block of custom code with fuel.utils.find_in_data_path and added comments to several functions.
Formatting updates after running flake8 for better legibility.
If the incoming dataset doesn't provide enough labels for standard training, they can be zero-padded with the --stretch option (sketched below).
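Padding out the target columns amounts to a zero-pad along the label axis. A minimal sketch of the idea; the function name and the 64-column default are assumptions based on the earlier note about widening labels from 40 to 64.

```python
import numpy


def stretch_labels(targets, width=64):
    """Zero-pad a (batch, n_labels) target array out to `width` columns.

    Sketch of what --stretch could do; the real implementation may differ.
    """
    pad = width - targets.shape[1]
    if pad <= 0:
        return targets
    return numpy.pad(targets, ((0, 0), (0, pad)), mode='constant')
```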
@vdumoulin - I've addressed most issues in the previous review including a general flake8 code cleanup.
This series of commits adds two new experiments that act as general-purpose tools as discussed in #2, which this pull request is meant to replace. It adds more generic versions of train_classifier and train_vae, which are intended to be compatible with the celeba versions but offer many additional options.
The discriminative_term is also kept separate in the cost function so it can be monitored on its own.
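In blocks, giving the Theano variable a name is enough for it to show up as its own monitoring channel; a small sketch under that assumption (names are illustrative):

```python
from blocks.extensions.monitoring import TrainingDataMonitoring


def monitor_terms(cost, discriminative_term):
    """Report the discriminative term as its own channel next to the cost.

    Sketch: assumes both arguments are Theano scalar variables.
    """
    discriminative_term.name = 'discriminative_term'
    return TrainingDataMonitoring([cost, discriminative_term],
                                  prefix='train', after_epoch=True)
```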