Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions core/src/main/java/com/google/adk/models/LlmRegistry.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ public interface LlmFactory {
/** Map of model name patterns regex to factories. */
private static final Map<String, LlmFactory> llmFactories = new ConcurrentHashMap<>();

/** Registers default LLM factories, e.g. for Gemini models. */
/* Registers default LLM factories, e.g. for Gemini models. */
static {
registerLlm("gemini-.*", modelName -> Gemini.builder().modelName(modelName).build());
registerLlm("apigee/.*", modelName -> ApigeeLlm.builder().modelName(modelName).build());
registerViaReflection("com.google.adk.models.Gemini", "gemini-.*");
registerViaReflection("com.google.adk.models.ApigeeLlm", "apigee/.*");
}

/**
Expand Down Expand Up @@ -78,6 +78,31 @@ private static BaseLlm createLlm(String modelName) {
throw new IllegalArgumentException("Unsupported model: " + modelName);
}

/**
* Registers an LLM factory via reflection, if the class is available.
*
* @param className The fully qualified class name of the LLM.
* @param pattern The regex pattern for matching model names.
*/
private static void registerViaReflection(String className, String pattern) {
try {
Class<?> llmClass = Class.forName(className);
LlmFactory factory =
modelName -> {
try {
Object builder = llmClass.getMethod("builder").invoke(null);
builder.getClass().getMethod("modelName", String.class).invoke(builder, modelName);
return (BaseLlm) builder.getClass().getMethod("build").invoke(builder);
} catch (ReflectiveOperationException e) {
throw new IllegalArgumentException("Failed to create instance of " + className, e);
}
};
registerLlm(pattern, factory);
} catch (ClassNotFoundException e) {
// ignore - LLM not available.
}
}

/**
* Registers an LLM factory for testing purposes. Clears cached instances matching the given
* pattern to ensure test isolation.
Expand Down
145 changes: 80 additions & 65 deletions core/src/main/java/com/google/adk/utils/ComponentRegistry.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,13 @@

import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.Callbacks;
import com.google.adk.agents.LlmAgent;
import com.google.adk.agents.LoopAgent;
import com.google.adk.agents.ParallelAgent;
import com.google.adk.agents.SequentialAgent;
import com.google.adk.tools.AgentTool;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.BaseToolset;
import com.google.adk.tools.ExampleTool;
import com.google.adk.tools.ExitLoopTool;
import com.google.adk.tools.GoogleMapsTool;
import com.google.adk.tools.GoogleSearchTool;
import com.google.adk.tools.LoadArtifactsTool;
import com.google.adk.tools.LongRunningFunctionTool;
import com.google.adk.tools.UrlContextTool;
import com.google.adk.tools.mcp.McpToolset;
import com.google.common.collect.ImmutableMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.Nonnull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -93,6 +80,8 @@
public class ComponentRegistry {

private static final Logger logger = LoggerFactory.getLogger(ComponentRegistry.class);
private static volatile ImmutableMap<String, Object> DEFAULT_REGISTRY;

private static volatile ComponentRegistry instance = new ComponentRegistry();

private final Map<String, Object> registry = new ConcurrentHashMap<>();
Expand All @@ -103,55 +92,80 @@ protected ComponentRegistry() {

/** Initializes the registry with base pre-wired ADK instances. */
private void initializePreWiredEntries() {
registerAdkAgentClass(LlmAgent.class);
registerAdkAgentClass(LoopAgent.class);
registerAdkAgentClass(ParallelAgent.class);
registerAdkAgentClass(SequentialAgent.class);

registerAdkToolInstance("google_search", GoogleSearchTool.INSTANCE);
registerAdkToolInstance("load_artifacts", LoadArtifactsTool.INSTANCE);
registerAdkToolInstance("exit_loop", ExitLoopTool.INSTANCE);
registerAdkToolInstance("url_context", UrlContextTool.INSTANCE);
registerAdkToolInstance("google_maps_grounding", GoogleMapsTool.INSTANCE);

registerAdkToolClass(AgentTool.class);
registerAdkToolClass(LongRunningFunctionTool.class);
registerAdkToolClass(ExampleTool.class);

registerAdkToolsetClass(McpToolset.class);
// TODO: add all python tools that also exist in Java.

if (DEFAULT_REGISTRY == null) {
synchronized (ComponentRegistry.class) {
if (DEFAULT_REGISTRY == null) {
registerAdkClassByName("com.google.adk.agents.LlmAgent");
registerAdkClassByName("com.google.adk.agents.LoopAgent");
registerAdkClassByName("com.google.adk.agents.ParallelAgent");
registerAdkClassByName("com.google.adk.agents.SequentialAgent");

registerAdkToolInstance("google_search", "com.google.adk.tools.GoogleSearchTool");
registerAdkToolInstance("load_artifacts", "com.google.adk.tools.LoadArtifactsTool");
registerAdkToolInstance("exit_loop", "com.google.adk.tools.ExitLoopTool");
registerAdkToolInstance("url_context", "com.google.adk.tools.UrlContextTool");
registerAdkToolInstance("google_maps_grounding", "com.google.adk.tools.GoogleMapsTool");

registerAdkClassByName("com.google.adk.tools.AgentTool");
registerAdkClassByName("com.google.adk.tools.LongRunningFunctionTool");
registerAdkClassByName("com.google.adk.tools.ExampleTool");

registerAdkClassByName("com.google.adk.tools.mcp.McpToolset");
// TODO: add all python tools that also exist in Java.

DEFAULT_REGISTRY = ImmutableMap.copyOf(registry);
return;
}
}
}
registry.putAll(DEFAULT_REGISTRY);
logger.debug("Initialized base pre-wired entries in ComponentRegistry");
}

private void registerAdkAgentClass(Class<? extends BaseAgent> agentClass) {
registry.put(agentClass.getName(), agentClass);
// For python compatibility, also register the name used in ADK Python.
registry.put("google.adk.agents." + agentClass.getSimpleName(), agentClass);
}
private void registerAdkClassByName(String className) {
try {
Class<?> clazz = Thread.currentThread().getContextClassLoader().loadClass(className);
String standardPrefix;
if (BaseAgent.class.isAssignableFrom(clazz)) {
standardPrefix = "google.adk.agents.";
} else if (BaseTool.class.isAssignableFrom(clazz)) {
standardPrefix = "google.adk.tools.";
} else if (BaseToolset.class.isAssignableFrom(clazz)) {
standardPrefix = "google.adk.tools.";
} else {
throw new IllegalArgumentException(
"Cannot determine standardPrefix for type " + clazz.getName());
}

private void registerAdkToolInstance(String name, @Nonnull Object toolInstance) {
registry.put(name, toolInstance);
// For python compatibility, also register the name used in ADK Python.
registry.put("google.adk.tools." + name, toolInstance);
}
registry.put(clazz.getName(), clazz);
// For python compatibility, also register the name used in ADK Python.
registry.put(standardPrefix + clazz.getSimpleName(), clazz);

private void registerAdkToolClass(@Nonnull Class<?> toolClass) {
registry.put(toolClass.getName(), toolClass);
// For python compatibility, also register the name used in ADK Python.
registry.put("google.adk.tools." + toolClass.getSimpleName(), toolClass);
if (BaseToolset.class.isAssignableFrom(clazz)) {
registry.put(clazz.getSimpleName(), clazz);
if (clazz.getSimpleName().equals("McpToolset")) {
registry.put("mcp.McpToolset", clazz);
}
}
} catch (Exception e) {
logger.info(
"{} not found, skipping registration: {}",
className.substring(className.lastIndexOf('.') + 1),
e.getMessage());
}
}

private void registerAdkToolsetClass(@Nonnull Class<? extends BaseToolset> toolsetClass) {
registry.put(toolsetClass.getName(), toolsetClass);
// For python compatibility, also register the name used in ADK Python.
registry.put("google.adk.tools." + toolsetClass.getSimpleName(), toolsetClass);
// Also register by simple class name
registry.put(toolsetClass.getSimpleName(), toolsetClass);
// Special support for toolsets with various naming conventions
String simpleName = toolsetClass.getSimpleName();
if (simpleName.equals("McpToolset")) {
registry.put("mcp.McpToolset", toolsetClass);
private void registerAdkToolInstance(String name, String toolClassName) {
try {
Object toolInstance = Class.forName(toolClassName).getField("INSTANCE").get(null);
registry.put(name, toolInstance);
// For python compatibility, also register the name used in ADK Python.
registry.put("google.adk.tools." + name, toolInstance);
} catch (Exception e) {
logger.info(
"{} not found, skipping registration: {}",
toolClassName.substring(toolClassName.lastIndexOf('.') + 1),
e.getMessage());
}
}

Expand Down Expand Up @@ -281,30 +295,31 @@ public static Optional<BaseAgent> resolveAgentInstance(String name) {
*/
public static Class<? extends BaseAgent> resolveAgentClass(String agentClassName) {
// If no agent_class is specified, it will default to LlmAgent.
if (isNullOrEmpty(agentClassName)) {
return LlmAgent.class;
}
final String effectiveAgentClassName =
isNullOrEmpty(agentClassName) ? "com.google.adk.agents.LlmAgent" : agentClassName;

Optional<Class<? extends BaseAgent>> agentClass;

if (agentClassName.contains(".")) {
if (effectiveAgentClassName.contains(".")) {
// If agentClassName contains '.', use it directly
agentClass = getType(agentClassName, BaseAgent.class);
agentClass = getType(effectiveAgentClassName, BaseAgent.class);
} else {
// First try the simple name
agentClass =
getType(agentClassName, BaseAgent.class)
getType(effectiveAgentClassName, BaseAgent.class)
// If not found, try with com.google.adk.agents prefix
.or(() -> getType("com.google.adk.agents." + agentClassName, BaseAgent.class))
.or(
() ->
getType("com.google.adk.agents." + effectiveAgentClassName, BaseAgent.class))
// For Python compatibility, also try with google.adk.agents prefix
.or(() -> getType("google.adk.agents." + agentClassName, BaseAgent.class));
.or(() -> getType("google.adk.agents." + effectiveAgentClassName, BaseAgent.class));
}

return agentClass.orElseThrow(
() ->
new IllegalArgumentException(
"agentClass '"
+ agentClassName
+ effectiveAgentClassName
+ "' is not in registry or not a subclass of BaseAgent."));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ public void testResolveToolClass_withFullyQualifiedName() {

@Test
public void testMcpToolsetRegistration() {
ComponentRegistry registry = ComponentRegistry.getInstance();
ComponentRegistry registry = new ComponentRegistry();

// Verify direct registry storage (tests lines 134, 136, 138, 142 in ComponentRegistry.java)
Optional<Object> directFullName = registry.get("com.google.adk.tools.mcp.McpToolset");
Expand Down