Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions injection/_core/asfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,9 @@ def asfunction[**P, T](
) -> Any:
def decorator(wp: AsFunctionWrappedType[P, T]) -> Callable[P, T]:
fake_method = wp.__call__.__get__(NotImplemented, wp)
factory: Caller[..., Callable[P, T]] = (
(module or mod())
.make_injected_function(
wp,
threadsafe=threadsafe,
)
.__injection_metadata__
factory: Caller[..., Callable[P, T]] = (module or mod())._metadata(
wp,
threadsafe,
)

wrapper: Callable[P, T] = (
Expand Down
4 changes: 2 additions & 2 deletions injection/_core/locator.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def _extract_caller[**P, T](


def _make_injectable[T](
injectable_factory: InjectableFactory[T],
factory: InjectableFactory[T],
recipe: Recipe[..., T],
) -> Injectable[T]:
return injectable_factory(_extract_caller(recipe))
return factory(_extract_caller(recipe))
61 changes: 34 additions & 27 deletions injection/_core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,12 +387,7 @@ def make_injected_function[**P, T](
/,
threadsafe: bool | None = None,
) -> InjectedFunction[P, T]:
metadata = InjectMetadata(wrapped, threadsafe)

@metadata.task
def listen() -> None:
metadata.update(self)
self.add_listener(metadata)
metadata = self._metadata(wrapped, threadsafe)

if iscoroutinefunction(wrapped):
return AsyncInjectedFunction(metadata) # type: ignore[arg-type, return-value]
Expand All @@ -405,11 +400,7 @@ def make_async_factory[T](
/,
threadsafe: bool | None = None,
) -> Callable[..., Awaitable[T]]:
factory: InjectedFunction[..., T] = self.make_injected_function(
wrapped,
threadsafe,
)
return factory.__injection_metadata__.acall
return self._metadata(wrapped, threadsafe).acall

async def afind_instance[T](
self,
Expand Down Expand Up @@ -522,12 +513,14 @@ def aget_lazy_instance[T, Default](
*,
threadsafe: bool | None = None,
) -> Awaitable[T | Default]:
function = self.make_injected_function(
lambda instance=default: instance,
threadsafe=threadsafe,
return SimpleAwaitable(
self._metadata(
lambda instance=default: instance,
threadsafe=threadsafe,
)
.set_owner(cls) # type: ignore[arg-type]
.acall
)
metadata = function.__injection_metadata__.set_owner(cls)
return SimpleAwaitable(metadata.acall)

if TYPE_CHECKING: # pragma: no cover

Expand Down Expand Up @@ -556,12 +549,14 @@ def get_lazy_instance[T, Default](
*,
threadsafe: bool | None = None,
) -> Invertible[T | Default]:
function = self.make_injected_function(
lambda instance=default: instance,
threadsafe=threadsafe,
return SimpleInvertible(
self._metadata(
lambda instance=default: instance,
threadsafe=threadsafe,
)
.set_owner(cls) # type: ignore[arg-type]
.call
)
metadata = function.__injection_metadata__.set_owner(cls)
return SimpleInvertible(metadata.call)

def update[T](self, updater: Updater[T]) -> Self:
self.__locator.update(updater)
Expand Down Expand Up @@ -686,6 +681,13 @@ def dispatch(self, event: Event) -> Iterator[None]:
finally:
self.__debug(event)

def _metadata[**P, T](
self,
wrapped: Callable[P, T],
threadsafe: bool | None = None,
) -> InjectMetadata[P, T]:
return InjectMetadata(wrapped, threadsafe).listen(self)

def _iter_locators(self) -> Iterator[Locator]:
for module in self.__modules:
yield from module._iter_locators()
Expand Down Expand Up @@ -879,16 +881,14 @@ def wrapped(self) -> Callable[P, T]:

async def abind(self, args: Iterable[Any], kwargs: Mapping[str, Any]) -> Arguments:
arguments = self.__get_arguments(args, kwargs)
dependencies = await self.__dependencies.aget_arguments(exclude=arguments)
if dependencies:
if dependencies := await self.__dependencies.aget_arguments(exclude=arguments):
return self.__merge_arguments(arguments, dependencies)

return Arguments(args, kwargs)

def bind(self, args: Iterable[Any], kwargs: Mapping[str, Any]) -> Arguments:
arguments = self.__get_arguments(args, kwargs)
dependencies = self.__dependencies.get_arguments(exclude=arguments)
if dependencies:
if dependencies := self.__dependencies.get_arguments(exclude=arguments):
return self.__merge_arguments(arguments, dependencies)

return Arguments(args, kwargs)
Expand Down Expand Up @@ -930,6 +930,14 @@ def decorator(wp: Callable[_P, _T]) -> Callable[_P, _T]:

return decorator(wrapped) if wrapped else decorator

def listen(self, module: Module) -> Self:
@self.task
def start_listening() -> None:
self.update(module)
module.add_listener(self)

return self

@singledispatchmethod
def on_event(self, event: Event, /) -> ContextManager[None] | None:
return None
Expand All @@ -945,8 +953,7 @@ def __get_arguments(
args: Iterable[Any],
kwargs: Mapping[str, Any],
) -> dict[str, Any]:
bound = self.signature.bind_partial(*args, **kwargs)
return bound.arguments
return self.signature.bind_partial(*args, **kwargs).arguments

def __merge_arguments(
self,
Expand Down
18 changes: 11 additions & 7 deletions injection/ext/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from types import GenericAlias
from typing import Annotated, Any, TypeAlias, TypeAliasType
Expand Down Expand Up @@ -25,20 +26,23 @@ def __call__[T](
) -> Any:
module = module or self.module
threadsafe = self.threadsafe if threadsafe is None else threadsafe
lazy_instance = module.aget_lazy_instance(cls, default, threadsafe=threadsafe)

async def dependency() -> T:
return await lazy_instance

class_name = getattr(cls, "__name__", str(cls))
dependency.__name__ = f"inject({class_name})"
awaitable = module.aget_lazy_instance(cls, default, threadsafe=threadsafe)
dependency = self.__make_dependency(awaitable)
dependency.__name__ = f"Inject[{getattr(cls, '__name__', str(cls))}]"
return Depends(dependency, use_cache=False)

def __getitem__[T, *Ts](self, params: T | tuple[T, *Ts], /) -> TypeAlias:
iter_params = iter(params if isinstance(params, tuple) else (params,))
cls = next(iter_params)
return Annotated[cls, self(cls), *iter_params]

@staticmethod
def __make_dependency[T](awaitable: Awaitable[T]) -> Callable[[], Awaitable[T]]:
async def dependency() -> T:
return await awaitable

return dependency


Inject = FastAPIInject()
InjectThreadSafe = FastAPIInject(threadsafe=True)
Expand Down
Loading