Skip to content

Instantly share code, notes, and snippets.

@CGamesPlay
Last active July 17, 2024 07:25
Show Gist options
  • Save CGamesPlay/d3f72fca787b16b879efb07f1ba46d7a to your computer and use it in GitHub Desktop.
Save CGamesPlay/d3f72fca787b16b879efb07f1ba46d7a to your computer and use it in GitHub Desktop.
Fully-typed Python decorator for functions, methods, staticmethods, and classmethods.
"""
Showcase a fully-typed decorator that can be applied to functions, methods,
staticmethods, and classmethods.
This example has some limitations due to the limited expressiveness of the
Python type system:
1. When applying the decorator to a function whose first argument is an
optional type object, the returned function's first argument is shown as
required.
2. It's not possible to apply the decorator to a staticmethod whose first
argument is a type object which is a superclass of the class it's defined
in, or whose first argument can hold an instance of the class it's defined
in. The returned function's type hints will be incorrect.
Both of these limitations can be worked around by removing the first overload
from the decorator; however, doing so requires all classmethods to use the
included classdecorator instead of the normal decorator.
"""
import functools
import types
import unittest
from typing import (
    Any,
    Callable,
    Concatenate,
    Generic,
    ParamSpec,
    TypeVar,
    overload,
    reveal_type,
)
P = ParamSpec("P")
BoundP = ParamSpec("BoundP")
R_co = TypeVar("R_co", covariant=True)
S = TypeVar("S")
T = TypeVar("T", bound=type)
class Decorator(Generic[P, R_co]):
def __init__(self, f: Callable[P, R_co]):
self.f = f
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
return self.f(*args, **kwargs)
@overload
def __get__(
self: "Decorator[Concatenate[S, BoundP], R_co]",
instance: S,
owner: type | None,
) -> "Decorator[BoundP, R_co]": ...
@overload
def __get__(self, instance: Any, owner: type | None) -> "Decorator[P, R_co]": ...
def __get__(self, instance: Any, owner: Any = None) -> Any:
# Overload 1 is for accessing through an instance (binding).
# Overload 2 is for accessing through the class itself (non-binding).
# We special case support for staticmethod and classmethod here.
if isinstance(self.f, staticmethod) or isinstance(self.f, classmethod):
return self.f.__get__(instance, owner)
if instance is None:
return self
return Decorator(types.MethodType(self.f, instance))
# This class exists for the benefit of type checkers. It is defined at runtime
# (the decorator/classdecorator annotations below reference it) but is never
# instantiated: both of those functions actually return a Decorator. It models
# functions which take type objects as the first parameter, and handles the
# self binding.
class MethodDecorator(Generic[T, P, R_co]):
    """Static typing model for a decorated callable whose first parameter is a type."""

    # Called directly, the type-object first argument must be passed explicitly.
    def __call__(self, instance: T, *args: P.args, **kwargs: P.kwargs) -> R_co: ...

    # Overload 1 matches member accesses where the class on the left is the
    # same as the method's first parameter (classmethods); that parameter is
    # bound and dropped from the result. Overload 2 matches everything else
    # (assume non-binding) and keeps the first parameter in the signature.
    @overload
    def __get__(self, instance: Any, owner: T) -> "Decorator[P, R_co]": ...
    @overload
    def __get__(
        self, instance: Any, owner: type | None
    ) -> "Decorator[Concatenate[T, P], R_co]": ...
    def __get__(self, instance: Any, owner: type | None) -> Any: ...
# Overloads are tried top to bottom, so the type-object case must come first.
# pyright reports the two overloads as overlapping; the ignore on the second
# one acknowledges that the overlap is intentional.
@overload
def decorator(f: Callable[Concatenate[T, P], R_co]) -> MethodDecorator[T, P, R_co]: ...
@overload
def decorator(  # pyright: ignore[reportOverlappingOverload]
    f: Callable[P, R_co]
) -> Decorator[P, R_co]: ...
def decorator(f: Any) -> Any:
    """Wrap ``f`` (function, method, staticmethod, or classmethod) in a Decorator.

    The overloads only shape what type checkers see. Callables whose first
    parameter is a type object are modelled as ``MethodDecorator`` so that
    classmethod access binds that parameter. Every other callable is modelled
    as a plain ``Decorator``. At runtime the result is always a ``Decorator``.
    """
    return Decorator(f)
def classdecorator(f: Callable[Concatenate[T, P], R_co]) -> MethodDecorator[T, P, R_co]:
    """Wrap a callable whose first parameter is a type object, typed as a MethodDecorator.

    The runtime result is the same ``Decorator`` that ``decorator`` produces;
    only the declared static type differs. Because the two don't match, the
    return needs a type-checker suppression.
    """
    result = Decorator(f)
    return result  # type: ignore
@decorator
def func() -> None:
    """Decorated free function that takes no parameters."""
    print("in func()")
@decorator
def func_param(val: int = 1) -> None:
    """Decorated free function whose only parameter has a default."""
    print(f"in func_param({val})")
@decorator
def func_typevar(val: type[float]) -> None:
    """Decorated free function whose first parameter is a type object."""
    print("in func_typevar")
class Class:
    """Fixture with one decorated member of each kind the decorator supports."""

    @decorator
    def method(self) -> None:
        """Instance method; verifies self was bound to a Class instance."""
        assert isinstance(self, Class)
        print("in Class.method()")

    @decorator
    def method_param(self, val: int = 1) -> None:
        """Instance method with a defaulted parameter after self."""
        assert isinstance(self, Class)
        print(f"in Class.method_param({val})")

    @decorator
    @staticmethod
    def static_method() -> None:
        """Staticmethod taking no parameters."""
        print("in Class.static_method()")

    @decorator
    @staticmethod
    def static_method_param(val: int = 1) -> None:
        """Staticmethod whose only parameter has a default."""
        print(f"in Class.static_method_param({val})")

    @decorator
    @staticmethod
    def static_method_typevar(val: type[float]) -> None:
        """Staticmethod whose first parameter is an unrelated type object."""
        print(f"in Class.static_method_typevar({val})")

    @decorator
    @staticmethod
    def static_method_typevar_same_type(val: "type[Class]") -> None:
        """Staticmethod whose first parameter accepts this class itself."""
        print(f"in Class.static_method_typevar_same_type({val})")

    @decorator
    @staticmethod
    def static_method_instance(val: object) -> None:
        """Staticmethod whose first parameter accepts any object."""
        print(f"in Class.static_method_instance({val})")

    @decorator
    @classmethod
    def class_method(cls) -> None:
        """Classmethod; verifies cls was bound to Class."""
        assert cls is Class
        print("in Class.class_method()")

    @decorator
    @classmethod
    def class_method_param(cls, val: int = 1) -> None:
        """Classmethod with a defaulted parameter after cls."""
        assert cls is Class
        print(f"in Class.class_method_param({val})")
class TestCases(unittest.TestCase):
    """Exercise every decorated fixture at runtime and under a type checker.

    The reveal_type() calls are meant to be read in the type checker's output;
    at runtime typing.reveal_type just prints the runtime type to stderr and
    returns its argument. They are written out one per line on purpose:
    folding them into a loop would make the checker reveal a single element
    type instead of each attribute's own type.
    """

    def test_runtime(self) -> None:
        """Call each fixture in every supported way; none of them should raise."""
        # Free functions.
        reveal_type(func)
        func()
        reveal_type(func_param)
        func_param()
        func_param(1)
        reveal_type(func_typevar)
        func_typevar(int)
        # Instance methods: class access takes self explicitly, instance access binds it.
        reveal_type(Class.method)
        reveal_type(Class().method)
        Class.method(Class())
        Class().method()
        reveal_type(Class.method_param)
        reveal_type(Class().method_param)
        Class.method_param(Class(), 1)
        Class().method_param(1)
        Class.method_param(Class())
        Class().method_param()
        # Staticmethods: the same signature via the class or an instance.
        reveal_type(Class.static_method)
        reveal_type(Class().static_method)
        Class.static_method()
        Class().static_method()
        reveal_type(Class.static_method_param)
        reveal_type(Class().static_method_param)
        Class.static_method_param(1)
        Class().static_method_param(1)
        Class.static_method_param()
        Class().static_method_param()
        reveal_type(Class.static_method_typevar)
        reveal_type(Class().static_method_typevar)
        Class.static_method_typevar(int)
        Class().static_method_typevar(int)
        # Classmethods: cls is bound on both class and instance access.
        reveal_type(Class.class_method)
        reveal_type(Class().class_method)
        Class.class_method()
        Class().class_method()
        reveal_type(Class.class_method_param)
        reveal_type(Class().class_method_param)
        Class.class_method_param(1)
        Class().class_method_param(1)
        Class.class_method_param()
        Class().class_method_param()

    def test_typing_failures(self) -> None:
        """Calls that work at runtime but that the decorator's typing rejects.

        These are limitation 2 from the module docstring: staticmethods whose
        first parameter can accept the defining class itself (type[Class]) or
        an instance of it (object). The ``# type: ignore`` comments suppress
        the expected checker errors.
        """
        reveal_type(Class.static_method_instance)
        reveal_type(Class().static_method_instance)
        Class.static_method_instance(Class()) # type: ignore
        Class().static_method_instance(Class()) # type: ignore
        reveal_type(Class.static_method_typevar_same_type)
        reveal_type(Class().static_method_typevar_same_type)
        Class.static_method_typevar_same_type(Class) # type: ignore
        Class().static_method_typevar_same_type(Class) # type: ignore
if __name__ == "__main__":
    # Running the file directly executes the TestCases suite.
    unittest.main()
@Viicos
Copy link

Viicos commented Jul 13, 2024

Nice attempt! However, I think you'll face limitations pretty quickly, in particular with default values that aren't retained by ParamSpec:

# Reported against an earlier revision of the gist. Per the revision 4 notes
# further down, a defaulted first parameter such as `int = 1` is now preserved.
@decorator
def func(arg: int = 1) -> None: ...

func()  # type error, expected one required argument

It's a shame that typing decorators that can be applied to both functions and methods isn't supported yet; we tried doing so for functools.cache, for example. Some related discussion can be found here.

@CGamesPlay
Copy link
Author

CGamesPlay commented Jul 13, 2024

That's definitely an unfortunate limitation, but on the plus side it only affects the first argument of functions (which implies it also only affects free functions/staticmethods, since methods/classmethods will always have a non-optional self argument). This is because the Concatenate to shift off the self argument is what loses the default value, but the remaining args in the ParamSpec actually do retain their default values. This limitation would go away if the type checkers supported preserving the default value in the case where the input and output ParamSpec is the same.

@CGamesPlay
Copy link
Author

Thought about this a bit more. The limitation I mentioned is necessary in general.

I don't think it's possible to correctly statically analyze self-binding without breaking currently valid behaviors in Python. The problem is that the signature of the method changes dynamically at runtime. Consider the following example (code sample in the pyright playground):

from typing import Callable


def func(): ...


class Class:
    # Annotated as a zero-argument callable, but a plain function stored on a
    # class is bound on instance access, so the instance is passed as an
    # extra positional argument.
    func: Callable[[], None] = func


Class().func()  # raises TypeError at runtime: func() accepts no arguments

This is a runtime TypeError in a fully typed Python program. Which of these lines should the type checker reject?

@CGamesPlay
Copy link
Author

I've updated the gist. Changes in revision 4:

  • Default arguments are supported for all arguments except those which are the first positional argument and can hold a type object (e.g. def func(first: object = SOME_OBJECT) is still not supported, but def func(first: int = 1) is).
  • Functions which can hold a type object in their first argument are now supported (e.g. def func(first: object) is now supported).
  • Staticmethods which can hold a type object in their first argument are now supported, except those which can also hold the class they're defined in (e.g. def static_method(first: object) is not supported, but def static_method(first: type[SomeUnrelatedClass]) is).
  • Staticmethods which can hold an object in their first argument are now supported, except those which can also hold an instance of the class they're defined in (e.g. def static_method(first: object) is not supported, but def static_method(first: SomeUnrelatedClass) is).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment