diff --git a/core/src/main/java/com/google/adk/models/LlmRegistry.java b/core/src/main/java/com/google/adk/models/LlmRegistry.java index a73d89430..50c4d6f4a 100644 --- a/core/src/main/java/com/google/adk/models/LlmRegistry.java +++ b/core/src/main/java/com/google/adk/models/LlmRegistry.java @@ -34,10 +34,10 @@ public interface LlmFactory { /** Map of model name patterns regex to factories. */ private static final Map llmFactories = new ConcurrentHashMap<>(); - /** Registers default LLM factories, e.g. for Gemini models. */ + /* Registers default LLM factories, e.g. for Gemini models. */ static { - registerLlm("gemini-.*", modelName -> Gemini.builder().modelName(modelName).build()); - registerLlm("apigee/.*", modelName -> ApigeeLlm.builder().modelName(modelName).build()); + registerViaReflection("com.google.adk.models.Gemini", "gemini-.*"); + registerViaReflection("com.google.adk.models.ApigeeLlm", "apigee/.*"); } /** @@ -78,6 +78,31 @@ private static BaseLlm createLlm(String modelName) { throw new IllegalArgumentException("Unsupported model: " + modelName); } + /** + * Registers an LLM factory via reflection, if the class is available. + * + * @param className The fully qualified class name of the LLM. + * @param pattern The regex pattern for matching model names. + */ + private static void registerViaReflection(String className, String pattern) { + try { + Class llmClass = Class.forName(className); + LlmFactory factory = + modelName -> { + try { + Object builder = llmClass.getMethod("builder").invoke(null); + builder.getClass().getMethod("modelName", String.class).invoke(builder, modelName); + return (BaseLlm) builder.getClass().getMethod("build").invoke(builder); + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException("Failed to create instance of " + className, e); + } + }; + registerLlm(pattern, factory); + } catch (ClassNotFoundException e) { + // ignore - LLM not available. + } + } + /** * Registers an LLM factory for testing purposes. Clears cached instances matching the given * pattern to ensure test isolation. diff --git a/core/src/main/java/com/google/adk/utils/ComponentRegistry.java b/core/src/main/java/com/google/adk/utils/ComponentRegistry.java index 0a9f55b16..7a83341bc 100644 --- a/core/src/main/java/com/google/adk/utils/ComponentRegistry.java +++ b/core/src/main/java/com/google/adk/utils/ComponentRegistry.java @@ -22,26 +22,13 @@ import com.google.adk.agents.BaseAgent; import com.google.adk.agents.Callbacks; -import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.LoopAgent; -import com.google.adk.agents.ParallelAgent; -import com.google.adk.agents.SequentialAgent; -import com.google.adk.tools.AgentTool; import com.google.adk.tools.BaseTool; import com.google.adk.tools.BaseToolset; -import com.google.adk.tools.ExampleTool; -import com.google.adk.tools.ExitLoopTool; -import com.google.adk.tools.GoogleMapsTool; -import com.google.adk.tools.GoogleSearchTool; -import com.google.adk.tools.LoadArtifactsTool; -import com.google.adk.tools.LongRunningFunctionTool; -import com.google.adk.tools.UrlContextTool; -import com.google.adk.tools.mcp.McpToolset; +import com.google.common.collect.ImmutableMap; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import javax.annotation.Nonnull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -93,6 +80,8 @@ public class ComponentRegistry { private static final Logger logger = LoggerFactory.getLogger(ComponentRegistry.class); + private static volatile ImmutableMap DEFAULT_REGISTRY; + private static volatile ComponentRegistry instance = new ComponentRegistry(); private final Map registry = new ConcurrentHashMap<>(); @@ -103,55 +92,80 @@ protected ComponentRegistry() { /** Initializes the registry with base pre-wired ADK instances. */ private void initializePreWiredEntries() { - registerAdkAgentClass(LlmAgent.class); - registerAdkAgentClass(LoopAgent.class); - registerAdkAgentClass(ParallelAgent.class); - registerAdkAgentClass(SequentialAgent.class); - - registerAdkToolInstance("google_search", GoogleSearchTool.INSTANCE); - registerAdkToolInstance("load_artifacts", LoadArtifactsTool.INSTANCE); - registerAdkToolInstance("exit_loop", ExitLoopTool.INSTANCE); - registerAdkToolInstance("url_context", UrlContextTool.INSTANCE); - registerAdkToolInstance("google_maps_grounding", GoogleMapsTool.INSTANCE); - - registerAdkToolClass(AgentTool.class); - registerAdkToolClass(LongRunningFunctionTool.class); - registerAdkToolClass(ExampleTool.class); - - registerAdkToolsetClass(McpToolset.class); - // TODO: add all python tools that also exist in Java. - + if (DEFAULT_REGISTRY == null) { + synchronized (ComponentRegistry.class) { + if (DEFAULT_REGISTRY == null) { + registerAdkClassByName("com.google.adk.agents.LlmAgent"); + registerAdkClassByName("com.google.adk.agents.LoopAgent"); + registerAdkClassByName("com.google.adk.agents.ParallelAgent"); + registerAdkClassByName("com.google.adk.agents.SequentialAgent"); + + registerAdkToolInstance("google_search", "com.google.adk.tools.GoogleSearchTool"); + registerAdkToolInstance("load_artifacts", "com.google.adk.tools.LoadArtifactsTool"); + registerAdkToolInstance("exit_loop", "com.google.adk.tools.ExitLoopTool"); + registerAdkToolInstance("url_context", "com.google.adk.tools.UrlContextTool"); + registerAdkToolInstance("google_maps_grounding", "com.google.adk.tools.GoogleMapsTool"); + + registerAdkClassByName("com.google.adk.tools.AgentTool"); + registerAdkClassByName("com.google.adk.tools.LongRunningFunctionTool"); + registerAdkClassByName("com.google.adk.tools.ExampleTool"); + + registerAdkClassByName("com.google.adk.tools.mcp.McpToolset"); + // TODO: add all python tools that also exist in Java. + + DEFAULT_REGISTRY = ImmutableMap.copyOf(registry); + return; + } + } + } + registry.putAll(DEFAULT_REGISTRY); logger.debug("Initialized base pre-wired entries in ComponentRegistry"); } - private void registerAdkAgentClass(Class agentClass) { - registry.put(agentClass.getName(), agentClass); - // For python compatibility, also register the name used in ADK Python. - registry.put("google.adk.agents." + agentClass.getSimpleName(), agentClass); - } + private void registerAdkClassByName(String className) { + try { + Class clazz = Thread.currentThread().getContextClassLoader().loadClass(className); + String standardPrefix; + if (BaseAgent.class.isAssignableFrom(clazz)) { + standardPrefix = "google.adk.agents."; + } else if (BaseTool.class.isAssignableFrom(clazz)) { + standardPrefix = "google.adk.tools."; + } else if (BaseToolset.class.isAssignableFrom(clazz)) { + standardPrefix = "google.adk.tools."; + } else { + throw new IllegalArgumentException( + "Cannot determine standardPrefix for type " + clazz.getName()); + } - private void registerAdkToolInstance(String name, @Nonnull Object toolInstance) { - registry.put(name, toolInstance); - // For python compatibility, also register the name used in ADK Python. - registry.put("google.adk.tools." + name, toolInstance); - } + registry.put(clazz.getName(), clazz); + // For python compatibility, also register the name used in ADK Python. + registry.put(standardPrefix + clazz.getSimpleName(), clazz); - private void registerAdkToolClass(@Nonnull Class toolClass) { - registry.put(toolClass.getName(), toolClass); - // For python compatibility, also register the name used in ADK Python. - registry.put("google.adk.tools." + toolClass.getSimpleName(), toolClass); + if (BaseToolset.class.isAssignableFrom(clazz)) { + registry.put(clazz.getSimpleName(), clazz); + if (clazz.getSimpleName().equals("McpToolset")) { + registry.put("mcp.McpToolset", clazz); + } + } + } catch (Exception e) { + logger.info( + "{} not found, skipping registration: {}", + className.substring(className.lastIndexOf('.') + 1), + e.getMessage()); + } } - private void registerAdkToolsetClass(@Nonnull Class toolsetClass) { - registry.put(toolsetClass.getName(), toolsetClass); - // For python compatibility, also register the name used in ADK Python. - registry.put("google.adk.tools." + toolsetClass.getSimpleName(), toolsetClass); - // Also register by simple class name - registry.put(toolsetClass.getSimpleName(), toolsetClass); - // Special support for toolsets with various naming conventions - String simpleName = toolsetClass.getSimpleName(); - if (simpleName.equals("McpToolset")) { - registry.put("mcp.McpToolset", toolsetClass); + private void registerAdkToolInstance(String name, String toolClassName) { + try { + Object toolInstance = Class.forName(toolClassName).getField("INSTANCE").get(null); + registry.put(name, toolInstance); + // For python compatibility, also register the name used in ADK Python. + registry.put("google.adk.tools." + name, toolInstance); + } catch (Exception e) { + logger.info( + "{} not found, skipping registration: {}", + toolClassName.substring(toolClassName.lastIndexOf('.') + 1), + e.getMessage()); } } @@ -281,30 +295,31 @@ public static Optional resolveAgentInstance(String name) { */ public static Class resolveAgentClass(String agentClassName) { // If no agent_class is specified, it will default to LlmAgent. - if (isNullOrEmpty(agentClassName)) { - return LlmAgent.class; - } + final String effectiveAgentClassName = + isNullOrEmpty(agentClassName) ? "com.google.adk.agents.LlmAgent" : agentClassName; Optional> agentClass; - if (agentClassName.contains(".")) { + if (effectiveAgentClassName.contains(".")) { // If agentClassName contains '.', use it directly - agentClass = getType(agentClassName, BaseAgent.class); + agentClass = getType(effectiveAgentClassName, BaseAgent.class); } else { // First try the simple name agentClass = - getType(agentClassName, BaseAgent.class) + getType(effectiveAgentClassName, BaseAgent.class) // If not found, try with com.google.adk.agents prefix - .or(() -> getType("com.google.adk.agents." + agentClassName, BaseAgent.class)) + .or( + () -> + getType("com.google.adk.agents." + effectiveAgentClassName, BaseAgent.class)) // For Python compatibility, also try with google.adk.agents prefix - .or(() -> getType("google.adk.agents." + agentClassName, BaseAgent.class)); + .or(() -> getType("google.adk.agents." + effectiveAgentClassName, BaseAgent.class)); } return agentClass.orElseThrow( () -> new IllegalArgumentException( "agentClass '" - + agentClassName + + effectiveAgentClassName + "' is not in registry or not a subclass of BaseAgent.")); } diff --git a/core/src/test/java/com/google/adk/utils/ComponentRegistryTest.java b/core/src/test/java/com/google/adk/utils/ComponentRegistryTest.java index ca401e6e9..8643b60a9 100644 --- a/core/src/test/java/com/google/adk/utils/ComponentRegistryTest.java +++ b/core/src/test/java/com/google/adk/utils/ComponentRegistryTest.java @@ -298,7 +298,7 @@ public void testResolveToolClass_withFullyQualifiedName() { @Test public void testMcpToolsetRegistration() { - ComponentRegistry registry = ComponentRegistry.getInstance(); + ComponentRegistry registry = new ComponentRegistry(); // Verify direct registry storage (tests lines 134, 136, 138, 142 in ComponentRegistry.java) Optional directFullName = registry.get("com.google.adk.tools.mcp.McpToolset");