Skip to content

Commit

Permalink
Merge pull request #1092 from liuchangshiye/model-selection-psql-ms-m…
Browse files Browse the repository at this point in the history
…odel-mlp

Add create model function for the dynamic model
  • Loading branch information
lzjpaul authored Sep 8, 2023
2 parents 5b36459 + 9e955b1 commit 20c9d79
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions examples/model_selection_psql/ms_model_mlp/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,49 @@ def train_one_batch(self, x, y, dist_option, spars, synflow_flag):

def set_optimizer(self, optimizer):
self.optimizer = optimizer


def create_model(pretrained=False, **kwargs):
"""Constructs a CNN model.
Args:
pretrained (bool): If True, returns a pre-trained model.
Returns:
The created CNN model.
"""
model = MSMLP(**kwargs)

return model


__all__ = ['MLP', 'create_model']

if __name__ == "__main__":
np.random.seed(0)

parser = argparse.ArgumentParser()
parser.add_argument('-p',
choices=['float32', 'float16'],
default='float32',
dest='precision')
parser.add_argument('-g',
'--disable-graph',
default='True',
action='store_false',
help='disable graph',
dest='graph')
parser.add_argument('-m',
'--max-epoch',
default=1001,
type=int,
help='maximum epochs',
dest='max_epoch')
args = parser.parse_args()

model = MLP(data_size=2, perceptron_size=3, num_classes=2)

# attach model to graph
model.set_optimizer(sgd)
model.compile([tx], is_train=True, use_graph=args.graph, sequential=True)
model.train()

0 comments on commit 20c9d79

Please sign in to comment.