diff --git a/laygo/transformers/transformer.py b/laygo/transformers/transformer.py
index 8c4d3e5..b277738 100644
--- a/laygo/transformers/transformer.py
+++ b/laygo/transformers/transformer.py
@@ -123,13 +123,54 @@ def flatten[T](
         """Flattens nested lists; the context is passed through the operation."""
         return self._pipe(lambda chunk, ctx: [item for sublist in chunk for item in sublist])  # type: ignore
 
-    def tap(self, function: PipelineFunction[Out, Any]) -> "Transformer[In, Out]":
-        """Applies a side-effect function without modifying the data."""
+    @overload
+    def tap(self, arg: "Transformer[Out, Any]") -> "Transformer[In, Out]": ...
 
-        if is_context_aware(function):
-            return self._pipe(lambda chunk, ctx: [x for x in chunk if function(x, ctx) or True])
+    @overload
+    def tap(self, arg: PipelineFunction[Out, Any]) -> "Transformer[In, Out]": ...
+
+    def tap(
+        self,
+        arg: Union["Transformer[Out, Any]", PipelineFunction[Out, Any]],
+    ) -> "Transformer[In, Out]":
+        """
+        Applies a side-effect without modifying the main data stream.
+
+        This method can be used in two ways:
+        1. With a `Transformer`: Applies a sub-pipeline to each chunk for side-effects
+           (e.g., logging a chunk), discarding the sub-pipeline's output.
+        2. With a `function`: Applies a function to each element individually for
+           side-effects (e.g., printing an item).
+
+        Args:
+            arg: A `Transformer` instance or a function to be applied for side-effects.
+
+        Returns:
+            The transformer instance for method chaining.
+        """
+        match arg:
+            # Case 1: The argument is another Transformer
+            case Transformer() as tapped_transformer:
+                tapped_func = tapped_transformer.transformer
+
+                def operation(chunk: list[Out], ctx: PipelineContext) -> list[Out]:
+                    # Execute the tapped transformer's logic on the chunk for side-effects.
+                    _ = tapped_func(chunk, ctx)
+                    # Return the original chunk to continue the main pipeline.
+                    return chunk
+
+                return self._pipe(operation)
+
+            # Case 2: The argument is a callable function
+            case function if callable(function):
+                if is_context_aware(function):
+                    return self._pipe(lambda chunk, ctx: [x for x in chunk if function(x, ctx) or True])
+
+                return self._pipe(lambda chunk, _ctx: [x for x in chunk if function(x) or True])  # type: ignore
 
-        return self._pipe(lambda chunk, _ctx: [function(x) or x for x in chunk])  # type: ignore
+            # Default case for robustness
+            case _:
+                raise TypeError(f"tap() argument must be a Transformer or a callable, not {type(arg).__name__}")
 
     def apply[T](self, t: Callable[[Self], "Transformer[In, T]"]) -> "Transformer[In, T]":
         """Apply another pipeline to the current one."""
diff --git a/tests/test_transformer.py b/tests/test_transformer.py
index 7135349..fbcdc93 100644
--- a/tests/test_transformer.py
+++ b/tests/test_transformer.py
@@ -85,6 +85,60 @@ def test_tap_with_context(self):
         assert result == [1, 2, 3]
         assert side_effects == ["item:1", "item:2", "item:3"]
 
+    def test_tap_with_transformer(self):
+        """Test tap with a transformer for side effects."""
+        side_effects = []
+
+        # Create a side-effect transformer that logs processed values
+        side_effect_transformer = (
+            createTransformer(int)
+            .map(lambda x: x * 10)  # Transform for side effect
+            .tap(lambda x: side_effects.append(x))  # Capture the transformed values
+        )
+
+        # Main transformer that uses the side-effect transformer via tap
+        main_transformer = (
+            createTransformer(int)
+            .map(lambda x: x * 2)  # Main transformation
+            .tap(side_effect_transformer)  # Apply side-effect transformer
+            .map(lambda x: x + 1)  # Continue main transformation
+        )
+
+        result = list(main_transformer([1, 2, 3]))
+
+        # Main pipeline should produce: [1,2,3] -> [2,4,6] -> [3,5,7]
+        assert result == [3, 5, 7]
+
+        # Side effects should capture: [2,4,6] -> [20,40,60]
+        assert side_effects == [20, 40, 60]
+
+    def test_tap_with_transformer_and_context(self):
+        """Test tap with a transformer that uses context."""
+        side_effects = []
+        context = PipelineContext({"multiplier": 5, "log_prefix": "processed:"})
+
+        # Create a context-aware side-effect transformer
+        side_effect_transformer = (
+            createTransformer(int)
+            .map(lambda x, ctx: x * ctx["multiplier"])  # Use context multiplier
+            .tap(lambda x, ctx: side_effects.append(f"{ctx['log_prefix']}{x}"))  # Log with context prefix
+        )
+
+        # Main transformer
+        main_transformer = (
+            createTransformer(int)
+            .map(lambda x: x + 10)  # Main transformation
+            .tap(side_effect_transformer)  # Apply side-effect transformer with context
+        )
+
+        result = list(main_transformer([1, 2, 3], context))
+
+        # Main pipeline: [1,2,3] -> [11,12,13]
+        assert result == [11, 12, 13]
+
+        # Side effects: [11,12,13] -> [55,60,65] -> ["processed:55", "processed:60", "processed:65"]
+        assert side_effects == ["processed:55", "processed:60", "processed:65"]
+
 
 class TestTransformerChaining:
     """Test chaining multiple transformer operations."""