Skip to content

Commit

Permalink
Merge pull request #1093 from NLGithubWP/update_model
Browse files Browse the repository at this point in the history
Add training process for the dynamic model
  • Loading branch information
lzjpaul authored Sep 8, 2023
2 parents 20c9d79 + f84ebe5 commit 5df0cec
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions examples/model_selection_psql/ms_model_mlp/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(self, data_size=10, perceptron_size=100, num_classes=10, layer_hidd
self.linear5 = layer.Linear(num_classes)
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
self.sum_error = SumErrorLayer()

def forward(self, inputs):
y = self.linear1(inputs)
y = self.relu(y)
Expand Down Expand Up @@ -187,9 +187,38 @@ def create_model(pretrained=False, **kwargs):
dest='max_epoch')
args = parser.parse_args()

# generate the boundary
f = lambda x: (5 * x + 1)
bd_x = np.linspace(-1.0, 1, 200)
bd_y = f(bd_x)

# generate the training data
x = np.random.uniform(-1, 1, 400)
y = f(x) + 2 * np.random.randn(len(x))

# choose one precision
precision = singa_dtype[args.precision]
np_precision = np_dtype[args.precision]

# convert training data to 2d space
label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)]).astype(np.int32)
data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np_precision)

dev = device.create_cuda_gpu_on(0)
sgd = opt.SGD(0.1, 0.9, 1e-5, dtype=singa_dtype[args.precision])
tx = tensor.Tensor((400, 2), dev, precision)
ty = tensor.Tensor((400,), dev, tensor.int32)
model = MLP(data_size=2, perceptron_size=3, num_classes=2)

# attach model to graph
model.set_optimizer(sgd)
model.compile([tx], is_train=True, use_graph=args.graph, sequential=True)
model.train()
model.train()

for i in range(args.max_epoch):
tx.copy_from_numpy(data)
ty.copy_from_numpy(label)
out, loss = model(tx, ty, 'fp32', spars=None)

if i % 100 == 0:
print("training loss = ", tensor.to_numpy(loss)[0])

0 comments on commit 5df0cec

Please sign in to comment.