From 5103c09d3bb9c00276d9e5f6b49fe6733a319a9a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 12 Jan 2026 12:42:59 -0800 Subject: [PATCH] refactor: Introducing a CallbackPlugin to wrap the old style Callbacks The goal is to unify the processing of Plugins and Callbacks. We should consider depercating and removing the old Callbacks. There are a bunch of cyclical dependencies caused by requests back to the agent to get specific Callbacks. The next step will be to augmet the InvocationContext's PluginManager with the appropriate agent specific callbacks PiperOrigin-RevId: 855343793 --- .../java/com/google/adk/agents/BaseAgent.java | 93 +++- .../com/google/adk/agents/CallbackPlugin.java | 335 ++++++++++++ .../java/com/google/adk/agents/LlmAgent.java | 170 +----- .../java/com/google/adk/agents/LoopAgent.java | 21 +- .../com/google/adk/agents/ParallelAgent.java | 17 +- .../google/adk/agents/ReadonlyContext.java | 5 + .../google/adk/agents/SequentialAgent.java | 19 +- .../com/google/adk/agents/BaseAgentTest.java | 106 +++- .../google/adk/agents/CallbackPluginTest.java | 499 ++++++++++++++++++ .../com/google/adk/testing/TestCallback.java | 164 ++++++ 10 files changed, 1192 insertions(+), 237 deletions(-) create mode 100644 core/src/main/java/com/google/adk/agents/CallbackPlugin.java create mode 100644 core/src/test/java/com/google/adk/agents/CallbackPluginTest.java create mode 100644 core/src/test/java/com/google/adk/testing/TestCallback.java diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 53a978974..a6ffbd8d7 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -17,6 +17,7 @@ package com.google.adk.agents; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Arrays.stream; import com.google.adk.Telemetry; import com.google.adk.agents.Callbacks.AfterAgentCallback; @@ -59,8 +60,7 @@ public abstract class BaseAgent { private final List subAgents; - private final Optional> beforeAgentCallback; - private final Optional> afterAgentCallback; + protected final CallbackPlugin callbackPlugin; /** * Creates a new BaseAgent. @@ -77,14 +77,34 @@ public BaseAgent( String name, String description, List subAgents, - List beforeAgentCallback, - List afterAgentCallback) { + @Nullable List beforeAgentCallback, + @Nullable List afterAgentCallback) { + this( + name, + description, + subAgents, + createCallbackPlugin(beforeAgentCallback, afterAgentCallback)); + } + + /** + * Creates a new BaseAgent. + * + * @param name Unique agent name. Cannot be "user" (reserved). + * @param description Agent purpose. + * @param subAgents Agents managed by this agent. + * @param callbackPlugin The callback plugin for this agent. + */ + protected BaseAgent( + String name, + String description, + List subAgents, + CallbackPlugin callbackPlugin) { this.name = name; this.description = description; this.parentAgent = null; this.subAgents = subAgents != null ? subAgents : ImmutableList.of(); - this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback); - this.afterAgentCallback = Optional.ofNullable(afterAgentCallback); + this.callbackPlugin = + callbackPlugin == null ? CallbackPlugin.builder().build() : callbackPlugin; // Establish parent relationships for all sub-agents if needed. for (BaseAgent subAgent : this.subAgents) { @@ -92,6 +112,18 @@ public BaseAgent( } } + /** Creates a {@link CallbackPlugin} from lists of before and after agent callbacks. */ + private static CallbackPlugin createCallbackPlugin( + @Nullable List beforeAgentCallbacks, + @Nullable List afterAgentCallbacks) { + CallbackPlugin.Builder builder = CallbackPlugin.builder(); + Stream.ofNullable(beforeAgentCallbacks).flatMap(List::stream).forEach(builder::addCallback); + Optional.ofNullable(afterAgentCallbacks).stream() + .flatMap(List::stream) + .forEach(builder::addCallback); + return builder.build(); + } + /** * Gets the agent's unique name. * @@ -172,11 +204,15 @@ public List subAgents() { } public Optional> beforeAgentCallback() { - return beforeAgentCallback; + return Optional.of(callbackPlugin.getBeforeAgentCallback()); } public Optional> afterAgentCallback() { - return afterAgentCallback; + return Optional.of(callbackPlugin.getAfterAgentCallback()); + } + + public Plugin getPlugin() { + return callbackPlugin; } /** @@ -221,8 +257,7 @@ public Flowable runAsync(InvocationContext parentContext) { () -> callCallback( beforeCallbacksToFunctions( - invocationContext.pluginManager(), - beforeAgentCallback.orElse(ImmutableList.of())), + invocationContext.pluginManager(), callbackPlugin), invocationContext) .flatMapPublisher( beforeEventOpt -> { @@ -239,7 +274,7 @@ public Flowable runAsync(InvocationContext parentContext) { callCallback( afterCallbacksToFunctions( invocationContext.pluginManager(), - afterAgentCallback.orElse(ImmutableList.of())), + callbackPlugin), invocationContext) .flatMapPublisher(Flowable::fromOptional)); @@ -251,30 +286,27 @@ public Flowable runAsync(InvocationContext parentContext) { /** * Converts before-agent callbacks to functions. * - * @param callbacks Before-agent callbacks. * @return callback functions. */ private ImmutableList>> beforeCallbacksToFunctions( - Plugin pluginManager, List callbacks) { - return Stream.concat( - Stream.of(ctx -> pluginManager.beforeAgentCallback(this, ctx)), - callbacks.stream() - .map(callback -> (Function>) callback::call)) + Plugin... plugins) { + return stream(plugins) + .map( + p -> + (Function>) ctx -> p.beforeAgentCallback(this, ctx)) .collect(toImmutableList()); } /** * Converts after-agent callbacks to functions. * - * @param callbacks After-agent callbacks. * @return callback functions. */ private ImmutableList>> afterCallbacksToFunctions( - Plugin pluginManager, List callbacks) { - return Stream.concat( - Stream.of(ctx -> pluginManager.afterAgentCallback(this, ctx)), - callbacks.stream() - .map(callback -> (Function>) callback::call)) + Plugin... plugins) { + return stream(plugins) + .map( + p -> (Function>) ctx -> p.afterAgentCallback(this, ctx)) .collect(toImmutableList()); } @@ -399,8 +431,11 @@ public abstract static class Builder> { protected String name; protected String description; protected ImmutableList subAgents; - protected ImmutableList beforeAgentCallback; - protected ImmutableList afterAgentCallback; + protected final CallbackPlugin.Builder callbackPluginBuilder = CallbackPlugin.builder(); + + protected CallbackPlugin.Builder callbackPluginBuilder() { + return callbackPluginBuilder; + } /** This is a safe cast to the concrete builder type. */ @SuppressWarnings("unchecked") @@ -434,25 +469,25 @@ public B subAgents(BaseAgent... subAgents) { @CanIgnoreReturnValue public B beforeAgentCallback(BeforeAgentCallback beforeAgentCallback) { - this.beforeAgentCallback = ImmutableList.of(beforeAgentCallback); + callbackPluginBuilder.addBeforeAgentCallback(beforeAgentCallback); return self(); } @CanIgnoreReturnValue public B beforeAgentCallback(List beforeAgentCallback) { - this.beforeAgentCallback = CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback); + beforeAgentCallback.forEach(callbackPluginBuilder::addCallback); return self(); } @CanIgnoreReturnValue public B afterAgentCallback(AfterAgentCallback afterAgentCallback) { - this.afterAgentCallback = ImmutableList.of(afterAgentCallback); + callbackPluginBuilder.addAfterAgentCallback(afterAgentCallback); return self(); } @CanIgnoreReturnValue public B afterAgentCallback(List afterAgentCallback) { - this.afterAgentCallback = CallbackUtil.getAfterAgentCallbacks(afterAgentCallback); + afterAgentCallback.forEach(callbackPluginBuilder::addCallback); return self(); } diff --git a/core/src/main/java/com/google/adk/agents/CallbackPlugin.java b/core/src/main/java/com/google/adk/agents/CallbackPlugin.java new file mode 100644 index 000000000..1d90f218c --- /dev/null +++ b/core/src/main/java/com/google/adk/agents/CallbackPlugin.java @@ -0,0 +1,335 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.agents; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.adk.agents.Callbacks.AfterAgentCallback; +import com.google.adk.agents.Callbacks.AfterAgentCallbackBase; +import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; +import com.google.adk.agents.Callbacks.AfterModelCallback; +import com.google.adk.agents.Callbacks.AfterModelCallbackBase; +import com.google.adk.agents.Callbacks.AfterModelCallbackSync; +import com.google.adk.agents.Callbacks.AfterToolCallback; +import com.google.adk.agents.Callbacks.AfterToolCallbackBase; +import com.google.adk.agents.Callbacks.AfterToolCallbackSync; +import com.google.adk.agents.Callbacks.BeforeAgentCallback; +import com.google.adk.agents.Callbacks.BeforeAgentCallbackBase; +import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; +import com.google.adk.agents.Callbacks.BeforeModelCallback; +import com.google.adk.agents.Callbacks.BeforeModelCallbackBase; +import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; +import com.google.adk.agents.Callbacks.BeforeToolCallback; +import com.google.adk.agents.Callbacks.BeforeToolCallbackBase; +import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.plugins.BasePlugin; +import com.google.adk.plugins.PluginManager; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.genai.types.Content; +import io.reactivex.rxjava3.core.Maybe; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A plugin that wraps callbacks and exposes them as a plugin. */ +public class CallbackPlugin extends PluginManager { + + private static final Logger logger = LoggerFactory.getLogger(CallbackPlugin.class); + + private final ImmutableListMultimap, Object> callbacks; + + private CallbackPlugin( + ImmutableList plugins, + ImmutableListMultimap, Object> callbacks) { + super(plugins); + this.callbacks = callbacks; + } + + @Override + public String getName() { + return "CallbackPlugin"; + } + + @SuppressWarnings("unchecked") // The builder ensures that the type is correct. + private ImmutableList getCallbacks(Class type) { + return callbacks.get(type).stream().map(callback -> (T) callback).collect(toImmutableList()); + } + + public ImmutableList getBeforeAgentCallback() { + return getCallbacks(Callbacks.BeforeAgentCallback.class); + } + + public ImmutableList getAfterAgentCallback() { + return getCallbacks(Callbacks.AfterAgentCallback.class); + } + + public ImmutableList getBeforeModelCallback() { + return getCallbacks(Callbacks.BeforeModelCallback.class); + } + + public ImmutableList getAfterModelCallback() { + return getCallbacks(Callbacks.AfterModelCallback.class); + } + + public ImmutableList getBeforeToolCallback() { + return getCallbacks(Callbacks.BeforeToolCallback.class); + } + + public ImmutableList getAfterToolCallback() { + return getCallbacks(Callbacks.AfterToolCallback.class); + } + + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link CallbackPlugin}. */ + public static class Builder { + // Ensures a unique name for each callback. + private static final AtomicInteger callbackId = new AtomicInteger(0); + + private final ImmutableList.Builder plugins = ImmutableList.builder(); + private final ListMultimap, Object> callbacks = ArrayListMultimap.create(); + + Builder() {} + + @CanIgnoreReturnValue + public Builder addBeforeAgentCallback(Callbacks.BeforeAgentCallback callback) { + callbacks.put(Callbacks.BeforeAgentCallback.class, callback); + plugins.add( + new BasePlugin("BeforeAgentCallback_" + callbackId.getAndIncrement()) { + @Override + public Maybe beforeAgentCallback( + BaseAgent agent, CallbackContext callbackContext) { + return callback.call(callbackContext); + } + }); + return this; + } + + @CanIgnoreReturnValue + public Builder addBeforeAgentCallbackSync(Callbacks.BeforeAgentCallbackSync callback) { + return addBeforeAgentCallback( + callbackContext -> Maybe.fromOptional(callback.call(callbackContext))); + } + + @CanIgnoreReturnValue + public Builder addAfterAgentCallback(Callbacks.AfterAgentCallback callback) { + callbacks.put(Callbacks.AfterAgentCallback.class, callback); + plugins.add( + new BasePlugin("AfterAgentCallback_" + callbackId.getAndIncrement()) { + @Override + public Maybe afterAgentCallback( + BaseAgent agent, CallbackContext callbackContext) { + return callback.call(callbackContext); + } + }); + return this; + } + + @CanIgnoreReturnValue + public Builder addAfterAgentCallbackSync(Callbacks.AfterAgentCallbackSync callback) { + return addAfterAgentCallback( + callbackContext -> Maybe.fromOptional(callback.call(callbackContext))); + } + + @CanIgnoreReturnValue + public Builder addBeforeModelCallback(Callbacks.BeforeModelCallback callback) { + callbacks.put(Callbacks.BeforeModelCallback.class, callback); + plugins.add( + new BasePlugin("BeforeModelCallback_" + callbackId.getAndIncrement()) { + @Override + public Maybe beforeModelCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { + return callback.call(callbackContext, llmRequest); + } + }); + return this; + } + + @CanIgnoreReturnValue + public Builder addBeforeModelCallbackSync(Callbacks.BeforeModelCallbackSync callback) { + return addBeforeModelCallback( + (callbackContext, llmRequest) -> + Maybe.fromOptional(callback.call(callbackContext, llmRequest))); + } + + @CanIgnoreReturnValue + public Builder addAfterModelCallback(Callbacks.AfterModelCallback callback) { + callbacks.put(Callbacks.AfterModelCallback.class, callback); + plugins.add( + new BasePlugin("AfterModelCallback_" + callbackId.getAndIncrement()) { + @Override + public Maybe afterModelCallback( + CallbackContext callbackContext, LlmResponse llmResponse) { + return callback.call(callbackContext, llmResponse); + } + }); + return this; + } + + @CanIgnoreReturnValue + public Builder addAfterModelCallbackSync(Callbacks.AfterModelCallbackSync callback) { + return addAfterModelCallback( + (callbackContext, llmResponse) -> + Maybe.fromOptional(callback.call(callbackContext, llmResponse))); + } + + @CanIgnoreReturnValue + public Builder addBeforeToolCallback(Callbacks.BeforeToolCallback callback) { + callbacks.put(Callbacks.BeforeToolCallback.class, callback); + plugins.add( + new BasePlugin("BeforeToolCallback_" + callbackId.getAndIncrement()) { + @Override + public Maybe> beforeToolCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext) { + return callback.call(toolContext.invocationContext(), tool, toolArgs, toolContext); + } + }); + return this; + } + + @CanIgnoreReturnValue + public Builder addBeforeToolCallbackSync(Callbacks.BeforeToolCallbackSync callback) { + return addBeforeToolCallback( + (invocationContext, tool, toolArgs, toolContext) -> + Maybe.fromOptional(callback.call(invocationContext, tool, toolArgs, toolContext))); + } + + @CanIgnoreReturnValue + public Builder addAfterToolCallback(Callbacks.AfterToolCallback callback) { + callbacks.put(Callbacks.AfterToolCallback.class, callback); + plugins.add( + new BasePlugin("AfterToolCallback_" + callbackId.getAndIncrement()) { + @Override + public Maybe> afterToolCallback( + BaseTool tool, + Map toolArgs, + ToolContext toolContext, + Map result) { + return callback.call( + toolContext.invocationContext(), tool, toolArgs, toolContext, result); + } + }); + return this; + } + + @CanIgnoreReturnValue + public Builder addAfterToolCallbackSync(Callbacks.AfterToolCallbackSync callback) { + return addAfterToolCallback( + (invocationContext, tool, toolArgs, toolContext, result) -> + Maybe.fromOptional( + callback.call(invocationContext, tool, toolArgs, toolContext, result))); + } + + @CanIgnoreReturnValue + public Builder addCallback(BeforeAgentCallbackBase callback) { + if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) { + addBeforeAgentCallback(beforeAgentCallbackInstance); + } else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) { + addBeforeAgentCallbackSync(beforeAgentCallbackSyncInstance); + } else { + logger.warn( + "Invalid beforeAgentCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addCallback(AfterAgentCallbackBase callback) { + if (callback instanceof AfterAgentCallback afterAgentCallbackInstance) { + addAfterAgentCallback(afterAgentCallbackInstance); + } else if (callback instanceof AfterAgentCallbackSync afterAgentCallbackSyncInstance) { + addAfterAgentCallbackSync(afterAgentCallbackSyncInstance); + } else { + logger.warn( + "Invalid afterAgentCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addCallback(BeforeModelCallbackBase callback) { + if (callback instanceof BeforeModelCallback beforeModelCallbackInstance) { + addBeforeModelCallback(beforeModelCallbackInstance); + } else if (callback instanceof BeforeModelCallbackSync beforeModelCallbackSyncInstance) { + addBeforeModelCallbackSync(beforeModelCallbackSyncInstance); + } else { + logger.warn( + "Invalid beforeModelCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addCallback(AfterModelCallbackBase callback) { + if (callback instanceof AfterModelCallback afterModelCallbackInstance) { + addAfterModelCallback(afterModelCallbackInstance); + } else if (callback instanceof AfterModelCallbackSync afterModelCallbackSyncInstance) { + addAfterModelCallbackSync(afterModelCallbackSyncInstance); + } else { + logger.warn( + "Invalid afterModelCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addCallback(BeforeToolCallbackBase callback) { + if (callback instanceof BeforeToolCallback beforeToolCallbackInstance) { + addBeforeToolCallback(beforeToolCallbackInstance); + } else if (callback instanceof BeforeToolCallbackSync beforeToolCallbackSyncInstance) { + addBeforeToolCallbackSync(beforeToolCallbackSyncInstance); + } else { + logger.warn( + "Invalid beforeToolCallback callback type: {}. Ignoring this callback.", + callback.getClass().getName()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addCallback(AfterToolCallbackBase callback) { + if (callback instanceof AfterToolCallback afterToolCallbackInstance) { + addAfterToolCallback(afterToolCallbackInstance); + } else if (callback instanceof AfterToolCallbackSync afterToolCallbackSyncInstance) { + addAfterToolCallbackSync(afterToolCallbackSyncInstance); + } else { + logger.warn( + "Invalid afterToolCallback callback type: {}. Ignoring this callback.", + callback.getClass().getName()); + } + return this; + } + + public CallbackPlugin build() { + return new CallbackPlugin(plugins.build(), ImmutableListMultimap.copyOf(callbacks)); + } + } +} diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index ab3ee7edb..25ba762ab 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -56,7 +56,6 @@ import com.google.genai.types.Part; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Flowable; -import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.List; @@ -95,10 +94,6 @@ public enum IncludeContents { private final Optional maxSteps; private final boolean disallowTransferToParent; private final boolean disallowTransferToPeers; - private final Optional> beforeModelCallback; - private final Optional> afterModelCallback; - private final Optional> beforeToolCallback; - private final Optional> afterToolCallback; private final Optional inputSchema; private final Optional outputSchema; private final Optional executor; @@ -113,8 +108,7 @@ protected LlmAgent(Builder builder) { builder.name, builder.description, builder.subAgents, - builder.beforeAgentCallback, - builder.afterAgentCallback); + builder.callbackPluginBuilder.build()); this.model = Optional.ofNullable(builder.model); this.instruction = builder.instruction == null ? new Instruction.Static("") : builder.instruction; @@ -128,10 +122,6 @@ protected LlmAgent(Builder builder) { this.maxSteps = Optional.ofNullable(builder.maxSteps); this.disallowTransferToParent = builder.disallowTransferToParent; this.disallowTransferToPeers = builder.disallowTransferToPeers; - this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback); - this.afterModelCallback = Optional.ofNullable(builder.afterModelCallback); - this.beforeToolCallback = Optional.ofNullable(builder.beforeToolCallback); - this.afterToolCallback = Optional.ofNullable(builder.afterToolCallback); this.inputSchema = Optional.ofNullable(builder.inputSchema); this.outputSchema = Optional.ofNullable(builder.outputSchema); this.executor = Optional.ofNullable(builder.executor); @@ -173,10 +163,6 @@ public static class Builder extends BaseAgent.Builder { private Integer maxSteps; private Boolean disallowTransferToParent; private Boolean disallowTransferToPeers; - private ImmutableList beforeModelCallback; - private ImmutableList afterModelCallback; - private ImmutableList beforeToolCallback; - private ImmutableList afterToolCallback; private Schema inputSchema; private Schema outputSchema; private Executor executor; @@ -290,200 +276,86 @@ public Builder disallowTransferToPeers(boolean disallowTransferToPeers) { @CanIgnoreReturnValue public Builder beforeModelCallback(BeforeModelCallback beforeModelCallback) { - this.beforeModelCallback = ImmutableList.of(beforeModelCallback); + callbackPluginBuilder.addBeforeModelCallback(beforeModelCallback); return this; } @CanIgnoreReturnValue public Builder beforeModelCallback(List beforeModelCallback) { - if (beforeModelCallback == null) { - this.beforeModelCallback = null; - } else if (beforeModelCallback.isEmpty()) { - this.beforeModelCallback = ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (BeforeModelCallbackBase callback : beforeModelCallback) { - if (callback instanceof BeforeModelCallback beforeModelCallbackInstance) { - builder.add(beforeModelCallbackInstance); - } else if (callback instanceof BeforeModelCallbackSync beforeModelCallbackSyncInstance) { - builder.add( - (BeforeModelCallback) - (callbackContext, llmRequestBuilder) -> - Maybe.fromOptional( - beforeModelCallbackSyncInstance.call( - callbackContext, llmRequestBuilder))); - } else { - logger.warn( - "Invalid beforeModelCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - } - this.beforeModelCallback = builder.build(); - } - + beforeModelCallback.forEach(callbackPluginBuilder::addCallback); return this; } @CanIgnoreReturnValue public Builder beforeModelCallbackSync(BeforeModelCallbackSync beforeModelCallbackSync) { - this.beforeModelCallback = - ImmutableList.of( - (callbackContext, llmRequestBuilder) -> - Maybe.fromOptional( - beforeModelCallbackSync.call(callbackContext, llmRequestBuilder))); + callbackPluginBuilder.addBeforeModelCallbackSync(beforeModelCallbackSync); return this; } @CanIgnoreReturnValue public Builder afterModelCallback(AfterModelCallback afterModelCallback) { - this.afterModelCallback = ImmutableList.of(afterModelCallback); + callbackPluginBuilder.addAfterModelCallback(afterModelCallback); return this; } @CanIgnoreReturnValue public Builder afterModelCallback(List afterModelCallback) { - if (afterModelCallback == null) { - this.afterModelCallback = null; - } else if (afterModelCallback.isEmpty()) { - this.afterModelCallback = ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (AfterModelCallbackBase callback : afterModelCallback) { - if (callback instanceof AfterModelCallback afterModelCallbackInstance) { - builder.add(afterModelCallbackInstance); - } else if (callback instanceof AfterModelCallbackSync afterModelCallbackSyncInstance) { - builder.add( - (AfterModelCallback) - (callbackContext, llmResponse) -> - Maybe.fromOptional( - afterModelCallbackSyncInstance.call(callbackContext, llmResponse))); - } else { - logger.warn( - "Invalid afterModelCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - } - this.afterModelCallback = builder.build(); - } - + afterModelCallback.forEach(callbackPluginBuilder::addCallback); return this; } @CanIgnoreReturnValue public Builder afterModelCallbackSync(AfterModelCallbackSync afterModelCallbackSync) { - this.afterModelCallback = - ImmutableList.of( - (callbackContext, llmResponse) -> - Maybe.fromOptional(afterModelCallbackSync.call(callbackContext, llmResponse))); + callbackPluginBuilder.addAfterModelCallbackSync(afterModelCallbackSync); return this; } @CanIgnoreReturnValue public Builder beforeAgentCallbackSync(BeforeAgentCallbackSync beforeAgentCallbackSync) { - this.beforeAgentCallback = - ImmutableList.of( - (callbackContext) -> - Maybe.fromOptional(beforeAgentCallbackSync.call(callbackContext))); + callbackPluginBuilder.addBeforeAgentCallbackSync(beforeAgentCallbackSync); return this; } @CanIgnoreReturnValue public Builder afterAgentCallbackSync(AfterAgentCallbackSync afterAgentCallbackSync) { - this.afterAgentCallback = - ImmutableList.of( - (callbackContext) -> - Maybe.fromOptional(afterAgentCallbackSync.call(callbackContext))); + callbackPluginBuilder.addAfterAgentCallbackSync(afterAgentCallbackSync); return this; } @CanIgnoreReturnValue public Builder beforeToolCallback(BeforeToolCallback beforeToolCallback) { - this.beforeToolCallback = ImmutableList.of(beforeToolCallback); + callbackPluginBuilder.addBeforeToolCallback(beforeToolCallback); return this; } @CanIgnoreReturnValue public Builder beforeToolCallback( @Nullable List beforeToolCallbacks) { - if (beforeToolCallbacks == null) { - this.beforeToolCallback = null; - } else if (beforeToolCallbacks.isEmpty()) { - this.beforeToolCallback = ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (BeforeToolCallbackBase callback : beforeToolCallbacks) { - if (callback instanceof BeforeToolCallback beforeToolCallbackInstance) { - builder.add(beforeToolCallbackInstance); - } else if (callback instanceof BeforeToolCallbackSync beforeToolCallbackSyncInstance) { - builder.add( - (invocationContext, baseTool, input, toolContext) -> - Maybe.fromOptional( - beforeToolCallbackSyncInstance.call( - invocationContext, baseTool, input, toolContext))); - } else { - logger.warn( - "Invalid beforeToolCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - } - this.beforeToolCallback = builder.build(); - } + beforeToolCallbacks.forEach(callbackPluginBuilder::addCallback); return this; } @CanIgnoreReturnValue public Builder beforeToolCallbackSync(BeforeToolCallbackSync beforeToolCallbackSync) { - this.beforeToolCallback = - ImmutableList.of( - (invocationContext, baseTool, input, toolContext) -> - Maybe.fromOptional( - beforeToolCallbackSync.call( - invocationContext, baseTool, input, toolContext))); + callbackPluginBuilder.addBeforeToolCallbackSync(beforeToolCallbackSync); return this; } @CanIgnoreReturnValue public Builder afterToolCallback(AfterToolCallback afterToolCallback) { - this.afterToolCallback = ImmutableList.of(afterToolCallback); + callbackPluginBuilder.addAfterToolCallback(afterToolCallback); return this; } @CanIgnoreReturnValue public Builder afterToolCallback(@Nullable List afterToolCallbacks) { - if (afterToolCallbacks == null) { - this.afterToolCallback = null; - } else if (afterToolCallbacks.isEmpty()) { - this.afterToolCallback = ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (AfterToolCallbackBase callback : afterToolCallbacks) { - if (callback instanceof AfterToolCallback afterToolCallbackInstance) { - builder.add(afterToolCallbackInstance); - } else if (callback instanceof AfterToolCallbackSync afterToolCallbackSyncInstance) { - builder.add( - (invocationContext, baseTool, input, toolContext, response) -> - Maybe.fromOptional( - afterToolCallbackSyncInstance.call( - invocationContext, baseTool, input, toolContext, response))); - } else { - logger.warn( - "Invalid afterToolCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - } - this.afterToolCallback = builder.build(); - } + afterToolCallbacks.forEach(callbackPluginBuilder::addCallback); return this; } @CanIgnoreReturnValue public Builder afterToolCallbackSync(AfterToolCallbackSync afterToolCallbackSync) { - this.afterToolCallback = - ImmutableList.of( - (invocationContext, baseTool, input, toolContext, response) -> - Maybe.fromOptional( - afterToolCallbackSync.call( - invocationContext, baseTool, input, toolContext, response))); + callbackPluginBuilder.addAfterToolCallbackSync(afterToolCallbackSync); return this; } @@ -757,19 +629,19 @@ public boolean disallowTransferToPeers() { } public Optional> beforeModelCallback() { - return beforeModelCallback; + return Optional.of(callbackPlugin.getBeforeModelCallback()); } public Optional> afterModelCallback() { - return afterModelCallback; + return Optional.of(callbackPlugin.getAfterModelCallback()); } public Optional> beforeToolCallback() { - return beforeToolCallback; + return Optional.of(callbackPlugin.getBeforeToolCallback()); } public Optional> afterToolCallback() { - return afterToolCallback; + return Optional.of(callbackPlugin.getAfterToolCallback()); } public Optional inputSchema() { @@ -830,8 +702,8 @@ private Model resolveModelInternal() { } BaseAgent current = this.parentAgent(); while (current != null) { - if (current instanceof LlmAgent) { - return ((LlmAgent) current).resolvedModel(); + if (current instanceof LlmAgent llmAgent) { + return llmAgent.resolvedModel(); } current = current.parentAgent(); } diff --git a/core/src/main/java/com/google/adk/agents/LoopAgent.java b/core/src/main/java/com/google/adk/agents/LoopAgent.java index d9d049f80..921ef3689 100644 --- a/core/src/main/java/com/google/adk/agents/LoopAgent.java +++ b/core/src/main/java/com/google/adk/agents/LoopAgent.java @@ -46,16 +46,13 @@ public class LoopAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private LoopAgent( - String name, - String description, - List subAgents, - Optional maxIterations, - List beforeAgentCallback, - List afterAgentCallback) { - - super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); - this.maxIterations = maxIterations; + private LoopAgent(Builder builder) { + super( + builder.name, + builder.description, + builder.subAgents, + builder.callbackPluginBuilder.build()); + this.maxIterations = builder.maxIterations; } /** Builder for {@link LoopAgent}. */ @@ -76,9 +73,7 @@ public Builder maxIterations(Optional maxIterations) { @Override public LoopAgent build() { - // TODO(b/410859954): Add validation for required fields like name. - return new LoopAgent( - name, description, subAgents, maxIterations, beforeAgentCallback, afterAgentCallback); + return new LoopAgent(this); } } diff --git a/core/src/main/java/com/google/adk/agents/ParallelAgent.java b/core/src/main/java/com/google/adk/agents/ParallelAgent.java index f30d951aa..583bfffcb 100644 --- a/core/src/main/java/com/google/adk/agents/ParallelAgent.java +++ b/core/src/main/java/com/google/adk/agents/ParallelAgent.java @@ -45,14 +45,12 @@ public class ParallelAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private ParallelAgent( - String name, - String description, - List subAgents, - List beforeAgentCallback, - List afterAgentCallback) { - - super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); + private ParallelAgent(Builder builder) { + super( + builder.name, + builder.description, + builder.subAgents, + builder.callbackPluginBuilder.build()); } /** Builder for {@link ParallelAgent}. */ @@ -60,8 +58,7 @@ public static class Builder extends BaseAgent.Builder { @Override public ParallelAgent build() { - return new ParallelAgent( - name, description, subAgents, beforeAgentCallback, afterAgentCallback); + return new ParallelAgent(this); } } diff --git a/core/src/main/java/com/google/adk/agents/ReadonlyContext.java b/core/src/main/java/com/google/adk/agents/ReadonlyContext.java index 7d3a5acb9..dc7480f58 100644 --- a/core/src/main/java/com/google/adk/agents/ReadonlyContext.java +++ b/core/src/main/java/com/google/adk/agents/ReadonlyContext.java @@ -34,6 +34,11 @@ public ReadonlyContext(InvocationContext invocationContext) { this.invocationContext = invocationContext; } + /** Returns the invocation context. */ + public InvocationContext invocationContext() { + return invocationContext; + } + /** Returns the user content that initiated this invocation. */ public Optional userContent() { return invocationContext.userContent(); diff --git a/core/src/main/java/com/google/adk/agents/SequentialAgent.java b/core/src/main/java/com/google/adk/agents/SequentialAgent.java index b0b45a0ec..aa4b76fb6 100644 --- a/core/src/main/java/com/google/adk/agents/SequentialAgent.java +++ b/core/src/main/java/com/google/adk/agents/SequentialAgent.java @@ -18,7 +18,6 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.events.Event; import io.reactivex.rxjava3.core.Flowable; -import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,14 +35,12 @@ public class SequentialAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private SequentialAgent( - String name, - String description, - List subAgents, - List beforeAgentCallback, - List afterAgentCallback) { - - super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); + private SequentialAgent(Builder builder) { + super( + builder.name, + builder.description, + builder.subAgents, + builder.callbackPluginBuilder.build()); } /** Builder for {@link SequentialAgent}. */ @@ -51,9 +48,7 @@ public static class Builder extends BaseAgent.Builder { @Override public SequentialAgent build() { - // TODO(b/410859954): Add validation for required fields like name. - return new SequentialAgent( - name, description, subAgents, beforeAgentCallback, afterAgentCallback); + return new SequentialAgent(this); } } diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index 6e06a34ab..92ec8791e 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -20,14 +20,12 @@ import com.google.adk.events.Event; import com.google.adk.testing.TestBaseAgent; +import com.google.adk.testing.TestCallback; import com.google.adk.testing.TestUtils; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; -import io.reactivex.rxjava3.core.Flowable; -import io.reactivex.rxjava3.core.Maybe; import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -51,37 +49,97 @@ public void constructor_setsNameAndDescription() { @Test public void runAsync_beforeAgentCallbackReturnsContent_endsInvocationAndSkipsRunAsyncImplAndAfterCallback() { - AtomicBoolean runAsyncImplCalled = new AtomicBoolean(false); - AtomicBoolean afterAgentCallbackCalled = new AtomicBoolean(false); + var runAsyncImpl = TestCallback.returningEmpty(); Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); - Callbacks.BeforeAgentCallback beforeCallback = (callbackContext) -> Maybe.just(callbackContent); - Callbacks.AfterAgentCallback afterCallback = - (callbackContext) -> { - afterAgentCallbackCalled.set(true); - return Maybe.empty(); - }; + var beforeCallback = TestCallback.returning(callbackContent); + var afterCallback = TestCallback.returningEmpty(); TestBaseAgent agent = new TestBaseAgent( TEST_AGENT_NAME, TEST_AGENT_DESCRIPTION, - ImmutableList.of(beforeCallback), - ImmutableList.of(afterCallback), - () -> - Flowable.defer( - () -> { - runAsyncImplCalled.set(true); - return Flowable.just( - Event.builder() - .content(Content.fromParts(Part.fromText("main_output"))) - .build()); - })); + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runAsyncImpl.asRunAsyncImplSupplier("main_output")); InvocationContext invocationContext = TestUtils.createInvocationContext(agent); List results = agent.runAsync(invocationContext).toList().blockingGet(); assertThat(results).hasSize(1); assertThat(results.get(0).content()).hasValue(callbackContent); - assertThat(runAsyncImplCalled.get()).isFalse(); - assertThat(afterAgentCallbackCalled.get()).isFalse(); + assertThat(runAsyncImpl.wasCalled()).isFalse(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isFalse(); + } + + @Test + public void runAsync_noCallbacks_invokesRunAsyncImpl() { + var runAsyncImpl = TestCallback.returningEmpty(); + Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + /* beforeAgentCallbacks= */ ImmutableList.of(), + /* afterAgentCallbacks= */ ImmutableList.of(), + runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(runAsyncImplContent); + assertThat(runAsyncImpl.wasCalled()).isTrue(); + } + + @Test + public void + runAsync_beforeCallbackReturnsEmptyAndAfterCallbackReturnsEmpty_invokesRunAsyncImplAndAfterCallbacks() { + var runAsyncImpl = TestCallback.returningEmpty(); + Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); + var beforeCallback = TestCallback.returningEmpty(); + var afterCallback = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(runAsyncImplContent); + assertThat(runAsyncImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); + } + + @Test + public void + runAsync_afterCallbackReturnsContent_invokesRunAsyncImplAndAfterCallbacksAndReturnsAllContent() { + var runAsyncImpl = TestCallback.returningEmpty(); + Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); + Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); + var beforeCallback = TestCallback.returningEmpty(); + var afterCallback = TestCallback.returning(afterCallbackContent); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + assertThat(results.get(0).content()).hasValue(runAsyncImplContent); + assertThat(results.get(1).content()).hasValue(afterCallbackContent); + assertThat(runAsyncImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); } } diff --git a/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java b/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java new file mode 100644 index 000000000..361c86193 --- /dev/null +++ b/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java @@ -0,0 +1,499 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.agents; + +import static com.google.adk.testing.TestUtils.createInvocationContext; +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.agents.Callbacks.AfterAgentCallback; +import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; +import com.google.adk.agents.Callbacks.AfterModelCallback; +import com.google.adk.agents.Callbacks.AfterModelCallbackSync; +import com.google.adk.agents.Callbacks.AfterToolCallback; +import com.google.adk.agents.Callbacks.AfterToolCallbackSync; +import com.google.adk.agents.Callbacks.BeforeAgentCallback; +import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; +import com.google.adk.agents.Callbacks.BeforeModelCallback; +import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; +import com.google.adk.agents.Callbacks.BeforeToolCallback; +import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; +import com.google.adk.events.EventActions; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.testing.TestCallback; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public final class CallbackPluginTest { + + @Rule public final MockitoRule mockito = MockitoJUnit.rule(); + @Mock private BaseAgent agent; + @Mock private BaseTool tool; + @Mock private ToolContext toolContext; + private InvocationContext invocationContext; + private CallbackContext callbackContext; + + @Before + public void setUp() { + invocationContext = createInvocationContext(agent); + callbackContext = + new CallbackContext( + invocationContext, + EventActions.builder().stateDelta(new ConcurrentHashMap<>()).build()); + } + + @Test + public void build_empty_successful() { + CallbackPlugin plugin = CallbackPlugin.builder().build(); + assertThat(plugin.getName()).isEqualTo("CallbackPlugin"); + assertThat(plugin.getBeforeAgentCallback()).isEmpty(); + assertThat(plugin.getAfterAgentCallback()).isEmpty(); + assertThat(plugin.getBeforeModelCallback()).isEmpty(); + assertThat(plugin.getAfterModelCallback()).isEmpty(); + assertThat(plugin.getBeforeToolCallback()).isEmpty(); + assertThat(plugin.getAfterToolCallback()).isEmpty(); + } + + @Test + public void addBeforeAgentCallback_isReturnedAndInvoked() { + Content expectedContent = Content.fromParts(Part.fromText("test")); + var testCallback = TestCallback.returning(expectedContent); + BeforeAgentCallback callback = testCallback.asBeforeAgentCallback(); + + CallbackPlugin plugin = CallbackPlugin.builder().addBeforeAgentCallback(callback).build(); + + assertThat(plugin.getBeforeAgentCallback()).containsExactly(callback); + + Content result = plugin.beforeAgentCallback(agent, callbackContext).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedContent); + } + + @Test + public void addBeforeAgentCallbackSync_isReturnedAndInvoked() { + Content expectedContent = Content.fromParts(Part.fromText("test")); + var testCallback = TestCallback.returning(expectedContent); + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeAgentCallbackSync(testCallback.asBeforeAgentCallbackSync()) + .build(); + + assertThat(plugin.getBeforeAgentCallback()).hasSize(1); + + Content result = plugin.beforeAgentCallback(agent, callbackContext).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedContent); + } + + @Test + public void addAfterAgentCallback_isReturnedAndInvoked() { + Content expectedContent = Content.fromParts(Part.fromText("test")); + var testCallback = TestCallback.returning(expectedContent); + AfterAgentCallback callback = testCallback.asAfterAgentCallback(); + + CallbackPlugin plugin = CallbackPlugin.builder().addAfterAgentCallback(callback).build(); + + assertThat(plugin.getAfterAgentCallback()).containsExactly(callback); + + Content result = plugin.afterAgentCallback(agent, callbackContext).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedContent); + } + + @Test + public void addAfterAgentCallbackSync_isReturnedAndInvoked() { + Content expectedContent = Content.fromParts(Part.fromText("test")); + var testCallback = TestCallback.returning(expectedContent); + CallbackPlugin plugin = + CallbackPlugin.builder() + .addAfterAgentCallbackSync(testCallback.asAfterAgentCallbackSync()) + .build(); + + assertThat(plugin.getAfterAgentCallback()).hasSize(1); + + Content result = plugin.afterAgentCallback(agent, callbackContext).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedContent); + } + + @Test + public void addBeforeModelCallback_isReturnedAndInvoked() { + LlmResponse expectedResponse = LlmResponse.builder().build(); + var testCallback = TestCallback.returning(expectedResponse); + BeforeModelCallback callback = testCallback.asBeforeModelCallback(); + + CallbackPlugin plugin = CallbackPlugin.builder().addBeforeModelCallback(callback).build(); + + assertThat(plugin.getBeforeModelCallback()).containsExactly(callback); + + LlmResponse result = + plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addBeforeModelCallbackSync_isReturnedAndInvoked() { + LlmResponse expectedResponse = LlmResponse.builder().build(); + var testCallback = TestCallback.returning(expectedResponse); + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeModelCallbackSync(testCallback.asBeforeModelCallbackSync()) + .build(); + + assertThat(plugin.getBeforeModelCallback()).hasSize(1); + + LlmResponse result = + plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addAfterModelCallback_isReturnedAndInvoked() { + LlmResponse initialResponse = LlmResponse.builder().build(); + LlmResponse expectedResponse = + LlmResponse.builder().content(Content.fromParts(Part.fromText("test"))).build(); + var testCallback = TestCallback.returning(expectedResponse); + AfterModelCallback callback = testCallback.asAfterModelCallback(); + + CallbackPlugin plugin = CallbackPlugin.builder().addAfterModelCallback(callback).build(); + + assertThat(plugin.getAfterModelCallback()).containsExactly(callback); + + LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addAfterModelCallbackSync_isReturnedAndInvoked() { + LlmResponse initialResponse = LlmResponse.builder().build(); + LlmResponse expectedResponse = + LlmResponse.builder().content(Content.fromParts(Part.fromText("test"))).build(); + var testCallback = TestCallback.returning(expectedResponse); + AfterModelCallbackSync callback = testCallback.asAfterModelCallbackSync(); + + CallbackPlugin plugin = CallbackPlugin.builder().addAfterModelCallbackSync(callback).build(); + + assertThat(plugin.getAfterModelCallback()).hasSize(1); + + LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addBeforeToolCallback_isReturnedAndInvoked() { + ImmutableMap expectedResult = ImmutableMap.of("key", "value"); + var testCallback = TestCallback.returning(expectedResult); + BeforeToolCallback callback = testCallback.asBeforeToolCallback(); + + CallbackPlugin plugin = CallbackPlugin.builder().addBeforeToolCallback(callback).build(); + + assertThat(plugin.getBeforeToolCallback()).containsExactly(callback); + + Map result = + plugin.beforeToolCallback(tool, ImmutableMap.of(), toolContext).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResult); + } + + @Test + public void addBeforeToolCallbackSync_isReturnedAndInvoked() { + ImmutableMap expectedResult = ImmutableMap.of("key", "value"); + var testCallback = TestCallback.returning(expectedResult); + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeToolCallbackSync(testCallback.asBeforeToolCallbackSync()) + .build(); + + assertThat(plugin.getBeforeToolCallback()).hasSize(1); + + Map result = + plugin.beforeToolCallback(tool, ImmutableMap.of(), toolContext).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResult); + } + + @Test + public void addAfterToolCallback_isReturnedAndInvoked() { + ImmutableMap initialResult = ImmutableMap.of("initial", "result"); + ImmutableMap expectedResult = ImmutableMap.of("key", "value"); + var testCallback = TestCallback.returning(expectedResult); + AfterToolCallback callback = testCallback.asAfterToolCallback(); + + CallbackPlugin plugin = CallbackPlugin.builder().addAfterToolCallback(callback).build(); + + assertThat(plugin.getAfterToolCallback()).containsExactly(callback); + + Map result = + plugin.afterToolCallback(tool, ImmutableMap.of(), toolContext, initialResult).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResult); + } + + @Test + public void addAfterToolCallbackSync_isReturnedAndInvoked() { + ImmutableMap initialResult = ImmutableMap.of("initial", "result"); + ImmutableMap expectedResult = ImmutableMap.of("key", "value"); + var testCallback = TestCallback.returning(expectedResult); + AfterToolCallbackSync callback = testCallback.asAfterToolCallbackSync(); + + CallbackPlugin plugin = CallbackPlugin.builder().addAfterToolCallbackSync(callback).build(); + + assertThat(plugin.getAfterToolCallback()).hasSize(1); + + Map result = + plugin.afterToolCallback(tool, ImmutableMap.of(), toolContext, initialResult).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResult); + } + + @Test + public void addCallback_beforeAgentCallback() { + BeforeAgentCallback callback = ctx -> Maybe.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getBeforeAgentCallback()).containsExactly(callback); + } + + @Test + public void addCallback_beforeAgentCallbackSync() { + BeforeAgentCallbackSync callback = ctx -> Optional.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getBeforeAgentCallback()).hasSize(1); + } + + @Test + public void addCallback_afterAgentCallback() { + AfterAgentCallback callback = ctx -> Maybe.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getAfterAgentCallback()).containsExactly(callback); + } + + @Test + public void addCallback_afterAgentCallbackSync() { + AfterAgentCallbackSync callback = ctx -> Optional.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getAfterAgentCallback()).hasSize(1); + } + + @Test + public void addCallback_beforeModelCallback() { + BeforeModelCallback callback = (ctx, req) -> Maybe.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getBeforeModelCallback()).containsExactly(callback); + } + + @Test + public void addCallback_beforeModelCallbackSync() { + BeforeModelCallbackSync callback = (ctx, req) -> Optional.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getBeforeModelCallback()).hasSize(1); + } + + @Test + public void addCallback_afterModelCallback() { + AfterModelCallback callback = (ctx, res) -> Maybe.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getAfterModelCallback()).containsExactly(callback); + } + + @Test + public void addCallback_afterModelCallbackSync() { + AfterModelCallbackSync callback = (ctx, res) -> Optional.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getAfterModelCallback()).hasSize(1); + } + + @Test + public void addCallback_beforeToolCallback() { + BeforeToolCallback callback = (invCtx, tool, toolArgs, toolCtx) -> Maybe.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getBeforeToolCallback()).containsExactly(callback); + } + + @Test + public void addCallback_beforeToolCallbackSync() { + BeforeToolCallbackSync callback = (invCtx, tool, toolArgs, toolCtx) -> Optional.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getBeforeToolCallback()).hasSize(1); + } + + @Test + public void addCallback_afterToolCallback() { + AfterToolCallback callback = (invCtx, tool, toolArgs, toolCtx, res) -> Maybe.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getAfterToolCallback()).containsExactly(callback); + } + + @Test + public void addCallback_afterToolCallbackSync() { + AfterToolCallbackSync callback = (invCtx, tool, toolArgs, toolCtx, res) -> Optional.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getAfterToolCallback()).hasSize(1); + } + + @Test + public void addMultipleBeforeModelCallbacks_invokedInOrder() { + LlmResponse expectedResponse = LlmResponse.builder().build(); + var testCallback1 = TestCallback.returningEmpty(); + var testCallback2 = TestCallback.returning(expectedResponse); + BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); + BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); + + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeModelCallback(callback1) + .addBeforeModelCallback(callback2) + .build(); + + assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); + + LlmResponse result = + plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); + assertThat(testCallback1.wasCalled()).isTrue(); + assertThat(testCallback2.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addMultipleBeforeModelCallbacks_shortCircuit() { + LlmResponse expectedResponse = LlmResponse.builder().build(); + var testCallback1 = TestCallback.returning(expectedResponse); + var testCallback2 = TestCallback.returningEmpty(); + BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); + BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); + + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeModelCallback(callback1) + .addBeforeModelCallback(callback2) + .build(); + + assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); + + LlmResponse result = + plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); + assertThat(testCallback1.wasCalled()).isTrue(); + assertThat(testCallback2.wasCalled()).isFalse(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addMultipleAfterModelCallbacks_shortCircuit() { + LlmResponse initialResponse = LlmResponse.builder().build(); + LlmResponse expectedResponse = + LlmResponse.builder().content(Content.fromParts(Part.fromText("response"))).build(); + var testCallback1 = TestCallback.returning(expectedResponse); + var testCallback2 = TestCallback.returningEmpty(); + AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); + AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); + CallbackPlugin plugin = + CallbackPlugin.builder() + .addAfterModelCallback(callback1) + .addAfterModelCallback(callback2) + .build(); + + assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); + LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); + assertThat(testCallback1.wasCalled()).isTrue(); + assertThat(testCallback2.wasCalled()).isFalse(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addMultipleAfterModelCallbacks_invokedInOrder() { + LlmResponse initialResponse = LlmResponse.builder().build(); + LlmResponse expectedResponse = + LlmResponse.builder().content(Content.fromParts(Part.fromText("second"))).build(); + var testCallback1 = TestCallback.returningEmpty(); + var testCallback2 = TestCallback.returning(expectedResponse); + AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); + AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); + + CallbackPlugin plugin = + CallbackPlugin.builder() + .addAfterModelCallback(callback1) + .addAfterModelCallback(callback2) + .build(); + + assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); + + LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); + assertThat(testCallback1.wasCalled()).isTrue(); + assertThat(testCallback2.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addMultipleBeforeModelCallbacks_bothEmpty_returnsEmpty() { + var testCallback1 = TestCallback.returningEmpty(); + var testCallback2 = TestCallback.returningEmpty(); + BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); + BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); + + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeModelCallback(callback1) + .addBeforeModelCallback(callback2) + .build(); + + assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); + + LlmResponse result = + plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); + assertThat(testCallback1.wasCalled()).isTrue(); + assertThat(testCallback2.wasCalled()).isTrue(); + assertThat(result).isNull(); + } + + @Test + public void addMultipleAfterModelCallbacks_bothEmpty_returnsEmpty() { + LlmResponse initialResponse = LlmResponse.builder().build(); + var testCallback1 = TestCallback.returningEmpty(); + var testCallback2 = TestCallback.returningEmpty(); + AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); + AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); + CallbackPlugin plugin = + CallbackPlugin.builder() + .addAfterModelCallback(callback1) + .addAfterModelCallback(callback2) + .build(); + + assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); + LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); + assertThat(testCallback1.wasCalled()).isTrue(); + assertThat(testCallback2.wasCalled()).isTrue(); + assertThat(result).isNull(); + } +} diff --git a/core/src/test/java/com/google/adk/testing/TestCallback.java b/core/src/test/java/com/google/adk/testing/TestCallback.java new file mode 100644 index 000000000..04f83ed9b --- /dev/null +++ b/core/src/test/java/com/google/adk/testing/TestCallback.java @@ -0,0 +1,164 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.testing; + +import com.google.adk.agents.Callbacks.AfterAgentCallback; +import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; +import com.google.adk.agents.Callbacks.AfterModelCallback; +import com.google.adk.agents.Callbacks.AfterModelCallbackSync; +import com.google.adk.agents.Callbacks.AfterToolCallback; +import com.google.adk.agents.Callbacks.AfterToolCallbackSync; +import com.google.adk.agents.Callbacks.BeforeAgentCallback; +import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; +import com.google.adk.agents.Callbacks.BeforeModelCallback; +import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; +import com.google.adk.agents.Callbacks.BeforeToolCallback; +import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; +import com.google.adk.events.Event; +import com.google.adk.models.LlmResponse; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +/** + * A test helper that wraps an {@link AtomicBoolean} and provides factory methods for creating + * callbacks that update the boolean when called. + * + * @param The type of the result returned by the callback. + */ +public final class TestCallback { + private final AtomicBoolean called = new AtomicBoolean(false); + private final Optional result; + + private TestCallback(Optional result) { + this.result = result; + } + + /** Creates a {@link TestCallback} that returns the given result. */ + public static TestCallback returning(T result) { + return new TestCallback<>(Optional.of(result)); + } + + /** Creates a {@link TestCallback} that returns an empty result. */ + public static TestCallback returningEmpty() { + return new TestCallback<>(Optional.empty()); + } + + /** Returns true if the callback was called. */ + public boolean wasCalled() { + return called.get(); + } + + /** Marks the callback as called. */ + public void markAsCalled() { + called.set(true); + } + + private Maybe callMaybe() { + called.set(true); + return result.map(Maybe::just).orElseGet(Maybe::empty); + } + + private Optional callOptional() { + called.set(true); + return result; + } + + /** + * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} + * with an event containing the given content. + */ + public Supplier> asRunAsyncImplSupplier(Content content) { + return () -> + Flowable.defer( + () -> { + markAsCalled(); + return Flowable.just(Event.builder().content(content).build()); + }); + } + + /** + * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} + */ + public Supplier> asRunAsyncImplSupplier(String contentText) { + return asRunAsyncImplSupplier(Content.fromParts(Part.fromText(contentText))); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Content. + public BeforeAgentCallback asBeforeAgentCallback() { + return ctx -> (Maybe) callMaybe(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Content. + public BeforeAgentCallbackSync asBeforeAgentCallbackSync() { + return ctx -> (Optional) callOptional(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Content. + public AfterAgentCallback asAfterAgentCallback() { + return ctx -> (Maybe) callMaybe(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Content. + public AfterAgentCallbackSync asAfterAgentCallbackSync() { + return ctx -> (Optional) callOptional(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. + public BeforeModelCallback asBeforeModelCallback() { + return (ctx, req) -> (Maybe) callMaybe(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. + public BeforeModelCallbackSync asBeforeModelCallbackSync() { + return (ctx, req) -> (Optional) callOptional(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. + public AfterModelCallback asAfterModelCallback() { + return (ctx, res) -> (Maybe) callMaybe(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. + public AfterModelCallbackSync asAfterModelCallbackSync() { + return (ctx, res) -> (Optional) callOptional(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Map. + public BeforeToolCallback asBeforeToolCallback() { + return (invCtx, tool, toolArgs, toolCtx) -> (Maybe>) callMaybe(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Map. + public BeforeToolCallbackSync asBeforeToolCallbackSync() { + return (invCtx, tool, toolArgs, toolCtx) -> (Optional>) callOptional(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Map. + public AfterToolCallback asAfterToolCallback() { + return (invCtx, tool, toolArgs, toolCtx, res) -> (Maybe>) callMaybe(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Map. + public AfterToolCallbackSync asAfterToolCallbackSync() { + return (invCtx, tool, toolArgs, toolCtx, res) -> (Optional>) callOptional(); + } +}