Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 35 additions & 23 deletions injection/_core/asfunction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from functools import wraps
from collections.abc import Awaitable, Callable
from functools import update_wrapper
from inspect import iscoroutinefunction
from typing import Any, Protocol

Expand All @@ -21,33 +21,45 @@ def asfunction[**P, T](
module: Module | None = None,
threadsafe: bool | None = None,
) -> Any:
module = module or mod()

def decorator(wp: AsFunctionWrappedType[P, T]) -> Callable[P, T]:
fake_method = wp.__call__.__get__(NotImplemented, wp)
factory: Caller[..., Callable[P, T]] = module.make_injected_function(
wp,
threadsafe=threadsafe,
).__injection_metadata__
factory: Caller[..., Callable[P, T]] = (
(module or mod())
.make_injected_function(
wp,
threadsafe=threadsafe,
)
.__injection_metadata__
)

wrapper: Callable[P, T]
wrapper: Callable[P, T] = (
_wrap_async(factory) # type: ignore[arg-type, assignment]
if iscoroutinefunction(fake_method)
else _wrap_sync(factory)
)
wrapper = update_wrapper(wrapper, fake_method)

if iscoroutinefunction(fake_method):
for attribute in ("__name__", "__qualname__"):
setattr(wrapper, attribute, getattr(wp, attribute))

@wraps(fake_method)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
self = await factory.acall()
return await self(*args, **kwargs) # type: ignore[misc]
return wrapper

else:
return decorator(wrapped) if wrapped else decorator

@wraps(fake_method)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
self = factory.call()
return self(*args, **kwargs)

wrapper.__name__ = wp.__name__
wrapper.__qualname__ = wp.__qualname__
return wrapper
def _wrap_async[**P, T](
factory: Caller[..., Callable[P, Awaitable[T]]],
) -> Callable[P, Awaitable[T]]:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
self = await factory.acall()
return await self(*args, **kwargs)

return decorator(wrapped) if wrapped else decorator
return wrapper


def _wrap_sync[**P, T](factory: Caller[..., Callable[P, T]]) -> Callable[P, T]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
self = factory.call()
return self(*args, **kwargs)

return wrapper
216 changes: 128 additions & 88 deletions injection/entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,104 @@
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine, Iterator
from contextlib import asynccontextmanager, contextmanager
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable, Coroutine, Iterator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import wraps
from types import ModuleType as PythonModule
from typing import TYPE_CHECKING, Any, Concatenate, Protocol, Self, final, overload
from typing import (
TYPE_CHECKING,
Any,
Concatenate,
Protocol,
Self,
overload,
runtime_checkable,
)

from injection import Module
from injection.loaders import ProfileLoader, PythonModuleLoader

__all__ = ("AsyncEntrypoint", "Entrypoint", "entrypointmaker")

type Entrypoint[**P, T] = EntrypointBuilder[P, Any, T]
type AsyncEntrypoint[**P, T] = Entrypoint[P, Awaitable[T]]

type AsyncEntrypoint[**P, T] = Entrypoint[P, Coroutine[Any, Any, T]]
type EntrypointSetupMethod[**P, **EPP, T1, T2] = Callable[
Concatenate[Entrypoint[EPP, T1], P],
Entrypoint[EPP, T2],
]


class Rule[**P, T1, T2](ABC):
__slots__ = ()

@abstractmethod
def apply(self, wrapped: Callable[P, T1]) -> Callable[P, T2]:
raise NotImplementedError


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _AsyncToSyncRule[**P, T](Rule[P, Awaitable[T], T]):
run: Callable[[Awaitable[T]], T]

def apply(self, wrapped: Callable[P, Awaitable[T]]) -> Callable[P, T]:
@wraps(wrapped)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return self.run(wrapped(*args, **kwargs))

return wrapper


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _DecorateRule[**P, T1, T2](Rule[P, T1, T2]):
decorator: Callable[[Callable[P, T1]], Callable[P, T2]]

def apply(self, wrapped: Callable[P, T1]) -> Callable[P, T2]:
return self.decorator(wrapped)


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _InjectRule[**P, T](Rule[P, T, T]):
module: Module

def apply(self, wrapped: Callable[P, T]) -> Callable[P, T]:
return self.module.make_injected_function(wrapped)


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _LoadModulesRule[**P, T](Rule[P, T, T]):
loader: PythonModuleLoader
packages: Sequence[PythonModule | str]

def apply(self, wrapped: Callable[P, T]) -> Callable[P, T]:
return self.__decorator()(wrapped)

@contextmanager
def __decorator(self) -> Iterator[None]:
self.loader.load(*self.packages)
yield


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class _LoadProfileRule[**P, T](Rule[P, T, T]):
loader: ProfileLoader
profile_name: str

def apply(self, wrapped: Callable[P, T]) -> Callable[P, T]:
return self.__decorator()(wrapped)

@contextmanager
def __decorator(self) -> Iterator[None]:
with self.loader.load(self.profile_name):
yield


@runtime_checkable
class _EntrypointDecorator[**P, T1, T2](Protocol):
__slots__ = ()

if TYPE_CHECKING: # pragma: no cover

@overload
Expand All @@ -42,18 +119,21 @@ def __call__(
autocall: bool = ...,
) -> Callable[[Callable[P, T1]], Callable[P, T2]]: ...

@abstractmethod
def __call__(
self,
wrapped: Callable[P, T1] | None = ...,
/,
*,
autocall: bool = ...,
) -> Any: ...
) -> Any:
raise NotImplementedError


# SMP = Setup Method Parameters
# EPP = EntryPoint Parameters


if TYPE_CHECKING: # pragma: no cover

@overload
Expand Down Expand Up @@ -85,114 +165,74 @@ def entrypointmaker[**SMP, **EPP, T1, T2](
def decorator(
wp: EntrypointSetupMethod[SMP, EPP, T1, T2],
) -> _EntrypointDecorator[EPP, T1, T2]:
return Entrypoint._make_decorator(wp, profile_loader)
pl = (profile_loader or ProfileLoader()).init()
setup_method = pl.module.make_injected_function(wp)
return setup_method(EntrypointBuilder(pl)) # type: ignore[call-arg]

return decorator(wrapped) if wrapped else decorator


@final
@dataclass(repr=False, eq=False, frozen=True, slots=True)
class Entrypoint[**P, T]:
function: Callable[P, T]
class EntrypointBuilder[**P, T1, T2](_EntrypointDecorator[P, T1, T2]):
profile_loader: ProfileLoader = field(default_factory=ProfileLoader)
__rules: list[Rule[P, Any, Any]] = field(default_factory=list, init=False)

def __call__(
self,
wrapped: Callable[P, T1] | None = None,
/,
*,
autocall: bool = False,
) -> Any:
def decorator(wp: Callable[P, T1]) -> Callable[P, T2]:
wrapper = self._apply(wp)

if autocall:
wrapper() # type: ignore[call-arg]

def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
return self.function(*args, **kwargs)
return wrapper

@property
def __module(self) -> Module:
return self.profile_loader.module
return decorator(wrapped) if wrapped else decorator

def async_to_sync[_T](
self: AsyncEntrypoint[P, _T],
self: EntrypointBuilder[P, T1, Awaitable[_T]],
run: Callable[[Coroutine[Any, Any, _T]], _T] = asyncio.run,
/,
) -> Entrypoint[P, _T]:
function = self.function
) -> EntrypointBuilder[P, T1, _T]:
return self._add_rule(_AsyncToSyncRule(run)) # type: ignore[arg-type]

@wraps(function)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> _T:
return run(function(*args, **kwargs))

return self.__recreate(wrapper)

def decorate(
def decorate[_T](
self,
decorator: Callable[[Callable[P, T]], Callable[P, T]],
decorator: Callable[[Callable[P, T2]], Callable[P, _T]],
/,
) -> Self:
return self.__recreate(decorator(self.function))
) -> EntrypointBuilder[P, T1, _T]:
return self._add_rule(_DecorateRule(decorator))

def inject(self) -> Self:
return self.decorate(self.__module.make_injected_function)
self._add_rule(_InjectRule(self.profile_loader.module))
return self

def load_modules(
self,
/,
loader: PythonModuleLoader,
*packages: PythonModule | str,
) -> Self:
return self.setup(lambda: loader.load(*packages))
self._add_rule(_LoadModulesRule(loader, packages))
return self

def load_profile(self, name: str, /) -> Self:
@contextmanager
def decorator(loader: ProfileLoader) -> Iterator[None]:
with loader.load(name):
yield

return self.decorate(decorator(self.profile_loader))

def setup(self, function: Callable[..., Any], /) -> Self:
@contextmanager
def decorator() -> Iterator[Any]:
yield function()
self._add_rule(_LoadProfileRule(self.profile_loader, name))
return self

return self.decorate(decorator())

def async_setup[_T](
self: AsyncEntrypoint[P, _T],
function: Callable[..., Awaitable[Any]],
/,
) -> AsyncEntrypoint[P, _T]:
@asynccontextmanager
async def decorator() -> AsyncIterator[Any]:
yield await function()

return self.decorate(decorator())

def __recreate[**_P, _T](
self: Entrypoint[Any, Any],
function: Callable[_P, _T],
/,
) -> Entrypoint[_P, _T]:
return type(self)(function, self.profile_loader)

@classmethod
def _make_decorator[**_P, _T](
cls,
setup_method: EntrypointSetupMethod[_P, P, T, _T],
/,
profile_loader: ProfileLoader | None = None,
) -> _EntrypointDecorator[P, T, _T]:
profile_loader = profile_loader or ProfileLoader()
setup_method = profile_loader.module.make_injected_function(setup_method)

def entrypoint_decorator(
wrapped: Callable[P, T] | None = None,
/,
*,
autocall: bool = False,
) -> Any:
def decorator(wp: Callable[P, T]) -> Callable[P, _T]:
profile_loader.init()
self = cls(wp, profile_loader)
wrapper = setup_method(self).function # type: ignore[call-arg]

if autocall:
wrapper() # type: ignore[call-arg]

return wrapper
def _add_rule[_T](
self,
rule: Rule[P, T2, _T],
) -> EntrypointBuilder[P, T1, _T]:
self.__rules.append(rule)
return self # type: ignore[return-value]

return decorator(wrapped) if wrapped else decorator
def _apply(self, function: Callable[P, T1], /) -> Callable[P, T2]:
for rule in self.__rules:
function = rule.apply(function)

return entrypoint_decorator
return function # type: ignore[return-value]
Loading