Support ipex in TorchEstimator #302

Open
wants to merge 7 commits into base: master
examples/pytorch_nyctaxi.py (4 changes: 3 additions & 1 deletion)
@@ -72,7 +72,9 @@ def forward(self, x):
feature_columns=features, feature_types=torch.float,
label_column="fare_amount", label_type=torch.float,
batch_size=64, num_epochs=30,
metrics_name = ["MeanAbsoluteError", "MeanSquaredError"])
metrics_name=["MeanAbsoluteError", "MeanSquaredError"],
use_ipex=False, use_bf16=False, use_amp=False,
use_ccl=False, use_jit_trace=False)
# Train the model
estimator.fit_on_spark(train_df, test_df)
# Get the trained model
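For comparison, here is a hypothetical sketch (not part of the committed example) of the same call with the new switches turned on. It assumes intel_extension_for_pytorch and oneccl_bindings_for_pytorch are installed; the tiny model, optimizer, loss and worker count are illustrative placeholders, while features, train_df and test_df are the names used in the example above.

import torch
from raydp.torch import TorchEstimator

# Placeholder model/optimizer/loss purely for illustration.
model = torch.nn.Linear(len(features), 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

estimator = TorchEstimator(num_workers=2, model=model, optimizer=optimizer, loss=criterion,
                           feature_columns=features, feature_types=torch.float,
                           label_column="fare_amount", label_type=torch.float,
                           batch_size=64, num_epochs=30,
                           metrics_name=["MeanAbsoluteError", "MeanSquaredError"],
                           use_ipex=True,        # run ipex.optimize on the model and optimizer
                           use_bf16=True,        # cast model parameters to torch.bfloat16
                           use_amp=True,         # run forward passes under torch.cpu.amp.autocast
                           use_ccl=True,         # initialize torch.distributed with the oneCCL backend
                           use_jit_trace=False)  # optionally trace the model with torch.jit.trace
estimator.fit_on_spark(train_df, test_df)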
python/raydp/torch/config.py (42 changes: 42 additions & 0 deletions)
@@ -0,0 +1,42 @@
from ray.train.torch.config import _TorchBackend
from ray.train.torch.config import TorchConfig as RayTorchConfig
from ray.train._internal.worker_group import WorkerGroup
from dataclasses import dataclass
from packaging import version
import importlib_metadata

@dataclass
class TorchConfig(RayTorchConfig):

@property
def backend_cls(self):
return EnableCCLBackend

def ccl_import():
# Importing oneccl_bindings_for_pytorch on each worker registers the "ccl"
# backend with torch.distributed before the process group is initialized.
# pylint: disable=import-outside-toplevel
import oneccl_bindings_for_pytorch

class EnableCCLBackend(_TorchBackend):

def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
for i in range(len(worker_group)):
worker_group.execute_single_async(i, ccl_import)
super().on_start(worker_group, backend_config)

def check_ipex():
# Returns the installed torch / intel_extension_for_pytorch versions as
# "major.minor" strings; a None entry means that package is not installed.
def get_major_and_minor_from_version(full_version):
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)

try:
_torch_version = importlib_metadata.version("torch")
except importlib_metadata.PackageNotFoundError:
return None, None

try:
_ipex_version = importlib_metadata.version("intel_extension_for_pytorch")
except importlib_metadata.PackageNotFoundError:
return _torch_version, None

torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
return torch_major_and_minor, ipex_major_and_minor
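In short, TorchConfig only swaps Ray Train's backend class for EnableCCLBackend, which imports oneccl_bindings_for_pytorch on every worker so the "ccl" backend is registered before torch.distributed is initialized, while check_ipex reports whether a matching intel_extension_for_pytorch is installed. A hypothetical standalone usage could look like the sketch below (train_func and the worker count are placeholders; the estimator wires this up the same way in fit further down):

from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

from raydp.torch.config import TorchConfig, check_ipex

# Verify that IPEX matches the installed PyTorch (both are "major.minor" strings).
torch_mm, ipex_mm = check_ipex()
assert ipex_mm is not None and torch_mm == ipex_mm, \
    "intel_extension_for_pytorch must match the installed PyTorch version"

def train_func():
    # Placeholder per-worker training loop; torch.distributed is already
    # initialized with the oneCCL backend when this runs.
    pass

trainer = TorchTrainer(train_func,
                       torch_config=TorchConfig(backend="ccl"),
                       scaling_config=ScalingConfig(num_workers=2))
result = trainer.fit()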
python/raydp/torch/estimator.py (93 changes: 83 additions & 10 deletions)
@@ -26,6 +26,7 @@
from raydp.torch.torch_metrics import TorchMetric
from raydp import stop_spark
from raydp.spark import spark_dataframe_to_ray_dataset
from raydp.torch.config import TorchConfig, check_ipex

import ray
from ray import train
@@ -89,6 +90,11 @@ def __init__(self,
num_processes_for_data_loader: int = 0,
metrics_name: Optional[List[Union[str, Callable]]] = None,
metrics_config: Optional[Dict[str,Dict[str, Any]]] = None,
use_ipex: bool = False,
use_bf16: bool = False,
use_amp: bool = False,
use_ccl: bool = False,
use_jit_trace: bool = False,
**extra_config):
"""
:param num_workers: the number of workers to do the distributed training
@@ -131,6 +137,12 @@ def __init__(self,
:param metrics_config: the optional config for the metrics. Its format is:
{"metric_name_1": {"param1": value1, "param2": value2}, "metric_name_2":{}}, where
param is the parameter corresponding to a concrete metric class of TorchMetrics.
:param use_ipex: whether to apply Intel Extension for PyTorch (IPEX) optimizations
to the model and optimizer
:param use_bf16: whether to cast model parameters to ``torch.bfloat16``
:param use_amp: whether to enable automatic mixed precision (CPU autocast) for the
forward passes
:param use_ccl: whether to use oneCCL (torch-ccl) as the backend of the default
distributed process group
:param use_jit_trace: whether to use ``torch.jit.trace`` to accelerate the model
:param extra_config: the extra config that will be passed to ray.train.torch.TorchTrainer
"""
self._num_workers = num_workers
@@ -149,11 +161,25 @@ def __init__(self,
self._shuffle = shuffle
self._num_processes_for_data_loader = num_processes_for_data_loader
self._metrics = TorchMetric(metrics_name, metrics_config)
self._use_ipex = use_ipex
self._use_bf16 = use_bf16
self._use_amp = use_amp
self._use_ccl = use_ccl
self._use_jit_trace = use_jit_trace
self._extra_config = extra_config

if self._num_processes_for_data_loader > 0:
raise TypeError("multiple processes for data loader has not supported")

if self._use_ipex:
torch_version, ipex_version = check_ipex()
assert torch_version is not None, "PyTorch is not found. Please install PyTorch."
assert ipex_version is not None, "Intel Extension for PyTorch is not found. "\
"Please install Intel Extension for PyTorch."
assert torch_version == ipex_version, "Intel Extension for PyTorch {ipex} needs to "\
"work with PyTorch {ipex}.*, but PyTorch {torch} is found. Please switch to "\
"the matching version.".format(ipex=ipex_version, torch=torch_version)

self._trainer: TorchTrainer = None

self._check()
@@ -211,6 +237,18 @@ def train_func(config):
# get metrics
metrics = config["metrics"]

# apply IPEX optimizations to the model and optimizer when enabled
use_ipex = config["use_ipex"]
use_bf16 = config["use_bf16"]
use_amp = config["use_amp"]
use_jit_trace = config["use_jit_trace"]
if use_ipex:
# pylint: disable=import-outside-toplevel
import intel_extension_for_pytorch as ipex
model = model.to(memory_format=torch.channels_last)
dtype = torch.bfloat16 if use_bf16 else None
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=dtype)

# create dataset
train_data_shard = session.get_dataset_shard("train")
train_dataset = train_data_shard.to_torch(feature_columns=config["feature_columns"],
@@ -233,11 +271,14 @@ def train_func(config):
loss_results = []
for epoch in range(config["num_epochs"]):
train_res, train_loss = TorchEstimator.train_epoch(train_dataset, model, loss,
optimizer, metrics, lr_scheduler)
optimizer, metrics, use_amp,
use_bf16, use_jit_trace,
lr_scheduler)
session.report(dict(epoch=epoch, train_res=train_res, train_loss=train_loss))
if config["evaluate"]:
eval_res, evaluate_loss = TorchEstimator.evaluate_epoch(evaluate_dataset,
model, loss, metrics)
eval_res, evaluate_loss = TorchEstimator.evaluate_epoch(evaluate_dataset, model,
loss, metrics, use_amp,
use_bf16, use_jit_trace)
session.report(dict(epoch=epoch, eval_res=eval_res, test_loss=evaluate_loss))
loss_results.append(evaluate_loss)
if hasattr(model, "module"):
@@ -250,13 +291,14 @@ def train_func(config):
}))

@staticmethod
def train_epoch(dataset, model, criterion, optimizer, metrics, scheduler=None):
def train_epoch(dataset, model, criterion, optimizer, metrics, use_amp, use_bf16, use_jit_trace,
scheduler=None):
model.train()
train_loss, data_size, batch_idx = 0, 0, 0
for batch_idx, (inputs, targets) in enumerate(dataset):
# Compute prediction error
outputs = model(inputs)
loss = criterion(outputs, targets)
outputs, loss = TorchEstimator.train_batch(batch_idx, model, inputs, targets, criterion,
use_amp, use_bf16, use_jit_trace)
train_loss += loss.item()
metrics.update(outputs, targets)
data_size += targets.size(0)
@@ -273,14 +315,16 @@ def train_epoch(dataset, model, criterion, optimizer, metrics, scheduler=None):
return train_res, train_loss

@staticmethod
def evaluate_epoch(dataset, model, criterion, metrics):
def evaluate_epoch(dataset, model, criterion, metrics, use_amp, use_bf16, use_jit_trace):
model.eval()
test_loss, data_size, batch_idx = 0, 0, 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(dataset):
# Compute prediction error
outputs = model(inputs)
test_loss += criterion(outputs, targets).item()
outputs, loss = TorchEstimator.train_batch(batch_idx, model, inputs, targets,
criterion, use_amp, use_bf16,
use_jit_trace, is_eval=True)
test_loss += loss.item()
metrics.update(outputs, targets)
data_size += targets.size(0)

@@ -289,6 +333,26 @@ def evaluate_epoch(dataset, model, criterion, metrics):
metrics.reset()
return eval_res, test_loss

@staticmethod
def train_batch(batch_idx, model, inputs, targets, criterion, use_amp, use_bf16, use_jit_trace,
is_eval=False):
# Forward-pass helper shared by train_epoch and evaluate_epoch: optionally runs
# under bf16 CPU autocast and, on the first batch, optionally traces the model
# with TorchScript (freezing it for evaluation).
if use_amp and use_bf16:
with torch.cpu.amp.autocast():
if use_jit_trace and batch_idx == 0:
model = torch.jit.trace(model, inputs)
if is_eval:
model = torch.jit.freeze(model)
outputs = model(inputs)
loss = criterion(outputs, targets)
else:
if use_jit_trace and batch_idx == 0:
model = torch.jit.trace(model, inputs)
if is_eval:
model = torch.jit.freeze(model)
outputs = model(inputs)
loss = criterion(outputs, targets)
return outputs, loss

def fit(self,
train_ds: Dataset,
evaluate_ds: Optional[Dataset] = None,
@@ -306,7 +370,11 @@ def fit(self,
"num_epochs": self._num_epochs,
"drop_last": self._drop_last,
"evaluate": True,
"metrics": self._metrics
"metrics": self._metrics,
"use_ipex": self._use_ipex,
"use_bf16": self._use_bf16,
"use_amp": self._use_amp,
"use_jit_trace": self._use_jit_trace
}
scaling_config = ScalingConfig(num_workers=self._num_workers,
resources_per_worker=self._resources_per_worker)
@@ -320,10 +388,15 @@ def fit(self,
train_loop_config["evaluate"] = False
else:
datasets["evaluate"] = evaluate_ds
if self._use_ccl:
torch_config = TorchConfig(backend="ccl")
else:
torch_config = None
self._trainer = TorchTrainer(TorchEstimator.train_func,
train_loop_config=train_loop_config,
scaling_config=scaling_config,
run_config=run_config,
torch_config=torch_config,
datasets=datasets)

result = self._trainer.fit()
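Taken together, the new flags map onto a fairly standard IPEX CPU recipe. The sketch below is hypothetical and standalone (a toy model and random tensors, not code from this PR); it shows roughly what each worker ends up doing when use_ipex, use_bf16, use_amp and use_jit_trace are enabled:

import torch
import intel_extension_for_pytorch as ipex  # assumes IPEX is installed

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()
inputs, targets = torch.randn(8, 16), torch.randn(8, 1)

# use_ipex: let IPEX rewrite the model and optimizer for CPU, optionally in bf16.
model = model.to(memory_format=torch.channels_last)
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)

# use_amp / use_bf16: run the forward pass under CPU autocast.
optimizer.zero_grad()
with torch.cpu.amp.autocast():
    outputs = model(inputs)
loss = criterion(outputs.float(), targets)  # compute the loss in fp32
loss.backward()
optimizer.step()

# use_jit_trace (evaluation): trace and freeze the model on the first batch.
model.eval()
with torch.no_grad(), torch.cpu.amp.autocast():
    traced = torch.jit.freeze(torch.jit.trace(model, inputs))
    preds = traced(inputs)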