diff --git a/.github/workflows/ubuntu.yaml b/.github/workflows/ubuntu.yaml index b67fcdaf8..feb801a4e 100644 --- a/.github/workflows/ubuntu.yaml +++ b/.github/workflows/ubuntu.yaml @@ -40,22 +40,24 @@ jobs: # run: cd build && make # - name: C++ test # run: build/bin/test_singa - + build-cpptest-on-cpu: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v1 - name: get-oneDNN run: wget https://github.com/oneapi-src/oneDNN/releases/download/v1.1/dnnl_lnx_1.1.0_cpu_gomp.tgz -P /tmp/ && tar zxf /tmp/dnnl_lnx_1.1.0_cpu_gomp.tgz -C /tmp + - name: setup-sys-env + run: sudo apt-get install -y curl wget git cmake - name: install-build-dependencies run: sudo apt-get install -y libgoogle-glog-dev libprotobuf-dev protobuf-compiler libncurses-dev libopenblas-dev gfortran libblas-dev liblapack-dev libatlas-base-dev swig dh-autoreconf lcov - name: configure run: mkdir build && cd build && cmake -DUSE_PYTHON=NO -DENABLE_TEST=YES -DCODE_COVERAGE=YES -DUSE_DNNL=YES .. env: - DNNL_ROOT: /tmp/dnnl_lnx_1.1.0_cpu_gomp/ + DNNL_ROOT: /tmp/dnnl_lnx_1.1.0_cpu_gomp/ - name: build - run: cd build && make + run: cd build && make -j8 - name: C++ test run: build/bin/test_singa - name: Upload coverage to Codecov diff --git a/examples/cnn_ms/run.sh b/examples/cnn_ms/run.sh new file mode 100644 index 000000000..a536a1e81 --- /dev/null +++ b/examples/cnn_ms/run.sh @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +#!/usr/bin/env python -W ignore::DeprecationWarning + +### mnist +python train_cnn.py mlp mnist +python train_cnn.py cnn mnist +python train_cnn.py resnet mnist +python train_cnn.py alexnet mnist + +### cifar10 +python train_cnn.py mlp cifar10 +python train_cnn.py cnn cifar10 +python train_cnn.py resnet cifar10 +python train_cnn.py alexnet cifar10 + +### cifar100 +python train_cnn.py mlp cifar100 +python train_cnn.py cnn cifar100 +python train_cnn.py resnet cifar100 +python train_cnn.py alexnet cifar100 diff --git a/examples/cnn_ms/train_cnn.py b/examples/cnn_ms/train_cnn.py new file mode 100644 index 000000000..d7f8f7076 --- /dev/null +++ b/examples/cnn_ms/train_cnn.py @@ -0,0 +1,554 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from singa import singa_wrap as singa +from singa import device +from singa import tensor +from singa import opt +from singa import autograd +from singa.opt import Optimizer +from singa.opt import DecayScheduler +from singa.opt import Constant +import numpy as np +import time +import argparse +from PIL import Image + +np_dtype = {"float16": np.float16, "float32": np.float32} + +singa_dtype = {"float16": tensor.float16, "float32": tensor.float32} + +### MSOptimizer: wraps Optimizer so that each update also returns the +### (param_name, param_value, param_grad) triples collected during backward +class MSOptimizer(Optimizer): + def __call__(self, loss): + pn_p_g_list = self.call_with_returns(loss) + self.step() + return pn_p_g_list + + def call_with_returns(self, loss): + pn_p_g_list = [] + for p, g in autograd.backward(loss): + if p.name is None: + p.name = id(p) + self.apply(p.name, p, g) + pn_p_g_list.append((p.name, p, g)) # collect (name, param, grad) as a tuple + return pn_p_g_list + +# MSSGD -- same update rule as opt.SGD; only the MSOptimizer return values differ +class MSSGD(MSOptimizer): + """Implements stochastic gradient descent (optionally with momentum). + + Nesterov momentum is based on the formula from `On the importance of initialization and momentum in deep learning`__. + + Args: + lr(float): learning rate + momentum(float, optional): momentum factor (default: 0) + weight_decay(float, optional): weight decay (L2 penalty) (default: 0) + dampening(float, optional): dampening for momentum (default: 0) + nesterov(bool, optional): enables Nesterov momentum (default: False) + + Typical usage example: + >>> from singa import opt + >>> optimizer = opt.SGD(lr=0.1, momentum=0.9) + >>> optimizer.update() + + __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf + + .. note:: + The implementation of SGD with Momentum / Nesterov subtly differs from + Sutskever et al. and implementations in some other frameworks. + + Considering the specific case of Momentum, the update can be written as + + .. math:: + v = \rho * v + g \\ + p = p - lr * v + + where p, g, v and :math:`\rho` denote the parameters, gradient, + velocity, and momentum respectively. + + This is in contrast to Sutskever et al. and + other frameworks which employ an update of the form + + .. math:: + v = \rho * v + lr * g \\ + p = p - v + + The Nesterov version is analogously modified.
+ """ + + def __init__(self, + lr=0.1, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + dtype=tensor.float32): + super(MSSGD, self).__init__(lr, dtype) + + # init momentum + if type(momentum) == float or type(momentum) == int: + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + self.momentum = Constant(momentum) + elif isinstance(momentum, DecayScheduler): + self.momentum = momentum + momentum = momentum.init_value + else: + raise TypeError("Wrong momentum type") + self.mom_value = self.momentum(self.step_counter).as_type(self.dtype) + + # init dampening + if type(dampening) == float or type(dampening) == int: + self.dampening = Constant(dampening) + elif isinstance(dampening, DecayScheduler): + self.dampening = dampening + dampening = dampening.init_value + else: + raise TypeError("Wrong dampening type") + self.dam_value = self.dampening(self.step_counter).as_type(self.dtype) + + # init weight_decay + if type(weight_decay) == float or type(weight_decay) == int: + if weight_decay < 0.0: + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay)) + self.weight_decay = Constant(weight_decay) + elif isinstance(weight_decay, DecayScheduler): + self.weight_decay = weight_decay + else: + raise TypeError("Wrong weight_decay type") + self.decay_value = self.weight_decay(self.step_counter).as_type( + self.dtype) + + # init other params + self.nesterov = nesterov + self.moments = dict() + + # check value + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError( + "Nesterov momentum requires a momentum and zero dampening") + + def apply(self, param_name, param_value, param_grad): + """Performs a single optimization step. + + Args: + param_name(String): the name of the param + param_value(Tensor): param values to be update in-place + grad(Tensor): param gradients; the values may be updated + in this function; cannot use it anymore + """ + assert param_value.shape == param_grad.shape, ("shape mismatch", + param_value.shape, + param_grad.shape) + self.device_check(param_value, self.step_counter, self.lr_value, + self.mom_value, self.dam_value, self.decay_value) + + # derive dtype from input + assert param_value.dtype == self.dtype + + # TODO add branch operator + # if self.decay_value != 0: + if self.weight_decay.init_value != 0: + singa.Axpy(self.decay_value.data, param_value.data, param_grad.data) + + if self.momentum.init_value != 0: + if param_name not in self.moments: + flag = param_value.device.graph_enabled() + param_value.device.EnableGraph(False) + self.moments[param_name] = tensor.zeros_like(param_value) + param_value.device.EnableGraph(flag) + + buf = self.moments[param_name] + buf *= self.mom_value + alpha = 1.0 - self.dam_value + singa.Axpy(alpha.data, param_grad.data, buf.data) + + if self.nesterov: + singa.Axpy(self.mom_value.data, buf.data, param_grad.data) + else: + param_grad = buf + + minus_lr = 0.0 - self.lr_value + singa.Axpy(minus_lr.data, param_grad.data, param_value.data) + + def step(self): + # increment step counter, lr and moment + super().step() + mom_value = self.momentum(self.step_counter).as_type(self.dtype) + dam_value = self.dampening(self.step_counter).as_type(self.dtype) + decay_value = self.weight_decay(self.step_counter).as_type(self.dtype) + self.mom_value.copy_from(mom_value) + self.dam_value.copy_from(dam_value) + self.decay_value.copy_from(decay_value) + + def get_states(self): + states = super().get_states() + if self.mom_value > 0: + states[ + 'moments'] = self.moments # a 
dict for 1st order moments tensors + return states + + def set_states(self, states): + super().set_states(states) + if 'moments' in states: + self.moments = states['moments'] + self.mom_value = self.momentum(self.step_counter) + + +# Data augmentation +def augmentation(x, batch_size): + xpad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'symmetric') + for data_num in range(0, batch_size): + offset = np.random.randint(8, size=2) + x[data_num, :, :, :] = xpad[data_num, :, + offset[0]:offset[0] + x.shape[2], + offset[1]:offset[1] + x.shape[2]] + if_flip = np.random.randint(2) + if (if_flip): + x[data_num, :, :, :] = x[data_num, :, :, ::-1] + return x + + +# Calculate accuracy +def accuracy(pred, target): + # y is network output to be compared with ground truth (int) + y = np.argmax(pred, axis=1) + a = y == target + correct = np.array(a, "int").sum() + return correct + + +# Data partition according to the rank +def partition(global_rank, world_size, train_x, train_y, val_x, val_y): + # Partition training data + data_per_rank = train_x.shape[0] // world_size + idx_start = global_rank * data_per_rank + idx_end = (global_rank + 1) * data_per_rank + train_x = train_x[idx_start:idx_end] + train_y = train_y[idx_start:idx_end] + + # Partition evaluation data + data_per_rank = val_x.shape[0] // world_size + idx_start = global_rank * data_per_rank + idx_end = (global_rank + 1) * data_per_rank + val_x = val_x[idx_start:idx_end] + val_y = val_y[idx_start:idx_end] + return train_x, train_y, val_x, val_y + + +# Function to all reduce NUMPY accuracy and loss from multiple devices +def reduce_variable(variable, dist_opt, reducer): + reducer.copy_from_numpy(variable) + dist_opt.all_reduce(reducer.data) + dist_opt.wait() + output = tensor.to_numpy(reducer) + return output + + +def resize_dataset(x, image_size): + num_data = x.shape[0] + dim = x.shape[1] + X = np.zeros(shape=(num_data, dim, image_size, image_size), + dtype=np.float32) + for n in range(0, num_data): + for d in range(0, dim): + X[n, d, :, :] = np.array(Image.fromarray(x[n, d, :, :]).resize( + (image_size, image_size), Image.BILINEAR), + dtype=np.float32) + return X + + +def run(global_rank, + world_size, + local_rank, + max_epoch, + batch_size, + model, + data, + mssgd, + graph, + verbosity, + dist_option='plain', + spars=None, + precision='float32'): + # dev = device.create_cuda_gpu_on(local_rank) # need to change to CPU device for CPU-only machines + dev = device.get_default_device() + dev.SetRandSeed(0) + np.random.seed(0) + + if data == 'cifar10': + from data import cifar10 + train_x, train_y, val_x, val_y = cifar10.load() + elif data == 'cifar100': + from data import cifar100 + train_x, train_y, val_x, val_y = cifar100.load() + elif data == 'mnist': + from data import mnist + train_x, train_y, val_x, val_y = mnist.load() + + + num_channels = train_x.shape[1] + image_size = train_x.shape[2] + data_size = np.prod(train_x.shape[1:train_x.ndim]).item() + num_classes = (np.max(train_y) + 1).item() + + if model == 'resnet': + from model import resnet + model = resnet.resnet50(num_channels=num_channels, + num_classes=num_classes) + elif model == 'xceptionnet': + from model import xceptionnet + model = xceptionnet.create_model(num_channels=num_channels, + num_classes=num_classes) + elif model == 'cnn': + from model import cnn + model = cnn.create_model(num_channels=num_channels, + num_classes=num_classes) + elif model == 'alexnet': + from model import alexnet + model = alexnet.create_model(num_channels=num_channels, + num_classes=num_classes) + elif 
model == 'mlp': + import os, sys, inspect + current = os.path.dirname( + os.path.abspath(inspect.getfile(inspect.currentframe()))) + parent = os.path.dirname(current) + sys.path.insert(0, parent) + from mlp import model + model = model.create_model(data_size=data_size, + num_classes=num_classes) + + elif model == 'msmlp': + import os, sys, inspect + current = os.path.dirname( + os.path.abspath(inspect.getfile(inspect.currentframe()))) + parent = os.path.dirname(current) + sys.path.insert(0, parent) + from msmlp import model + model = model.create_model(data_size=data_size, + num_classes=num_classes) + + # For distributed training, sequential has better performance + if hasattr(mssgd, "communicator"): + DIST = True + sequential = True + else: + DIST = False + sequential = False + + if DIST: + train_x, train_y, val_x, val_y = partition(global_rank, world_size, + train_x, train_y, val_x, + val_y) + + if model.dimension == 4: + tx = tensor.Tensor( + (batch_size, num_channels, model.input_size, model.input_size), dev, + singa_dtype[precision]) + elif model.dimension == 2: + tx = tensor.Tensor((batch_size, data_size), dev, singa_dtype[precision]) + np.reshape(train_x, (train_x.shape[0], -1)) + np.reshape(val_x, (val_x.shape[0], -1)) + + ty = tensor.Tensor((batch_size,), dev, tensor.int32) + num_train_batch = train_x.shape[0] // batch_size + num_val_batch = val_x.shape[0] // batch_size + idx = np.arange(train_x.shape[0], dtype=np.int32) + + # Attach model to graph + model.set_optimizer(mssgd) + model.compile([tx], is_train=True, use_graph=graph, sequential=sequential) + dev.SetVerbosity(verbosity) + + # Training and evaluation loop + for epoch in range(max_epoch): + start_time = time.time() + np.random.shuffle(idx) + + if global_rank == 0: + print('Starting Epoch %d:' % (epoch)) + + # Training phase + train_correct = np.zeros(shape=[1], dtype=np.float32) + test_correct = np.zeros(shape=[1], dtype=np.float32) + train_loss = np.zeros(shape=[1], dtype=np.float32) + + model.train() + print ("num_train_batch: \n", num_train_batch) + print () + for b in range(num_train_batch): + if b % 200 == 0: + print ("b: \n", b) + # Generate the patch data in this iteration + x = train_x[idx[b * batch_size:(b + 1) * batch_size]] + if model.dimension == 4: + x = augmentation(x, batch_size) + if (image_size != model.input_size): + x = resize_dataset(x, model.input_size) + x = x.astype(np_dtype[precision]) + y = train_y[idx[b * batch_size:(b + 1) * batch_size]] + + + synflow_flag = False + # Train the model + if epoch == (max_epoch - 1) and b == (num_train_batch - 1): ### synflow calcuation for the last batch + print ("last epoch calculate synflow") + synflow_flag = True + ### step 1: all one input + # Copy the patch data into input tensors + tx.copy_from_numpy(np.ones(x.shape, dtype=np.float32)) + ty.copy_from_numpy(y) + ### step 2: all weights turned to positive (done) + ### step 3: new loss (done) + pn_p_g_list, out, loss = model(tx, ty,dist_option, spars, synflow_flag) + ### step 4: calculate the multiplication of weights + synflow_score = 0.0 + for pn_p_g_item in pn_p_g_list: + print ("calculate weight param * grad parameter name: \n", pn_p_g_item[0]) + if len(pn_p_g_item[1].data.shape) == 2: # param_value.data is "weight" + synflow_score += np.sum(np.absolute(tensor.to_numpy(pn_p_g_item[1].data) * tensor.to_numpy(pn_p_g_item[2].data))) + print ("synflow_score: \n", synflow_score) + elif epoch == (max_epoch - 1) and b == (num_train_batch - 2): # all weights turned to positive + # Copy the patch data into input 
tensors + tx.copy_from_numpy(x) + ty.copy_from_numpy(y) + pn_p_g_list, out, loss = model(tx, ty, dist_option, spars, synflow_flag) + train_correct += accuracy(tensor.to_numpy(out), y) + train_loss += tensor.to_numpy(loss)[0] + # all params turned to positive + for pn_p_g_item in pn_p_g_list: + print ("absolute value parameter name: \n", pn_p_g_item[0]) + pn_p_g_item[1] = tensor.abs(pn_p_g_item[1]) # tensor variables + else: # normal train steps + # Copy the patch data into input tensors + tx.copy_from_numpy(x) + ty.copy_from_numpy(y) + pn_p_g_list, out, loss = model(tx, ty, synflow_flag, dist_option, spars) + train_correct += accuracy(tensor.to_numpy(out), y) + train_loss += tensor.to_numpy(loss)[0] + + if DIST: + # Reduce the evaluation accuracy and loss from multiple devices + reducer = tensor.Tensor((1,), dev, tensor.float32) + train_correct = reduce_variable(train_correct, mssgd, reducer) + train_loss = reduce_variable(train_loss, mssgd, reducer) + + if global_rank == 0: + print('Training loss = %f, training accuracy = %f' % + (train_loss, train_correct / + (num_train_batch * batch_size * world_size)), + flush=True) + + # Evaluation phase + model.eval() + for b in range(num_val_batch): + x = val_x[b * batch_size:(b + 1) * batch_size] + if model.dimension == 4: + if (image_size != model.input_size): + x = resize_dataset(x, model.input_size) + x = x.astype(np_dtype[precision]) + y = val_y[b * batch_size:(b + 1) * batch_size] + tx.copy_from_numpy(x) + ty.copy_from_numpy(y) + out_test = model(tx) + test_correct += accuracy(tensor.to_numpy(out_test), y) + + if DIST: + # Reduce the evaulation accuracy from multiple devices + test_correct = reduce_variable(test_correct, mssgd, reducer) + + # Output the evaluation accuracy + if global_rank == 0: + print('Evaluation accuracy = %f, Elapsed Time = %fs' % + (test_correct / (num_val_batch * batch_size * world_size), + time.time() - start_time), + flush=True) + + dev.PrintTimeProfiling() + + +if __name__ == '__main__': + # Use argparse to get command config: max_epoch, model, data, etc., for single gpu training + parser = argparse.ArgumentParser( + description='Training using the autograd and graph.') + parser.add_argument( + 'model', + choices=['cnn', 'resnet', 'xceptionnet', 'mlp', 'msmlp', 'alexnet'], + default='cnn') + parser.add_argument('data', + choices=['mnist', 'cifar10', 'cifar100'], + default='mnist') + parser.add_argument('-p', + choices=['float32', 'float16'], + default='float32', + dest='precision') + parser.add_argument('-m', + '--max-epoch', + default=100, + type=int, + help='maximum epochs', + dest='max_epoch') + parser.add_argument('-b', + '--batch-size', + default=64, + type=int, + help='batch size', + dest='batch_size') + parser.add_argument('-l', + '--learning-rate', + default=0.005, + type=float, + help='initial learning rate', + dest='lr') + # Determine which gpu to use + parser.add_argument('-i', + '--device-id', + default=0, + type=int, + help='which GPU to use', + dest='device_id') + parser.add_argument('-g', + '--disable-graph', + default='True', + action='store_false', + help='disable graph', + dest='graph') + parser.add_argument('-v', + '--log-verbosity', + default=0, + type=int, + help='logging verbosity', + dest='verbosity') + + args = parser.parse_args() + + mssgd = MSSGD(lr=args.lr, momentum=0.9, weight_decay=1e-5, dtype=singa_dtype[args.precision]) + run(0, + 1, + args.device_id, + args.max_epoch, + args.batch_size, + args.model, + args.data, + mssgd, + args.graph, + args.verbosity, + precision=args.precision) diff 
--git a/examples/cnn_ms/train_mpi.py b/examples/cnn_ms/train_mpi.py new file mode 100644 index 000000000..563d4b2c5 --- /dev/null +++ b/examples/cnn_ms/train_mpi.py @@ -0,0 +1,91 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +from singa import singa_wrap as singa +from singa import opt +from singa import tensor +import argparse +import train_cnn + +singa_dtype = {"float16": tensor.float16, "float32": tensor.float32} + +if __name__ == '__main__': + # Use argparse to get command config: max_epoch, model, data, etc., for single gpu training + parser = argparse.ArgumentParser( + description='Training using the autograd and graph.') + parser.add_argument('model', + choices=['cnn', 'resnet', 'xceptionnet', 'mlp'], + default='cnn') + parser.add_argument('data', choices=['mnist', 'cifar10', 'cifar100'], default='mnist') + parser.add_argument('-p', + choices=['float32', 'float16'], + default='float32', + dest='precision') + parser.add_argument('-m', + '--max-epoch', + default=10, + type=int, + help='maximum epochs', + dest='max_epoch') + parser.add_argument('-b', + '--batch-size', + default=64, + type=int, + help='batch size', + dest='batch_size') + parser.add_argument('-l', + '--learning-rate', + default=0.005, + type=float, + help='initial learning rate', + dest='lr') + parser.add_argument('-d', + '--dist-option', + default='plain', + choices=['plain','half','partialUpdate','sparseTopK','sparseThreshold'], + help='distibuted training options', + dest='dist_option') # currently partialUpdate support graph=False only + parser.add_argument('-s', + '--sparsification', + default='0.05', + type=float, + help='the sparsity parameter used for sparsification, between 0 to 1', + dest='spars') + parser.add_argument('-g', + '--disable-graph', + default='True', + action='store_false', + help='disable graph', + dest='graph') + parser.add_argument('-v', + '--log-verbosity', + default=0, + type=int, + help='logging verbosity', + dest='verbosity') + + args = parser.parse_args() + + sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5, dtype=singa_dtype[args.precision]) + sgd = opt.DistOpt(sgd) + + train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch, + args.batch_size, args.model, args.data, sgd, args.graph, + args.verbosity, args.dist_option, args.spars, args.precision) diff --git a/examples/cnn_ms/train_multiprocess.py b/examples/cnn_ms/train_multiprocess.py new file mode 100644 index 000000000..182dd35ee --- /dev/null +++ b/examples/cnn_ms/train_multiprocess.py @@ -0,0 +1,111 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +from singa import singa_wrap as singa +from singa import opt +from singa import tensor +import argparse +import train_cnn +import multiprocessing + +singa_dtype = {"float16": tensor.float16, "float32": tensor.float32} + +def run(args, local_rank, world_size, nccl_id): + sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5, dtype=singa_dtype[args.precision]) + sgd = opt.DistOpt(sgd, nccl_id=nccl_id, local_rank=local_rank, world_size=world_size) + train_cnn.run(sgd.global_rank, sgd.world_size, sgd.local_rank, args.max_epoch, + args.batch_size, args.model, args.data, sgd, args.graph, + args.verbosity, args.dist_option, args.spars, args.precision) + + +if __name__ == '__main__': + # Use argparse to get command config: max_epoch, model, data, etc., for single gpu training + parser = argparse.ArgumentParser( + description='Training using the autograd and graph.') + parser.add_argument('model', + choices=['resnet', 'xceptionnet', 'cnn', 'mlp'], + default='cnn') + parser.add_argument('data', choices=['cifar10', 'cifar100', 'mnist'], default='mnist') + parser.add_argument('-p', + choices=['float32', 'float16'], + default='float32', + dest='precision') + parser.add_argument('-m', + '--max-epoch', + default=10, + type=int, + help='maximum epochs', + dest='max_epoch') + parser.add_argument('-b', + '--batch-size', + default=64, + type=int, + help='batch size', + dest='batch_size') + parser.add_argument('-l', + '--learning-rate', + default=0.005, + type=float, + help='initial learning rate', + dest='lr') + parser.add_argument('-w', + '--world-size', + default=2, + type=int, + help='number of gpus to be used', + dest='world_size') + parser.add_argument('-d', + '--dist-option', + default='plain', + choices=['plain','half','partialUpdate','sparseTopK','sparseThreshold'], + help='distibuted training options', + dest='dist_option') # currently partialUpdate support graph=False only + parser.add_argument('-s', + '--sparsification', + default='0.05', + type=float, + help='the sparsity parameter used for sparsification, between 0 to 1', + dest='spars') + parser.add_argument('-g', + '--disable-graph', + default='True', + action='store_false', + help='disable graph', + dest='graph') + parser.add_argument('-v', + '--log-verbosity', + default=0, + type=int, + help='logging verbosity', + dest='verbosity') + + args = parser.parse_args() + + # Generate a NCCL ID to be used for collective communication + nccl_id = singa.NcclIdHolder() + + process = [] + for local_rank in range(0, args.world_size): + process.append( + multiprocessing.Process(target=run, + args=(args, local_rank, args.world_size, nccl_id))) + + for p in process: + p.start() diff --git a/examples/hfl/README.md b/examples/hfl/README.md index cf20e64cd..2916bc560 100644 --- a/examples/hfl/README.md +++ b/examples/hfl/README.md @@ -27,7 +27,7 @@ This example uses the Bank dataset and an MLP model in FL. 
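A note on the `examples/cnn_ms/train_multiprocess.py` launcher added above: it spawns one process per GPU, all sharing a single NCCL id, but only calls `p.start()` on each worker. Below is a minimal, self-contained sketch of the same launch pattern; the `worker` body is a stand-in for the real `opt.DistOpt` + `train_cnn.run` call, and the trailing `join()` loop is my addition (not part of the patch) so the parent waits for the workers to finish.

```python
# Sketch of the one-process-per-device launch pattern used in train_multiprocess.py.
# The worker body is a placeholder for wrapping opt.SGD in opt.DistOpt and
# calling train_cnn.run with the shared NCCL id.
import multiprocessing


def worker(args, local_rank, world_size, nccl_id):
    # real script: sgd = opt.DistOpt(sgd, nccl_id=nccl_id, local_rank=local_rank,
    #                                world_size=world_size); train_cnn.run(...)
    print(f"rank {local_rank}/{world_size} starting with id {nccl_id}")


if __name__ == "__main__":
    world_size = 2
    nccl_id = "dummy-nccl-id"  # placeholder for singa.NcclIdHolder()
    procs = [
        multiprocessing.Process(target=worker,
                                args=(None, rank, world_size, nccl_id))
        for rank in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:  # join so the parent does not exit before the workers
        p.join()
```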
## Preparation -Go to the Conda environment that contains the Singa library, and run +Go to the Conda environment that contains the Singa library, and install the required libraries. ```bash pip install -r requirements.txt @@ -41,18 +41,18 @@ Download the bank dataset and split it into 3 partitions. # 3. run the following command which: # (1) splits the dataset into N subsets # (2) splits each subsets into train set and test set (8:2) -python -m bank N +python -m bank 3 ``` ## Run the example -Run the server first (set the number of epochs to 3) +Run the server first (set the maximum number of epochs to 3 by the "-m" parameter) ```bash python -m src.server -m 3 --num_clients 3 ``` -Then, start 3 clients in different terminal +Then, start 3 clients in different terminals (similarly set the maximum number of epochs to 3) ```bash python -m src.client --model mlp --data bank -m 3 -i 0 -d non-iid @@ -60,4 +60,4 @@ python -m src.client --model mlp --data bank -m 3 -i 1 -d non-iid python -m src.client --model mlp --data bank -m 3 -i 2 -d non-iid ``` -Finally, the server and clients finish the FL training. \ No newline at end of file +Finally, the server and clients finish the FL training. diff --git a/examples/hfl/src/client.py b/examples/hfl/src/client.py index 80ab11f3a..dbff42b4d 100644 --- a/examples/hfl/src/client.py +++ b/examples/hfl/src/client.py @@ -40,6 +40,7 @@ np_dtype = {"float16": np.float16, "float32": np.float32} singa_dtype = {"float16": tensor.float16, "float32": tensor.float32} + class Client: """Client sends and receives protobuf messages. @@ -63,6 +64,7 @@ def __init__( Args: global_rank (int, optional): The rank in training process. Defaults to 0. + Provided by the '-i' parameter (device_id) in the running script. host (str, optional): Host ip address. Defaults to '127.0.0.1'. port (str, optional): Port. Defaults to 1234. """ diff --git a/examples/hfl/src/server.py b/examples/hfl/src/server.py index 7450cc1cf..68780e13c 100644 --- a/examples/hfl/src/server.py +++ b/examples/hfl/src/server.py @@ -80,6 +80,7 @@ def __start_rank_pairing(self) -> None: """Start pair each client to a global rank""" for _ in range(self.num_clients): conn, addr = self.sock.accept() + # rank is the global device_id when initializing the client rank = utils.receive_int(conn) self.conns[rank] = conn self.addrs[rank] = addr diff --git a/examples/model_selection/Trails/README.md b/examples/model_selection/Trails/README.md index 39bd01260..8f1452507 100644 --- a/examples/model_selection/Trails/README.md +++ b/examples/model_selection/Trails/README.md @@ -23,7 +23,7 @@ ![image-20230702035806963](documents/ai_db.001.jpeg) -# Build & Run examples +# Build & Run examples: ## Singa + PostgreSQL diff --git a/examples/model_slicing_psql/README.md b/examples/model_slicing_psql/README.md new file mode 100644 index 000000000..bfdbd5e06 --- /dev/null +++ b/examples/model_slicing_psql/README.md @@ -0,0 +1,22 @@ + + +# Dynamic Model Slicing on PostgreSQL + +Examples inside this folder show how to dynamically slice a model for a subset of database records dynamically specified by a corresponding SQL query inside RDBMS, such as PostgreSQL. \ No newline at end of file diff --git a/examples/msmlp/model.py b/examples/msmlp/model.py new file mode 100644 index 000000000..2a4d0e663 --- /dev/null +++ b/examples/msmlp/model.py @@ -0,0 +1,202 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from singa import layer +from singa import model +from singa import tensor +from singa import opt +from singa import device +from singa.autograd import Operator +from singa.layer import Layer +from singa import singa_wrap as singa +import argparse +import numpy as np + +np_dtype = {"float16": np.float16, "float32": np.float32} + +singa_dtype = {"float16": tensor.float16, "float32": tensor.float32} + +#### self-defined loss begin + +### from autograd.py +class SumError(Operator): + + def __init__(self): + super(SumError, self).__init__() + # self.t = t.data + + def forward(self, x): + # self.err = singa.__sub__(x, self.t) + self.data_x = x + # sqr = singa.Square(self.err) + # loss = singa.SumAll(sqr) + loss = singa.SumAll(x) + # self.n = 1 + # for s in x.shape(): + # self.n *= s + # loss /= self.n + return loss + + def backward(self, dy=1.0): + # dx = self.err + dev = device.get_default_device() + dx = tensor.Tensor(self.data_x.shape, dev, singa_dtype['float32']) + dx.copy_from_numpy(np.ones(self.data_x.shape)) + # dx *= float(2 / self.n) + dx *= dy + return dx + +def se_loss(x): + # assert x.shape == t.shape, "input and target shape different: %s, %s" % ( + # x.shape, t.shape) + return SumError()(x)[0] + +### from layer.py +class SumErrorLayer(Layer): + """ + Generate a MeanSquareError operator + """ + + def __init__(self): + super(SumErrorLayer, self).__init__() + + def forward(self, x): + return se_loss(x) + +#### self-defined loss end + +class MSMLP(model.Model): + + def __init__(self, data_size=10, perceptron_size=100, num_classes=10): + super(MSMLP, self).__init__() + self.num_classes = num_classes + self.dimension = 2 + + self.relu = layer.ReLU() + self.linear1 = layer.Linear(perceptron_size) + self.linear2 = layer.Linear(num_classes) + self.softmax_cross_entropy = layer.SoftMaxCrossEntropy() + self.sum_error = SumErrorLayer() + + def forward(self, inputs): + y = self.linear1(inputs) + y = self.relu(y) + y = self.linear2(y) + return y + + def train_one_batch(self, x, y, synflow_flag, dist_option, spars): + out = self.forward(x) + if synflow_flag: + loss = self.sum_error(out) + else: # normal training + loss = self.softmax_cross_entropy(out, y) + + if dist_option == 'plain': + pn_p_g_list = self.optimizer(loss) + elif dist_option == 'half': + self.optimizer.backward_and_update_half(loss) + elif dist_option == 'partialUpdate': + self.optimizer.backward_and_partial_update(loss) + elif dist_option == 'sparseTopK': + self.optimizer.backward_and_sparse_update(loss, + topK=True, + spars=spars) + elif dist_option == 'sparseThreshold': + self.optimizer.backward_and_sparse_update(loss, + topK=False, + spars=spars) + return pn_p_g_list, out, loss + + def set_optimizer(self, optimizer): + self.optimizer = optimizer + + +def create_model(pretrained=False, **kwargs): + """Constructs a CNN model. 
+ + Args: + pretrained (bool): If True, returns a pre-trained model. + + Returns: + The created model. + """ + model = MSMLP(**kwargs) + + return model + + +__all__ = ['MSMLP', 'create_model'] + +if __name__ == "__main__": + np.random.seed(0) + + parser = argparse.ArgumentParser() + parser.add_argument('-p', + choices=['float32', 'float16'], + default='float32', + dest='precision') + parser.add_argument('-g', + '--disable-graph', + default='True', + action='store_false', + help='disable graph', + dest='graph') + parser.add_argument('-m', + '--max-epoch', + default=1001, + type=int, + help='maximum epochs', + dest='max_epoch') + args = parser.parse_args() + + # generate the boundary + f = lambda x: (5 * x + 1) + bd_x = np.linspace(-1.0, 1, 200) + bd_y = f(bd_x) + + # generate the training data + x = np.random.uniform(-1, 1, 400) + y = f(x) + 2 * np.random.randn(len(x)) + + # choose one precision + precision = singa_dtype[args.precision] + np_precision = np_dtype[args.precision] + + # convert training data to 2d space + label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)]).astype(np.int32) + data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np_precision) + + dev = device.create_cuda_gpu_on(0) + sgd = opt.SGD(0.1, 0.9, 1e-5, dtype=singa_dtype[args.precision]) + tx = tensor.Tensor((400, 2), dev, precision) + ty = tensor.Tensor((400,), dev, tensor.int32) + model = MSMLP(data_size=2, perceptron_size=3, num_classes=2) + + # attach model to graph + model.set_optimizer(sgd) + model.compile([tx], is_train=True, use_graph=args.graph, sequential=True) + model.train() + + for i in range(args.max_epoch): + tx.copy_from_numpy(data) + ty.copy_from_numpy(label) + # synflow_flag=False, dist_option='plain' to match train_one_batch's signature + pn_p_g_list, out, loss = model(tx, ty, False, 'plain', spars=None) + + if i % 100 == 0: + print("training loss = ", tensor.to_numpy(loss)[0]) diff --git a/examples/msmlp/native.py b/examples/msmlp/native.py new file mode 100644 index 000000000..a82ec3b24 --- /dev/null +++ b/examples/msmlp/native.py @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+# + +from singa import tensor +from singa.tensor import Tensor +from singa import autograd +from singa import opt +import numpy as np +from singa import device +import argparse + +np_dtype = {"float16": np.float16, "float32": np.float32} + +singa_dtype = {"float16": tensor.float16, "float32": tensor.float32} + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-p', + choices=['float32', 'float16'], + default='float32', + dest='precision') + parser.add_argument('-m', + '--max-epoch', + default=1001, + type=int, + help='maximum epochs', + dest='max_epoch') + args = parser.parse_args() + + np.random.seed(0) + + autograd.training = True + + # prepare training data in numpy array + + # generate the boundary + f = lambda x: (5 * x + 1) + bd_x = np.linspace(-1.0, 1, 200) + bd_y = f(bd_x) + + # generate the training data + x = np.random.uniform(-1, 1, 400) + y = f(x) + 2 * np.random.randn(len(x)) + + # convert training data to 2d space + label = np.asarray([5 * a + 1 > b for (a, b) in zip(x, y)]) + data = np.array([[a, b] for (a, b) in zip(x, y)], dtype=np.float32) + + def to_categorical(y, num_classes): + """ + Converts a class vector (integers) to binary class matrix. + + Args: + y: class vector to be converted into a matrix + (integers from 0 to num_classes). + num_classes: total number of classes. + + Returns: + A binary matrix representation of the input. + """ + y = np.array(y, dtype="int") + n = y.shape[0] + categorical = np.zeros((n, num_classes)) + categorical[np.arange(n), y] = 1 + return categorical + + label = to_categorical(label, 2).astype(np.float32) + print("train_data_shape:", data.shape) + print("train_label_shape:", label.shape) + + precision = singa_dtype[args.precision] + np_precision = np_dtype[args.precision] + + dev = device.create_cuda_gpu() + + inputs = Tensor(data=data, device=dev) + target = Tensor(data=label, device=dev) + + inputs = inputs.as_type(precision) + target = target.as_type(tensor.int32) + + w0_np = np.random.normal(0, 0.1, (2, 3)).astype(np_precision) + w0 = Tensor(data=w0_np, + device=dev, + dtype=precision, + requires_grad=True, + stores_grad=True) + b0 = Tensor(shape=(3,), + device=dev, + dtype=precision, + requires_grad=True, + stores_grad=True) + b0.set_value(0.0) + + w1_np = np.random.normal(0, 0.1, (3, 2)).astype(np_precision) + w1 = Tensor(data=w1_np, + device=dev, + dtype=precision, + requires_grad=True, + stores_grad=True) + b1 = Tensor(shape=(2,), + device=dev, + dtype=precision, + requires_grad=True, + stores_grad=True) + b1.set_value(0.0) + + sgd = opt.SGD(0.05, 0.8) + + # training process + for i in range(args.max_epoch): + x = autograd.matmul(inputs, w0) + x = autograd.add_bias(x, b0) + x = autograd.relu(x) + x = autograd.matmul(x, w1) + x = autograd.add_bias(x, b1) + loss = autograd.softmax_cross_entropy(x, target) + sgd(loss) + + if i % 100 == 0: + print("%d, training loss = " % i, tensor.to_numpy(loss)[0])
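To summarize the scoring step that motivates `MSOptimizer`/`MSSGD` and the `SumErrorLayer` loss in this patch: on the last training batch, `train_cnn.py` feeds an all-ones input, uses the sum-of-outputs loss, and accumulates `|param * grad|` over the 2-D weight tensors returned as `(name, param, grad)` triples. Below is a minimal NumPy sketch of that SynFlow-style score, assuming the triples have already been converted to arrays (the real code calls `tensor.to_numpy` on SINGA tensors).

```python
# Sketch of the SynFlow-style score accumulated in train_cnn.py.
# pn_p_g_list holds (param_name, param_value, param_grad) triples as returned
# by MSOptimizer.call_with_returns; plain NumPy arrays stand in for tensors.
import numpy as np


def synflow_score(pn_p_g_list):
    score = 0.0
    for name, param, grad in pn_p_g_list:
        if param.ndim == 2:  # only weight matrices, as in the training loop
            score += np.sum(np.abs(param * grad))
    return score


# toy example with one weight matrix and one bias vector
rng = np.random.default_rng(0)
w, gw = rng.normal(size=(4, 3)), rng.normal(size=(4, 3))
b, gb = rng.normal(size=3), rng.normal(size=3)
print(synflow_score([("w", w, gw), ("b", b, gb)]))  # only |w * gw| contributes
```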