From e2b496fa99ded0842d930a84f6c7ca62ce37b97d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 3 May 2024 14:45:46 +0200 Subject: [PATCH] Made some fixes to the verification function of the `TranslatedJaxprSDFG`. --- src/jace/translator/translated_jaxpr_sdfg.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/jace/translator/translated_jaxpr_sdfg.py b/src/jace/translator/translated_jaxpr_sdfg.py index b2b7dc8..bc1ffab 100644 --- a/src/jace/translator/translated_jaxpr_sdfg.py +++ b/src/jace/translator/translated_jaxpr_sdfg.py @@ -46,12 +46,18 @@ class TranslatedJaxprSDFG: def validate(self) -> bool: """Validate the underlying SDFG.""" - # To prevent the 'non initialized' data warnings we have to temporary promote the - # input arguments as global. + # To prevent the 'non initialized' data warnings we have to temporary + # promote input and output arguments to globals + promote_to_glob: set[str] = set() org_trans_state: dict[str, bool] = {} - for var in self.inp_names: + if self.inp_names: + promote_to_glob.update(self.inp_names) + if self.out_names: + promote_to_glob.update(self.out_names) + for var in promote_to_glob: org_trans_state[var] = self.sdfg.arrays[var].transient self.sdfg.arrays[var].transient = False + try: self.sdfg.validate() finally: