From 0bdeaad93683239d012cef45d5670fd1e0b5eca9 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 19 Feb 2024 22:31:32 +0000 Subject: [PATCH] Revert "add register_device_op_overrides (#119268)" This reverts commit 2864a7e161cc107f7e4c00cccdf860a6089c73c3. Reverted https://github.com/pytorch/pytorch/pull/119268 on behalf of https://github.com/malfet due to Broke lint ([comment](https://github.com/pytorch/pytorch/pull/119268#issuecomment-1953231324)) --- torch/_inductor/codegen/common.py | 40 ++++++++----------- .../codegen/cuda/device_op_overrides.py | 4 +- 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 4ff98c7763612..e0665863c76ec 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -89,22 +89,6 @@ class DeviceCodegen: device_codegens: Dict[str, DeviceCodegen] = {} -class DeviceOpOverrides: - def import_get_raw_stream_as(self, name): - raise NotImplementedError() - - def set_device(self, device_idx): - raise NotImplementedError() - - def synchronize(self): - raise NotImplementedError() - - def device_guard(self, device_idx): - raise NotImplementedError() - - -device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {} - # The code generated by Inductor consists of two main parts: kernel code and wrapper code. # For any new backend looking to integrate with Inductor, customization of these two main # parts are necessary to generate its specific code. @@ -148,17 +132,13 @@ def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes): # added contiguous index prevents reordering return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] -def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides): - device_op_overrides_dict[device] = device_op_overrides def get_device_op_overrides(device: str): assert isinstance(device, str) + if device == "cuda": + from .cuda.device_op_overrides import CUDADeviceOpOverrides - if not device_op_overrides_dict.keys(): - from .cuda import device_op_overrides - - if device in device_op_overrides_dict.keys(): - return device_op_overrides_dict[device] + return CUDADeviceOpOverrides() return DeviceOpOverrides() @@ -823,6 +803,20 @@ def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]: return h +class DeviceOpOverrides: + def import_get_raw_stream_as(self, name): + raise NotImplementedError() + + def set_device(self, device_idx): + raise NotImplementedError() + + def synchronize(self): + raise NotImplementedError() + + def device_guard(self, device_idx): + raise NotImplementedError() + + class DeferredLine(DeferredLineBase): """A line that can be 'unwritten' by adding name to V.graph.removed_buffers""" diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py index 468c75d7444d5..7f722cf86d678 100644 --- a/torch/_inductor/codegen/cuda/device_op_overrides.py +++ b/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -1,4 +1,4 @@ -from ..common import DeviceOpOverrides, register_device_op_overrides +from ..common import DeviceOpOverrides class CUDADeviceOpOverrides(DeviceOpOverrides): @@ -13,5 +13,3 @@ def synchronize(self): def device_guard(self, device_idx): return f"torch.cuda._DeviceGuard({device_idx})" - -register_device_op_overrides('cuda', CUDADeviceOpOverrides())