Skip to content

Commit

Permalink
Add support in Neural Typecheck to disable semantic checks (NVIDIA#8212)
Browse files Browse the repository at this point in the history
Signed-off-by: smajumdar <titu1994@gmail.com>
  • Loading branch information
titu1994 committed Jan 24, 2024
1 parent 0773702 commit aeb9799
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 7 deletions.
48 changes: 41 additions & 7 deletions nemo/core/classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
__all__ = ['Typing', 'FileIO', 'Model', 'Serialization', 'typecheck', 'PretrainedModelInfo']

_TYPECHECK_ENABLED = True
_TYPECHECK_SEMANTIC_CHECK_ENABLED = True
# TODO @blisc: Remove _HAS_HYDRA
_HAS_HYDRA = True

Expand All @@ -54,6 +55,13 @@ def is_typecheck_enabled():
return _TYPECHECK_ENABLED


def is_semantic_typecheck_enabled():
"""
Getter method for typechecking semantics state.
"""
return _TYPECHECK_SEMANTIC_CHECK_ENABLED


@dataclass
class TypecheckMetadata:
"""
Expand Down Expand Up @@ -178,7 +186,6 @@ def _validate_input_types(self, input_types=None, ignore_collections=False, **kw
kwargs: Dictionary of argument_name:argument_value pairs passed to the wrapped
function upon call.
"""
# TODO: Properly implement this
if input_types is not None:
# Precompute metadata
metadata = TypecheckMetadata(original_types=input_types, ignore_collections=ignore_collections)
Expand All @@ -202,9 +209,11 @@ def _validate_input_types(self, input_types=None, ignore_collections=False, **kw
)

# Perform neural type check
if hasattr(value, 'neural_type') and not metadata.base_types[key].compare(value.neural_type) in (
NeuralTypeComparisonResult.SAME,
NeuralTypeComparisonResult.GREATER,
if (
hasattr(value, 'neural_type')
and is_semantic_typecheck_enabled()
and not metadata.base_types[key].compare(value.neural_type)
in (NeuralTypeComparisonResult.SAME, NeuralTypeComparisonResult.GREATER,)
):
error_msg = [
f"{input_types[key].compare(value.neural_type)} :",
Expand Down Expand Up @@ -379,9 +388,11 @@ def __check_neural_type(self, obj, metadata: TypecheckMetadata, depth: int, name
f"Expected nested depth : {metadata.container_depth[name]}"
)

if hasattr(obj, 'neural_type') and not type_val.compare(obj.neural_type) in (
NeuralTypeComparisonResult.SAME,
NeuralTypeComparisonResult.GREATER,
if (
hasattr(obj, 'neural_type')
and is_semantic_typecheck_enabled()
and not type_val.compare(obj.neural_type)
in (NeuralTypeComparisonResult.SAME, NeuralTypeComparisonResult.GREATER,)
):
raise TypeError(
f"{type_val.compare(obj.neural_type)} : \n"
Expand Down Expand Up @@ -1114,3 +1125,26 @@ def disable_checks():
yield
finally:
typecheck.set_typecheck_enabled(enabled=True)

@staticmethod
def set_semantic_check_enabled(enabled: bool = True):
"""
Global method to enable/disable semantic typechecking.
Args:
enabled: bool, when True will enable semantic typechecking.
"""
global _TYPECHECK_SEMANTIC_CHECK_ENABLED
_TYPECHECK_SEMANTIC_CHECK_ENABLED = enabled

@staticmethod
@contextmanager
def disable_semantic_checks():
"""
Context manager that temporarily disables semantic type checking within its context.
"""
typecheck.set_semantic_check_enabled(enabled=False)
try:
yield
finally:
typecheck.set_semantic_check_enabled(enabled=True)
38 changes: 38 additions & 0 deletions tests/core/test_typecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,3 +1152,41 @@ def __call__(self, x):
assert len(outA[0]) == 3
for i in range(len(outA)):
assert outA[0][i].neural_type.compare(NeuralType(('B', 'D'), LogitsType()))

@pytest.mark.unit
def test_disable_semantic_types_input_output(self):
class InputOutputTypes(Typing):
@property
def input_types(self):
return {"x": NeuralType(('B',), LogprobsType())}

@property
def output_types(self):
return {"y": NeuralType(('B',), LabelsType())}

@typecheck()
def __call__(self, x):
x += 1
return x

obj = InputOutputTypes()
result = obj(x=torch.zeros(10))

assert result.sum() == torch.tensor(10.0)
assert result.neural_type.compare(NeuralType(('B',), LabelsType())) == NeuralTypeComparisonResult.SAME

# Test that input is provided with wrong type and semantic checks are not disabled
with pytest.raises(TypeError):
input_data = torch.zeros(10)
input_data.neural_type = NeuralType(('B',), LabelsType())
_ = obj(x=input_data)

# Provide input with wrong type after disabling semantic type checks
with typecheck.disable_semantic_checks():
input_data = torch.zeros(10)
input_data.neural_type = NeuralType(('B',), LabelsType()) # Should be LogprobsType()
result = obj(x=input_data)

# assert that even if semantic types are disabled, output is attached with appropriate types
assert result.sum() == torch.tensor(10.0)
assert result.neural_type.compare(NeuralType(('B',), LabelsType())) == NeuralTypeComparisonResult.SAME

0 comments on commit aeb9799

Please sign in to comment.