Skip to content

Commit

Permalink
Merge pull request #30 from piercefreeman/feature/ilike-support
Browse files Browse the repository at this point in the history
Verify ilike usage occurs on string columns
  • Loading branch information
piercefreeman authored Nov 18, 2024
2 parents 896a462 + c3d82a0 commit 1bf0f9a
Show file tree
Hide file tree
Showing 6 changed files with 343 additions and 5 deletions.
263 changes: 263 additions & 0 deletions iceaxe/__tests__/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
import ast
import inspect
import os
from contextlib import contextmanager
from dataclasses import dataclass
from json import JSONDecodeError, dump as json_dump, loads as json_loads
from re import Pattern
from tempfile import NamedTemporaryFile, TemporaryDirectory
from textwrap import dedent

from pyright import run


@dataclass
class PyrightDiagnostic:
file: str
severity: str
message: str
rule: str | None
line: int
column: int


class ExpectedPyrightError(Exception):
"""
Exception raised when Pyright doesn't produce the expected error
"""

pass


def get_imports_from_module(module_source: str) -> set[str]:
"""
Extract all import statements from module source
"""
tree = ast.parse(module_source)
imports: set[str] = set()

for node in ast.walk(tree):
if isinstance(node, ast.Import):
for name in node.names:
imports.add(f"import {name.name}")
elif isinstance(node, ast.ImportFrom):
names = ", ".join(name.name for name in node.names)
if node.module is None:
# Handle "from . import x" case
imports.add(f"from . import {names}")
else:
imports.add(f"from {node.module} import {names}")

return imports


def strip_type_ignore(line: str) -> str:
"""
Strip type: ignore comments from a line while preserving the line content
"""
if "#" not in line:
return line

# Split only on the first #
code_part, *comment_parts = line.split("#", 1)
if not comment_parts:
return line

comment = comment_parts[0]
# If this is a type: ignore comment, return just the code
if "type:" in comment and "ignore" in comment:
return code_part.rstrip()

# Otherwise return the full line
return line


def extract_current_function_code():
"""
Extracts the source code of the function calling this utility,
along with any necessary imports at the module level. This only works for
functions in a pytest testing context that are prefixed with `test_`.
"""
# Get the frame of the calling function
frame = inspect.currentframe()

try:
# Go up until we find the test function; workaround to not
# knowing the entrypoint of our contextmanager at runtime
while frame is not None:
func_name = frame.f_code.co_name
if func_name.startswith("test_"):
test_frame = frame
break
frame = frame.f_back
else:
raise RuntimeError("Could not find test function frame")

# Source code of the function
func_source = inspect.getsource(test_frame.f_code)

# Source code of the larger test file, which contains the test function
# All the imports used by the test function should be within this file
module = inspect.getmodule(test_frame)
if not module:
raise RuntimeError("Could not find module for test function")

module_source = inspect.getsource(module)

# Postprocess the source code to build into a valid new module
imports = get_imports_from_module(module_source)
filtered_lines = [strip_type_ignore(line) for line in func_source.split("\n")]
return "\n".join(sorted(imports)) + "\n\n" + dedent("\n".join(filtered_lines))

finally:
del frame # Avoid reference cycles


def create_pyright_config():
"""
Creates a new pyright configuration that ignores unused imports or other
issues that are not related to context-manager wrapped type checking.
"""
return {
"include": ["."],
"exclude": [],
"ignore": [],
"strict": [],
"typeCheckingMode": "strict",
"reportUnusedImport": False,
"reportUnusedVariable": False,
# Focus only on type checking
"reportOptionalMemberAccess": True,
"reportGeneralTypeIssues": True,
"reportPropertyTypeMismatch": True,
"reportFunctionMemberAccess": True,
"reportTypeCommentUsage": True,
"reportMissingTypeStubs": False,
# Only typehint intentional typehints, not inferred values
"reportUnknownParameterType": False,
"reportUnknownVariableType": False,
"reportUnknownMemberType": False,
"reportUnknownArgumentType": False,
"reportMissingParameterType": False,
}


def run_pyright(file_path: str) -> list[PyrightDiagnostic]:
"""
Run pyright on a file and return the diagnostics
"""
try:
with TemporaryDirectory() as temp_dir:
# Create pyright config
config_path = os.path.join(temp_dir, "pyrightconfig.json")
with open(config_path, "w") as f:
json_dump(create_pyright_config(), f)

# Copy the file to analyze into the project directory
test_file = os.path.join(temp_dir, "test.py")
with open(file_path, "r") as src, open(test_file, "w") as dst:
dst.write(src.read())

# Run pyright with the config
result = run(
"--project",
temp_dir,
"--outputjson",
test_file,
capture_output=True,
text=True,
)

try:
output = json_loads(result.stdout)
except JSONDecodeError:
print(f"Failed to parse pyright output: {result.stdout}") # noqa: T201
print(f"Stderr: {result.stderr}") # noqa: T201
raise

if "generalDiagnostics" not in output:
raise RuntimeError(
f"Unknown pyright output, missing generalDiagnostics: {output}"
)

diagnostics: list[PyrightDiagnostic] = []
for diag in output["generalDiagnostics"]:
diagnostics.append(
PyrightDiagnostic(
file=diag["file"],
severity=diag["severity"],
message=diag["message"],
rule=diag.get("rule"),
line=diag["range"]["start"]["line"] + 1, # Convert to 1-based
column=(
diag["range"]["start"]["character"]
+ 1 # Convert to 1-based
),
)
)

return diagnostics

except Exception as e:
raise RuntimeError(f"Failed to run pyright: {str(e)}")


@contextmanager
def pyright_raises(
expected_rule: str,
expected_line: int | None = None,
matches: Pattern | None = None,
):
"""
Context manager that verifies code produces a specific Pyright error.
:params expected_rule: The Pyright rule that should be violated
:params expected_line: Optional line number where the error should occur
:raises ExpectedPyrightError: If Pyright doesn't produce the expected error
"""
# Create a temporary file to store the code
with NamedTemporaryFile(mode="w", suffix=".py") as temp_file:
temp_path = temp_file.name

# Extract the source code of the calling function
source_code = extract_current_function_code()
print(f"Running Pyright on:\n{source_code}") # noqa: T201

# Write the source code to the temporary file
temp_file.write(source_code)
temp_file.flush()

# At runtime, our actual code is probably a no-op but we still let it run
# inside the scope of the contextmanager
yield

# Run Pyright on the temporary file
diagnostics = run_pyright(temp_path)

# Check if any of the diagnostics match our expected error
for diagnostic in diagnostics:
if diagnostic.rule == expected_rule:
if expected_line is not None and diagnostic.line != expected_line:
continue
if matches and not matches.search(diagnostic.message):
continue
# Found matching error
return

# If we get here, we didn't find the expected error
actual_errors = [
f"{d.rule or 'unknown'} on line {d.line}: {d.message}" for d in diagnostics
]
raise ExpectedPyrightError(
f"Expected Pyright error {expected_rule}"
f"{f' on line {expected_line}' if expected_line else ''}"
f" but got: {', '.join(actual_errors) if actual_errors else 'no errors'}"
)
35 changes: 35 additions & 0 deletions iceaxe/__tests__/test_comparison.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from re import compile as re_compile
from typing import Any

import pytest
from typing_extensions import assert_type

from iceaxe.__tests__.helpers import pyright_raises
from iceaxe.base import TableBase
from iceaxe.comparison import ComparisonType, FieldComparison
from iceaxe.field import DBFieldClassDefinition, DBFieldInfo
from iceaxe.typing import column


def test_comparison_type_enum():
Expand All @@ -17,6 +21,9 @@ def test_comparison_type_enum():
assert ComparisonType.IN == "IN"
assert ComparisonType.NOT_IN == "NOT IN"
assert ComparisonType.LIKE == "LIKE"
assert ComparisonType.NOT_LIKE == "NOT LIKE"
assert ComparisonType.ILIKE == "ILIKE"
assert ComparisonType.NOT_ILIKE == "NOT ILIKE"
assert ComparisonType.IS == "IS"
assert ComparisonType.IS_NOT == "IS NOT"

Expand Down Expand Up @@ -158,3 +165,31 @@ def test_comparison_with_different_types(db_field: DBFieldClassDefinition, value
assert result.left == db_field
assert isinstance(result.comparison, ComparisonType)
assert result.right == value


#
# Typehinting
# These checks are run as part of the static typechecking we do
# for our codebase, not as part of the pytest runtime.
#


def test_typehint_ilike():
class UserDemo(TableBase):
id: int
value_str: str
value_int: int

str_col = column(UserDemo.value_str)
int_col = column(UserDemo.value_int)

assert_type(str_col, DBFieldClassDefinition[str])
assert_type(int_col, DBFieldClassDefinition[int])

assert_type(str_col.ilike("test"), bool)

with pyright_raises(
"reportAttributeAccessIssue",
matches=re_compile('Cannot access attribute "ilike"'),
):
int_col.ilike(5) # type: ignore
9 changes: 9 additions & 0 deletions iceaxe/__tests__/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from iceaxe.__tests__.helpers import pyright_raises


def test_basic_type_error():
def type_error_func(x: int) -> int:
return 10

with pyright_raises("reportArgumentType"):
type_error_func("20") # type: ignore
27 changes: 25 additions & 2 deletions iceaxe/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from iceaxe.typing import is_column, is_comparison, is_comparison_group

T = TypeVar("T", bound="ComparisonBase")
J = TypeVar("J")


class ComparisonType(StrEnum):
Expand All @@ -18,7 +19,12 @@ class ComparisonType(StrEnum):
GE = ">="
IN = "IN"
NOT_IN = "NOT IN"

LIKE = "LIKE"
NOT_LIKE = "NOT LIKE"
ILIKE = "ILIKE"
NOT_ILIKE = "NOT ILIKE"

IS = "IS"
IS_NOT = "IS NOT"

Expand Down Expand Up @@ -95,7 +101,7 @@ def to_query(self, start: int = 1):
return QueryLiteral(queries), all_variables


class ComparisonBase(ABC):
class ComparisonBase(ABC, Generic[J]):
def __eq__(self, other): # type: ignore
if other is None:
return self._compare(ComparisonType.IS, None)
Expand Down Expand Up @@ -124,9 +130,26 @@ def in_(self, other) -> bool:
def not_in(self, other) -> bool:
return self._compare(ComparisonType.NOT_IN, other) # type: ignore

def like(self, other) -> bool:
def like(
self: "ComparisonBase[str] | ComparisonBase[str | None]", other: str
) -> bool:
return self._compare(ComparisonType.LIKE, other) # type: ignore

def not_like(
self: "ComparisonBase[str] | ComparisonBase[str | None]", other: str
) -> bool:
return self._compare(ComparisonType.NOT_LIKE, other) # type: ignore

def ilike(
self: "ComparisonBase[str] | ComparisonBase[str | None]", other: str
) -> bool:
return self._compare(ComparisonType.ILIKE, other) # type: ignore

def not_ilike(
self: "ComparisonBase[str] | ComparisonBase[str | None]", other: str
) -> bool:
return self._compare(ComparisonType.NOT_ILIKE, other) # type: ignore

def _compare(self, comparison: ComparisonType, other: Any) -> FieldComparison[Self]:
return FieldComparison(left=self, comparison=comparison, right=other)

Expand Down
Loading

0 comments on commit 1bf0f9a

Please sign in to comment.