diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index e6b260e0d5b6..30d2f8f6f5b3 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -43,6 +43,7 @@ __all__ = ['Typing', 'FileIO', 'Model', 'Serialization', 'typecheck', 'PretrainedModelInfo'] _TYPECHECK_ENABLED = True +_TYPECHECK_SEMANTIC_CHECK_ENABLED = True # TODO @blisc: Remove _HAS_HYDRA _HAS_HYDRA = True @@ -54,6 +55,13 @@ def is_typecheck_enabled(): return _TYPECHECK_ENABLED +def is_semantic_typecheck_enabled(): + """ + Getter method for typechecking semantics state. + """ + return _TYPECHECK_SEMANTIC_CHECK_ENABLED + + @dataclass class TypecheckMetadata: """ @@ -178,7 +186,6 @@ def _validate_input_types(self, input_types=None, ignore_collections=False, **kw kwargs: Dictionary of argument_name:argument_value pairs passed to the wrapped function upon call. """ - # TODO: Properly implement this if input_types is not None: # Precompute metadata metadata = TypecheckMetadata(original_types=input_types, ignore_collections=ignore_collections) @@ -202,9 +209,11 @@ def _validate_input_types(self, input_types=None, ignore_collections=False, **kw ) # Perform neural type check - if hasattr(value, 'neural_type') and not metadata.base_types[key].compare(value.neural_type) in ( - NeuralTypeComparisonResult.SAME, - NeuralTypeComparisonResult.GREATER, + if ( + hasattr(value, 'neural_type') + and is_semantic_typecheck_enabled() + and not metadata.base_types[key].compare(value.neural_type) + in (NeuralTypeComparisonResult.SAME, NeuralTypeComparisonResult.GREATER,) ): error_msg = [ f"{input_types[key].compare(value.neural_type)} :", @@ -379,9 +388,11 @@ def __check_neural_type(self, obj, metadata: TypecheckMetadata, depth: int, name f"Expected nested depth : {metadata.container_depth[name]}" ) - if hasattr(obj, 'neural_type') and not type_val.compare(obj.neural_type) in ( - NeuralTypeComparisonResult.SAME, - NeuralTypeComparisonResult.GREATER, + if ( + hasattr(obj, 'neural_type') + and is_semantic_typecheck_enabled() + and not type_val.compare(obj.neural_type) + in (NeuralTypeComparisonResult.SAME, NeuralTypeComparisonResult.GREATER,) ): raise TypeError( f"{type_val.compare(obj.neural_type)} : \n" @@ -1114,3 +1125,26 @@ def disable_checks(): yield finally: typecheck.set_typecheck_enabled(enabled=True) + + @staticmethod + def set_semantic_check_enabled(enabled: bool = True): + """ + Global method to enable/disable semantic typechecking. + + Args: + enabled: bool, when True will enable semantic typechecking. + """ + global _TYPECHECK_SEMANTIC_CHECK_ENABLED + _TYPECHECK_SEMANTIC_CHECK_ENABLED = enabled + + @staticmethod + @contextmanager + def disable_semantic_checks(): + """ + Context manager that temporarily disables semantic type checking within its context. + """ + typecheck.set_semantic_check_enabled(enabled=False) + try: + yield + finally: + typecheck.set_semantic_check_enabled(enabled=True) diff --git a/tests/core/test_typecheck.py b/tests/core/test_typecheck.py index 1b0c927bae57..b55c8a1c9c4c 100644 --- a/tests/core/test_typecheck.py +++ b/tests/core/test_typecheck.py @@ -1152,3 +1152,41 @@ def __call__(self, x): assert len(outA[0]) == 3 for i in range(len(outA)): assert outA[0][i].neural_type.compare(NeuralType(('B', 'D'), LogitsType())) + + @pytest.mark.unit + def test_disable_semantic_types_input_output(self): + class InputOutputTypes(Typing): + @property + def input_types(self): + return {"x": NeuralType(('B',), LogprobsType())} + + @property + def output_types(self): + return {"y": NeuralType(('B',), LabelsType())} + + @typecheck() + def __call__(self, x): + x += 1 + return x + + obj = InputOutputTypes() + result = obj(x=torch.zeros(10)) + + assert result.sum() == torch.tensor(10.0) + assert result.neural_type.compare(NeuralType(('B',), LabelsType())) == NeuralTypeComparisonResult.SAME + + # Test that input is provided with wrong type and semantic checks are not disabled + with pytest.raises(TypeError): + input_data = torch.zeros(10) + input_data.neural_type = NeuralType(('B',), LabelsType()) + _ = obj(x=input_data) + + # Provide input with wrong type after disabling semantic type checks + with typecheck.disable_semantic_checks(): + input_data = torch.zeros(10) + input_data.neural_type = NeuralType(('B',), LabelsType()) # Should be LogprobsType() + result = obj(x=input_data) + + # assert that even if semantic types are disabled, output is attached with appropriate types + assert result.sum() == torch.tensor(10.0) + assert result.neural_type.compare(NeuralType(('B',), LabelsType())) == NeuralTypeComparisonResult.SAME