diff --git a/patches/pyink.patch b/patches/pyink.patch index 0cb94c99c6f..ee82bed7e7a 100644 --- a/patches/pyink.patch +++ b/patches/pyink.patch @@ -370,6 +370,19 @@ ) lines: List[Tuple[int, int]] = [] +@@ -1153,9 +677,10 @@ + """ + if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): + raise NothingChanged ++ line = ink.get_code_start(src) + if ( +- src[:2] == "%%" +- and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics ++ line[:2] == "%%" ++ and line.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics + ): + raise NothingChanged + @@ -1175,7 +1219,6 @@ raise NothingChanged @@ -1395,6 +1408,19 @@ runner = CliRunner() +@@ -209,6 +209,12 @@ + assert result == expected_output + + ++def test_cell_magic_with_custom_python_magic_after_spaces_and_comments_noop() -> None: ++ src = "\n \n # comment\n\t\n %%custom_python_magic \nx=2" ++ with pytest.raises(NothingChanged): ++ format_cell(src, fast=True, mode=JUPYTER_MODE) ++ ++ + def test_cell_magic_nested() -> None: + src = "%%time\n%%time\n2+2" + result = format_cell(src, fast=True, mode=JUPYTER_MODE) @@ -385,6 +385,45 @@ assert result == expected diff --git a/src/pyink/__init__.py b/src/pyink/__init__.py index f96ac01cb20..2d160691a2f 100644 --- a/src/pyink/__init__.py +++ b/src/pyink/__init__.py @@ -1150,9 +1150,10 @@ def validate_cell(src: str, mode: Mode) -> None: """ if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): raise NothingChanged + line = ink.get_code_start(src) if ( - src[:2] == "%%" - and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics + line[:2] == "%%" + and line.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics ): raise NothingChanged diff --git a/src/pyink/ink.py b/src/pyink/ink.py index 6cd4b3ffdd1..2cb219972e6 100644 --- a/src/pyink/ink.py +++ b/src/pyink/ink.py @@ -3,6 +3,7 @@ This is a separate module for easier patch management. """ +import re from typing import ( Collection, Iterator, @@ -62,6 +63,26 @@ def majority_quote(node: LN) -> Quote: return Quote.DOUBLE +def get_code_start(src: str) -> str: + """Provides the first line where the code starts. + + Iterates over lines of code until it finds the first line that doesn't + contain only empty spaces and comments. If such line doesn't exist, it + returns an empty string. + + Args: + src: The multi-line source code. + + Returns: + The first line of code without initial spaces or an empty string. + """ + for match in re.finditer(".+", src): + line = match.group(0).lstrip() + if line and not line.startswith("#"): + return line + return "" + + def convert_unchanged_lines(src_node: Node, lines: Collection[Tuple[int, int]]): """Converts unchanged lines to STANDALONE_COMMENT. diff --git a/tests/test_ipynb.py b/tests/test_ipynb.py index 7cd2950eb82..115cea65d97 100644 --- a/tests/test_ipynb.py +++ b/tests/test_ipynb.py @@ -210,6 +210,12 @@ def test_cell_magic_with_custom_python_magic( assert result == expected_output +def test_cell_magic_with_custom_python_magic_after_spaces_and_comments_noop() -> None: + src = "\n \n # comment\n\t\n %%custom_python_magic \nx=2" + with pytest.raises(NothingChanged): + format_cell(src, fast=True, mode=JUPYTER_MODE) + + def test_cell_magic_nested() -> None: src = "%%time\n%%time\n2+2" result = format_cell(src, fast=True, mode=JUPYTER_MODE)