-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cherry pick (batch 2) to rel-1.5.1 (#5290)
* remove implicit linking of tensorrt and dnnl ep shared libs (#5262) * Update DirectML Nuget to 1.3.0 (#5274) * Update PyTorch TransformerModel sample (#5275) * Insert telemetry template into GPU build, add telemry build switches. (#5278) * Synchronize training dependency versions between Docker image and Python wheel (#5261) * Downgrade GCC (#5269) * Remove --enable_symbolic_shape_infer_tests to fix linux ci pipeline build error. Co-authored-by: Edward Chen Co-authored-by: George Wu <jywu@microsoft.com> Co-authored-by: Dwayne Robinson <dwayner@microsoft.com> Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com> Co-authored-by: Dmitri Smirnov <yuslepukhin@users.noreply.github.com> Co-authored-by: edgchen1 <18449977+edgchen1@users.noreply.github.com> Co-authored-by: Changming Sun <chasun@microsoft.com>
- Loading branch information
1 parent
389cca7
commit c00e13a
Showing
57 changed files
with
651 additions
and
334 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
<?xml version="1.0" encoding="utf-8"?> | ||
<packages> | ||
<package id="DirectML" version="3.0.0" targetFramework="native" /> | ||
<package id="DirectML" version="1.3.0" targetFramework="native" /> | ||
<package id="GoogleTestAdapter" version="0.17.1" targetFramework="net46" /> | ||
</packages> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import argparse | ||
import math | ||
import torch | ||
import onnxruntime | ||
|
||
from utils import prepare_data, get_batch | ||
from ort_utils import my_loss, transformer_model_description_dynamic_axes | ||
from pt_model import TransformerModel | ||
|
||
|
||
def train(trainer, data_source, device, epoch, args, bptt=35): | ||
total_loss = 0. | ||
for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)): | ||
data, targets = get_batch(data_source, i) | ||
|
||
loss, pred = trainer.train_step(data, targets) | ||
total_loss += loss.item() | ||
if batch % args.log_interval == 0 and batch > 0: | ||
cur_loss = total_loss / args.log_interval | ||
print('epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}'.format(epoch, | ||
batch, | ||
len(data_source) // bptt, | ||
cur_loss)) | ||
total_loss = 0 | ||
|
||
|
||
def evaluate(trainer, data_source, bptt=35): | ||
total_loss = 0. | ||
with torch.no_grad(): | ||
for i in range(0, data_source.size(0) - 1, bptt): | ||
data, targets = get_batch(data_source, i) | ||
loss, pred = trainer.eval_step(data, targets) | ||
total_loss += len(data) * loss.item() | ||
return total_loss / (len(data_source) - 1) | ||
|
||
|
||
if __name__ == "__main__": | ||
# Training settings | ||
parser = argparse.ArgumentParser(description='PyTorch TransformerModel example') | ||
parser.add_argument('--batch-size', type=int, default=20, metavar='N', | ||
help='input batch size for training (default: 20)') | ||
parser.add_argument('--test-batch-size', type=int, default=20, metavar='N', | ||
help='input batch size for testing (default: 20)') | ||
parser.add_argument('--epochs', type=int, default=2, metavar='N', | ||
help='number of epochs to train (default: 2)') | ||
parser.add_argument('--lr', type=float, default=0.001, metavar='LR', | ||
help='learning rate (default: 0.001)') | ||
parser.add_argument('--no-cuda', action='store_true', default=False, | ||
help='disables CUDA training') | ||
parser.add_argument('--seed', type=int, default=1, metavar='S', | ||
help='random seed (default: 1)') | ||
parser.add_argument('--log-interval', type=int, default=200, metavar='N', | ||
help='how many batches to wait before logging training status (default: 200)') | ||
|
||
# Basic setup | ||
args = parser.parse_args() | ||
if not args.no_cuda and torch.cuda.is_available(): | ||
device = "cuda" | ||
else: | ||
device = "cpu" | ||
torch.manual_seed(args.seed) | ||
onnxruntime.set_seed(args.seed) | ||
|
||
# Model | ||
optim_config = onnxruntime.training.optim.SGDConfig(lr=args.lr) | ||
model_desc = transformer_model_description_dynamic_axes() | ||
model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) | ||
|
||
# Preparing data | ||
train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size) | ||
trainer = onnxruntime.training.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss) | ||
|
||
# Train | ||
for epoch in range(1, args.epochs + 1): | ||
train(trainer, train_data, device, epoch, args) | ||
val_loss = evaluate(trainer, val_data) | ||
print('-' * 89) | ||
print('| end of epoch {:3d} | valid loss {:5.2f} | '.format(epoch, val_loss)) | ||
print('-' * 89) | ||
|
||
# Evaluate | ||
test_loss = evaluate(trainer, test_data) | ||
print('=' * 89) | ||
print('| End of training | test loss {:5.2f}'.format(test_loss)) | ||
print('=' * 89) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import argparse | ||
import math | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from utils import prepare_data, get_batch | ||
from pt_model import TransformerModel | ||
|
||
|
||
def train(model, data_source, device, epoch, args, bptt=35): | ||
total_loss = 0. | ||
model.train() | ||
for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)): | ||
data, targets = get_batch(data_source, i) | ||
|
||
optimizer.zero_grad() | ||
output = model(data) | ||
loss = criterion(output.view(-1, 28785), targets) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
total_loss += loss.item() | ||
if batch % args.log_interval == 0 and batch > 0: | ||
cur_loss = total_loss / args.log_interval | ||
print('epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}'.format(epoch, | ||
batch, | ||
len(data_source) // bptt, | ||
cur_loss)) | ||
total_loss = 0 | ||
|
||
|
||
def evaluate(model, data_source, criterion, bptt=35): | ||
total_loss = 0. | ||
model.eval() | ||
with torch.no_grad(): | ||
for i in range(0, data_source.size(0) - 1, bptt): | ||
data, targets = get_batch(data_source, i) | ||
output = model(data) | ||
output_flat = output.view(-1, 28785) | ||
total_loss += len(data) * criterion(output_flat, targets).item() | ||
return total_loss / (len(data_source) - 1) | ||
|
||
|
||
if __name__ == "__main__": | ||
# Training settings | ||
parser = argparse.ArgumentParser(description='PyTorch TransformerModel example') | ||
parser.add_argument('--batch-size', type=int, default=20, metavar='N', | ||
help='input batch size for training (default: 20)') | ||
parser.add_argument('--test-batch-size', type=int, default=20, metavar='N', | ||
help='input batch size for testing (default: 20)') | ||
parser.add_argument('--epochs', type=int, default=2, metavar='N', | ||
help='number of epochs to train (default: 2)') | ||
parser.add_argument('--lr', type=float, default=0.001, metavar='LR', | ||
help='learning rate (default: 0.001)') | ||
parser.add_argument('--no-cuda', action='store_true', default=False, | ||
help='disables CUDA training') | ||
parser.add_argument('--seed', type=int, default=1, metavar='S', | ||
help='random seed (default: 1)') | ||
parser.add_argument('--log-interval', type=int, default=200, metavar='N', | ||
help='how many batches to wait before logging training status (default: 200)') | ||
|
||
# Basic setup | ||
args = parser.parse_args() | ||
if not args.no_cuda and torch.cuda.is_available(): | ||
device = "cuda" | ||
else: | ||
device = "cpu" | ||
torch.manual_seed(args.seed) | ||
|
||
# Model | ||
criterion = nn.CrossEntropyLoss() | ||
lr = 0.001 | ||
model = TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) | ||
optimizer = torch.optim.SGD(model.parameters(), lr=lr) | ||
|
||
# Preparing data | ||
train_data, val_data, test_data = prepare_data(device, args.batch_size, args.test_batch_size) | ||
|
||
# Train | ||
for epoch in range(1, args.epochs + 1): | ||
train(model, train_data, device, epoch, args) | ||
val_loss = evaluate(model, val_data, criterion) | ||
print('-' * 89) | ||
print('| end of epoch {:3d} | valid loss {:5.2f} | '.format(epoch, val_loss)) | ||
print('-' * 89) | ||
|
||
# Evaluate | ||
test_loss = evaluate(model, test_data, criterion) | ||
print('=' * 89) | ||
print('| End of training | test loss {:5.2f}'.format(test_loss)) | ||
print('=' * 89) |
Oops, something went wrong.