diff --git a/laygo/transformers/transformer.py b/laygo/transformers/transformer.py index b277738..361bfce 100644 --- a/laygo/transformers/transformer.py +++ b/laygo/transformers/transformer.py @@ -176,6 +176,54 @@ def apply[T](self, t: Callable[[Self], "Transformer[In, T]"]) -> "Transformer[In """Apply another pipeline to the current one.""" return t(self) + def loop( + self, + loop_transformer: "Transformer[Out, Out]", + condition: Callable[[list[Out]], bool] | Callable[[list[Out], PipelineContext], bool], + max_iterations: int | None = None, + ) -> "Transformer[In, Out]": + """ + Repeatedly applies a transformer to each chunk until a condition is met. + + The loop continues as long as the `condition` function returns `True` and + the number of iterations has not reached `max_iterations`. The provided + `loop_transformer` must take a chunk of a certain type and return a chunk + of the same type. + + Args: + loop_transformer: The `Transformer` to apply in each iteration. Its + input and output types must match the current pipeline's + output type (`Transformer[Out, Out]`). + condition: A function that takes the current chunk (and optionally + the `PipelineContext`) and returns `True` to continue the + loop, or `False` to stop. + max_iterations: An optional integer to limit the number of repetitions + and prevent infinite loops. + + Returns: + The transformer instance for method chaining. + """ + looped_func = loop_transformer.transformer + condition_is_context_aware = is_context_aware(condition) + + def operation(chunk: list[Out], ctx: PipelineContext) -> list[Out]: + condition_checker = ( # noqa: E731 + lambda current_chunk: condition(current_chunk, ctx) if condition_is_context_aware else condition(current_chunk) # type: ignore + ) + + current_chunk = chunk + + iterations = 0 + + # The loop now uses the single `condition_checker` function. + while (max_iterations is None or iterations < max_iterations) and condition_checker(current_chunk): # type: ignore + current_chunk = looped_func(current_chunk, ctx) + iterations += 1 + + return current_chunk + + return self._pipe(operation) + def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]: """ Executes the transformer on a data source. diff --git a/tests/test_transformer.py b/tests/test_transformer.py index fbcdc93..f626784 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -57,6 +57,50 @@ def test_tap_side_effects(self): assert result == [1, 2, 3] # Data unchanged assert side_effects == [1, 2, 3] # Side effect applied + def test_loop_basic_operation(self): + """Test loop applies transformer repeatedly until condition is met.""" + # Create a loop transformer that adds 1 to each element + increment_transformer = createTransformer(int).map(lambda x: x + 1) + + # Continue looping while any element is less than 5 + def condition(chunk): + return any(x < 5 for x in chunk) + + transformer = createTransformer(int).loop(increment_transformer, condition, max_iterations=10) + result = list(transformer([1, 2, 3])) + + # Should increment until all elements are >= 5: [1,2,3] -> [2,3,4] -> [3,4,5] -> [4,5,6] -> [5,6,7] + assert result == [5, 6, 7] + + def test_loop_with_max_iterations(self): + """Test loop respects max_iterations limit.""" + # Create a loop transformer that adds 1 to each element + increment_transformer = createTransformer(int).map(lambda x: x + 1) + + # Condition that would normally continue indefinitely + def always_true_condition(chunk): + return True + + transformer = createTransformer(int).loop(increment_transformer, always_true_condition, max_iterations=3) + result = list(transformer([1, 2, 3])) + + # Should stop after 3 iterations: [1,2,3] -> [2,3,4] -> [3,4,5] -> [4,5,6] + assert result == [4, 5, 6] + + def test_loop_no_iterations(self): + """Test loop when condition is false from the start.""" + increment_transformer = createTransformer(int).map(lambda x: x + 1) + + # Condition that's immediately false + def exit_immediately(chunk): + return False + + transformer = createTransformer(int).loop(increment_transformer, exit_immediately) + result = list(transformer([1, 2, 3])) + + # Should not iterate at all + assert result == [1, 2, 3] + class TestTransformerContextSupport: """Test context-aware transformer operations.""" @@ -139,6 +183,52 @@ def test_tap_with_transformer_and_context(self): # Side effects: [11,12,13] -> [55,60,65] -> ["processed:55", "processed:60", "processed:65"] assert side_effects == ["processed:55", "processed:60", "processed:65"] + def test_loop_with_context(self): + """Test loop with context-aware condition and transformer.""" + side_effects = [] + context = PipelineContext({"target_sum": 15, "increment": 2}) + + # Create a context-aware loop transformer that uses context increment + loop_transformer = ( + createTransformer(int) + .map(lambda x, ctx: x + ctx["increment"]) # Use context increment + .tap(lambda x, ctx: side_effects.append(f"iteration:{x}")) # Log each iteration + ) + + # Context-aware condition: continue while sum of chunk is less than target_sum + def condition_with_context(chunk, ctx): + return sum(chunk) < ctx["target_sum"] + + main_transformer = createTransformer(int).loop(loop_transformer, condition_with_context, max_iterations=10) + + result = list(main_transformer([1, 2, 3], context)) + + # Initial: [1,2,3] sum=6 < 15, continue + # After 1st: [3,4,5] sum=12 < 15, continue + # After 2nd: [5,6,7] sum=18 >= 15, stop + assert result == [5, 6, 7] + + # Should have logged both iterations + assert side_effects == ["iteration:3", "iteration:4", "iteration:5", "iteration:5", "iteration:6", "iteration:7"] + + def test_loop_with_context_and_side_effects(self): + """Test loop with context-aware condition that reads context data.""" + context = PipelineContext({"max_value": 20, "increment": 3}) + + # Simple loop transformer that uses context increment + loop_transformer = createTransformer(int).map(lambda x, ctx: x + ctx["increment"]) + + # Context-aware condition: continue while max value in chunk is less than context max_value + def condition_with_context(chunk, ctx): + return max(chunk) < ctx["max_value"] + + main_transformer = createTransformer(int).loop(loop_transformer, condition_with_context, max_iterations=10) + + result = list(main_transformer([5, 8], context)) + + # [5,8] -> [8,11] -> [11,14] -> [14,17] -> [17,20] (stop because max(17,20) >= 20) + assert result == [17, 20] + class TestTransformerChaining: """Test chaining multiple transformer operations."""