diff --git a/examples/model_selection_psql/ms_model_mlp/model.py b/examples/model_selection_psql/ms_model_mlp/model.py index 890673622..1e2b8191c 100644 --- a/examples/model_selection_psql/ms_model_mlp/model.py +++ b/examples/model_selection_psql/ms_model_mlp/model.py @@ -147,3 +147,49 @@ def train_one_batch(self, x, y, dist_option, spars, synflow_flag): def set_optimizer(self, optimizer): self.optimizer = optimizer + + +def create_model(pretrained=False, **kwargs): + """Constructs a CNN model. + + Args: + pretrained (bool): If True, returns a pre-trained model. + + Returns: + The created CNN model. + """ + model = MSMLP(**kwargs) + + return model + + +__all__ = ['MLP', 'create_model'] + +if __name__ == "__main__": + np.random.seed(0) + + parser = argparse.ArgumentParser() + parser.add_argument('-p', + choices=['float32', 'float16'], + default='float32', + dest='precision') + parser.add_argument('-g', + '--disable-graph', + default='True', + action='store_false', + help='disable graph', + dest='graph') + parser.add_argument('-m', + '--max-epoch', + default=1001, + type=int, + help='maximum epochs', + dest='max_epoch') + args = parser.parse_args() + + model = MLP(data_size=2, perceptron_size=3, num_classes=2) + + # attach model to graph + model.set_optimizer(sgd) + model.compile([tx], is_train=True, use_graph=args.graph, sequential=True) + model.train() \ No newline at end of file