Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions laygo/transformers/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,54 @@ def apply[T](self, t: Callable[[Self], "Transformer[In, T]"]) -> "Transformer[In
"""Apply another pipeline to the current one."""
return t(self)

def loop(
self,
loop_transformer: "Transformer[Out, Out]",
condition: Callable[[list[Out]], bool] | Callable[[list[Out], PipelineContext], bool],
max_iterations: int | None = None,
) -> "Transformer[In, Out]":
"""
Repeatedly applies a transformer to each chunk until a condition is met.

The loop continues as long as the `condition` function returns `True` and
the number of iterations has not reached `max_iterations`. The provided
`loop_transformer` must take a chunk of a certain type and return a chunk
of the same type.

Args:
loop_transformer: The `Transformer` to apply in each iteration. Its
input and output types must match the current pipeline's
output type (`Transformer[Out, Out]`).
condition: A function that takes the current chunk (and optionally
the `PipelineContext`) and returns `True` to continue the
loop, or `False` to stop.
max_iterations: An optional integer to limit the number of repetitions
and prevent infinite loops.

Returns:
The transformer instance for method chaining.
"""
looped_func = loop_transformer.transformer
condition_is_context_aware = is_context_aware(condition)

def operation(chunk: list[Out], ctx: PipelineContext) -> list[Out]:
condition_checker = ( # noqa: E731
lambda current_chunk: condition(current_chunk, ctx) if condition_is_context_aware else condition(current_chunk) # type: ignore
)

current_chunk = chunk

iterations = 0

# The loop now uses the single `condition_checker` function.
while (max_iterations is None or iterations < max_iterations) and condition_checker(current_chunk): # type: ignore
current_chunk = looped_func(current_chunk, ctx)
iterations += 1

return current_chunk

return self._pipe(operation)

def __call__(self, data: Iterable[In], context: PipelineContext | None = None) -> Iterator[Out]:
"""
Executes the transformer on a data source.
Expand Down
90 changes: 90 additions & 0 deletions tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,50 @@ def test_tap_side_effects(self):
assert result == [1, 2, 3] # Data unchanged
assert side_effects == [1, 2, 3] # Side effect applied

def test_loop_basic_operation(self):
"""Test loop applies transformer repeatedly until condition is met."""
# Create a loop transformer that adds 1 to each element
increment_transformer = createTransformer(int).map(lambda x: x + 1)

# Continue looping while any element is less than 5
def condition(chunk):
return any(x < 5 for x in chunk)

transformer = createTransformer(int).loop(increment_transformer, condition, max_iterations=10)
result = list(transformer([1, 2, 3]))

# Should increment until all elements are >= 5: [1,2,3] -> [2,3,4] -> [3,4,5] -> [4,5,6] -> [5,6,7]
assert result == [5, 6, 7]

def test_loop_with_max_iterations(self):
"""Test loop respects max_iterations limit."""
# Create a loop transformer that adds 1 to each element
increment_transformer = createTransformer(int).map(lambda x: x + 1)

# Condition that would normally continue indefinitely
def always_true_condition(chunk):
return True

transformer = createTransformer(int).loop(increment_transformer, always_true_condition, max_iterations=3)
result = list(transformer([1, 2, 3]))

# Should stop after 3 iterations: [1,2,3] -> [2,3,4] -> [3,4,5] -> [4,5,6]
assert result == [4, 5, 6]

def test_loop_no_iterations(self):
"""Test loop when condition is false from the start."""
increment_transformer = createTransformer(int).map(lambda x: x + 1)

# Condition that's immediately false
def exit_immediately(chunk):
return False

transformer = createTransformer(int).loop(increment_transformer, exit_immediately)
result = list(transformer([1, 2, 3]))

# Should not iterate at all
assert result == [1, 2, 3]


class TestTransformerContextSupport:
"""Test context-aware transformer operations."""
Expand Down Expand Up @@ -139,6 +183,52 @@ def test_tap_with_transformer_and_context(self):
# Side effects: [11,12,13] -> [55,60,65] -> ["processed:55", "processed:60", "processed:65"]
assert side_effects == ["processed:55", "processed:60", "processed:65"]

def test_loop_with_context(self):
"""Test loop with context-aware condition and transformer."""
side_effects = []
context = PipelineContext({"target_sum": 15, "increment": 2})

# Create a context-aware loop transformer that uses context increment
loop_transformer = (
createTransformer(int)
.map(lambda x, ctx: x + ctx["increment"]) # Use context increment
.tap(lambda x, ctx: side_effects.append(f"iteration:{x}")) # Log each iteration
)

# Context-aware condition: continue while sum of chunk is less than target_sum
def condition_with_context(chunk, ctx):
return sum(chunk) < ctx["target_sum"]

main_transformer = createTransformer(int).loop(loop_transformer, condition_with_context, max_iterations=10)

result = list(main_transformer([1, 2, 3], context))

# Initial: [1,2,3] sum=6 < 15, continue
# After 1st: [3,4,5] sum=12 < 15, continue
# After 2nd: [5,6,7] sum=18 >= 15, stop
assert result == [5, 6, 7]

# Should have logged both iterations
assert side_effects == ["iteration:3", "iteration:4", "iteration:5", "iteration:5", "iteration:6", "iteration:7"]

def test_loop_with_context_and_side_effects(self):
"""Test loop with context-aware condition that reads context data."""
context = PipelineContext({"max_value": 20, "increment": 3})

# Simple loop transformer that uses context increment
loop_transformer = createTransformer(int).map(lambda x, ctx: x + ctx["increment"])

# Context-aware condition: continue while max value in chunk is less than context max_value
def condition_with_context(chunk, ctx):
return max(chunk) < ctx["max_value"]

main_transformer = createTransformer(int).loop(loop_transformer, condition_with_context, max_iterations=10)

result = list(main_transformer([5, 8], context))

# [5,8] -> [8,11] -> [11,14] -> [14,17] -> [17,20] (stop because max(17,20) >= 20)
assert result == [17, 20]


class TestTransformerChaining:
"""Test chaining multiple transformer operations."""
Expand Down
Loading