From 240887e5e08268f7bfbf1027575017ebbdf9b6e6 Mon Sep 17 00:00:00 2001 From: EmilyBourne Date: Tue, 5 Mar 2024 10:10:23 +0100 Subject: [PATCH] Fix partial templates (#1780) Fix bug with partial templates (fixes #1779 ) where the partial template is incorrectly stripped away with unused template arguments. Add a test and update patch version. --- CHANGELOG.md | 11 +++++++++++ pyccel/parser/semantic.py | 3 ++- pyccel/version.py | 2 +- tests/epyccel/test_epyccel_decorators.py | 25 +++++++++++++++++++++++- 4 files changed, 38 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d01c4bee9c..322d037c61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file. ### Added +### Fixed + +### Changed + +### Deprecated + +## \[1.11.2\] - 2024-03-05 + +### Added + - #1689 : Add Python support for list method `append()`. - #1692 : Add Python support for list method `insert()`. - #1690 : Add Python support for list method `pop()`. @@ -19,6 +29,7 @@ All notable changes to this project will be documented in this file. - #1575 : Fixed inhomogeneous tuple (due to incompatible sizes) being treated as homogeneous tuple. - #1182 : Fix tuples containing objects with different ranks. - #1575 : Fix duplication operator for non-homogeneous tuples with a non-literal but constant multiplier. +- #1779 : Fix standalone partial templates. ### Changed diff --git a/pyccel/parser/semantic.py b/pyccel/parser/semantic.py index c3914c4eb2..4d596d65bc 100644 --- a/pyccel/parser/semantic.py +++ b/pyccel/parser/semantic.py @@ -3800,7 +3800,8 @@ def _visit_FunctionDef(self, expr): arg_annotations = [annot for a in templatable_args for annot in (a.type_list \ if isinstance(a, UnionTypeAnnotation) else [a]) \ if isinstance(annot, SyntacticTypeAnnotation)] - used_type_names = set(a.dtype for a in arg_annotations) + type_names = [a.dtype for a in arg_annotations] + used_type_names = set(d.base if isinstance(d, IndexedElement) else d for d in type_names) templates = {t: v for t,v in templates.items() if t in used_type_names} template_combinations = list(product(*[v.type_list for v in templates.values()])) diff --git a/pyccel/version.py b/pyccel/version.py index 8a2175f417..b01a093893 100644 --- a/pyccel/version.py +++ b/pyccel/version.py @@ -1,4 +1,4 @@ """ Module specifying the current version string for pyccel """ -__version__ = "1.11.1" +__version__ = "1.11.2" diff --git a/tests/epyccel/test_epyccel_decorators.py b/tests/epyccel/test_epyccel_decorators.py index e4e80724d9..dd5f70576f 100644 --- a/tests/epyccel/test_epyccel_decorators.py +++ b/tests/epyccel/test_epyccel_decorators.py @@ -4,7 +4,7 @@ import pytest import numpy as np from pyccel.epyccel import epyccel -from pyccel.decorators import private, inline +from pyccel.decorators import private, inline, template @pytest.mark.parametrize( 'lang', ( pytest.param("fortran", marks = pytest.mark.fortran), @@ -146,3 +146,26 @@ def get_val(x : int = None , y : int = None): g = epyccel(f, language=language) assert f() == g() + +def test_indexed_template(language): + @template(name='T', types=[float, complex]) + def my_sum(v: 'T[:]'): + return v.sum() + + pyccel_sum = epyccel(my_sum, language=language) + + x = np.ones(4, dtype=float) + + python_fl = my_sum(x) + pyccel_fl = pyccel_sum(x) + + assert python_fl == pyccel_fl + assert isinstance(python_fl, type(pyccel_fl)) + + y = np.full(4, 1 + 3j) + + python_cmplx = my_sum(y) + pyccel_cmplx = pyccel_sum(y) + + assert python_cmplx == pyccel_cmplx + assert isinstance(python_cmplx, type(pyccel_cmplx))