Skip to content

Commit

Permalink
Fixes spotify#3334: Add support for custom parameter classes in mypy …
Browse files Browse the repository at this point in the history
…plugin
  • Loading branch information
starhel committed Jan 17, 2025
1 parent 6d12bd1 commit 1c8cf31
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 29 deletions.
56 changes: 28 additions & 28 deletions luigi/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from __future__ import annotations

import re
import sys
from typing import Callable, Dict, Final, Iterator, List, Literal, Optional

Expand Down Expand Up @@ -61,10 +60,6 @@

METADATA_TAG: Final[str] = "task"

PARAMETER_FULLNAME_MATCHER: Final[re.Pattern] = re.compile(
r"^luigi(\.parameter)?\.\w*Parameter$"
)

if sys.version_info[:2] < (3, 8):
# This plugin uses the walrus operator, which is only available in Python 3.8+
raise RuntimeError("This plugin requires Python 3.8+")
Expand All @@ -84,12 +79,17 @@ def get_function_hook(
self, fullname: str
) -> Callable[[FunctionContext], Type] | None:
"""Adjust the return type of the `Parameters` function."""
if PARAMETER_FULLNAME_MATCHER.match(fullname):
if self.check_parameter(fullname):
return self._task_parameter_field_callback
return None

def check_parameter(self, fullname):
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo):
return any(base.fullname == "luigi.parameter.Parameter" for base in sym.node.mro)

def _task_class_maker_callback(self, ctx: ClassDefContext) -> None:
transformer = TaskTransformer(ctx.cls, ctx.reason, ctx.api)
transformer = TaskTransformer(ctx.cls, ctx.reason, ctx.api, self)
transformer.transform()

def _task_parameter_field_callback(self, ctx: FunctionContext) -> Type:
Expand All @@ -109,6 +109,7 @@ def _task_parameter_field_callback(self, ctx: FunctionContext) -> Type:
# if no `default` argument is found, return AnyType with unannotated type.
except ValueError:
return AnyType(TypeOfAny.unannotated)
print(ctx.context)

default_args = ctx.args[default_idx]

Expand Down Expand Up @@ -210,10 +211,12 @@ def __init__(
cls: ClassDef,
reason: Expression | Statement,
api: SemanticAnalyzerPluginInterface,
task_plugin: TaskPlugin,
) -> None:
self._cls = cls
self._reason = reason
self._api = api
self._task_plugin = task_plugin

def transform(self) -> bool:
"""Apply all the necessary transformations to the underlying gokart.Task"""
Expand Down Expand Up @@ -311,7 +314,7 @@ def collect_attributes(self) -> Optional[List[TaskAttribute]]:
# Second, collect attributes belonging to the current class.
current_attr_names: set[str] = set()
for stmt in self._get_assignment_statements_from_block(cls.defs):
if not is_parameter_call(stmt.rvalue):
if not self.is_parameter_call(stmt.rvalue):
continue

# a: int, b: str = 1, 'foo' is not supported syntax so we
Expand Down Expand Up @@ -435,29 +438,26 @@ def _infer_task_attr_init_type(

return default

def is_parameter_call(self, expr: Expression) -> bool:
"""Checks if the expression is a call to luigi.Parameter()"""
if not isinstance(expr, CallExpr):
return False

def is_parameter_call(expr: Expression) -> bool:
"""Checks if the expression is a call to luigi.Parameter()"""
if not isinstance(expr, CallExpr):
return False

callee = expr.callee
if isinstance(callee, MemberExpr):
type_info = callee.node
if type_info is None and isinstance(callee.expr, NameExpr):
return (
PARAMETER_FULLNAME_MATCHER.match(f"{callee.expr.name}.{callee.name}")
is not None
)
elif isinstance(callee, NameExpr):
type_info = callee.node
else:
return False
callee = expr.callee
fullname = None
if isinstance(callee, MemberExpr):
type_info = callee.node
if type_info is None and isinstance(callee.expr, NameExpr):
fullname = f"{callee.expr.name}.{callee.name}"
elif isinstance(callee, NameExpr):
type_info = callee.node
else:
return False

if isinstance(type_info, TypeInfo):
return PARAMETER_FULLNAME_MATCHER.match(type_info.fullname) is not None
if isinstance(type_info, TypeInfo):
fullname = type_info.fullname

return False
return fullname is not None and self._task_plugin.check_parameter(fullname)


def plugin(version: str) -> type[Plugin]:
Expand Down
9 changes: 8 additions & 1 deletion test/mypy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@ def test_plugin_no_issue(self):

test_code = """
import luigi
from uuid import UUID
class UUIDParameter(luigi.Parameter):
def parse(self, s):
return UUID(s)
class MyTask(luigi.Task):
foo: int = luigi.IntParameter()
bar: str = luigi.Parameter()
uniq: UUID = UUIDParameter()
baz: str = luigi.Parameter(default="baz")
MyTask(foo=1, bar='bar')
MyTask(foo=1, bar='bar', uniq=UUID("9b0591d7-a167-4978-bc6d-41f7d84a288c"))
"""

with tempfile.NamedTemporaryFile(suffix=".py") as test_file:
Expand Down

0 comments on commit 1c8cf31

Please sign in to comment.