Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 29 additions & 71 deletions core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@

package com.google.adk.agents;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Arrays.stream;

import com.google.adk.Telemetry;
import com.google.adk.agents.Callbacks.AfterAgentCallback;
import com.google.adk.agents.Callbacks.BeforeAgentCallback;
Expand Down Expand Up @@ -255,9 +252,8 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
spanContext,
span,
() ->
callCallback(
beforeCallbacksToFunctions(
invocationContext.pluginManager(), callbackPlugin),
processAgentCallbackResult(
ctx -> invocationContext.combinedPlugin().beforeAgentCallback(this, ctx),
invocationContext)
.flatMapPublisher(
beforeEventOpt -> {
Expand All @@ -271,10 +267,11 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
Flowable<Event> afterEvents =
Flowable.defer(
() ->
callCallback(
afterCallbacksToFunctions(
invocationContext.pluginManager(),
callbackPlugin),
processAgentCallbackResult(
ctx ->
invocationContext
.combinedPlugin()
.afterAgentCallback(this, ctx),
invocationContext)
.flatMapPublisher(Flowable::fromOptional));

Expand All @@ -284,71 +281,32 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
}

/**
* Converts before-agent callbacks to functions.
* Processes the result of an agent callback, creating an {@link Event} if necessary.
*
* @return callback functions.
* @param agentCallback The callback function.
* @param invocationContext The current invocation context.
* @return A {@link Single} emitting an {@link Optional} containing the created {@link Event}, or
* {@link Optional#empty()} if no event is produced.
*/
private ImmutableList<Function<CallbackContext, Maybe<Content>>> beforeCallbacksToFunctions(
Plugin... plugins) {
return stream(plugins)
.map(
p ->
(Function<CallbackContext, Maybe<Content>>) ctx -> p.beforeAgentCallback(this, ctx))
.collect(toImmutableList());
}

/**
* Converts after-agent callbacks to functions.
*
* @return callback functions.
*/
private ImmutableList<Function<CallbackContext, Maybe<Content>>> afterCallbacksToFunctions(
Plugin... plugins) {
return stream(plugins)
.map(
p -> (Function<CallbackContext, Maybe<Content>>) ctx -> p.afterAgentCallback(this, ctx))
.collect(toImmutableList());
}

/**
* Calls agent callbacks and returns the first produced event, if any.
*
* @param agentCallbacks Callback functions.
* @param invocationContext Current invocation context.
* @return single emitting first event, or empty if none.
*/
private Single<Optional<Event>> callCallback(
List<Function<CallbackContext, Maybe<Content>>> agentCallbacks,
private Single<Optional<Event>> processAgentCallbackResult(
Function<CallbackContext, Maybe<Content>> agentCallback,
InvocationContext invocationContext) {
if (agentCallbacks == null || agentCallbacks.isEmpty()) {
return Single.just(Optional.empty());
}

CallbackContext callbackContext =
new CallbackContext(invocationContext, /* eventActions= */ null);

return Flowable.fromIterable(agentCallbacks)
.concatMap(
callback -> {
Maybe<Content> maybeContent = callback.apply(callbackContext);

return maybeContent
.map(
content -> {
invocationContext.setEndInvocation(true);
return Optional.of(
Event.builder()
.id(Event.generateEventId())
.invocationId(invocationContext.invocationId())
.author(name())
.branch(invocationContext.branch())
.actions(callbackContext.eventActions())
.content(content)
.build());
})
.toFlowable();
var callbackContext = new CallbackContext(invocationContext, /* eventActions= */ null);
return agentCallback
.apply(callbackContext)
.map(
content -> {
invocationContext.setEndInvocation(true);
return Optional.of(
Event.builder()
.id(Event.generateEventId())
.invocationId(invocationContext.invocationId())
.author(name())
.branch(invocationContext.branch())
.actions(callbackContext.eventActions())
.content(content)
.build());
})
.firstElement()
.switchIfEmpty(
Single.defer(
() -> {
Expand Down
13 changes: 13 additions & 0 deletions core/src/main/java/com/google/adk/agents/InvocationContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.adk.plugins.PluginManager;
import com.google.adk.sessions.BaseSessionService;
import com.google.adk.sessions.Session;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.errorprone.annotations.InlineMe;
Expand All @@ -44,6 +45,7 @@ public class InvocationContext {
private final BaseArtifactService artifactService;
private final BaseMemoryService memoryService;
private final Plugin pluginManager;
private final Plugin combinedPlugin;
private final Optional<LiveRequestQueue> liveRequestQueue;
private final Map<String, ActiveStreamingTool> activeStreamingTools;
private final String invocationId;
Expand Down Expand Up @@ -73,6 +75,13 @@ protected InvocationContext(Builder builder) {
this.endInvocation = builder.endInvocation;
this.resumabilityConfig = builder.resumabilityConfig;
this.invocationCostManager = builder.invocationCostManager;
this.combinedPlugin =
Optional.ofNullable(builder.agent)
.map(BaseAgent::getPlugin)
.map(
agentPlugin ->
(Plugin) new PluginManager(ImmutableList.of(pluginManager, agentPlugin)))
.orElse(pluginManager);
}

/**
Expand Down Expand Up @@ -235,6 +244,10 @@ public Plugin pluginManager() {
return pluginManager;
}

public Plugin combinedPlugin() {
return combinedPlugin;
}

/** Returns a map of tool call IDs to active streaming tools for the current invocation. */
public Map<String, ActiveStreamingTool> activeStreamingTools() {
return activeStreamingTools;
Expand Down
48 changes: 8 additions & 40 deletions core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ private Flowable<LlmResponse> callLlm(
.onErrorResumeNext(
exception ->
context
.pluginManager()
.combinedPlugin()
.onModelErrorCallback(
new CallbackContext(
context, eventForCallbackUsage.actions()),
Expand Down Expand Up @@ -243,27 +243,9 @@ private Single<Optional<LlmResponse>> handleBeforeModelCallback(
Event callbackEvent = modelResponseEvent.toBuilder().build();
CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions());

Maybe<LlmResponse> pluginResult =
context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder);

LlmAgent agent = (LlmAgent) context.agent();

Optional<List<? extends BeforeModelCallback>> callbacksOpt = agent.beforeModelCallback();
if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty());
}

List<? extends BeforeModelCallback> callbacks = callbacksOpt.get();

Maybe<LlmResponse> callbackResult =
Maybe.defer(
() ->
Flowable.fromIterable(callbacks)
.concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder))
.firstElement());

return pluginResult
.switchIfEmpty(callbackResult)
return context
.combinedPlugin()
.beforeModelCallback(callbackContext, llmRequestBuilder)
.map(Optional::of)
.defaultIfEmpty(Optional.empty());
}
Expand All @@ -279,24 +261,10 @@ private Single<LlmResponse> handleAfterModelCallback(
Event callbackEvent = modelResponseEvent.toBuilder().build();
CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions());

Maybe<LlmResponse> pluginResult =
context.pluginManager().afterModelCallback(callbackContext, llmResponse);

LlmAgent agent = (LlmAgent) context.agent();
Optional<List<? extends AfterModelCallback>> callbacksOpt = agent.afterModelCallback();

if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
return pluginResult.defaultIfEmpty(llmResponse);
}

Maybe<LlmResponse> callbackResult =
Maybe.defer(
() ->
Flowable.fromIterable(callbacksOpt.get())
.concatMapMaybe(callback -> callback.call(callbackContext, llmResponse))
.firstElement());

return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse);
return context
.combinedPlugin()
.afterModelCallback(callbackContext, llmResponse)
.defaultIfEmpty(llmResponse);
}

/**
Expand Down
64 changes: 5 additions & 59 deletions core/src/main/java/com/google/adk/flows/llmflows/Functions.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@

import com.google.adk.Telemetry;
import com.google.adk.agents.ActiveStreamingTool;
import com.google.adk.agents.Callbacks.AfterToolCallback;
import com.google.adk.agents.Callbacks.BeforeToolCallback;
import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LlmAgent;
import com.google.adk.agents.RunConfig.ToolExecutionMode;
import com.google.adk.events.Event;
import com.google.adk.events.EventActions;
Expand Down Expand Up @@ -388,7 +385,7 @@ private static Maybe<Event> postProcessFunctionResult(
.onErrorResumeNext(
t ->
invocationContext
.pluginManager()
.combinedPlugin()
.onToolErrorCallback(tool, functionArgs, toolContext, t)
.map(isLive ? Optional::ofNullable : Optional::of)
.switchIfEmpty(Single.error(t)))
Expand Down Expand Up @@ -457,30 +454,7 @@ private static Maybe<Map<String, Object>> maybeInvokeBeforeToolCall(
BaseTool tool,
Map<String, Object> functionArgs,
ToolContext toolContext) {
if (invocationContext.agent() instanceof LlmAgent) {
LlmAgent agent = (LlmAgent) invocationContext.agent();

Maybe<Map<String, Object>> pluginResult =
invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext);

Optional<List<? extends BeforeToolCallback>> callbacksOpt = agent.beforeToolCallback();
if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
return pluginResult;
}
List<? extends BeforeToolCallback> callbacks = callbacksOpt.get();

Maybe<Map<String, Object>> callbackResult =
Maybe.defer(
() ->
Flowable.fromIterable(callbacks)
.concatMapMaybe(
callback ->
callback.call(invocationContext, tool, functionArgs, toolContext))
.firstElement());

return pluginResult.switchIfEmpty(callbackResult);
}
return Maybe.empty();
return invocationContext.combinedPlugin().beforeToolCallback(tool, functionArgs, toolContext);
}

private static Maybe<Map<String, Object>> maybeInvokeAfterToolCall(
Expand All @@ -489,37 +463,9 @@ private static Maybe<Map<String, Object>> maybeInvokeAfterToolCall(
Map<String, Object> functionArgs,
ToolContext toolContext,
Map<String, Object> functionResult) {
if (invocationContext.agent() instanceof LlmAgent) {
LlmAgent agent = (LlmAgent) invocationContext.agent();

Maybe<Map<String, Object>> pluginResult =
invocationContext
.pluginManager()
.afterToolCallback(tool, functionArgs, toolContext, functionResult);

Optional<List<? extends AfterToolCallback>> callbacksOpt = agent.afterToolCallback();
if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
return pluginResult;
}
List<? extends AfterToolCallback> callbacks = callbacksOpt.get();

Maybe<Map<String, Object>> callbackResult =
Maybe.defer(
() ->
Flowable.fromIterable(callbacks)
.concatMapMaybe(
callback ->
callback.call(
invocationContext,
tool,
functionArgs,
toolContext,
functionResult))
.firstElement());

return pluginResult.switchIfEmpty(callbackResult);
}
return Maybe.empty();
return invocationContext
.combinedPlugin()
.afterToolCallback(tool, functionArgs, toolContext, functionResult);
}

private static Maybe<Map<String, Object>> callTool(
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ public Flowable<Event> runAsync(
updatedSession,
session);
return contextWithUpdatedSession
.pluginManager()
.combinedPlugin()
.onEventCallback(
contextWithUpdatedSession,
registeredEvent)
Expand Down
Loading