Skip to content

Commit

Permalink
Revert "add register_device_op_overrides (pytorch#119268)"
Browse files Browse the repository at this point in the history
This reverts commit 2864a7e.

Reverted pytorch#119268 on behalf of https://github.com/malfet due to Broke lint ([comment](pytorch#119268 (comment)))
  • Loading branch information
pytorchmergebot committed Feb 19, 2024
1 parent 3ad067f commit 0bdeaad
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 26 deletions.
40 changes: 17 additions & 23 deletions torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,6 @@ class DeviceCodegen:
device_codegens: Dict[str, DeviceCodegen] = {}


class DeviceOpOverrides:
def import_get_raw_stream_as(self, name):
raise NotImplementedError()

def set_device(self, device_idx):
raise NotImplementedError()

def synchronize(self):
raise NotImplementedError()

def device_guard(self, device_idx):
raise NotImplementedError()


device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}

# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# For any new backend looking to integrate with Inductor, customization of these two main
# parts are necessary to generate its specific code.
Expand Down Expand Up @@ -148,17 +132,13 @@ def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
# added contiguous index prevents reordering
return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]

def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
device_op_overrides_dict[device] = device_op_overrides

def get_device_op_overrides(device: str):
assert isinstance(device, str)
if device == "cuda":
from .cuda.device_op_overrides import CUDADeviceOpOverrides

if not device_op_overrides_dict.keys():
from .cuda import device_op_overrides

if device in device_op_overrides_dict.keys():
return device_op_overrides_dict[device]
return CUDADeviceOpOverrides()

return DeviceOpOverrides()

Expand Down Expand Up @@ -823,6 +803,20 @@ def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
return h


class DeviceOpOverrides:
def import_get_raw_stream_as(self, name):
raise NotImplementedError()

def set_device(self, device_idx):
raise NotImplementedError()

def synchronize(self):
raise NotImplementedError()

def device_guard(self, device_idx):
raise NotImplementedError()


class DeferredLine(DeferredLineBase):
"""A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

Expand Down
4 changes: 1 addition & 3 deletions torch/_inductor/codegen/cuda/device_op_overrides.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ..common import DeviceOpOverrides, register_device_op_overrides
from ..common import DeviceOpOverrides


class CUDADeviceOpOverrides(DeviceOpOverrides):
Expand All @@ -13,5 +13,3 @@ def synchronize(self):

def device_guard(self, device_idx):
return f"torch.cuda._DeviceGuard({device_idx})"

register_device_op_overrides('cuda', CUDADeviceOpOverrides())

0 comments on commit 0bdeaad

Please sign in to comment.