diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index a6ffbd8d..f472cba6 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -16,9 +16,6 @@ package com.google.adk.agents; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.util.Arrays.stream; - import com.google.adk.Telemetry; import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.Callbacks.BeforeAgentCallback; @@ -255,10 +252,11 @@ public Flowable runAsync(InvocationContext parentContext) { spanContext, span, () -> - callCallback( - beforeCallbacksToFunctions( - invocationContext.pluginManager(), callbackPlugin), + processAgentCallbackResult( + ctx -> invocationContext.combinedPlugin().beforeAgentCallback(this, ctx), invocationContext) + .map(Optional::of) + .switchIfEmpty(Single.just(Optional.empty())) .flatMapPublisher( beforeEventOpt -> { if (invocationContext.endInvocation()) { @@ -271,11 +269,14 @@ public Flowable runAsync(InvocationContext parentContext) { Flowable afterEvents = Flowable.defer( () -> - callCallback( - afterCallbacksToFunctions( - invocationContext.pluginManager(), - callbackPlugin), + processAgentCallbackResult( + ctx -> + invocationContext + .combinedPlugin() + .afterAgentCallback(this, ctx), invocationContext) + .map(Optional::of) + .switchIfEmpty(Single.just(Optional.empty())) .flatMapPublisher(Flowable::fromOptional)); return Flowable.concat(beforeEvents, mainEvents, afterEvents); @@ -284,73 +285,32 @@ public Flowable runAsync(InvocationContext parentContext) { } /** - * Converts before-agent callbacks to functions. + * Processes the result of an agent callback, creating an {@link Event} if necessary. * - * @return callback functions. + * @param agentCallback The callback function. + * @param invocationContext The current invocation context. + * @return A {@link Maybe} emitting an {@link Event} if one is produced, or empty otherwise. */ - private ImmutableList>> beforeCallbacksToFunctions( - Plugin... plugins) { - return stream(plugins) - .map( - p -> - (Function>) ctx -> p.beforeAgentCallback(this, ctx)) - .collect(toImmutableList()); - } - - /** - * Converts after-agent callbacks to functions. - * - * @return callback functions. - */ - private ImmutableList>> afterCallbacksToFunctions( - Plugin... plugins) { - return stream(plugins) - .map( - p -> (Function>) ctx -> p.afterAgentCallback(this, ctx)) - .collect(toImmutableList()); - } - - /** - * Calls agent callbacks and returns the first produced event, if any. - * - * @param agentCallbacks Callback functions. - * @param invocationContext Current invocation context. - * @return single emitting first event, or empty if none. - */ - private Single> callCallback( - List>> agentCallbacks, + private Maybe processAgentCallbackResult( + Function> agentCallback, InvocationContext invocationContext) { - if (agentCallbacks == null || agentCallbacks.isEmpty()) { - return Single.just(Optional.empty()); - } - - CallbackContext callbackContext = - new CallbackContext(invocationContext, /* eventActions= */ null); - - return Flowable.fromIterable(agentCallbacks) - .concatMap( - callback -> { - Maybe maybeContent = callback.apply(callbackContext); - - return maybeContent - .map( - content -> { - invocationContext.setEndInvocation(true); - return Optional.of( - Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(name()) - .branch(invocationContext.branch()) - .actions(callbackContext.eventActions()) - .content(content) - .build()); - }) - .toFlowable(); + var callbackContext = new CallbackContext(invocationContext, /* eventActions= */ null); + return agentCallback + .apply(callbackContext) + .map( + content -> { + invocationContext.setEndInvocation(true); + return Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(name()) + .branch(invocationContext.branch()) + .actions(callbackContext.eventActions()) + .content(content) + .build(); }) - .firstElement() .switchIfEmpty( - Single.defer( + Maybe.defer( () -> { if (callbackContext.state().hasDelta()) { Event.Builder eventBuilder = @@ -361,9 +321,9 @@ private Single> callCallback( .branch(invocationContext.branch()) .actions(callbackContext.eventActions()); - return Single.just(Optional.of(eventBuilder.build())); + return Maybe.just(eventBuilder.build()); } else { - return Single.just(Optional.empty()); + return Maybe.empty(); } })); } diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 532bc92f..0a8ed416 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -25,6 +25,7 @@ import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.InlineMe; @@ -44,6 +45,7 @@ public class InvocationContext { private final BaseArtifactService artifactService; private final BaseMemoryService memoryService; private final Plugin pluginManager; + private final Plugin combinedPlugin; private final Optional liveRequestQueue; private final Map activeStreamingTools; private final String invocationId; @@ -73,6 +75,13 @@ protected InvocationContext(Builder builder) { this.endInvocation = builder.endInvocation; this.resumabilityConfig = builder.resumabilityConfig; this.invocationCostManager = builder.invocationCostManager; + this.combinedPlugin = + Optional.ofNullable(builder.agent) + .map(BaseAgent::getPlugin) + .map( + agentPlugin -> + (Plugin) new PluginManager(ImmutableList.of(pluginManager, agentPlugin))) + .orElse(pluginManager); } /** @@ -235,6 +244,14 @@ public Plugin pluginManager() { return pluginManager; } + /** + * Returns a {@link Plugin} that combines agent-specific plugins with framework-level plugins, + * allowing tools from both to be invoked. + */ + public Plugin combinedPlugin() { + return combinedPlugin; + } + /** Returns a map of tool call IDs to active streaming tools for the current invocation. */ public Map activeStreamingTools() { return activeStreamingTools; diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 5e6331b7..307b159f 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -199,7 +199,7 @@ private Flowable callLlm( .onErrorResumeNext( exception -> context - .pluginManager() + .combinedPlugin() .onModelErrorCallback( new CallbackContext( context, eventForCallbackUsage.actions()), @@ -243,27 +243,9 @@ private Single> handleBeforeModelCallback( Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); - Maybe pluginResult = - context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); - - LlmAgent agent = (LlmAgent) context.agent(); - - Optional> callbacksOpt = agent.beforeModelCallback(); - if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) { - return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty()); - } - - List callbacks = callbacksOpt.get(); - - Maybe callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) - .firstElement()); - - return pluginResult - .switchIfEmpty(callbackResult) + return context + .combinedPlugin() + .beforeModelCallback(callbackContext, llmRequestBuilder) .map(Optional::of) .defaultIfEmpty(Optional.empty()); } @@ -279,24 +261,10 @@ private Single handleAfterModelCallback( Event callbackEvent = modelResponseEvent.toBuilder().build(); CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); - Maybe pluginResult = - context.pluginManager().afterModelCallback(callbackContext, llmResponse); - - LlmAgent agent = (LlmAgent) context.agent(); - Optional> callbacksOpt = agent.afterModelCallback(); - - if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) { - return pluginResult.defaultIfEmpty(llmResponse); - } - - Maybe callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacksOpt.get()) - .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse)) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); + return context + .combinedPlugin() + .afterModelCallback(callbackContext, llmResponse) + .defaultIfEmpty(llmResponse); } /** diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 8536e470..9a81da8e 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -22,10 +22,7 @@ import com.google.adk.Telemetry; import com.google.adk.agents.ActiveStreamingTool; -import com.google.adk.agents.Callbacks.AfterToolCallback; -import com.google.adk.agents.Callbacks.BeforeToolCallback; import com.google.adk.agents.InvocationContext; -import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig.ToolExecutionMode; import com.google.adk.events.Event; import com.google.adk.events.EventActions; @@ -388,7 +385,7 @@ private static Maybe postProcessFunctionResult( .onErrorResumeNext( t -> invocationContext - .pluginManager() + .combinedPlugin() .onToolErrorCallback(tool, functionArgs, toolContext, t) .map(isLive ? Optional::ofNullable : Optional::of) .switchIfEmpty(Single.error(t))) @@ -457,30 +454,7 @@ private static Maybe> maybeInvokeBeforeToolCall( BaseTool tool, Map functionArgs, ToolContext toolContext) { - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); - - Maybe> pluginResult = - invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); - - Optional> callbacksOpt = agent.beforeToolCallback(); - if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) { - return pluginResult; - } - List callbacks = callbacksOpt.get(); - - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback.call(invocationContext, tool, functionArgs, toolContext)) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); - } - return Maybe.empty(); + return invocationContext.combinedPlugin().beforeToolCallback(tool, functionArgs, toolContext); } private static Maybe> maybeInvokeAfterToolCall( @@ -489,37 +463,9 @@ private static Maybe> maybeInvokeAfterToolCall( Map functionArgs, ToolContext toolContext, Map functionResult) { - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); - - Maybe> pluginResult = - invocationContext - .pluginManager() - .afterToolCallback(tool, functionArgs, toolContext, functionResult); - - Optional> callbacksOpt = agent.afterToolCallback(); - if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) { - return pluginResult; - } - List callbacks = callbacksOpt.get(); - - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback.call( - invocationContext, - tool, - functionArgs, - toolContext, - functionResult)) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); - } - return Maybe.empty(); + return invocationContext + .combinedPlugin() + .afterToolCallback(tool, functionArgs, toolContext, functionResult); } private static Maybe> callTool( diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 740b5469..7df90c51 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -514,7 +514,7 @@ public Flowable runAsync( updatedSession, session); return contextWithUpdatedSession - .pluginManager() + .combinedPlugin() .onEventCallback( contextWithUpdatedSession, registeredEvent) diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index 92ec8791..8258d32d 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -18,6 +18,8 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.adk.agents.Callbacks.AfterAgentCallback; +import com.google.adk.agents.Callbacks.BeforeAgentCallback; import com.google.adk.events.Event; import com.google.adk.testing.TestBaseAgent; import com.google.adk.testing.TestCallback; @@ -25,6 +27,7 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -71,6 +74,25 @@ public void constructor_setsNameAndDescription() { assertThat(afterCallback.wasCalled()).isFalse(); } + @Test + public void runAsync_firstBeforeCallbackReturnsContent_skipsSecondBeforeCallback() { + Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); + var beforeCallback1 = TestCallback.returning(callbackContent); + var beforeCallback2 = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of( + beforeCallback1.asBeforeAgentCallback(), beforeCallback2.asBeforeAgentCallback()), + ImmutableList.of(), + TestCallback.returningEmpty().asRunAsyncImplSupplier("main_output")); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + var unused = agent.runAsync(invocationContext).toList().blockingGet(); + assertThat(beforeCallback1.wasCalled()).isTrue(); + assertThat(beforeCallback2.wasCalled()).isFalse(); + } + @Test public void runAsync_noCallbacks_invokesRunAsyncImpl() { var runAsyncImpl = TestCallback.returningEmpty(); @@ -142,4 +164,101 @@ public void runAsync_noCallbacks_invokesRunAsyncImpl() { assertThat(beforeCallback.wasCalled()).isTrue(); assertThat(afterCallback.wasCalled()).isTrue(); } + + @Test + public void + runAsync_beforeCallbackMutatesStateAndReturnsEmpty_invokesRunAsyncImplAndReturnsStateEvent() { + var runAsyncImpl = TestCallback.returningEmpty(); + Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); + BeforeAgentCallback beforeCallback = + new BeforeAgentCallback() { + @Override + public Maybe call(CallbackContext context) { + context.state().put("key", "value"); + return Maybe.empty(); + } + }; + var afterCallback = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + // State event from before callback + assertThat(results.get(0).content()).isEmpty(); + assertThat(results.get(0).actions().stateDelta()).containsEntry("key", "value"); + // Content event from runAsyncImpl + assertThat(results.get(1).content()).hasValue(runAsyncImplContent); + assertThat(runAsyncImpl.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); + } + + @Test + public void + runAsync_afterCallbackMutatesStateAndReturnsEmpty_invokesRunAsyncImplAndReturnsStateEvent() { + var runAsyncImpl = TestCallback.returningEmpty(); + Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); + var beforeCallback = TestCallback.returningEmpty(); + AfterAgentCallback afterCallback = + new AfterAgentCallback() { + @Override + public Maybe call(CallbackContext context) { + context.state().put("key", "value"); + return Maybe.empty(); + } + }; + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback), + runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + // Content event from runAsyncImpl + assertThat(results.get(0).content()).hasValue(runAsyncImplContent); + // State event from after callback + assertThat(results.get(1).content()).isEmpty(); + assertThat(results.get(1).actions().stateDelta()).containsEntry("key", "value"); + assertThat(runAsyncImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + } + + @Test + public void runAsync_firstAfterCallbackReturnsContent_skipsSecondAfterCallback() { + var runAsyncImpl = TestCallback.returningEmpty(); + Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); + Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); + var afterCallback1 = TestCallback.returning(afterCallbackContent); + var afterCallback2 = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(), + ImmutableList.of( + afterCallback1.asAfterAgentCallback(), afterCallback2.asAfterAgentCallback()), + runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + assertThat(results.get(0).content()).hasValue(runAsyncImplContent); + assertThat(results.get(1).content()).hasValue(afterCallbackContent); + assertThat(runAsyncImpl.wasCalled()).isTrue(); + assertThat(afterCallback1.wasCalled()).isTrue(); + assertThat(afterCallback2.wasCalled()).isFalse(); + } }