diff --git a/laygo/context/__init__.py b/laygo/context/__init__.py new file mode 100644 index 0000000..12abc3d --- /dev/null +++ b/laygo/context/__init__.py @@ -0,0 +1,19 @@ +""" +Laygo Context Management Package. + +This package provides different strategies for managing state (context) +within a data pipeline, from simple in-memory dictionaries to +process-safe managers for parallel execution. +""" + +from .parallel import ParallelContextManager +from .simple import SimpleContextManager +from .types import IContextHandle +from .types import IContextManager + +__all__ = [ + "IContextManager", + "IContextHandle", + "SimpleContextManager", + "ParallelContextManager", +] diff --git a/laygo/context/parallel.py b/laygo/context/parallel.py new file mode 100644 index 0000000..e5bd274 --- /dev/null +++ b/laygo/context/parallel.py @@ -0,0 +1,138 @@ +""" +A context manager for parallel and distributed processing using +multiprocessing.Manager to share state across processes. +""" + +from collections.abc import Callable +from collections.abc import Iterator +import multiprocessing as mp +from multiprocessing.managers import DictProxy +import threading +from threading import Lock +from typing import Any +from typing import TypeVar + +from laygo.context.types import IContextHandle +from laygo.context.types import IContextManager + +R = TypeVar("R") + + +class ParallelContextHandle(IContextHandle): + """ + A lightweight, picklable handle that carries the actual shared objects + (the DictProxy and Lock) to worker processes. + """ + + def __init__(self, shared_dict: DictProxy, lock: Lock): + self._shared_dict = shared_dict + self._lock = lock + + def create_proxy(self) -> "IContextManager": + """ + Creates a new ParallelContextManager instance that wraps the shared + objects received by the worker process. + """ + return ParallelContextManager(handle=self) + + +class ParallelContextManager(IContextManager): + """ + A context manager that enables state sharing across processes. + + It operates in two modes: + 1. Main Mode: When created normally, it starts a multiprocessing.Manager + and creates a shared dictionary and lock. + 2. Proxy Mode: When created from a handle, it wraps a DictProxy and Lock + that were received from another process. It does not own the manager. + """ + + def __init__(self, initial_context: dict[str, Any] | None = None, handle: ParallelContextHandle | None = None): + """ + Initializes the manager. If a handle is provided, it initializes in + proxy mode; otherwise, it starts a new manager. + """ + if handle: + # --- PROXY MODE INITIALIZATION --- + # This instance is a client wrapping objects from an existing server. + self._manager = None # Proxies do not own the manager process. + self._shared_dict = handle._shared_dict + self._lock = handle._lock + else: + # --- MAIN MODE INITIALIZATION --- + # This instance owns the manager and its shared objects. 
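+ # Note: mp.Manager() starts a separate server process; the dict and
+ # lock created below are proxy objects that forward every operation
+ # to that server, which is what makes them safe to share and pickle.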
+ self._manager = mp.Manager() + self._shared_dict = self._manager.dict(initial_context or {}) + self._lock = self._manager.Lock() + + # Thread-local storage for lock state to handle concurrent access + self._local = threading.local() + + def _lock_context(self) -> None: + """Acquire the lock for this context manager.""" + if not getattr(self._local, "is_locked", False): + self._lock.acquire() + self._local.is_locked = True + + def _unlock_context(self) -> None: + """Release the lock for this context manager.""" + if getattr(self._local, "is_locked", False): + self._lock.release() + self._local.is_locked = False + + def _execute_locked(self, operation: Callable[[], R]) -> R: + """A private helper to execute an operation within a lock.""" + if not getattr(self._local, "is_locked", False): + self._lock_context() + try: + return operation() + finally: + self._unlock_context() + else: + return operation() + + def get_handle(self) -> ParallelContextHandle: + """ + Returns a picklable handle containing the shared dict and lock. + Only the main instance can generate handles. + """ + if not self._manager: + raise TypeError("Cannot get a handle from a proxy context instance.") + + return ParallelContextHandle(self._shared_dict, self._lock) + + def shutdown(self) -> None: + """ + Shuts down the background manager process. + This is a no-op for proxy instances. + """ + if self._manager: + self._manager.shutdown() + + def __enter__(self) -> "ParallelContextManager": + """Acquires the lock for use in a 'with' statement.""" + self._lock_context() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Releases the lock.""" + self._unlock_context() + + def __getitem__(self, key: str) -> Any: + return self._shared_dict[key] + + def __setitem__(self, key: str, value: Any) -> None: + self._execute_locked(lambda: self._shared_dict.__setitem__(key, value)) + + def __delitem__(self, key: str) -> None: + self._execute_locked(lambda: self._shared_dict.__delitem__(key)) + + def __iter__(self) -> Iterator[str]: + # Iteration needs to copy the keys to be safe across processes + return self._execute_locked(lambda: iter(list(self._shared_dict.keys()))) + + def __len__(self) -> int: + return self._execute_locked(lambda: len(self._shared_dict)) + + def to_dict(self) -> dict[str, Any]: + return self._execute_locked(lambda: dict(self._shared_dict)) diff --git a/laygo/context/simple.py b/laygo/context/simple.py new file mode 100644 index 0000000..dbb9fab --- /dev/null +++ b/laygo/context/simple.py @@ -0,0 +1,97 @@ +""" +A simple, dictionary-based context manager for single-process pipelines. +""" + +from collections.abc import Iterator +from typing import Any + +from laygo.context.types import IContextHandle +from laygo.context.types import IContextManager + + +class SimpleContextHandle(IContextHandle): + """ + A handle for the SimpleContextManager that provides a reference back to the + original manager instance. + + In a single-process environment, the "proxy" is the manager itself, ensuring + all transformers in a chain share the exact same context dictionary. + """ + + def __init__(self, manager_instance: "IContextManager"): + self._manager_instance = manager_instance + + def create_proxy(self) -> "IContextManager": + """ + Returns the original SimpleContextManager instance. + + This ensures that in a non-distributed pipeline, all chained transformers + operate on the same shared dictionary. 
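+
+ A minimal sketch of the round trip (names from this package):
+
+ manager = SimpleContextManager({"count": 0})
+ proxy = manager.get_handle().create_proxy()
+ assert proxy is manager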
+ """ + return self._manager_instance + + +class SimpleContextManager(IContextManager): + """ + A basic context manager that uses a standard Python dictionary for state. + + This manager is suitable for single-threaded, single-process pipelines where + no state needs to be shared across process boundaries. It is the default + context manager for a Laygo pipeline. + """ + + def __init__(self, initial_context: dict[str, Any] | None = None) -> None: + """ + Initializes the context manager with an optional dictionary. + + Args: + initial_context: An optional dictionary to populate the context with. + """ + self._context = dict(initial_context or {}) + + def get_handle(self) -> IContextHandle: + """ + Returns a handle that holds a reference back to this same instance. + """ + return SimpleContextHandle(self) + + def __enter__(self) -> "SimpleContextManager": + """ + Provides 'with' statement compatibility. No lock is needed for this + simple, single-threaded context manager. + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """ + Provides 'with' statement compatibility. No lock is needed for this + simple, single-threaded context manager. + """ + pass + + def __getitem__(self, key: str) -> Any: + return self._context[key] + + def __setitem__(self, key: str, value: Any) -> None: + self._context[key] = value + + def __delitem__(self, key: str) -> None: + del self._context[key] + + def __iter__(self) -> Iterator[str]: + return iter(self._context) + + def __len__(self) -> int: + return len(self._context) + + def shutdown(self) -> None: + """No-op for the simple context manager.""" + pass + + def to_dict(self) -> dict[str, Any]: + """ + Returns a copy of the entire context as a standard Python dictionary. + + This operation is performed atomically to ensure consistency. + """ + return self._context diff --git a/laygo/context/types.py b/laygo/context/types.py new file mode 100644 index 0000000..7a9fbe9 --- /dev/null +++ b/laygo/context/types.py @@ -0,0 +1,104 @@ +""" +Defines the abstract base classes for context management in Laygo. + +This module provides the core interfaces (IContextHandle and IContextManager) +that all context managers must implement, ensuring a consistent API for +state management across different execution environments (simple, threaded, parallel). +""" + +from abc import ABC +from abc import abstractmethod +from collections.abc import MutableMapping +from typing import Any + + +class IContextHandle(ABC): + """ + An abstract base class for a picklable handle to a context manager. + + A handle contains the necessary information for a worker process to + reconstruct a connection (a proxy) to the shared context. + """ + + @abstractmethod + def create_proxy(self) -> "IContextManager": + """ + Creates the appropriate context proxy instance from the handle's data. + + This method is called within a worker process to establish its own + connection to the shared state. + + Returns: + An instance of an IContextManager proxy. + """ + raise NotImplementedError + + +class IContextManager(MutableMapping[str, Any], ABC): + """ + Abstract base class for managing shared state (context) in a pipeline. + + This class defines the contract for all context managers, ensuring they + provide a dictionary-like interface for state manipulation by inheriting + from `collections.abc.MutableMapping`. It also includes methods for + distribution (get_handle), resource management (shutdown), and context + management (__enter__, __exit__). 
+ """ + + @abstractmethod + def get_handle(self) -> IContextHandle: + """ + Returns a picklable handle for connecting from a worker process. + + This handle is serialized and sent to distributed workers, which then + use it to create a proxy to the shared context. + + Returns: + A picklable IContextHandle instance. + """ + raise NotImplementedError + + @abstractmethod + def shutdown(self) -> None: + """ + Performs final synchronization and cleans up any resources. + + This method is responsible for releasing connections, shutting down + background processes, or any other cleanup required by the manager. + """ + raise NotImplementedError + + def __enter__(self) -> "IContextManager": + """ + Enters the runtime context related to this object. + + Returns: + The context manager instance itself. + """ + return self + + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None: + """ + Exits the runtime context and performs cleanup. + + Args: + exc_type: The exception type, if an exception was raised. + exc_val: The exception instance, if an exception was raised. + exc_tb: The traceback object, if an exception was raised. + """ + self.shutdown() + + def to_dict(self) -> dict[str, Any]: + """ + Returns a copy of the entire shared context as a standard + Python dictionary. + + This operation is performed atomically using a lock to ensure a + consistent snapshot of the context is returned. + + Returns: + A standard dict containing a copy of the shared context. + """ + # The dict() constructor iterates over the proxy and copies its items. + # The lock ensures this happens atomically without race conditions. + raise NotImplementedError diff --git a/laygo/errors.py b/laygo/errors.py index 5e017ce..c1329b9 100644 --- a/laygo/errors.py +++ b/laygo/errors.py @@ -1,11 +1,11 @@ from collections.abc import Callable -from laygo.helpers import PipelineContext +from laygo.context.types import IContextManager -ChunkErrorHandler = Callable[[list, Exception, PipelineContext], None] +ChunkErrorHandler = Callable[[list, Exception, IContextManager], None] -def raise_error(chunk: list, error: Exception, context: PipelineContext) -> None: +def raise_error(chunk: list, error: Exception, context: IContextManager) -> None: """Handler that always re-raises the error, stopping execution. This is a default error handler that provides fail-fast behavior by @@ -47,7 +47,7 @@ def on_error(self, handler: ChunkErrorHandler) -> "ErrorHandler": self._handlers.insert(0, handler) return self - def handle(self, chunk: list, error: Exception, context: PipelineContext) -> None: + def handle(self, chunk: list, error: Exception, context: IContextManager) -> None: """Execute all handlers in the chain. Handlers are executed in reverse order of addition. Execution stops diff --git a/laygo/helpers.py b/laygo/helpers.py index f3bbbf2..fcb25c6 100644 --- a/laygo/helpers.py +++ b/laygo/helpers.py @@ -3,10 +3,15 @@ from typing import Any from typing import TypeGuard +from laygo.context.types import IContextManager + class PipelineContext(dict[str, Any]): """Generic, untyped context available to all pipeline operations. + DEPRECATED: This class is deprecated and will be removed in a future version. + Use IContextManager implementations (SimpleContextManager, etc.) instead. + A dictionary-based context that can store arbitrary data shared across pipeline operations. This allows passing state and configuration between different stages of data processing. 
@@ -16,14 +21,14 @@ class PipelineContext(dict[str, Any]): # Define the specific callables for clarity -ContextAwareCallable = Callable[[Any, PipelineContext], Any] -ContextAwareReduceCallable = Callable[[Any, Any, PipelineContext], Any] +ContextAwareCallable = Callable[[Any, IContextManager], Any] +ContextAwareReduceCallable = Callable[[Any, Any, IContextManager], Any] def is_context_aware(func: Callable[..., Any]) -> TypeGuard[ContextAwareCallable]: """Check if a function is context-aware by inspecting its signature. - A context-aware function accepts a PipelineContext as its second parameter, + A context-aware function accepts an IContextManager as its second parameter, allowing it to access shared state during pipeline execution. Args: @@ -40,7 +45,7 @@ def is_context_aware_reduce(func: Callable[..., Any]) -> TypeGuard[ContextAwareR """Check if a reduce function is context-aware by inspecting its signature. A context-aware reduce function accepts an accumulator, current value, - and PipelineContext as its three parameters. + and IContextManager as its three parameters. Args: func: The reduce function to inspect for context awareness. diff --git a/laygo/pipeline.py b/laygo/pipeline.py index c94dfd1..931d83e 100644 --- a/laygo/pipeline.py +++ b/laygo/pipeline.py @@ -5,16 +5,16 @@ from concurrent.futures import ThreadPoolExecutor from concurrent.futures import as_completed import itertools -import multiprocessing as mp from queue import Queue from typing import Any from typing import TypeVar from typing import overload -from laygo.helpers import PipelineContext +from laygo.context import IContextManager +from laygo.context.parallel import ParallelContextManager +from laygo.context.types import IContextHandle from laygo.helpers import is_context_aware from laygo.transformers.transformer import Transformer -from laygo.transformers.transformer import passthrough_chunks T = TypeVar("T") U = TypeVar("U") @@ -44,12 +44,14 @@ class Pipeline[T]: pipeline effectively single-use unless the data source is re-initialized. """ - def __init__(self, *data: Iterable[T]) -> None: + def __init__(self, *data: Iterable[T], context_manager: IContextManager | None = None) -> None: """Initialize a pipeline with one or more data sources. Args: *data: One or more iterable data sources. If multiple sources are provided, they will be chained together. + context_manager: An instance of a class that implements IContextManager. + If None, a ParallelContextManager is used by default. Raises: ValueError: If no data sources are provided. @@ -59,25 +61,16 @@ def __init__(self, *data: Iterable[T]) -> None: self.data_source: Iterable[T] = itertools.chain.from_iterable(data) if len(data) > 1 else data[0] self.processed_data: Iterator = iter(self.data_source) - # Always create a shared context with multiprocessing manager - self._manager = mp.Manager() - self.ctx = self._manager.dict() - # Add a shared lock to the context for safe concurrent updates - self.ctx["lock"] = self._manager.Lock() - - # Store reference to original context for final synchronization - self._original_context_ref: PipelineContext | None = None + # Rule 1: Pipeline creates a parallel context manager by default.
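+ # A SimpleContextManager can be injected instead when the pipeline is
+ # known to stay in a single process, avoiding the cost of starting a
+ # multiprocessing.Manager server.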
+ self.context_manager = context_manager if context_manager is not None else ParallelContextManager() def __del__(self) -> None: - """Clean up the multiprocessing manager when the pipeline is destroyed.""" - try: - self._sync_context_back() - self._manager.shutdown() - except Exception: - pass + """Clean up the context manager when the pipeline is destroyed.""" + if hasattr(self, "context_manager"): + self.context_manager.shutdown() - def context(self, ctx: PipelineContext) -> "Pipeline[T]": - """Update the pipeline context and store a reference to the original context. + def context(self, ctx: dict[str, Any]) -> "Pipeline[T]": + """Update the pipeline's context manager with values from a dictionary. The provided values are copied into the pipeline's context manager and become available to every transformer during execution. @@ -96,25 +89,10 @@ exposed through the tuple returned by terminal operations such as `to_list`, rather than being synchronized back to the original dictionary. """ - # Store reference to the original context - self._original_context_ref = ctx - # Copy the context data to the pipeline's shared context - self.ctx.update(ctx) + self._user_context = ctx + self.context_manager.update(ctx) return self - def _sync_context_back(self) -> None: - """Synchronize the final pipeline context back to the original context reference. - - This is called after processing is complete to update the original - context with any changes made during pipeline execution. - """ - if self._original_context_ref is not None: - # Copy the final context state back to the original context reference - final_context_state = dict(self.ctx) - final_context_state.pop("lock", None) # Remove non-serializable lock - self._original_context_ref.clear() - self._original_context_ref.update(final_context_state) def transform[U](self, t: Callable[[Transformer[T, T]], Transformer[T, U]]) -> "Pipeline[U]": """Apply a transformation using a lambda function. @@ -146,13 +124,13 @@ def apply[U](self, transformer: Transformer[T, U]) -> "Pipeline[U]": ... def apply[U](self, transformer: Callable[[Iterable[T]], Iterator[U]]) -> "Pipeline[U]": ... @overload - def apply[U](self, transformer: Callable[[Iterable[T], PipelineContext], Iterator[U]]) -> "Pipeline[U]": ... + def apply[U](self, transformer: Callable[[Iterable[T], IContextManager], Iterator[U]]) -> "Pipeline[U]": ... def apply[U]( self, transformer: Transformer[T, U] | Callable[[Iterable[T]], Iterator[U]] - | Callable[[Iterable[T], PipelineContext], Iterator[U]], + | Callable[[Iterable[T], IContextManager], Iterator[U]], ) -> "Pipeline[U]": """Apply a transformer to the current data source.
@@ -181,10 +159,11 @@ def apply[U]( """ match transformer: case Transformer(): - self.processed_data = transformer(self.processed_data, self.ctx) # type: ignore + # Pass the pipeline's context manager to the transformer + self.processed_data = transformer(self.processed_data, context=self.context_manager) # type: ignore case _ if callable(transformer): if is_context_aware(transformer): - self.processed_data = transformer(self.processed_data, self.ctx) # type: ignore + self.processed_data = transformer(self.processed_data, self.context_manager) # type: ignore else: self.processed_data = transformer(self.processed_data) # type: ignore case _: @@ -192,95 +171,6 @@ def apply[U]( return self # type: ignore - def branch( - self, - branches: dict[str, Transformer[T, Any]], - batch_size: int = 1000, - max_batch_buffer: int = 1, - use_queue_chunks: bool = True, - ) -> dict[str, list[Any]]: - """Forks the pipeline into multiple branches for concurrent, parallel processing. - - This is a **terminal operation** that implements a fan-out pattern where - the entire dataset is copied to each branch for independent processing. - Each branch processes the complete dataset concurrently using separate - transformers, and results are collected and returned in a dictionary. - - Args: - branches: A dictionary where keys are branch names (str) and values - are `Transformer` instances of any subtype. - batch_size: The number of items to batch together when sending data - to branches. Larger batches can improve throughput but - use more memory. Defaults to 1000. - max_batch_buffer: The maximum number of batches to buffer for each - branch queue. Controls memory usage and creates - backpressure. Defaults to 1. - use_queue_chunks: Whether to use passthrough chunking for the - transformers. When True, batches are processed - as chunks. Defaults to True. - - Returns: - A dictionary where keys are the branch names and values are lists - of all items processed by that branch's transformer. - - Note: - This operation consumes the pipeline's iterator, making subsequent - operations on the same pipeline return empty results. - """ - if not branches: - self.consume() - return {} - - source_iterator = self.processed_data - branch_items = list(branches.items()) - num_branches = len(branch_items) - final_results: dict[str, list[Any]] = {} - - queues = [Queue(maxsize=max_batch_buffer) for _ in range(num_branches)] - - def producer() -> None: - """Reads from the source and distributes batches to ALL branch queues.""" - # Use itertools.batched for clean and efficient batch creation. - for batch_tuple in itertools.batched(source_iterator, batch_size): - # The batch is a tuple; convert to a list for consumers. - batch_list = list(batch_tuple) - for q in queues: - q.put(batch_list) - - # Signal to all consumers that the stream is finished. 
- for q in queues: - q.put(None) - - def consumer(transformer: Transformer, queue: Queue) -> list[Any]: - """Consumes batches from a queue and runs them through a transformer.""" - - def stream_from_queue() -> Iterator[T]: - while (batch := queue.get()) is not None: - yield batch - - if use_queue_chunks: - transformer = transformer.set_chunker(passthrough_chunks) - - result_iterator = transformer(stream_from_queue(), self.ctx) # type: ignore - return list(result_iterator) - - with ThreadPoolExecutor(max_workers=num_branches + 1) as executor: - executor.submit(producer) - - future_to_name = { - executor.submit(consumer, transformer, queues[i]): name for i, (name, transformer) in enumerate(branch_items) - } - - for future in as_completed(future_to_name): - name = future_to_name[future] - try: - final_results[name] = future.result() - except Exception as e: - print(f"Branch '{name}' raised an exception: {e}") - final_results[name] = [] - - return final_results - def buffer(self, size: int, batch_size: int = 1000) -> "Pipeline[T]": """Inserts a buffer in the pipeline to allow downstream processing to read ahead. @@ -340,7 +230,7 @@ def __iter__(self) -> Iterator[T]: """ yield from self.processed_data - def to_list(self) -> list[T]: + def to_list(self) -> tuple[list[T], dict[str, Any]]: """Execute the pipeline and return the results as a list. This is a terminal operation that consumes the pipeline's iterator @@ -353,9 +243,9 @@ def to_list(self) -> list[T]: This operation consumes the pipeline's iterator, making subsequent operations on the same pipeline return empty results. """ - return list(self.processed_data) + return list(self.processed_data), self.context_manager.to_dict() - def each(self, function: PipelineFunction[T]) -> None: + def each(self, function: PipelineFunction[T]) -> tuple[None, dict[str, Any]]: """Apply a function to each element (terminal operation). This is a terminal operation that processes each element for side effects @@ -372,7 +262,9 @@ def each(self, function: PipelineFunction[T]) -> None: for item in self.processed_data: function(item) - def first(self, n: int = 1) -> list[T]: + return None, self.context_manager.to_dict() + + def first(self, n: int = 1) -> tuple[list[T], dict[str, Any]]: """Get the first n elements of the pipeline (terminal operation). This is a terminal operation that consumes up to n elements from the @@ -393,9 +285,9 @@ def first(self, n: int = 1) -> list[T]: operations will continue from where this operation left off. """ assert n >= 1, "n must be at least 1" - return list(itertools.islice(self.processed_data, n)) + return list(itertools.islice(self.processed_data, n)), self.context_manager.to_dict() - def consume(self) -> None: + def consume(self) -> tuple[None, dict[str, Any]]: """Consume the pipeline without returning results (terminal operation). This is a terminal operation that processes all elements in the pipeline @@ -408,3 +300,101 @@ def consume(self) -> None: """ for _ in self.processed_data: pass + + return None, self.context_manager.to_dict() + + def branch( + self, + branches: dict[str, Transformer[T, Any]], + batch_size: int = 1000, + max_batch_buffer: int = 1, + ) -> tuple[dict[str, list[Any]], dict[str, Any]]: + """Forks the pipeline into multiple branches for concurrent, parallel processing. + + This is a **terminal operation** that implements a fan-out pattern where + the entire dataset is copied to each branch for independent processing. 
+ Each branch runs in its own Pipeline instance whose context manager is a proxy + to the parent's, so every branch reads and writes the same shared context; + results are collected and returned in a dictionary. + + Args: + branches: A dictionary where keys are branch names (str) and values + are `Transformer` instances of any subtype. + batch_size: The number of items to batch together when sending data + to branches. Larger batches can improve throughput but + use more memory. Defaults to 1000. + max_batch_buffer: The maximum number of batches to buffer for each + branch queue. Controls memory usage and creates + backpressure. Defaults to 1. + + Returns: + A tuple containing: + - A dictionary where keys are the branch names and values are lists + of all items processed by that branch's transformer. + - The final state of the shared context after all branches complete. + + Note: + This operation consumes the pipeline's iterator, making subsequent + operations on the same pipeline return empty results. + """ + if not branches: + self.consume() + return {}, {} + + source_iterator = self.processed_data + branch_items = list(branches.items()) + num_branches = len(branch_items) + final_results: dict[str, list[Any]] = {} + + queues = [Queue(maxsize=max_batch_buffer) for _ in range(num_branches)] + + def producer() -> None: + """Reads from the source and distributes batches to ALL branch queues.""" + # Use itertools.batched for clean and efficient batch creation. + for batch_tuple in itertools.batched(source_iterator, batch_size): + # The batch is a tuple; convert to a list for consumers. + batch_list = list(batch_tuple) + for q in queues: + q.put(batch_list) + + # Signal to all consumers that the stream is finished. + for q in queues: + q.put(None) + + def consumer( + transformer: Transformer, queue: Queue, context_handle: IContextHandle + ) -> tuple[list[Any], dict[str, Any]]: + """Consumes batches from a queue and processes them through a dedicated pipeline.""" + + def stream_from_queue() -> Iterator[T]: + while (batch := queue.get()) is not None: + yield from batch + + # Create a new pipeline for this branch but share the parent's context manager + # This ensures all branches share the same context + branch_pipeline = Pipeline(stream_from_queue(), context_manager=context_handle.create_proxy()) # type: ignore + + # Apply the transformer to the branch pipeline and get results + result_list, branch_context = branch_pipeline.apply(transformer).to_list() + + return result_list, branch_context + + with ThreadPoolExecutor(max_workers=num_branches + 1) as executor: + executor.submit(producer) + + future_to_name = { + executor.submit(consumer, transformer, queues[i], self.context_manager.get_handle()): name + for i, (name, transformer) in enumerate(branch_items) + } + + # Collect results - context is shared through the same context manager + for future in as_completed(future_to_name): + name = future_to_name[future] + try: + result_list, _ = future.result() + final_results[name] = result_list + except Exception: + final_results[name] = [] + + # After all threads complete, get the final context state + final_context = self.context_manager.to_dict() + return final_results, final_context diff --git a/laygo/transformers/http.py b/laygo/transformers/http.py index 8385a47..5127eb1 100644 --- a/laygo/transformers/http.py +++ b/laygo/transformers/http.py @@ -16,8 +16,9 @@ import requests +from laygo.context import IContextManager +from laygo.context import SimpleContextManager from laygo.errors import ErrorHandler -from laygo.helpers import PipelineContext from
laygo.transformers.transformer import ChunkErrorHandler from laygo.transformers.transformer import PipelineFunction from laygo.transformers.transformer import Transformer @@ -85,6 +86,8 @@ def __init__( self.max_workers = max_workers self.session = requests.Session() self._worker_url: str | None = None + # HTTP transformers always use a simple context manager to avoid serialization issues + self._default_context = SimpleContextManager() def _finalize_config(self) -> None: """Determine the final worker URL, generating one if needed. @@ -107,7 +110,7 @@ def _finalize_config(self) -> None: self.endpoint = path.lstrip("/") self._worker_url = f"{self.base_url}/{self.endpoint}" - def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]: + def __call__(self, data: Iterable[In], context: IContextManager | None = None) -> Iterator[Out]: """Execute distributed processing on the data (CLIENT-SIDE). This method is called by the Pipeline to start distributed processing. @@ -115,11 +118,14 @@ def __call__(self, data: Iterable[In], context: PipelineContext | None = None) - Args: data: The input data to process. - context: Optional pipeline context (currently not used in HTTP mode). + context: Optional pipeline context. HTTP transformers always use their + internal SimpleContextManager regardless of the provided context. Returns: An iterator over the processed data. """ + run_context = self._default_context + self._finalize_config() def process_chunk(chunk: list) -> list: @@ -143,18 +149,24 @@ def process_chunk(chunk: list) -> list: print(f"Error calling worker {self._worker_url}: {e}") return [] - with ThreadPoolExecutor(max_workers=self.max_workers) as executor: - chunk_iterator = self._chunk_generator(data) - futures = {executor.submit(process_chunk, chunk) for chunk in itertools.islice(chunk_iterator, self.max_workers)} - while futures: - done, futures = wait(futures, return_when=FIRST_COMPLETED) - for future in done: - yield from future.result() - try: - new_chunk = next(chunk_iterator) - futures.add(executor.submit(process_chunk, new_chunk)) - except StopIteration: - continue + try: + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + chunk_iterator = self._chunk_generator(data) + futures = { + executor.submit(process_chunk, chunk) for chunk in itertools.islice(chunk_iterator, self.max_workers) + } + while futures: + done, futures = wait(futures, return_when=FIRST_COMPLETED) + for future in done: + yield from future.result() + try: + new_chunk = next(chunk_iterator) + futures.add(executor.submit(process_chunk, new_chunk)) + except StopIteration: + continue + finally: + # Always clean up our context since we always use the default one + run_context.shutdown() def get_route(self): """Get the route configuration for registering this transformer as a worker. @@ -167,7 +179,7 @@ def get_route(self): """ self._finalize_config() - def worker_view_func(chunk: list, context: PipelineContext): + def worker_view_func(chunk: list, context: IContextManager): """The actual worker logic for this transformer. 
Args: @@ -226,6 +238,6 @@ def catch[U]( super().catch(sub_pipeline_builder, on_error) return self # type: ignore - def short_circuit(self, function: Callable[[PipelineContext], bool | None]) -> "HTTPTransformer[In, Out]": + def short_circuit(self, function: Callable[[IContextManager], bool | None]) -> "HTTPTransformer[In, Out]": super().short_circuit(function) return self diff --git a/laygo/transformers/parallel.py b/laygo/transformers/parallel.py index a85c576..f30b4e4 100644 --- a/laygo/transformers/parallel.py +++ b/laygo/transformers/parallel.py @@ -4,48 +4,41 @@ from collections.abc import Callable from collections.abc import Iterable from collections.abc import Iterator -from collections.abc import MutableMapping from concurrent.futures import FIRST_COMPLETED -from concurrent.futures import Future from concurrent.futures import wait import copy import itertools -import multiprocessing as mp -from multiprocessing.managers import DictProxy from typing import Any from typing import Union from typing import overload -from loky import ProcessPoolExecutor from loky import get_reusable_executor +from laygo.context import ParallelContextManager +from laygo.context.types import IContextHandle +from laygo.context.types import IContextManager from laygo.errors import ErrorHandler -from laygo.helpers import PipelineContext from laygo.transformers.transformer import ChunkErrorHandler from laygo.transformers.transformer import InternalTransformer from laygo.transformers.transformer import PipelineFunction from laygo.transformers.transformer import Transformer -def _process_chunk_for_multiprocessing[In, Out]( - transformer: InternalTransformer[In, Out], - shared_context: MutableMapping[str, Any], +def _worker_process_chunk[In, Out]( + transformer_logic: InternalTransformer[In, Out], + context_handle: IContextHandle, chunk: list[In], ) -> list[Out]: - """Process a single chunk at the top level. - - This function is designed to work with 'loky' which uses cloudpickle - to serialize the 'transformer' object. - - Args: - transformer: The transformation function to apply. - shared_context: The shared context for processing. - chunk: The data chunk to process. - - Returns: - The processed chunk. """ - return transformer(chunk, shared_context) # type: ignore + Top-level function executed by each worker process. + It reconstructs the context proxy from the handle and runs the transformation. + """ + context_proxy = context_handle.create_proxy() + try: + return transformer_logic(chunk, context_proxy) + finally: + # The proxy's shutdown is a no-op, but it's good practice to call it. + context_proxy.shutdown() def createParallelTransformer[T]( @@ -100,6 +93,8 @@ def __init__( super().__init__(chunk_size, transformer) self.max_workers = max_workers self.ordered = ordered + # Rule 3: Parallel transformers create a parallel context manager by default. + self._default_context = ParallelContextManager() @classmethod def from_transformer[T, U]( @@ -127,129 +122,57 @@ def from_transformer[T, U]( ordered=ordered, ) - def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]: - """Execute the transformer on data concurrently. - - It uses the shared context provided by the Pipeline, if available. - - Args: - data: The input data to process. - context: Optional pipeline context for shared state. - - Returns: - An iterator over the transformed data. - """ - run_context = context if context is not None else self.context - - # Detect if the context is already managed by the Pipeline. 
- is_managed_context = isinstance(run_context, DictProxy) - - if is_managed_context: - # Use the existing shared context and lock from the Pipeline. - shared_context = run_context - yield from self._execute_with_context(data, shared_context) - else: - # Fallback for standalone use: create a temporary manager. - with mp.Manager() as manager: - initial_ctx_data = dict(run_context) - shared_context = manager.dict(initial_ctx_data) - if "lock" not in shared_context: - shared_context["lock"] = manager.Lock() - - yield from self._execute_with_context(data, shared_context) - - # Copy results back to the original non-shared context. - final_context_state = dict(shared_context) - final_context_state.pop("lock", None) - run_context.update(final_context_state) + def __call__(self, data: Iterable[In], context: IContextManager | None = None) -> Iterator[Out]: + """Execute the transformer by distributing chunks to a process pool.""" + run_context = context if context is not None else self._default_context - def _execute_with_context(self, data: Iterable[In], shared_context: MutableMapping[str, Any]) -> Iterator[Out]: - """Execute the transformation logic with a given context. + # Get the picklable handle from the context manager. + context_handle = run_context.get_handle() - Args: - data: The input data to process. - shared_context: The shared context for the execution. - - Returns: - An iterator over the transformed data. - """ executor = get_reusable_executor(max_workers=self.max_workers) - chunks_to_process = self._chunk_generator(data) + gen_func = self._ordered_generator if self.ordered else self._unordered_generator - processed_chunks_iterator = gen_func(chunks_to_process, executor, shared_context) - for result_chunk in processed_chunks_iterator: - yield from result_chunk + try: + processed_chunks_iterator = gen_func(chunks_to_process, executor, context_handle) + for result_chunk in processed_chunks_iterator: + yield from result_chunk + finally: + if run_context is self._default_context: + self._default_context.shutdown() def _ordered_generator( self, chunks_iter: Iterator[list[In]], - executor: ProcessPoolExecutor, - shared_context: MutableMapping[str, Any], + executor, + context_handle: IContextHandle, ) -> Iterator[list[Out]]: - """Generate results in their original order. - - Args: - chunks_iter: Iterator over data chunks. - executor: The process pool executor. - shared_context: The shared context for processing. - - Returns: - An iterator over processed chunks in order. - """ - futures: deque[Future[list[Out]]] = deque() + """Generate results in their original order.""" + futures = deque() for _ in range(self.max_workers + 1): try: chunk = next(chunks_iter) - futures.append( - executor.submit( - _process_chunk_for_multiprocessing, - self.transformer, - shared_context, - chunk, - ) - ) + futures.append(executor.submit(_worker_process_chunk, self.transformer, context_handle, chunk)) except StopIteration: break while futures: yield futures.popleft().result() try: chunk = next(chunks_iter) - futures.append( - executor.submit( - _process_chunk_for_multiprocessing, - self.transformer, - shared_context, - chunk, - ) - ) + futures.append(executor.submit(_worker_process_chunk, self.transformer, context_handle, chunk)) except StopIteration: continue def _unordered_generator( self, chunks_iter: Iterator[list[In]], - executor: ProcessPoolExecutor, - shared_context: MutableMapping[str, Any], + executor, + context_handle: IContextHandle, ) -> Iterator[list[Out]]: - """Generate results as they complete. 
- - Args: - chunks_iter: Iterator over data chunks. - executor: The process pool executor. - shared_context: The shared context for processing. - - Returns: - An iterator over processed chunks as they complete. - """ + """Generate results as they complete.""" futures = { - executor.submit( - _process_chunk_for_multiprocessing, - self.transformer, - shared_context, - chunk, - ) + executor.submit(_worker_process_chunk, self.transformer, context_handle, chunk) for chunk in itertools.islice(chunks_iter, self.max_workers + 1) } while futures: @@ -258,14 +181,7 @@ def _unordered_generator( yield future.result() try: chunk = next(chunks_iter) - futures.add( - executor.submit( - _process_chunk_for_multiprocessing, - self.transformer, - shared_context, - chunk, - ) - ) + futures.add(executor.submit(_worker_process_chunk, self.transformer, context_handle, chunk)) except StopIteration: continue @@ -315,6 +231,6 @@ def catch[U]( super().catch(sub_pipeline_builder, on_error) return self # type: ignore - def short_circuit(self, function: Callable[[PipelineContext], bool | None]) -> "ParallelTransformer[In, Out]": + def short_circuit(self, function: Callable[[IContextManager], bool | None]) -> "ParallelTransformer[In, Out]": super().short_circuit(function) return self diff --git a/laygo/transformers/threaded.py b/laygo/transformers/threaded.py index f49d643..11dedca 100644 --- a/laygo/transformers/threaded.py +++ b/laygo/transformers/threaded.py @@ -4,7 +4,6 @@ from collections.abc import Callable from collections.abc import Iterable from collections.abc import Iterator -from collections.abc import MutableMapping from concurrent.futures import FIRST_COMPLETED from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor @@ -12,14 +11,13 @@ import copy from functools import partial import itertools -from multiprocessing.managers import DictProxy -import threading from typing import Any from typing import Union from typing import overload +from laygo.context import IContextManager +from laygo.context import ParallelContextManager from laygo.errors import ErrorHandler -from laygo.helpers import PipelineContext from laygo.transformers.transformer import DEFAULT_CHUNK_SIZE from laygo.transformers.transformer import ChunkErrorHandler from laygo.transformers.transformer import InternalTransformer @@ -27,12 +25,6 @@ from laygo.transformers.transformer import Transformer -class ThreadedPipelineContextType(PipelineContext): - """A specific context type for threaded transformers that includes a lock.""" - - lock: threading.Lock - - def createThreadedTransformer[T]( _type_hint: type[T], max_workers: int = 4, @@ -85,6 +77,9 @@ def __init__( super().__init__(chunk_size, transformer) self.max_workers = max_workers self.ordered = ordered + # Rule 3: Threaded transformers create a parallel context manager by default. + # This is because threads share memory, so a thread-safe (locking) manager is required. + self._default_context = ParallelContextManager() @classmethod def from_transformer[T, U]( @@ -112,7 +107,7 @@ def from_transformer[T, U]( ordered=ordered, ) - def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]: + def __call__(self, data: Iterable[In], context: IContextManager | None = None) -> Iterator[Out]: """Execute the transformer on data concurrently. It uses the shared context provided by the Pipeline, if available. 
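+
+ If no context is provided, the transformer falls back to its own
+ ParallelContextManager and shuts it down once iteration completes.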
@@ -124,24 +119,18 @@ def __call__(self, data: Iterable[In], context: PipelineContext | None = None) - Returns: An iterator over the transformed data. """ - run_context = context if context is not None else self.context - - # Detect if the context is already managed by the Pipeline. - is_managed_context = isinstance(run_context, DictProxy) - - if is_managed_context: - # Use the existing shared context and lock from the Pipeline. - shared_context = run_context - yield from self._execute_with_context(data, shared_context) - else: - # Fallback for standalone use: create a thread-safe context. - # Since threads share memory, we can use the context directly with a lock. - if "lock" not in run_context: - run_context["lock"] = threading.Lock() + run_context = context if context is not None else self._default_context + # Since threads share memory, we can pass the context manager directly. + # No handle/proxy mechanism is needed, but the locking inside + # ParallelContextManager is crucial for thread safety. + try: yield from self._execute_with_context(data, run_context) + finally: + if run_context is self._default_context: + self._default_context.shutdown() - def _execute_with_context(self, data: Iterable[In], shared_context: MutableMapping[str, Any]) -> Iterator[Out]: + def _execute_with_context(self, data: Iterable[In], shared_context: IContextManager) -> Iterator[Out]: """Execute the transformation logic with a given context. Args: @@ -152,7 +141,7 @@ def _execute_with_context(self, data: Iterable[In], shared_context: MutableMappi An iterator over the transformed data. """ - def process_chunk(chunk: list[In], shared_context: MutableMapping[str, Any]) -> list[Out]: + def process_chunk(chunk: list[In], shared_context: IContextManager) -> list[Out]: """Process a single chunk by passing the chunk and context explicitly. Args: @@ -257,6 +246,6 @@ def catch[U]( super().catch(sub_pipeline_builder, on_error) return self # type: ignore - def short_circuit(self, function: Callable[[PipelineContext], bool | None]) -> "ThreadedTransformer[In, Out]": + def short_circuit(self, function: Callable[[IContextManager], bool | None]) -> "ThreadedTransformer[In, Out]": super().short_circuit(function) return self diff --git a/laygo/transformers/transformer.py b/laygo/transformers/transformer.py index 17c4f73..b84354a 100644 --- a/laygo/transformers/transformer.py +++ b/laygo/transformers/transformer.py @@ -12,20 +12,21 @@ from typing import Union from typing import overload +from laygo.context import IContextManager +from laygo.context import SimpleContextManager from laygo.errors import ErrorHandler -from laygo.helpers import PipelineContext from laygo.helpers import is_context_aware from laygo.helpers import is_context_aware_reduce DEFAULT_CHUNK_SIZE = 1000 -type PipelineFunction[Out, T] = Callable[[Out], T] | Callable[[Out, PipelineContext], T] -type PipelineReduceFunction[U, Out] = Callable[[U, Out], U] | Callable[[U, Out, PipelineContext], U] +type PipelineFunction[Out, T] = Callable[[Out], T] | Callable[[Out, IContextManager], T] +type PipelineReduceFunction[U, Out] = Callable[[U, Out], U] | Callable[[U, Out, IContextManager], U] # The internal transformer function signature is changed to explicitly accept a context. 
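+# For reference, a context-aware element function (the second alternative in
+# PipelineFunction above) has this shape; the names here are illustrative:
+#
+# def scale(x: int, ctx: IContextManager) -> int:
+# return x * ctx["multiplier"]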
-type InternalTransformer[In, Out] = Callable[[list[In], PipelineContext], list[Out]] -type ChunkErrorHandler[In, U] = Callable[[list[In], Exception, PipelineContext], list[U]] +type InternalTransformer[In, Out] = Callable[[list[In], IContextManager], list[Out]] +type ChunkErrorHandler[In, U] = Callable[[list[In], Exception, IContextManager], list[U]] def createTransformer[T](_type_hint: type[T], chunk_size: int = DEFAULT_CHUNK_SIZE) -> "Transformer[T, T]": @@ -96,11 +97,12 @@ def __init__( transformer: Optional existing transformer logic to use. """ self.chunk_size = chunk_size - self.context: PipelineContext = PipelineContext() # The default transformer now accepts and ignores a context argument. self.transformer: InternalTransformer[In, Out] = transformer or (lambda chunk, ctx: chunk) # type: ignore self.error_handler = ErrorHandler() self._chunk_generator = build_chunk_generator(chunk_size) if chunk_size else lambda x: iter([list(x)]) + # Rule 2: Transformers create a simple context manager by default for standalone use. + self._default_context = SimpleContextManager() @classmethod def from_transformer[T, U]( @@ -151,7 +153,7 @@ def on_error(self, handler: ChunkErrorHandler[In, Out] | ErrorHandler) -> "Trans self.error_handler.on_error(handler) # type: ignore return self - def _pipe[U](self, operation: Callable[[list[Out], PipelineContext], list[U]]) -> "Transformer[In, U]": + def _pipe[U](self, operation: Callable[[list[Out], IContextManager], list[U]]) -> "Transformer[In, U]": """Compose the current transformer with a new context-aware operation. Args: @@ -175,9 +177,11 @@ def map[U](self, function: PipelineFunction[Out, U]) -> "Transformer[In, U]": A new transformer with the mapping operation applied. """ if is_context_aware(function): - return self._pipe(lambda chunk, ctx: [function(x, ctx) for x in chunk]) + context_aware_func: Callable[[Out, IContextManager], U] = function # type: ignore + return self._pipe(lambda chunk, ctx: [context_aware_func(x, ctx) for x in chunk]) - return self._pipe(lambda chunk, _ctx: [function(x) for x in chunk]) # type: ignore + non_context_func: Callable[[Out], U] = function # type: ignore + return self._pipe(lambda chunk, _ctx: [non_context_func(x) for x in chunk]) def filter(self, predicate: PipelineFunction[Out, bool]) -> "Transformer[In, Out]": """Filter elements, passing context explicitly to the predicate function. @@ -190,9 +194,11 @@ def filter(self, predicate: PipelineFunction[Out, bool]) -> "Transformer[In, Out A transformer with the filtering operation applied. """ if is_context_aware(predicate): - return self._pipe(lambda chunk, ctx: [x for x in chunk if predicate(x, ctx)]) + context_aware_predicate: Callable[[Out, IContextManager], bool] = predicate # type: ignore + return self._pipe(lambda chunk, ctx: [x for x in chunk if context_aware_predicate(x, ctx)]) - return self._pipe(lambda chunk, _ctx: [x for x in chunk if predicate(x)]) # type: ignore + non_context_predicate: Callable[[Out], bool] = predicate # type: ignore + return self._pipe(lambda chunk, _ctx: [x for x in chunk if non_context_predicate(x)]) @overload def flatten[T](self: "Transformer[In, list[T]]") -> "Transformer[In, T]": ... @@ -246,7 +252,7 @@ def tap( case Transformer() as tapped_transformer: tapped_func = tapped_transformer.transformer - def operation(chunk: list[Out], ctx: PipelineContext) -> list[Out]: + def operation(chunk: list[Out], ctx: IContextManager) -> list[Out]: # Execute the tapped transformer's logic on the chunk for side-effects. 
_ = tapped_func(chunk, ctx) # Return the original chunk to continue the main pipeline. @@ -257,9 +263,11 @@ def operation(chunk: list[Out], ctx: PipelineContext) -> list[Out]: # Case 2: The argument is a callable function case function if callable(function): if is_context_aware(function): - return self._pipe(lambda chunk, ctx: [x for x in chunk if function(x, ctx) or True]) + context_aware_func: Callable[[Out, IContextManager], Any] = function # type: ignore + return self._pipe(lambda chunk, ctx: [x for x in chunk if context_aware_func(x, ctx) or True]) - return self._pipe(lambda chunk, _ctx: [x for x in chunk if function(x) or True]) # type: ignore + non_context_func: Callable[[Out], Any] = function # type: ignore + return self._pipe(lambda chunk, _ctx: [x for x in chunk if non_context_func(x) or True]) # Default case for robustness case _: @@ -279,7 +287,7 @@ def apply[T](self, t: Callable[[Self], "Transformer[In, T]"]) -> "Transformer[In def loop( self, loop_transformer: "Transformer[Out, Out]", - condition: Callable[[list[Out]], bool] | Callable[[list[Out], PipelineContext], bool], + condition: Callable[[list[Out]], bool] | Callable[[list[Out], IContextManager], bool], max_iterations: int | None = None, ) -> "Transformer[In, Out]": """ @@ -295,7 +303,7 @@ def loop( input and output types must match the current pipeline's output type (`Transformer[Out, Out]`). condition: A function that takes the current chunk (and optionally - the `PipelineContext`) and returns `True` to continue the + the `IContextManager`) and returns `True` to continue the loop, or `False` to stop. max_iterations: An optional integer to limit the number of repetitions and prevent infinite loops. @@ -306,7 +314,7 @@ def loop( looped_func = loop_transformer.transformer condition_is_context_aware = is_context_aware(condition) - def operation(chunk: list[Out], ctx: PipelineContext) -> list[Out]: + def operation(chunk: list[Out], ctx: IContextManager) -> list[Out]: condition_checker = ( # noqa: E731 lambda current_chunk: condition(current_chunk, ctx) if condition_is_context_aware else condition(current_chunk) # type: ignore ) @@ -324,7 +332,7 @@ def operation(chunk: list[Out], ctx: PipelineContext) -> list[Out]: return self._pipe(operation) - def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]: + def __call__(self, data: Iterable[In], context: IContextManager | None = None) -> Iterator[Out]: """Execute the transformer on a data source. It uses the provided `context` by reference. If none is provided, it uses @@ -332,17 +340,22 @@ def __call__(self, data: Iterable[In], context: PipelineContext | None = None) - Args: data: The input data to process. - context: Optional pipeline context to use during processing. + context: Optional context (IContextManager or dict) to use during processing. Returns: An iterator over the transformed data. """ - # Use the provided context by reference, or default to the instance's context. - run_context = context or self.context - for chunk in self._chunk_generator(data): - # The context is now passed explicitly through the transformer chain. - yield from self.transformer(chunk, run_context) + # Use the provided context by reference, or default to a simple context. + run_context = context if context is not None else self._default_context + + try: + for chunk in self._chunk_generator(data): + # The context is now passed explicitly through the transformer chain. 
+ yield from self.transformer(chunk, run_context) + finally: + if run_context is self._default_context: + self._default_context.shutdown() @overload def reduce[U]( @@ -362,7 +375,7 @@ def reduce[U]( initial: U, *, per_chunk: Literal[False] = False, - ) -> Callable[[Iterable[In], PipelineContext | None], Iterator[U]]: + ) -> Callable[[Iterable[In], IContextManager | None], Iterator[U]]: """Reduces the entire dataset to a single value (terminal operation).""" ... @@ -372,7 +385,7 @@ def reduce[U]( initial: U, *, per_chunk: bool = False, - ) -> Union["Transformer[In, U]", Callable[[Iterable[In], PipelineContext | None], Iterator[U]]]: # type: ignore + ) -> Union["Transformer[In, U]", Callable[[Iterable[In], IContextManager | None], Iterator[U]]]: # type: ignore """Reduces elements to a single value, either per-chunk or for the entire dataset.""" if per_chunk: # --- Efficient "per-chunk" logic (chainable) --- @@ -380,43 +393,49 @@ def reduce[U]( # The context-awareness check is now hoisted and executed only ONCE. if is_context_aware_reduce(function): # We define a specialized operation for the context-aware case. - def reduce_chunk_operation(chunk: list[Out], ctx: PipelineContext) -> list[U]: + context_aware_reduce_func: Callable[[U, Out, IContextManager], U] = function # type: ignore + + def reduce_chunk_operation(chunk: list[Out], ctx: IContextManager) -> list[U]: if not chunk: return [] # No check happens here; we know the function needs the context. - wrapper = lambda acc, val: function(acc, val, ctx) # noqa: E731, W291 + wrapper = lambda acc, val: context_aware_reduce_func(acc, val, ctx) # noqa: E731 return [reduce(wrapper, chunk, initial)] else: # We define a specialized, simpler operation for the non-aware case. - def reduce_chunk_operation(chunk: list[Out], ctx: PipelineContext) -> list[U]: + non_context_reduce_func: Callable[[U, Out], U] = function # type: ignore + + def reduce_chunk_operation(chunk: list[Out], ctx: IContextManager) -> list[U]: if not chunk: return [] # No check happens here; the function is called directly. 
- return [reduce(function, chunk, initial)] # type: ignore + return [reduce(non_context_reduce_func, chunk, initial)] return self._pipe(reduce_chunk_operation) # --- "Entire dataset" logic with `match` (terminal) --- match is_context_aware_reduce(function): case True: + context_aware_reduce_func: Callable[[U, Out, IContextManager], U] = function # type: ignore - def _reduce_with_context(data: Iterable[In], context: PipelineContext | None = None) -> Iterator[U]: - run_context = context or self.context + def _reduce_with_context(data: Iterable[In], context: IContextManager | None = None) -> Iterator[U]: + run_context = context or self._default_context data_iterator = self(data, run_context) def function_wrapper(acc, val): - return function(acc, val, run_context) # type: ignore + return context_aware_reduce_func(acc, val, run_context) yield reduce(function_wrapper, data_iterator, initial) return _reduce_with_context case False: + non_context_reduce_func: Callable[[U, Out], U] = function # type: ignore - def _reduce(data: Iterable[In], context: PipelineContext | None = None) -> Iterator[U]: - run_context = context or self.context + def _reduce(data: Iterable[In], context: IContextManager | None = None) -> Iterator[U]: + run_context = context or self._default_context data_iterator = self(data, run_context) - yield reduce(function, data_iterator, initial) # type: ignore + yield reduce(non_context_reduce_func, data_iterator, initial) return _reduce @@ -447,7 +466,7 @@ def catch[U]( sub_pipeline = sub_pipeline_builder(temp_transformer) sub_transformer_func = sub_pipeline.transformer - def operation(chunk: list[Out], ctx: PipelineContext) -> list[U]: + def operation(chunk: list[Out], ctx: IContextManager) -> list[U]: try: # Attempt to process the whole chunk with the sub-pipeline return sub_transformer_func(chunk, ctx) @@ -458,7 +477,7 @@ def operation(chunk: list[Out], ctx: PipelineContext) -> list[U]: return self._pipe(operation) # type: ignore - def short_circuit(self, function: Callable[[PipelineContext], bool | None]) -> "Transformer[In, Out]": + def short_circuit(self, function: Callable[[IContextManager], bool | None]) -> "Transformer[In, Out]": """Execute a function on the context before processing the next step for a chunk. This can be used for short-circuiting by raising an exception based on the @@ -467,7 +486,7 @@ def short_circuit(self, function: Callable[[PipelineContext], bool | None]) -> " operation in the chain. Args: - function: A callable that accepts the `PipelineContext` as its sole + function: A callable that accepts the context (IContextManager or dict) as its sole argument. If it returns True, the pipeline is stopped with an exception. @@ -479,7 +498,7 @@ def short_circuit(self, function: Callable[[PipelineContext], bool | None]) -> " condition has been met. """ - def operation(chunk: list[Out], ctx: PipelineContext) -> list[Out]: + def operation(chunk: list[Out], ctx: IContextManager) -> list[Out]: """The internal operation that wraps the user's function.""" # Execute the user's function with the current context. 
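# A truthy result signals that the stop condition has been met; the
# branch below halts the pipeline by raising an exception.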
if function(ctx): diff --git a/tests/test_http_transformer.py b/tests/test_http_transformer.py index 97868e3..7fa01df 100644 --- a/tests/test_http_transformer.py +++ b/tests/test_http_transformer.py @@ -4,7 +4,7 @@ from laygo import HTTPTransformer from laygo import Pipeline -from laygo import PipelineContext +from laygo.context.simple import SimpleContextManager class TestHTTPTransformer: @@ -42,7 +42,7 @@ def mock_response(request, context): input_chunk = request.json() # Call the actual view function logic obtained from get_route() - # We pass None for the context as it's not used in this simple case. - output_chunk = worker_view_func(chunk=input_chunk, context=PipelineContext()) + # We pass a fresh SimpleContextManager; the context is not used in this simple case. + output_chunk = worker_view_func(chunk=input_chunk, context=SimpleContextManager()) return output_chunk # Use requests_mock context manager @@ -52,7 +52,7 @@ def mock_response(request, context): # 5. Run the standard Pipeline with the configured transformer initial_data = list(range(10)) # [0, 1, 2, ..., 9] pipeline = Pipeline(initial_data).apply(http_transformer) - result = pipeline.to_list() + result, _ = pipeline.to_list() # 6. Assert the final result expected_result = [12, 14, 16, 18] diff --git a/tests/test_integration.py b/tests/test_integration.py index 9fd850a..81900ba 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -2,9 +2,9 @@ from laygo import ParallelTransformer from laygo import Pipeline -from laygo import PipelineContext from laygo import Transformer from laygo import createTransformer +from laygo.context.types import IContextManager class TestPipelineTransformerBasics: @@ -13,19 +13,19 @@ class TestPipelineTransformerBasics: def test_basic_pipeline_transformer_integration(self): """Test basic pipeline and transformer integration.""" transformer = createTransformer(int).map(lambda x: x * 2).filter(lambda x: x > 5) - result = Pipeline([1, 2, 3, 4, 5]).apply(transformer).to_list() + result, _ = Pipeline([1, 2, 3, 4, 5]).apply(transformer).to_list() assert result == [6, 8, 10] def test_pipeline_context_sharing(self): """Test that context is properly shared between pipeline and transformers.""" - context = PipelineContext({"multiplier": 3, "threshold": 5}) + context = {"multiplier": 3, "threshold": 5} transformer = Transformer().map(lambda x, ctx: x * ctx["multiplier"]).filter(lambda x, ctx: x > ctx["threshold"]) - result = Pipeline([1, 2, 3]).context(context).apply(transformer).to_list() + result, _ = Pipeline([1, 2, 3]).context(context).apply(transformer).to_list() assert result == [6, 9] def test_pipeline_transform_shorthand(self): """Test pipeline transform shorthand method.""" - result = ( + result, _ = ( Pipeline([1, 2, 3, 4, 5]) .transform(lambda t: t.map(lambda x: x * 3)) .transform(lambda t: t.filter(lambda x: x > 6)) @@ -47,7 +47,7 @@ def test_etl_pattern(self): ] # Extract names of people over 28 with salary > 55000 - result = ( + result, _ = ( Pipeline(raw_data) .transform(lambda t: t.filter(lambda x: x["age"] > 28 and x["salary"] > 55000)) .transform(lambda t: t.map(lambda x: x["name"])) @@ -71,7 +71,7 @@ def validate_and_convert(x): except (ValueError, TypeError): return None - result = ( + result, _ = ( Pipeline(raw_data) .transform(lambda t: t.map(validate_and_convert)) .transform(lambda t: t.filter(lambda x: x is not None)) @@ -82,15 +82,15 @@ def validate_and_convert(x): assert valid_numbers == [1.0, 2.0, 3.0, 5.0, 7.0]
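Note: the integration helpers that follow lock through the context object itself rather than through a stored `ctx["lock"]` entry. A minimal sketch of that pattern, assuming only the Pipeline/ParallelTransformer API these tests exercise (the key name `seen` is illustrative):

from laygo import ParallelTransformer
from laygo import Pipeline
from laygo.context.types import IContextManager


def tally(x: int, ctx: IContextManager) -> int:
    with ctx:  # entering the context manager acquires the shared lock
        ctx["seen"] = ctx["seen"] + 1  # the read-modify-write cannot interleave
    return x


result, ctx = (
    Pipeline([1, 2, 3])
    .context({"seen": 0})
    .apply(ParallelTransformer[int, int](max_workers=2, chunk_size=2).map(tally))
    .to_list()
)
assert ctx["seen"] == 3

-def safe_increment_and_transform(x: int, ctx: PipelineContext) -> int: - with ctx["lock"]: +def safe_increment_and_transform(x: int, ctx: 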
IContextManager) -> int: + with ctx: ctx["processed_count"] += 1 ctx["sum_total"] += x return x * 2 -def count_and_transform(x: int, ctx: PipelineContext) -> int: - with ctx["lock"]: +def count_and_transform(x: int, ctx: IContextManager) -> int: + with ctx: ctx["items_processed"] += 1 if x % 2 == 0: ctx["even_count"] += 1 @@ -99,17 +99,17 @@ def count_and_transform(x: int, ctx: PipelineContext) -> int: return x * 3 -def stage1_processor(x: int, ctx: PipelineContext) -> int: +def stage1_processor(x: int, ctx: IContextManager) -> int: """First stage processing with context update.""" - with ctx["lock"]: + with ctx: ctx["stage1_processed"] += 1 ctx["total_sum"] += x return x * 2 -def stage2_processor(x: int, ctx: PipelineContext) -> int: +def stage2_processor(x: int, ctx: IContextManager) -> int: """Second stage processing with context update.""" - with ctx["lock"]: + with ctx: ctx["stage2_processed"] += 1 ctx["total_sum"] += x # Add transformed value too return x + 10 @@ -123,46 +123,46 @@ def test_parallel_transformer_basic_integration(self): parallel_transformer = ParallelTransformer[int, int](max_workers=2, chunk_size=2) parallel_transformer = parallel_transformer.map(lambda x: x * 2).filter(lambda x: x > 5) - result = Pipeline([1, 2, 3, 4, 5]).apply(parallel_transformer).to_list() + result, _ = Pipeline([1, 2, 3, 4, 5]).apply(parallel_transformer).to_list() assert sorted(result) == [6, 8, 10] def test_parallel_transformer_with_context_modification(self): """Test parallel transformer safely modifying shared context.""" - context = PipelineContext({"processed_count": 0, "sum_total": 0}) + context = {"processed_count": 0, "sum_total": 0} parallel_transformer = ParallelTransformer[int, int](max_workers=2, chunk_size=2) parallel_transformer = parallel_transformer.map(safe_increment_and_transform) data = [1, 2, 3, 4, 5] - result = Pipeline(data).context(context).apply(parallel_transformer).to_list() + result, processed_context = Pipeline(data).context(context).apply(parallel_transformer).to_list() # Verify transformation results assert sorted(result) == [2, 4, 6, 8, 10] # Verify context was safely modified - assert context["processed_count"] == len(data) - assert context["sum_total"] == sum(data) + assert processed_context["processed_count"] == len(data) + assert processed_context["sum_total"] == sum(data) def test_pipeline_accesses_modified_context(self): """Test that pipeline can access context data modified by parallel transformer.""" - context = PipelineContext({"items_processed": 0, "even_count": 0, "odd_count": 0}) + context = {"items_processed": 0, "even_count": 0, "odd_count": 0} parallel_transformer = ParallelTransformer[int, int](max_workers=2, chunk_size=3) parallel_transformer = parallel_transformer.map(count_and_transform) data = [1, 2, 3, 4, 5, 6] pipeline = Pipeline(data).context(context) - result = pipeline.apply(parallel_transformer).to_list() + result, _ = pipeline.apply(parallel_transformer).to_list() # Verify results and context access assert sorted(result) == [3, 6, 9, 12, 15, 18] - assert pipeline.ctx["items_processed"] == 6 - assert pipeline.ctx["even_count"] == 3 # 2, 4, 6 - assert pipeline.ctx["odd_count"] == 3 # 1, 3, 5 + assert pipeline.context_manager["items_processed"] == 6 + assert pipeline.context_manager["even_count"] == 3 # 2, 4, 6 + assert pipeline.context_manager["odd_count"] == 3 # 1, 3, 5 def test_multiple_parallel_transformers_chaining(self): """Test chaining multiple parallel transformers with shared context.""" # Shared context for statistics across 
transformations - context = PipelineContext({"stage1_processed": 0, "stage2_processed": 0, "total_sum": 0}) + context = {"stage1_processed": 0, "stage2_processed": 0, "total_sum": 0} # Create two parallel transformers stage1 = ParallelTransformer[int, int](max_workers=2, chunk_size=2).map(stage1_processor) @@ -172,7 +172,7 @@ def test_multiple_parallel_transformers_chaining(self): # Chain parallel transformers in pipeline pipeline = Pipeline(data).context(context) - result = ( + result, _ = ( pipeline.apply(stage1) # [2, 4, 6, 8, 10] .apply(stage2) # [12, 14, 16, 18, 20] .to_list() @@ -184,7 +184,7 @@ def test_multiple_parallel_transformers_chaining(self): assert result == expected_final # Verify context reflects both stages - final_context = pipeline.ctx + final_context = pipeline.context_manager assert final_context["stage1_processed"] == 5 assert final_context["stage2_processed"] == 5 @@ -199,11 +199,11 @@ def test_pipeline_context_isolation_with_parallel_processing(self): # Create base context structure def create_context(): - return PipelineContext({"count": 0}) + return {"count": 0} - def increment_counter(x: int, ctx: PipelineContext) -> int: + def increment_counter(x: int, ctx: IContextManager) -> int: """Increment counter in context.""" - with ctx["lock"]: + with ctx: ctx["count"] += 1 return x * 2 @@ -217,16 +217,16 @@ def increment_counter(x: int, ctx: PipelineContext) -> int: pipeline2 = Pipeline(data).context(create_context()) # Process with both pipelines - result1 = pipeline1.apply(parallel_transformer).to_list() - result2 = pipeline2.apply(parallel_transformer).to_list() + result1, pipeline1_context = pipeline1.apply(parallel_transformer).to_list() + result2, pipeline2_context = pipeline2.apply(parallel_transformer).to_list() # Both should have same transformation results assert result1 == [2, 4, 6] assert result2 == [2, 4, 6] # But contexts should be isolated - assert pipeline1.ctx["count"] == 3 - assert pipeline2.ctx["count"] == 3 + assert pipeline1_context["count"] == 3 + assert pipeline2_context["count"] == 3 # Verify they are different context objects - assert pipeline1.ctx is not pipeline2.ctx + assert pipeline1_context is not pipeline2_context diff --git a/tests/test_parallel_transformer.py b/tests/test_parallel_transformer.py index e7fb67f..813c5c6 100644 --- a/tests/test_parallel_transformer.py +++ b/tests/test_parallel_transformer.py @@ -5,7 +5,8 @@ from laygo import ErrorHandler from laygo import ParallelTransformer -from laygo import PipelineContext +from laygo.context import IContextManager +from laygo.context import ParallelContextManager from laygo.transformers.parallel import createParallelTransformer from laygo.transformers.transformer import createTransformer @@ -82,16 +83,18 @@ def test_tap_side_effects(self): assert sorted(side_effects) == [1, 2, 3, 4] -def safe_increment(x: int, ctx: PipelineContext) -> int: - with ctx["lock"]: +def safe_increment(x: int, ctx: IContextManager) -> int: + # `with ctx:` is safe here: the concrete ParallelContextManager implements the context-manager protocol + with ctx: # type: ignore current_items = ctx["items"] time.sleep(0.001) ctx["items"] = current_items + 1 return x * 2 -def update_stats(x: int, ctx: PipelineContext) -> int: - with ctx["lock"]: +def update_stats(x: int, ctx: IContextManager) -> int: + # `with ctx:` is safe here: the concrete ParallelContextManager implements the context-manager protocol + with ctx: # type: ignore ctx["total_sum"] += x ctx["item_count"] += 1 ctx["max_value"] = max(ctx["max_value"], x) @@ -103,25 +106,14 @@ class 
TestParallelTransformerContextSupport: def test_map_with_context(self): """Test map with context-aware function in concurrent execution.""" - context = PipelineContext({"multiplier": 3}) + context = ParallelContextManager({"multiplier": 3}) transformer = createParallelTransformer(int).map(lambda x, ctx: x * ctx["multiplier"]) result = list(transformer([1, 2, 3], context)) assert result == [3, 6, 9] - def test_context_modification_with_locking(self): - """Test safe context modification with locking in concurrent execution.""" - context = PipelineContext({"items": 0}) - - transformer = createParallelTransformer(int, max_workers=4, chunk_size=1).map(safe_increment) - data = list(range(1, 11)) - result = list(transformer(data, context)) - - assert sorted(result) == sorted([x * 2 for x in data]) - assert context["items"] == len(data) - def test_multiple_context_values_modification(self): """Test modifying multiple context values safely.""" - context = PipelineContext({"total_sum": 0, "item_count": 0, "max_value": 0}) + context = ParallelContextManager({"total_sum": 0, "item_count": 0, "max_value": 0}) transformer = createParallelTransformer(int, max_workers=3, chunk_size=2).map(update_stats) data = [1, 5, 3, 8, 2, 7, 4, 6] diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index ef74138..bb1f2ad 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,7 +1,7 @@ """Tests for the Pipeline class.""" from laygo import Pipeline -from laygo import PipelineContext +from laygo.context.types import IContextManager from laygo.transformers.transformer import createTransformer @@ -11,12 +11,14 @@ class TestPipelineBasics: def test_single_iterable_creation(self): """Test creating pipeline from single iterable.""" pipeline = Pipeline([1, 2, 3]) - assert pipeline.to_list() == [1, 2, 3] + result, _ = pipeline.to_list() + assert result == [1, 2, 3] def test_multiple_iterables_creation(self): """Test creating pipeline from multiple iterables.""" pipeline = Pipeline([1, 2], [3, 4], [5]) - assert pipeline.to_list() == [1, 2, 3, 4, 5] + result, _ = pipeline.to_list() + assert result == [1, 2, 3, 4, 5] def test_pipeline_iteration(self): """Test pipeline is iterable.""" @@ -26,8 +28,8 @@ def test_iterator_consumption(self): """Test that to_list consumes the iterator.""" pipeline = Pipeline([1, 2, 3]) - first_result = pipeline.to_list() - second_result = pipeline.to_list() + first_result, _ = pipeline.to_list() + second_result, _ = pipeline.to_list() assert first_result == [1, 2, 3] assert second_result == [] # Iterator is consumed @@ -38,7 +40,7 @@ class TestPipelineTransformations: def test_apply_with_transformer(self): """Test apply with transformer object.""" transformer = createTransformer(int).map(lambda x: x * 2).filter(lambda x: x > 4) - result = Pipeline([1, 2, 3, 4]).apply(transformer).to_list() + result, _ = Pipeline([1, 2, 3, 4]).apply(transformer).to_list() assert result == [6, 8] def test_apply_with_generator_function(self): @@ -48,17 +50,17 @@ def double_generator(data): for item in data: yield item * 2 - result = Pipeline([1, 2, 3]).apply(double_generator).to_list() + result, _ = Pipeline([1, 2, 3]).apply(double_generator).to_list() assert result == [2, 4, 6]
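Note: the `result, _ = ...` unpacking throughout this file reflects the new convention that terminal operations return a `(results, context)` pair rather than a bare value. A minimal sketch, assuming the Pipeline API as used in these tests:

from laygo import Pipeline

result, ctx = Pipeline([1, 2, 3]).to_list()
assert result == [1, 2, 3]

first_two, _ = Pipeline([1, 2, 3]).first(2)
assert first_two == [1, 2]

def test_transform_shorthand(self): """Test transform shorthand method.""" - result = Pipeline([1, 2, 3, 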
4]).transform(lambda t: t.map(lambda x: x * 2).filter(lambda x: x > 4)).to_list() assert result == [6, 8] def test_chained_transformations(self): """Test chaining multiple transformations.""" - result = ( + result, _ = ( Pipeline([1, 2, 3, 4]) .transform(lambda t: t.map(lambda x: x * 2)) .transform(lambda t: t.filter(lambda x: x > 4)) @@ -78,24 +80,24 @@ def test_each_applies_side_effects(self): def test_first_gets_n_elements(self): """Test first gets specified number of elements.""" - result = Pipeline([1, 2, 3, 4, 5]).first(3) + result, _ = Pipeline([1, 2, 3, 4, 5]).first(3) assert result == [1, 2, 3] def test_first_default_one_element(self): """Test first with no argument gets one element.""" - result = Pipeline([1, 2, 3]).first() + result, _ = Pipeline([1, 2, 3]).first() assert result == [1] def test_first_with_insufficient_data(self): """Test first when requesting more elements than available.""" - result = Pipeline([1, 2]).first(5) + result, _ = Pipeline([1, 2]).first(5) assert result == [1, 2] def test_consume_processes_without_return(self): """Test consume processes all elements without returning anything.""" side_effects = [] transformer = createTransformer(int).tap(lambda x: side_effects.append(x)) - result = Pipeline([1, 2, 3]).apply(transformer).consume() + result, _ = Pipeline([1, 2, 3]).apply(transformer).consume() assert result is None assert side_effects == [1, 2, 3] @@ -104,20 +106,20 @@ def test_consume_processes_without_return(self): class TestPipelineDataTypes: """Test pipeline with various data types.""" - def test_string_processing(self): + def test_string_transformation(self): """Test pipeline with string data.""" - result = Pipeline(["hello", "world"]).transform(lambda t: t.map(lambda x: x.upper())).to_list() + result, _ = Pipeline(["hello", "world"]).transform(lambda t: t.map(lambda x: x.upper())).to_list() assert result == ["HELLO", "WORLD"] - def test_mixed_types_processing(self): + def test_mixed_type_transformation(self): """Test pipeline with mixed data types.""" - result = Pipeline([1, "hello", 3.14]).transform(lambda t: t.map(lambda x: str(x))).to_list() + result, _ = Pipeline([1, "hello", 3.14]).transform(lambda t: t.map(lambda x: str(x))).to_list() assert result == ["1", "hello", "3.14"] - def test_complex_objects_processing(self): - """Test pipeline with complex objects.""" + def test_dict_transformation(self): + """Test pipeline with dictionary data.""" data = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}] - result = Pipeline(data).transform(lambda t: t.map(lambda x: x["name"])).to_list() + result, _ = Pipeline(data).transform(lambda t: t.map(lambda x: x["name"])).to_list() assert result == ["Alice", "Bob"] @@ -126,28 +128,32 @@ class TestPipelineEdgeCases: def test_empty_pipeline(self): """Test pipeline with empty data.""" - assert Pipeline([]).to_list() == [] + result, _ = Pipeline([]).to_list() + assert result == [] # Test terminal operations on empty pipeline results = [] Pipeline([]).each(lambda x: results.append(x)) assert results == [] - assert Pipeline([]).first(5) == [] - assert Pipeline([]).consume() is None + first_result, _ = Pipeline([]).first(5) + assert first_result == [] + consume_result, _ = Pipeline([]).consume() + assert consume_result is None def test_single_element_pipeline(self): """Test pipeline with single element.""" - assert Pipeline([42]).to_list() == [42] + result, _ = Pipeline([42]).to_list() + assert result == [42] def test_type_preservation(self): """Test that pipeline preserves and transforms types 
correctly.""" # Integers preserved - int_result = Pipeline([1, 2, 3]).to_list() + int_result, _ = Pipeline([1, 2, 3]).to_list() assert all(isinstance(x, int) for x in int_result) # Transform to strings - str_result = Pipeline([1, 2, 3]).transform(lambda t: t.map(lambda x: str(x))).to_list() + str_result, _ = Pipeline([1, 2, 3]).transform(lambda t: t.map(lambda x: str(x))).to_list() assert all(isinstance(x, str) for x in str_result) assert str_result == ["1", "2", "3"] @@ -158,7 +164,9 @@ class TestPipelinePerformance: def test_large_dataset_processing(self): """Test pipeline handles large datasets efficiently.""" large_data = list(range(10000)) - result = Pipeline(large_data).transform(lambda t: t.map(lambda x: x * 2).filter(lambda x: x % 100 == 0)).to_list() + result, _ = ( + Pipeline(large_data).transform(lambda t: t.map(lambda x: x * 2).filter(lambda x: x % 100 == 0)).to_list() + ) # Every 50th element doubled (0, 100, 200, ..., 19900) expected = [x * 2 for x in range(0, 10000, 50)] @@ -168,7 +176,7 @@ def test_chunked_processing_consistency(self): """Test that chunked processing produces consistent results.""" # Use small chunk size to test chunking behavior transformer = createTransformer(int, chunk_size=10).map(lambda x: x + 1) - result = Pipeline(list(range(100))).apply(transformer).to_list() + result, _ = Pipeline(list(range(100))).apply(transformer).to_list() expected = list(range(1, 101)) # [1, 2, 3, ..., 100] assert result == expected @@ -190,7 +198,7 @@ def second_map(x): return x + 1 # Apply buffering with 2 workers between two map operations - result = ( + result, _ = ( Pipeline(data) .transform(lambda t: t.map(first_map)) .buffer(2) # Buffer with 2 workers @@ -227,7 +235,7 @@ def test_branch_basic_functionality(self): square_branch = createTransformer(int).map(lambda x: x**2) # Execute branching - result = pipeline.branch({"doubled": double_branch, "squared": square_branch}) + result, _ = pipeline.branch({"doubled": double_branch, "squared": square_branch}) # Verify results contain processed items for each branch assert "doubled" in result @@ -247,7 +255,7 @@ def test_branch_with_empty_input(self): double_branch = createTransformer(int).map(lambda x: x * 2) square_branch = createTransformer(int).map(lambda x: x**2) - result = pipeline.branch({"doubled": double_branch, "squared": square_branch}) + result, _ = pipeline.branch({"doubled": double_branch, "squared": square_branch}) # Should return empty lists for all branches assert result == {"doubled": [], "squared": []} @@ -256,7 +264,7 @@ def test_branch_with_empty_branches_dict(self): """Test branch with empty branches dictionary.""" pipeline = Pipeline([1, 2, 3]) - result = pipeline.branch({}) + result, _ = pipeline.branch({}) # Should return empty dictionary assert result == {} @@ -267,7 +275,7 @@ def test_branch_with_single_branch(self): triple_branch = createTransformer(int).map(lambda x: x * 3) - result = pipeline.branch({"tripled": triple_branch}) + result, _ = pipeline.branch({"tripled": triple_branch}) assert len(result) == 1 assert "tripled" in result @@ -282,7 +290,7 @@ def test_branch_with_custom_queue_size(self): triple_branch = createTransformer(int).map(lambda x: x * 3) # Test with a small queue size - result = pipeline.branch( + result, _ = pipeline.branch( { "doubled": double_branch, "tripled": triple_branch, @@ -304,7 +312,7 @@ def test_branch_with_three_branches(self): add_20 = createTransformer(int).map(lambda x: x + 20) add_30 = createTransformer(int).map(lambda x: x + 30)
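Note: `branch()` follows the same convention: it is terminal, feeds every branch the full stream, and returns a `({name: results}, context)` pair. A minimal sketch, assuming only the behavior asserted in these tests:

from laygo import Pipeline
from laygo.transformers.transformer import createTransformer

branches, _ = Pipeline([1, 2, 3]).branch(
    {
        "doubled": createTransformer(int).map(lambda x: x * 2),
        "squared": createTransformer(int).map(lambda x: x**2),
    }
)
assert sorted(branches["doubled"]) == [2, 4, 6]
assert sorted(branches["squared"]) == [1, 4, 9]

- result = 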
pipeline.branch({"add_10": add_10, "add_20": add_20, "add_30": add_30}) + result, _ = pipeline.branch({"add_10": add_10, "add_20": add_20, "add_30": add_30}) # Each branch gets all items: # add_10 gets all items: [1, 2, 3, 4, 5, 6, 7, 8, 9] -> [11, 12, 13, 14, 15, 16, 17, 18, 19] @@ -322,7 +330,7 @@ def test_branch_with_filtering_transformers(self): even_branch = createTransformer(int).filter(lambda x: x % 2 == 0) odd_branch = createTransformer(int).filter(lambda x: x % 2 == 1) - result = pipeline.branch({"evens": even_branch, "odds": odd_branch}) + result, _ = pipeline.branch({"evens": even_branch, "odds": odd_branch}) # Each branch gets all items and then filters: # evens gets all items [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -> filters to [2, 4, 6, 8, 10] @@ -340,7 +348,7 @@ def test_branch_with_multiple_transformations(self): # Simple transformer: just multiply by 10 simple_branch = createTransformer(int).map(lambda x: x * 10) - result = pipeline.branch({"complex": complex_branch, "simple": simple_branch}) + result, _ = pipeline.branch({"complex": complex_branch, "simple": simple_branch}) # Each branch gets all items: # complex gets all items [1, 2, 3, 4, 5, 6] -> filters to [2, 4, 6] -> [4, 8, 12] -> [5, 9, 13] @@ -359,7 +367,12 @@ def test_branch_with_chunked_data(self): small_chunk_transformer = createTransformer(int, chunk_size=5).map(lambda x: x * 2) identity_transformer = createTransformer(int, chunk_size=5) - result = pipeline.branch({"doubled": small_chunk_transformer, "identity": identity_transformer}) + result, _ = pipeline.branch( + { + "doubled": small_chunk_transformer, + "identity": identity_transformer, + } + ) # Each branch gets all items: # doubled gets all items [1, 2, 3, ..., 20] -> @@ -376,7 +389,12 @@ def test_branch_with_flatten_operation(self): flatten_branch = createTransformer(list).flatten() count_branch = createTransformer(list).map(lambda x: len(x)) - result = pipeline.branch({"flattened": flatten_branch, "lengths": count_branch}) + result, _ = pipeline.branch( + { + "flattened": flatten_branch, + "lengths": count_branch, + } + ) # Each branch gets all items: # flattened gets all items [[1, 2], [3, 4], [5, 6]] -> flattens to [1, 2, 3, 4, 5, 6] @@ -392,14 +410,14 @@ def test_branch_is_terminal_operation(self): double_branch = createTransformer(int).map(lambda x: x * 2) # Execute branch - result = pipeline.branch({"doubled": double_branch}) + result, _ = pipeline.branch({"doubled": double_branch}) # Verify the result - each branch gets all items: [1, 2, 3, 4, 5] -> [2, 4, 6, 8, 10] assert sorted(result["doubled"]) == [2, 4, 6, 8, 10] # Attempt to use the pipeline again should yield empty results # since the iterator has been consumed - empty_result = pipeline.to_list() + empty_result, _ = pipeline.to_list() assert empty_result == [] def test_branch_with_different_chunk_sizes(self): @@ -411,7 +429,7 @@ def test_branch_with_different_chunk_sizes(self): large_chunk_branch = createTransformer(int, chunk_size=10).map(lambda x: x + 100) small_chunk_branch = createTransformer(int, chunk_size=3).map(lambda x: x + 200) - result = pipeline.branch({"large_chunk": large_chunk_branch, "small_chunk": small_chunk_branch}) + result, _ = pipeline.branch({"large_chunk": large_chunk_branch, "small_chunk": small_chunk_branch}) # Each branch gets all items: # large_chunk gets all items [1, 2, 3, ..., 15] -> [101, 102, 103, ..., 115] @@ -428,7 +446,7 @@ def test_branch_preserves_data_order_within_chunks(self): identity_branch = createTransformer(int) reverse_branch = 
createTransformer(int).map(lambda x: -x) - result = pipeline.branch({"identity": identity_branch, "negated": reverse_branch}) + result, _ = pipeline.branch({"identity": identity_branch, "negated": reverse_branch}) # Each branch gets all items: # identity gets all items: [5, 3, 8, 1, 9, 2] (preserves order) @@ -446,7 +464,7 @@ def test_branch_with_error_handling(self): # The division_branch should fail when processing 0 # The current implementation catches exceptions and returns empty lists for failed branches - result = pipeline.branch({"division": division_branch, "safe": safe_branch}) + result, _ = pipeline.branch({"division": division_branch, "safe": safe_branch}) # division gets all items [1, 2, 0, 4, 5] -> fails on 0, returns [] # safe gets all items [1, 2, 0, 4, 5] -> [2, 4, 0, 8, 10] @@ -458,18 +476,18 @@ def test_branch_context_isolation(self): pipeline = Pipeline([1, 2, 3]) # Create context-aware transformers that modify context - def context_modifier_a(chunk: list[int], ctx: PipelineContext) -> list[int]: + def context_modifier_a(chunk: list[int], ctx: IContextManager) -> list[int]: ctx["branch_a_processed"] = len(chunk) return [x * 2 for x in chunk] - def context_modifier_b(chunk: list[int], ctx: PipelineContext) -> list[int]: + def context_modifier_b(chunk: list[int], ctx: IContextManager) -> list[int]: ctx["branch_b_processed"] = len(chunk) return [x * 3 for x in chunk] branch_a = createTransformer(int)._pipe(context_modifier_a) branch_b = createTransformer(int)._pipe(context_modifier_b) - result = pipeline.branch({"branch_a": branch_a, "branch_b": branch_b}) + result, context = pipeline.branch({"branch_a": branch_a, "branch_b": branch_b}) # Each branch gets all items: # branch_a gets all items: [1, 2, 3] -> [2, 4, 6] @@ -478,5 +496,5 @@ def context_modifier_b(chunk: list[int], ctx: PipelineContext) -> list[int]: assert result["branch_b"] == [3, 6, 9] # Context values should reflect the actual chunk sizes processed - assert pipeline.ctx.get("branch_a_processed") == 3 - assert pipeline.ctx.get("branch_b_processed") == 3 + assert context["branch_a_processed"] == 3 + assert context["branch_b_processed"] == 3 diff --git a/tests/test_threaded_transformer.py b/tests/test_threaded_transformer.py index 4c25fc8..56ff9cc 100644 --- a/tests/test_threaded_transformer.py +++ b/tests/test_threaded_transformer.py @@ -1,13 +1,11 @@ """Tests for the ThreadedTransformer class.""" -import threading import time -from unittest.mock import patch from laygo import ErrorHandler -from laygo import PipelineContext from laygo import ThreadedTransformer -from laygo import Transformer +from laygo.context.parallel import ParallelContextManager +from laygo.context.types import IContextManager from laygo.transformers.threaded import createThreadedTransformer from laygo.transformers.transformer import createTransformer @@ -91,7 +89,7 @@ class TestThreadedTransformerContextSupport: def test_map_with_context(self): """Test map with context-aware function in concurrent execution.""" - context = PipelineContext({"multiplier": 3}) + context = ParallelContextManager({"multiplier": 3}) transformer = ThreadedTransformer[int, int](max_workers=2, chunk_size=2) transformer = transformer.map(lambda x, ctx: x * ctx["multiplier"]) result = list(transformer([1, 2, 3], context))
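Note: in the threaded tests below, `with ctx:` takes the shared lock of the ParallelContextManager introduced in this diff, which is what makes the guarded read-modify-write safe. A minimal standalone sketch of that behavior:

from laygo.context.parallel import ParallelContextManager

ctx = ParallelContextManager({"hits": 0})
with ctx:  # acquires the manager's shared lock
    ctx["hits"] += 1  # no other worker can interleave this read-modify-write
assert ctx.to_dict() == {"hits": 1}
ctx.shutdown()  # stop the background manager process when done

@@ -99,13 +97,13 @@ def test_map_with_context(self): def test_context_modification_with_locking(self): """Test safe context modification with locking in 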
concurrent execution.""" - context = PipelineContext({"items": 0, "_lock": threading.Lock()}) + context = ParallelContextManager({"items": 0}) - def safe_increment(x: int, ctx: PipelineContext) -> int: - with ctx["_lock"]: - current_items = ctx["items"] + def safe_increment(x: int, ctx: IContextManager) -> int: + with ctx: + # The sleep below runs while the lock is held; without the lock this read-modify-write would race time.sleep(0.001) # Increase chance of race condition - ctx["items"] = current_items + 1 + ctx["items"] = ctx["items"] + 1 return x * 2 transformer = ThreadedTransformer[int, int](max_workers=4, chunk_size=1) @@ -119,10 +117,10 @@ def safe_increment(x: int, ctx: PipelineContext) -> int: def test_multiple_context_values_modification(self): """Test modifying multiple context values safely.""" - context = PipelineContext({"total_sum": 0, "item_count": 0, "max_value": 0, "_lock": threading.Lock()}) + context = ParallelContextManager({"total_sum": 0, "item_count": 0, "max_value": 0}) - def update_stats(x: int, ctx: PipelineContext) -> int: - with ctx["_lock"]: + def update_stats(x: int, ctx: IContextManager) -> int: + with ctx: ctx["total_sum"] += x ctx["item_count"] += 1 ctx["max_value"] = max(ctx["max_value"], x) @@ -170,48 +168,6 @@ def test_unordered_vs_ordered_same_elements(self): assert ordered_result == [x * 2 for x in data] # Ordered maintains sequence -class TestThreadedTransformerPerformance: - """Test performance aspects of parallel transformer.""" - - def test_concurrent_performance_improvement(self): - """Test that concurrent execution improves performance for slow operations.""" - - def slow_operation(x: int) -> int: - time.sleep(0.01) # 10ms delay - return x * 2 - - data = list(range(8)) # 8 items, 80ms total sequential time - - # Sequential execution - start_time = time.time() - sequential = Transformer[int, int](chunk_size=4) - seq_result = list(sequential.map(slow_operation)(data)) - seq_time = time.time() - start_time - - # Concurrent execution - start_time = time.time() - concurrent = ThreadedTransformer[int, int](max_workers=4, chunk_size=4) - conc_result = list(concurrent.map(slow_operation)(data)) - conc_time = time.time() - start_time - - assert seq_result == conc_result - assert conc_time < seq_time * 0.8 # At least 20% faster - - def test_thread_pool_management(self): - """Test that thread pool is properly created and cleaned up.""" - with patch("laygo.transformers.threaded.ThreadPoolExecutor") as mock_executor: - mock_executor.return_value.__enter__.return_value = mock_executor.return_value - mock_executor.return_value.__exit__.return_value = None - mock_executor.return_value.submit.return_value.result.return_value = [2, 4] - - transformer = ThreadedTransformer[int, int](max_workers=2, chunk_size=2) - list(transformer([1, 2])) - - mock_executor.assert_called_with(max_workers=2) - mock_executor.return_value.__enter__.assert_called_once() - mock_executor.return_value.__exit__.assert_called_once() - - class TestThreadedTransformerChunking: """Test chunking behavior with concurrent execution.""" diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 777b131..6fe53e6 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -3,8 +3,8 @@ import pytest from laygo import ErrorHandler -from laygo import PipelineContext from laygo import Transformer +from laygo.context.simple import SimpleContextManager from laygo.transformers.transformer import createTransformer
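Note: the single-process tests below use SimpleContextManager, which wraps a plain dict. A minimal sketch of passing one directly to a transformer, mirroring the calls in this file:

from laygo import Transformer
from laygo.context.simple import SimpleContextManager

ctx = SimpleContextManager({"offset": 10})
shifted = Transformer().map(lambda x, c: x + c["offset"])  # context-aware lambda
assert list(shifted([1, 2, 3], ctx)) == [11, 12, 13]

@@ -107,14 +107,14 @@ class TestTransformerContextSupport: def test_map_with_context(self): """Test map with context-aware function."""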
- context = PipelineContext({"multiplier": 3}) + context = SimpleContextManager({"multiplier": 3}) transformer = Transformer().map(lambda x, ctx: x * ctx["multiplier"]) result = list(transformer([1, 2, 3], context)) assert result == [3, 6, 9] def test_filter_with_context(self): """Test filter with context-aware function.""" - context = PipelineContext({"threshold": 3}) + context = SimpleContextManager({"threshold": 3}) transformer = Transformer().filter(lambda x, ctx: x > ctx["threshold"]) result = list(transformer([1, 2, 3, 4, 5], context)) assert result == [4, 5] @@ -122,7 +122,7 @@ def test_filter_with_context(self): def test_tap_with_context(self): """Test tap with context-aware function.""" side_effects = [] - context = PipelineContext({"prefix": "item:"}) + context = SimpleContextManager({"prefix": "item:"}) transformer = Transformer().tap(lambda x, ctx: side_effects.append(f"{ctx['prefix']}{x}")) result = list(transformer([1, 2, 3], context)) @@ -159,7 +159,7 @@ def test_tap_with_transformer(self): def test_tap_with_transformer_and_context(self): """Test tap with a transformer that uses context.""" side_effects = [] - context = PipelineContext({"multiplier": 5, "log_prefix": "processed:"}) + context = SimpleContextManager({"multiplier": 5, "log_prefix": "processed:"}) # Create a context-aware side-effect transformer side_effect_transformer = ( @@ -186,7 +186,7 @@ def test_tap_with_transformer_and_context(self): def test_loop_with_context(self): """Test loop with context-aware condition and transformer.""" side_effects = [] - context = PipelineContext({"target_sum": 15, "increment": 2}) + context = SimpleContextManager({"target_sum": 15, "increment": 2}) # Create a context-aware loop transformer that uses context increment loop_transformer = ( @@ -213,7 +213,7 @@ def condition_with_context(chunk, ctx): def test_loop_with_context_and_side_effects(self): """Test loop with context-aware condition that reads context data.""" - context = PipelineContext({"max_value": 20, "increment": 3}) + context = SimpleContextManager({"max_value": 20, "increment": 3}) # Simple loop transformer that uses context increment loop_transformer = createTransformer(int).map(lambda x, ctx: x + ctx["increment"]) @@ -270,7 +270,7 @@ def test_basic_reduce(self): def test_reduce_with_context(self): """Test reduce with context-aware function.""" - context = PipelineContext({"multiplier": 2}) + context = SimpleContextManager({"multiplier": 2}) transformer = Transformer() reducer = transformer.reduce(lambda acc, x, ctx: acc + (x * ctx["multiplier"]), initial=0) result = list(reducer([1, 2, 3], context)) @@ -292,7 +292,7 @@ def test_reduce_per_chunk_basic(self): def test_reduce_per_chunk_with_context(self): """Test reduce with per_chunk=True and context-aware function.""" - context = PipelineContext({"multiplier": 2}) + context = SimpleContextManager({"multiplier": 2}) transformer = createTransformer(int, chunk_size=2).reduce( lambda acc, x, ctx: acc + (x * ctx["multiplier"]), initial=0, per_chunk=True )
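Note: the reduce tests above exercise both modes introduced in the transformer.py hunks. A minimal sketch of the two call shapes, assuming the signatures shown there:

from laygo.context.simple import SimpleContextManager
from laygo.transformers.transformer import createTransformer

ctx = SimpleContextManager({"multiplier": 2})

# per_chunk=True stays chainable and emits one value per chunk:
# chunks [1, 2] and [3, 4] -> [1*2 + 2*2, 3*2 + 4*2] == [6, 14]
per_chunk = createTransformer(int, chunk_size=2).reduce(
    lambda acc, x, c: acc + x * c["multiplier"], initial=0, per_chunk=True
)
assert list(per_chunk([1, 2, 3, 4], ctx)) == [6, 14]

# The default reduces the entire dataset; the returned callable is terminal
# and yields a single value.
total = createTransformer(int).reduce(lambda acc, x: acc + x, initial=0)
assert list(total([1, 2, 3, 4])) == [10]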