diff --git a/laygo/__init__.py b/laygo/__init__.py index 1767cc0..ce85957 100644 --- a/laygo/__init__.py +++ b/laygo/__init__.py @@ -14,7 +14,9 @@ from laygo.transformers.threaded import ThreadedTransformer from laygo.transformers.threaded import createThreadedTransformer from laygo.transformers.transformer import Transformer +from laygo.transformers.transformer import build_chunk_generator from laygo.transformers.transformer import createTransformer +from laygo.transformers.transformer import passthrough_chunks __all__ = [ "Pipeline", @@ -28,4 +30,6 @@ "createHTTPTransformer", "PipelineContext", "ErrorHandler", + "passthrough_chunks", + "build_chunk_generator", ] diff --git a/laygo/pipeline.py b/laygo/pipeline.py index 0f82255..a637085 100644 --- a/laygo/pipeline.py +++ b/laygo/pipeline.py @@ -1,20 +1,23 @@ # pipeline.py - from collections.abc import Callable from collections.abc import Iterable from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import as_completed import itertools import multiprocessing as mp +from queue import Queue from typing import Any from typing import TypeVar from typing import overload from laygo.helpers import PipelineContext from laygo.helpers import is_context_aware -from laygo.transformers.threaded import ThreadedTransformer from laygo.transformers.transformer import Transformer +from laygo.transformers.transformer import passthrough_chunks T = TypeVar("T") +U = TypeVar("U") PipelineFunction = Callable[[T], Any] @@ -147,16 +150,109 @@ def apply[U]( return self # type: ignore - def buffer(self, size: int) -> "Pipeline[T]": - """Buffer the pipeline using threaded processing. 
+ def branch( + self, + branches: dict[str, Transformer[T, Any]], + batch_size: int = 1000, + max_batch_buffer: int = 1, + use_queue_chunks: bool = True, + ) -> dict[str, list[Any]]: + """Forks the pipeline into multiple branches for concurrent, parallel processing.""" + if not branches: + self.consume() + return {} + + source_iterator = self.processed_data + branch_items = list(branches.items()) + num_branches = len(branch_items) + final_results: dict[str, list[Any]] = {} + + queues = [Queue(maxsize=max_batch_buffer) for _ in range(num_branches)] + + def producer() -> None: + """Reads from the source and distributes batches to ALL branch queues.""" + # Use itertools.batched for clean and efficient batch creation. + for batch_tuple in itertools.batched(source_iterator, batch_size): + # The batch is a tuple; convert to a list for consumers. + batch_list = list(batch_tuple) + for q in queues: + q.put(batch_list) + + # Signal to all consumers that the stream is finished. + for q in queues: + q.put(None) + + def consumer(transformer: Transformer, queue: Queue) -> list[Any]: + """Consumes batches from a queue and runs them through a transformer.""" + + def stream_from_queue() -> Iterator[T]: + while (batch := queue.get()) is not None: + yield batch + + if use_queue_chunks: + transformer = transformer.set_chunker(passthrough_chunks) + + result_iterator = transformer(stream_from_queue(), self.ctx) # type: ignore + return list(result_iterator) + + with ThreadPoolExecutor(max_workers=num_branches + 1) as executor: + executor.submit(producer) + + future_to_name = { + executor.submit(consumer, transformer, queues[i]): name for i, (name, transformer) in enumerate(branch_items) + } + + for future in as_completed(future_to_name): + name = future_to_name[future] + try: + final_results[name] = future.result() + except Exception as e: + print(f"Branch '{name}' raised an exception: {e}") + final_results[name] = [] + + return final_results + + def buffer(self, size: int, batch_size: 
int = 1000) -> "Pipeline[T]": + """Inserts a buffer in the pipeline to allow downstream processing to read ahead. + + This creates a background thread that reads from the upstream data source + and fills a queue, decoupling the upstream and downstream stages. Args: - size: The number of worker threads to use for buffering. + size: The number of **batches** to hold in the buffer. + batch_size: The number of items to accumulate per batch. Returns: The pipeline instance for method chaining. """ - self.apply(ThreadedTransformer(max_workers=size)) + source_iterator = self.processed_data + + def _buffered_stream() -> Iterator[T]: + queue = Queue(maxsize=size) + # We only need one background thread for the producer. + executor = ThreadPoolExecutor(max_workers=1) + + def _producer() -> None: + """The producer reads from the source and fills the queue.""" + try: + for batch_tuple in itertools.batched(source_iterator, batch_size): + queue.put(list(batch_tuple)) + finally: + # Always put the sentinel value to signal the end of the stream. + queue.put(None) + + # Start the producer in the background thread. + executor.submit(_producer) + + try: + # The main thread becomes the consumer. + while (batch := queue.get()) is not None: + yield from batch + finally: + # Ensure the background thread is cleaned up. + executor.shutdown(wait=False, cancel_futures=True) + + self.processed_data = _buffered_stream() return self def __iter__(self) -> Iterator[T]: diff --git a/laygo/transformers/transformer.py b/laygo/transformers/transformer.py index 6f7bdd6..bbb6792 100644 --- a/laygo/transformers/transformer.py +++ b/laygo/transformers/transformer.py @@ -60,6 +60,20 @@ def chunk_generator(data: Iterable[T]) -> Iterator[list[T]]: return chunk_generator +def passthrough_chunks[T](data: Iterable[list[T]]) -> Iterator[list[T]]: + """A chunk generator that yields the entire input as a single chunk. + + This is useful for transformers that do not require chunking. 
+ + Args: + data: The input data to process. + + Returns: + An iterator yielding the entire input as a single chunk. + """ + yield from iter(data) + + class Transformer[In, Out]: """Define and compose data transformations by passing context explicitly. diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index d9443c6..ef74138 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,6 +1,7 @@ """Tests for the Pipeline class.""" from laygo import Pipeline +from laygo import PipelineContext from laygo.transformers.transformer import createTransformer @@ -211,3 +212,271 @@ def second_map(x): assert sorted(first_map_values) == list(range(10)) assert sorted(second_map_values) == [x * 2 for x in range(10)] + + +class TestPipelineBranch: + """Test pipeline branch method functionality.""" + + def test_branch_basic_functionality(self): + """Test basic branch operation with simple transformers.""" + # Create a pipeline with basic data + pipeline = Pipeline([1, 2, 3, 4, 5]) + + # Create two different branch transformers + double_branch = createTransformer(int).map(lambda x: x * 2) + square_branch = createTransformer(int).map(lambda x: x**2) + + # Execute branching + result = pipeline.branch({"doubled": double_branch, "squared": square_branch}) + + # Verify results contain processed items for each branch + assert "doubled" in result + assert "squared" in result + assert len(result) == 2 + + # Each branch gets all items from the pipeline: + # doubled gets all items: [1, 2, 3, 4, 5] -> [2, 4, 6, 8, 10] + # squared gets all items: [1, 2, 3, 4, 5] -> [1, 4, 9, 16, 25] + assert sorted(result["doubled"]) == [2, 4, 6, 8, 10] + assert sorted(result["squared"]) == [1, 4, 9, 16, 25] + + def test_branch_with_empty_input(self): + """Test branch with empty input data.""" + pipeline = Pipeline([]) + + double_branch = createTransformer(int).map(lambda x: x * 2) + square_branch = createTransformer(int).map(lambda x: x**2) + + result = pipeline.branch({"doubled": 
class TestPipelineBranch:
    """Tests for ``Pipeline.branch``.

    Every test builds a small in-memory pipeline, forks it into one or more
    named transformers and checks the per-branch result lists in the returned
    dict. Results are compared with ``sorted(...)`` where the test does not
    want to depend on output order.
    """

    def test_branch_basic_functionality(self):
        """Test basic branch operation with simple transformers."""
        # Create a pipeline with basic data
        pipeline = Pipeline([1, 2, 3, 4, 5])

        # Create two different branch transformers
        double_branch = createTransformer(int).map(lambda x: x * 2)
        square_branch = createTransformer(int).map(lambda x: x**2)

        # Execute branching
        result = pipeline.branch({"doubled": double_branch, "squared": square_branch})

        # Verify results contain processed items for each branch
        assert "doubled" in result
        assert "squared" in result
        assert len(result) == 2

        # Each branch gets all items from the pipeline:
        # doubled gets all items: [1, 2, 3, 4, 5] -> [2, 4, 6, 8, 10]
        # squared gets all items: [1, 2, 3, 4, 5] -> [1, 4, 9, 16, 25]
        assert sorted(result["doubled"]) == [2, 4, 6, 8, 10]
        assert sorted(result["squared"]) == [1, 4, 9, 16, 25]

    def test_branch_with_empty_input(self):
        """Test branch with empty input data."""
        pipeline = Pipeline([])

        double_branch = createTransformer(int).map(lambda x: x * 2)
        square_branch = createTransformer(int).map(lambda x: x**2)

        result = pipeline.branch({"doubled": double_branch, "squared": square_branch})

        # Every branch name must still be present, mapped to an empty list
        assert result == {"doubled": [], "squared": []}

    def test_branch_with_empty_branches_dict(self):
        """Test branch with empty branches dictionary."""
        pipeline = Pipeline([1, 2, 3])

        result = pipeline.branch({})

        # Should return empty dictionary
        assert result == {}

    def test_branch_with_single_branch(self):
        """Test branch with only one branch."""
        pipeline = Pipeline([1, 2, 3, 4])

        triple_branch = createTransformer(int).map(lambda x: x * 3)

        result = pipeline.branch({"tripled": triple_branch})

        assert len(result) == 1
        assert "tripled" in result
        # With only one branch, it gets all items
        assert sorted(result["tripled"]) == [3, 6, 9, 12]

    def test_branch_with_custom_queue_size(self):
        """Test branch with a custom ``max_batch_buffer`` (per-branch queue capacity, in batches)."""
        pipeline = Pipeline([1, 2, 3, 4, 5])

        double_branch = createTransformer(int).map(lambda x: x * 2)
        triple_branch = createTransformer(int).map(lambda x: x * 3)

        # Override the default queue capacity of 1 batch per branch
        result = pipeline.branch(
            {
                "doubled": double_branch,
                "tripled": triple_branch,
            },
            max_batch_buffer=2,
        )

        # Each branch gets all items regardless of queue capacity:
        # doubled gets all items: [1, 2, 3, 4, 5] -> [2, 4, 6, 8, 10]
        # tripled gets all items: [1, 2, 3, 4, 5] -> [3, 6, 9, 12, 15]
        assert sorted(result["doubled"]) == [2, 4, 6, 8, 10]
        assert sorted(result["tripled"]) == [3, 6, 9, 12, 15]

    def test_branch_with_three_branches(self):
        """Test branch with three branches to verify every branch receives the full input."""
        pipeline = Pipeline([1, 2, 3, 4, 5, 6, 7, 8, 9])

        add_10 = createTransformer(int).map(lambda x: x + 10)
        add_20 = createTransformer(int).map(lambda x: x + 20)
        add_30 = createTransformer(int).map(lambda x: x + 30)

        result = pipeline.branch({"add_10": add_10, "add_20": add_20, "add_30": add_30})

        # Each branch gets all items:
        # add_10 gets all items: [1, 2, 3, 4, 5, 6, 7, 8, 9] -> [11, 12, 13, 14, 15, 16, 17, 18, 19]
        # add_20 gets all items: [1, 2, 3, 4, 5, 6, 7, 8, 9] -> [21, 22, 23, 24, 25, 26, 27, 28, 29]
        # add_30 gets all items: [1, 2, 3, 4, 5, 6, 7, 8, 9] -> [31, 32, 33, 34, 35, 36, 37, 38, 39]
        assert sorted(result["add_10"]) == [11, 12, 13, 14, 15, 16, 17, 18, 19]
        assert sorted(result["add_20"]) == [21, 22, 23, 24, 25, 26, 27, 28, 29]
        assert sorted(result["add_30"]) == [31, 32, 33, 34, 35, 36, 37, 38, 39]

    def test_branch_with_filtering_transformers(self):
        """Test branch with transformers that filter data."""
        pipeline = Pipeline([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

        # Create transformers that filter data
        even_branch = createTransformer(int).filter(lambda x: x % 2 == 0)
        odd_branch = createTransformer(int).filter(lambda x: x % 2 == 1)

        result = pipeline.branch({"evens": even_branch, "odds": odd_branch})

        # Each branch gets all items and then filters:
        # evens gets all items [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -> filters to [2, 4, 6, 8, 10]
        # odds gets all items [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -> filters to [1, 3, 5, 7, 9]
        # (compared unsorted, so this also pins the output order)
        assert result["evens"] == [2, 4, 6, 8, 10]
        assert result["odds"] == [1, 3, 5, 7, 9]

    def test_branch_with_multiple_transformations(self):
        """Test branch with complex multi-step transformers."""
        pipeline = Pipeline([1, 2, 3, 4, 5, 6])

        # Complex transformer: filter evens, then double, then add 1
        complex_branch = createTransformer(int).filter(lambda x: x % 2 == 0).map(lambda x: x * 2).map(lambda x: x + 1)

        # Simple transformer: just multiply by 10
        simple_branch = createTransformer(int).map(lambda x: x * 10)

        result = pipeline.branch({"complex": complex_branch, "simple": simple_branch})

        # Each branch gets all items:
        # complex gets all items [1, 2, 3, 4, 5, 6] -> filters to [2, 4, 6] -> [4, 8, 12] -> [5, 9, 13]
        # simple gets all items [1, 2, 3, 4, 5, 6] -> [10, 20, 30, 40, 50, 60]
        assert result["complex"] == [5, 9, 13]
        assert sorted(result["simple"]) == [10, 20, 30, 40, 50, 60]

    def test_branch_with_chunked_data(self):
        """Test branch with transformers configured with a small ``chunk_size``."""
        # NOTE(review): branch() defaults to batch_size=1000 and
        # use_queue_chunks=True, so these 20 items presumably reach each
        # transformer as one batch and chunk_size=5 may not actually be
        # exercised - confirm, or pass batch_size / use_queue_chunks here.
        data = list(range(1, 21))  # [1, 2, 3, ..., 20]
        pipeline = Pipeline(data)

        # Transformers created with a chunk size smaller than the data set
        small_chunk_transformer = createTransformer(int, chunk_size=5).map(lambda x: x * 2)
        identity_transformer = createTransformer(int, chunk_size=5)

        result = pipeline.branch({"doubled": small_chunk_transformer, "identity": identity_transformer})

        # Each branch gets all items:
        # doubled gets all items [1, 2, 3, ..., 20] ->
        # [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40]
        # identity gets all items [1, 2, 3, ..., 20] ->
        # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
        assert sorted(result["doubled"]) == [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40]
        assert sorted(result["identity"]) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]

    def test_branch_with_flatten_operation(self):
        """Test branch with flatten operations."""
        pipeline = Pipeline([[1, 2], [3, 4], [5, 6]])

        flatten_branch = createTransformer(list).flatten()
        count_branch = createTransformer(list).map(lambda x: len(x))

        result = pipeline.branch({"flattened": flatten_branch, "lengths": count_branch})

        # Each branch gets all items:
        # flattened gets all items [[1, 2], [3, 4], [5, 6]] -> flattens to [1, 2, 3, 4, 5, 6]
        # lengths gets all items [[1, 2], [3, 4], [5, 6]] -> [2, 2, 2]
        assert sorted(result["flattened"]) == [1, 2, 3, 4, 5, 6]
        assert result["lengths"] == [2, 2, 2]

    def test_branch_is_terminal_operation(self):
        """Test that branch is a terminal operation that consumes the pipeline."""
        pipeline = Pipeline([1, 2, 3, 4, 5])

        # Create a simple transformer
        double_branch = createTransformer(int).map(lambda x: x * 2)

        # Execute branch
        result = pipeline.branch({"doubled": double_branch})

        # Verify the result - each branch gets all items: [1, 2, 3, 4, 5] -> [2, 4, 6, 8, 10]
        assert sorted(result["doubled"]) == [2, 4, 6, 8, 10]

        # Attempt to use the pipeline again should yield empty results
        # since the iterator has been consumed
        empty_result = pipeline.to_list()
        assert empty_result == []

    def test_branch_with_different_chunk_sizes(self):
        """Test branch with transformers that have different chunk sizes."""
        data = list(range(1, 16))  # [1, 2, 3, ..., 15]
        pipeline = Pipeline(data)

        # Different chunk sizes for different branches
        large_chunk_branch = createTransformer(int, chunk_size=10).map(lambda x: x + 100)
        small_chunk_branch = createTransformer(int, chunk_size=3).map(lambda x: x + 200)

        result = pipeline.branch({"large_chunk": large_chunk_branch, "small_chunk": small_chunk_branch})

        # Each branch gets all items:
        # large_chunk gets all items [1, 2, 3, ..., 15] -> [101, 102, 103, ..., 115]
        # small_chunk gets all items [1, 2, 3, ..., 15] -> [201, 202, 203, ..., 215]

        assert sorted(result["large_chunk"]) == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115]
        assert sorted(result["small_chunk"]) == [201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215]

    def test_branch_preserves_data_order_within_chunks(self):
        """Test that each branch returns its results in input order."""
        pipeline = Pipeline([5, 3, 8, 1, 9, 2])

        # Identity transformer should preserve order
        identity_branch = createTransformer(int)
        reverse_branch = createTransformer(int).map(lambda x: -x)

        result = pipeline.branch({"identity": identity_branch, "negated": reverse_branch})

        # Each branch gets all items:
        # identity gets all items: [5, 3, 8, 1, 9, 2] (preserves order)
        # negated gets all items: [5, 3, 8, 1, 9, 2] -> [-5, -3, -8, -1, -9, -2] (preserves order)
        assert result["identity"] == [5, 3, 8, 1, 9, 2]
        assert result["negated"] == [-5, -3, -8, -1, -9, -2]

    def test_branch_with_error_handling(self):
        """Test branch behavior when transformers encounter errors."""
        pipeline = Pipeline([1, 2, 0, 4, 5])

        # Create a transformer that will fail on zero division
        division_branch = createTransformer(int).map(lambda x: 10 // x)
        safe_branch = createTransformer(int).map(lambda x: x * 2)

        # The division_branch should fail when processing 0
        # The current implementation catches exceptions and returns empty lists for failed branches
        result = pipeline.branch({"division": division_branch, "safe": safe_branch})

        # division gets all items [1, 2, 0, 4, 5] -> fails on 0, returns []
        # safe gets all items [1, 2, 0, 4, 5] -> [2, 4, 0, 8, 10]
        assert result["division"] == []  # Error causes empty result
        assert sorted(result["safe"]) == [0, 2, 4, 8, 10]

    def test_branch_context_isolation(self):
        """Test that branches writing distinct keys to the shared pipeline context do not clobber each other."""
        pipeline = Pipeline([1, 2, 3])

        # Create context-aware transformers that each write their own key
        def context_modifier_a(chunk: list[int], ctx: PipelineContext) -> list[int]:
            ctx["branch_a_processed"] = len(chunk)
            return [x * 2 for x in chunk]

        def context_modifier_b(chunk: list[int], ctx: PipelineContext) -> list[int]:
            ctx["branch_b_processed"] = len(chunk)
            return [x * 3 for x in chunk]

        branch_a = createTransformer(int)._pipe(context_modifier_a)
        branch_b = createTransformer(int)._pipe(context_modifier_b)

        result = pipeline.branch({"branch_a": branch_a, "branch_b": branch_b})

        # Each branch gets all items:
        # branch_a gets all items: [1, 2, 3] -> [2, 4, 6]
        # branch_b gets all items: [1, 2, 3] -> [3, 6, 9]
        assert sorted(result["branch_a"]) == [2, 4, 6]
        assert result["branch_b"] == [3, 6, 9]

        # Both keys must be visible on the pipeline's own context, each
        # holding the size of the chunk that branch processed
        assert pipeline.ctx.get("branch_a_processed") == 3
        assert pipeline.ctx.get("branch_b_processed") == 3