Skip to content

Commit

Permalink
Support builtin function frozenset
Browse files Browse the repository at this point in the history
Cherry-pick: pytorch#134563

Change-Id: Iab6b0a8795edea759fde697f752a47f96b6e9e56
  • Loading branch information
internal developer authored and aostrowski-hbn committed Oct 16, 2024
1 parent 087cfb1 commit 78a7e04
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 0 deletions.
65 changes: 65 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2723,6 +2723,71 @@ def fn(x):
ref = opt_fn(x)
self.assertEqual(ref, res)

def test_frozenset_construction(self):
def fn(x):
s = frozenset({x})
t = frozenset(s)
return len(t)

opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
res = fn(x)
ref = opt_fn(x)
self.assertEqual(ref, res)

def test_frozenset_reconstruction(self):
d = {}
f = frozenset()
d[f] = torch.randn(4)

def fn(x):
k = frozenset()
torch._dynamo.graph_break()
return d[k] * x

opt_fn = torch.compile(fn, backend="eager")
x = torch.randn(4)
res = fn(x)
ref = opt_fn(x)
self.assertEqual(ref, res)

def test_frozenset_illegal_call_method(self):
def fn_add():
s = frozenset((1, 2, 3))
s.add({2})
return len(s)

def fn_pop():
s = frozenset((1, 2, 3))
s.pop()
return len(s)

def fn_update():
s = frozenset((1, 2, 3))
s.update({4, 5, 6})
return len(s)

def fn_remove():
s = frozenset((1, 2, 3))
s.remove(2)
return len(s)

def fn_discard():
s = frozenset((1, 2, 3))
s.discard(2)
return len(s)

def fn_clear():
s = frozenset((1, 2, 3))
s.clear()
return len(s)

for fn in [fn_add, fn_pop, fn_update, fn_remove, fn_discard, fn_clear]:
torch._dynamo.reset()
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
with self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError):
opt_fn()

def test_is_tensor_tensor(self):
def fn(x, y):
if x is y:
Expand Down
1 change: 1 addition & 0 deletions torch/_dynamo/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
CustomizedDictVariable,
DataClassVariable,
DefaultDictVariable,
FrozensetVariable,
SetVariable,
)
from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable
Expand Down
15 changes: 15 additions & 0 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
ConstDictVariable,
DefaultDictVariable,
DictView,
FrozensetVariable,
is_hashable,
SetVariable,
)
Expand Down Expand Up @@ -1341,6 +1342,20 @@ def call_set(self, tx, *args, **kwargs):
else:
unimplemented(f"set(): {args} {kwargs}")

def call_frozenset(self, tx, *args, **kwargs):
assert not kwargs
if not args:
return FrozensetVariable([])
assert len(args) == 1
arg = args[0]
if isinstance(arg, variables.FrozensetVariable):
return FrozensetVariable([x.vt for x in arg.set_items])
elif arg.has_unpack_var_sequence(tx):
items = arg.unpack_var_sequence(tx)
return FrozensetVariable(items)
else:
unimplemented(f"frozenset(): {args} {kwargs}")

def call_zip(self, tx, *args, **kwargs):
if kwargs:
assert len(kwargs) == 1 and "strict" in kwargs
Expand Down
47 changes: 47 additions & 0 deletions torch/_dynamo/variables/dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,53 @@ def getitem_const(self, arg: VariableTracker):
raise RuntimeError("Illegal to getitem on a set")


class FrozensetVariable(SetVariable):
def __init__(
self,
items: List[VariableTracker],
**kwargs,
) -> None:
super().__init__(items, **kwargs)

def debug_repr(self):
if not self.items:
return "frozenset()"
else:
return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"

@property
def set_items(self):
return self.items.keys()

def python_type(self):
return frozenset

def as_python_constant(self):
return {k.vt.as_python_constant() for k in self.set_items}

def reconstruct(self, codegen):
codegen.foreach([x.vt for x in self.set_items])
codegen.add_push_null(
lambda: codegen.extend_output(
[
codegen.create_load_global("frozenset"),
]
)
)
codegen.extend_output(create_call_function(0, False))

def call_method(
self,
tx,
name,
args: List[VariableTracker],
kwargs: Dict[str, VariableTracker],
) -> "VariableTracker":
if name in ["add", "pop", "update", "remove", "discard", "clear"]:
raise RuntimeError(f"Illegal call_method {name} on a frozenset")
return super().call_method(tx, name, args, kwargs)


class DictView(VariableTracker):
"""
Models _PyDictViewObject
Expand Down

0 comments on commit 78a7e04

Please sign in to comment.