diff --git a/core/pom.xml b/core/pom.xml index c8839262..9f42e0a1 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -55,7 +55,7 @@ docker-java-transport-httpclient5 - io.modelcontextprotocol.sdk + io.modelcontextprotocol.sdk mcp @@ -189,6 +189,10 @@ opentelemetry-sdk-testing test + + net.javacrumbs.future-converter + future-converter-java8-guava + diff --git a/core/src/main/java/com/google/adk/models/Gemini.java b/core/src/main/java/com/google/adk/models/Gemini.java index 74cf78b9..1d2eb02e 100644 --- a/core/src/main/java/com/google/adk/models/Gemini.java +++ b/core/src/main/java/com/google/adk/models/Gemini.java @@ -17,9 +17,12 @@ package com.google.adk.models; import static com.google.common.base.StandardSystemProperty.JAVA_VERSION; +import static net.javacrumbs.futureconverter.java8guava.FutureConverter.toListenableFuture; import com.google.adk.Version; import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.Client; import com.google.genai.ResponseStream; @@ -32,11 +35,14 @@ import com.google.genai.types.LiveConnectConfig; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Scheduler; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -205,6 +211,23 @@ public Gemini build() { } } + private static Single toSingle(ListenableFuture future, Scheduler scheduler) { + return Single.create( + emitter -> { + future.addListener( + () -> { + try { + emitter.onSuccess(Futures.getDone(future)); + } catch (ExecutionException e) { + emitter.onError(e.getCause()); + } + }, + scheduler::scheduleDirect); + + emitter.setCancellable(() -> future.cancel(false)); + }); + } + @Override public Flowable generateContent(LlmRequest llmRequest, boolean stream) { llmRequest = @@ -218,14 +241,17 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre if (stream) { logger.debug("Sending streaming generateContent request to model {}", effectiveModelName); - CompletableFuture> streamFuture = - apiClient.async.models.generateContentStream( - effectiveModelName, llmRequest.contents(), config); + ListenableFuture> streamFuture = + toListenableFuture( + apiClient.async.models.generateContentStream( + effectiveModelName, llmRequest.contents(), config)); return Flowable.defer( () -> processRawResponses( - Flowable.fromFuture(streamFuture).flatMapIterable(iterable -> iterable))) + toSingle(streamFuture, Schedulers.io()) + .toFlowable() + .flatMapIterable(iterable -> iterable))) .filter( llmResponse -> llmResponse @@ -243,12 +269,16 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre .orElse(false)); } else { logger.debug("Sending generateContent request to model {}", effectiveModelName); - return Flowable.fromFuture( - apiClient - .async - .models - .generateContent(effectiveModelName, llmRequest.contents(), config) - .thenApplyAsync(LlmResponse::create)); + final LlmRequest finalLlmRequest = llmRequest; + return toSingle( + toListenableFuture( + apiClient + .async + .models + .generateContent(effectiveModelName, finalLlmRequest.contents(), config) + .thenApplyAsync(LlmResponse::create)), + Schedulers.io()) + .toFlowable(); } } diff --git a/pom.xml b/pom.xml index 6a1aa5af..c3476fda 100644 --- a/pom.xml +++ b/pom.xml @@ -72,6 +72,7 @@ 1.4.0 3.9.0 5.4.3 + 1.2.0 @@ -274,6 +275,11 @@ assertj-core ${assertj.version} + + net.javacrumbs.future-converter + future-converter-java8-guava + ${future-converter-java8-guava.version} +