From 75f487f6cc5326ba2ce89be2a3e1436aea334b81 Mon Sep 17 00:00:00 2001 From: sirivarma Date: Thu, 6 Mar 2025 11:10:32 -0800 Subject: [PATCH 1/6] Conversation first commit Signed-off-by: Siri Varma Vegiraju Signed-off-by: sirivarma --- examples/pom.xml | 5 + .../conversation/DemoConversationAI.java | 32 ++++ pom.xml | 3 +- sdk-ai/pom.xml | 164 ++++++++++++++++++ .../java/io/dapr/ai/client/DaprAiClient.java | 29 ++++ .../ai/client/DaprConversationClient.java | 164 ++++++++++++++++++ .../dapr/ai/client/DaprConversationInput.java | 66 +++++++ .../ai/client/DaprConversationOutput.java | 43 +++++ .../ai/client/DaprConversationResponse.java | 52 ++++++ .../dapr/ai/client/DaprConversationRole.java | 13 ++ sdk-ai/src/test/java/io/dapr/ai/AITest.java | 14 ++ sdk/pom.xml | 30 ++++ 12 files changed, 614 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java create mode 100644 sdk-ai/pom.xml create mode 100644 sdk-ai/src/main/java/io/dapr/ai/client/DaprAiClient.java create mode 100644 sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationClient.java create mode 100644 sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationInput.java create mode 100644 sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationOutput.java create mode 100644 sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationResponse.java create mode 100644 sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationRole.java create mode 100644 sdk-ai/src/test/java/io/dapr/ai/AITest.java diff --git a/examples/pom.xml b/examples/pom.xml index 0380911f72..4a0026a566 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -119,6 +119,11 @@ dapr-sdk ${project.version} + + io.dapr + dapr-sdk-ai + ${project.version} + com.evanlennick retry4j diff --git a/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java b/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java new file mode 100644 index 0000000000..7d876c030c --- /dev/null +++ b/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java @@ -0,0 +1,32 @@ +package io.dapr.examples.conversation; + +import io.dapr.ai.client.DaprConversationClient; +import io.dapr.ai.client.DaprConversationInput; +import io.dapr.ai.client.DaprConversationResponse; +import io.dapr.v1.DaprProtos; +import reactor.core.publisher.Mono; + +import java.util.ArrayList; +import java.util.Collections; + +public class DemoConversationAI { + /** + * The main method to start the client. + * + * @param args Input arguments (unused). + */ + public static void main(String[] args) { + try (DaprConversationClient client = new DaprConversationClient(null)) { + DaprConversationInput daprConversationInput = new DaprConversationInput("11"); + + // Component name is the name provided in the metadata block of the conversation.yaml file. + Mono instanceId = client.converse("openai", new ArrayList<>(Collections.singleton(daprConversationInput)), "1234", false, 0.0d); + System.out.printf("Started a new chaining model workflow with instance ID: %s%n", instanceId); + DaprConversationResponse response = instanceId.block(); + + System.out.println(response); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} \ No newline at end of file diff --git a/pom.xml b/pom.xml index bf727fe94d..3d94c55502 100644 --- a/pom.xml +++ b/pom.xml @@ -17,7 +17,7 @@ 1.69.0 3.25.5 protoc - https://raw.githubusercontent.com/dapr/dapr/v1.14.4/dapr/proto + https://raw.githubusercontent.com/dapr/dapr/v1.15.2/dapr/proto 1.15.0-SNAPSHOT 0.15.0-SNAPSHOT 1.7.1 @@ -333,6 +333,7 @@ sdk-autogen sdk sdk-actors + sdk-ai sdk-workflows sdk-springboot dapr-spring diff --git a/sdk-ai/pom.xml b/sdk-ai/pom.xml new file mode 100644 index 0000000000..fe42edfea6 --- /dev/null +++ b/sdk-ai/pom.xml @@ -0,0 +1,164 @@ + + 4.0.0 + + + io.dapr + dapr-sdk-parent + 1.15.0-SNAPSHOT + + + dapr-sdk-ai + jar + 1.15.0-SNAPSHOT + dapr-sdk-ai + SDK for AI on Dapr + + + false + + + + + io.dapr + dapr-sdk + ${project.parent.version} + + + io.dapr + dapr-sdk-autogen + 1.14.0-SNAPSHOT + compile + + + org.mockito + mockito-core + test + + + org.mockito + mockito-inline + 4.2.0 + test + + + org.junit.jupiter + junit-jupiter + test + + + org.junit.vintage + junit-vintage-engine + 5.7.0 + test + + + com.microsoft + durabletask-client + 1.5.0 + + + + com.fasterxml.jackson.core + jackson-core + ${jackson.version} + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + com.fasterxml.jackson.core + jackson-annotations + ${jackson.version} + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + ${jackson.version} + + + + + + + org.apache.maven.plugins + maven-source-plugin + 3.2.1 + + + attach-sources + + jar-no-fork + + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 3.2.0 + + + attach-javadocs + + jar + + + + + + org.jacoco + jacoco-maven-plugin + 0.8.11 + + + default-prepare-agent + + prepare-agent + + + + report + test + + report + + + target/jacoco-report/ + + + + check + + check + + + + + BUNDLE + + + LINE + COVEREDRATIO + 80% + + + + + + + + + + + + diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprAiClient.java b/sdk-ai/src/main/java/io/dapr/ai/client/DaprAiClient.java new file mode 100644 index 0000000000..193132e130 --- /dev/null +++ b/sdk-ai/src/main/java/io/dapr/ai/client/DaprAiClient.java @@ -0,0 +1,29 @@ +package io.dapr.ai.client; + +import reactor.core.publisher.Mono; + +import javax.annotation.Nullable; +import java.util.List; + +/** + * Defines client operations for managing Dapr AI instances. + */ +interface DaprAiClient { + + /** + * Method to call the Dapr Converse API. + * + * @param conversationComponentName name for the conversation component. + * @param daprConversationInputs prompts that are part of the conversation. + * @param contextId identifier of an existing chat (like in ChatGPT) + * @param scrubPii data that comes from the LLM. + * @param temperature to optimize from creativity or predictability. + * @return @ConversationResponse. + */ + Mono converse( + String conversationComponentName, + List daprConversationInputs, + @Nullable String contextId, + boolean scrubPii, + double temperature); +} diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationClient.java b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationClient.java new file mode 100644 index 0000000000..1f78a2cc5f --- /dev/null +++ b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationClient.java @@ -0,0 +1,164 @@ +package io.dapr.ai.client; + +import com.google.protobuf.Any; +import io.dapr.client.resiliency.ResiliencyOptions; +import io.dapr.config.Properties; +import io.dapr.exceptions.DaprException; +import io.dapr.internal.exceptions.DaprHttpException; +import io.dapr.internal.grpc.interceptors.DaprTimeoutInterceptor; +import io.dapr.internal.grpc.interceptors.DaprTracingInterceptor; +import io.dapr.internal.resiliency.RetryPolicy; +import io.dapr.internal.resiliency.TimeoutPolicy; +import io.dapr.utils.NetworkUtils; +import io.dapr.v1.DaprGrpc; +import io.dapr.v1.DaprProtos; +import io.grpc.ManagedChannel; +import io.grpc.stub.StreamObserver; +import org.jetbrains.annotations.Nullable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.util.context.ContextView; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +public class DaprConversationClient implements AutoCloseable, DaprAiClient { + + /** + * Stub that has the method to call the conversation apis. + */ + private final DaprGrpc.DaprStub asyncStub; + + /** + * The GRPC managed channel to be used. + */ + private final ManagedChannel channel; + + /** + * The retry policy. + */ + private final RetryPolicy retryPolicy; + + /** + * The timeout policy. + */ + private final TimeoutPolicy timeoutPolicy; + + /** + * ConversationClient constructor. + * + * @param resiliencyOptions timeout and retry policies. + */ + public DaprConversationClient( + @Nullable ResiliencyOptions resiliencyOptions) { + this.channel = NetworkUtils.buildGrpcManagedChannel(new Properties()); + this.asyncStub = DaprGrpc.newStub(this.channel); + this.retryPolicy = new RetryPolicy(resiliencyOptions == null ? null : resiliencyOptions.getMaxRetries()); + this.timeoutPolicy = new TimeoutPolicy(resiliencyOptions == null ? null : resiliencyOptions.getTimeout()); + } + + @Override + public Mono converse( + String conversationComponentName, + List daprConversationInputs, + @Nullable String contextId, + boolean scrubPii, + double temperature) { + + try { + if ((conversationComponentName == null) || (conversationComponentName.trim().isEmpty())) { + throw new IllegalArgumentException("Conversation component name cannot be null or empty."); + } + + if ((daprConversationInputs == null) || (daprConversationInputs.isEmpty())) { + throw new IllegalArgumentException("Conversation inputs cannot be null or empty."); + } + + DaprProtos.ConversationRequest.Builder conversationRequest = DaprProtos.ConversationRequest.newBuilder() + .setTemperature(temperature) + .setScrubPII(scrubPii) + .setName(conversationComponentName); + + if (contextId != null) { + conversationRequest.setContextID(contextId); + } + + for (DaprConversationInput input : daprConversationInputs) { + conversationRequest.addInputs(DaprProtos.ConversationInput.newBuilder() + .setContent(input.getContent()).build()); + } + + Mono conversationResponseMono = Mono.deferContextual( + context -> this.createMono( + it -> intercept(context, asyncStub) + .converseAlpha1(conversationRequest.build(), it) + ) + ); + + return conversationResponseMono.map(conversationResponse -> { + + List daprConversationOutputs = new ArrayList<>(); + for (DaprProtos.ConversationResult conversationResult : conversationResponse.getOutputsList()) { + Map parameters = new HashMap<>(); + for (Map.Entry entrySet : conversationResult.getParametersMap().entrySet()) { + parameters.put(entrySet.getKey(), entrySet.getValue().toByteArray()); + } + + DaprConversationOutput daprConversationOutput = + new DaprConversationOutput(conversationResult.getResult(), parameters); + daprConversationOutputs.add(daprConversationOutput); + } + + return new DaprConversationResponse(conversationResponse.getContextID(), daprConversationOutputs); + }); + } catch (Exception ex) { + return DaprException.wrapMono(ex); + } + } + + @Override + public void close() throws Exception { + DaprException.wrap(() -> { + if (channel != null && !channel.isShutdown()) { + channel.shutdown(); + } + + return true; + }).call(); + } + + private DaprGrpc.DaprStub intercept( + ContextView context, DaprGrpc.DaprStub client) { + return client.withInterceptors( + new DaprTimeoutInterceptor(this.timeoutPolicy), + new DaprTracingInterceptor(context)); + } + + private Mono createMono(Consumer> consumer) { + return retryPolicy.apply( + Mono.create(sink -> DaprException.wrap(() -> consumer.accept( + createStreamObserver(sink))).run())); + } + + private StreamObserver createStreamObserver(MonoSink sink) { + return new StreamObserver() { + @Override + public void onNext(T value) { + sink.success(value); + } + + @Override + public void onError(Throwable t) { + sink.error(DaprException.propagate(DaprHttpException.fromGrpcExecutionException(null, t))); + } + + @Override + public void onCompleted() { + sink.success(); + } + }; + } +} diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationInput.java b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationInput.java new file mode 100644 index 0000000000..5189f1da40 --- /dev/null +++ b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationInput.java @@ -0,0 +1,66 @@ +package io.dapr.ai.client; + +/** + * Represents an input message for a conversation with an LLM. + */ +public class DaprConversationInput { + + private final String content; + + private DaprConversationRole role; + + private boolean scrubPii; + + public DaprConversationInput(String content) { + this.content = content; + } + + /** + * Retrieves the content of the conversation input. + * + * @return The content to be sent to the LLM. + */ + public String getContent() { + return content; + } + + /** + * Retrieves the role associated with the conversation input. + * + * @return this. + */ + public DaprConversationRole getRole() { + return role; + } + + /** + * Sets the role associated with the conversation input. + * + * @param role The role to assign to the message. + * @return this. + */ + public DaprConversationInput setRole(DaprConversationRole role) { + this.role = role; + return this; + } + + /** + * Checks if Personally Identifiable Information (PII) should be scrubbed before sending to the LLM. + * + * @return {@code true} if PII should be scrubbed, {@code false} otherwise. + */ + public boolean isScrubPii() { + return scrubPii; + } + + /** + * Sets whether to scrub Personally Identifiable Information (PII) before sending to the LLM. + * + * @param scrubPii A boolean indicating whether to remove PII. + * @return this. + */ + public DaprConversationInput setScrubPii(boolean scrubPii) { + this.scrubPii = scrubPii; + return this; + } +} diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationOutput.java b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationOutput.java new file mode 100644 index 0000000000..a1a5d5ec76 --- /dev/null +++ b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationOutput.java @@ -0,0 +1,43 @@ +package io.dapr.ai.client; + +import java.util.Collections; +import java.util.Map; + +/** + * Returns the conversation output. + */ +public class DaprConversationOutput { + + private final String result; + + private final Map parameters; + + /** + * Constructor. + * + * @param result result for one of the conversation input. + * @param parameters all custom fields. + */ + public DaprConversationOutput(String result, Map parameters) { + this.result = result; + this.parameters = parameters; + } + + /** + * Result for the one conversation input. + * + * @return result output from the LLM. + */ + public String getResult() { + return this.result; + } + + /** + * Parameters for all custom fields. + * + * @return parameters. + */ + public Map getParameters() { + return Collections.unmodifiableMap(this.parameters); + } +} diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationResponse.java b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationResponse.java new file mode 100644 index 0000000000..677352495c --- /dev/null +++ b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationResponse.java @@ -0,0 +1,52 @@ +package io.dapr.ai.client; + +import java.util.Collections; +import java.util.List; + +/** + * Response from the Dapr Conversation API. + */ +public class DaprConversationResponse { + + private String contextId; + + private final List daprConversationOutputs; + + /** + * Constructor. + * + * @param daprConversationOutputs outputs from the LLM. + */ + public DaprConversationResponse(List daprConversationOutputs) { + this.daprConversationOutputs = daprConversationOutputs; + } + + /** + * Constructor. + * + * @param contextId context id supplied to LLM. + * @param daprConversationOutputs outputs from the LLM. + */ + public DaprConversationResponse(String contextId, List daprConversationOutputs) { + this.contextId = contextId; + this.daprConversationOutputs = daprConversationOutputs; + } + + /** + * The ID of an existing chat (like in ChatGPT). + * + * @return String identifier. + */ + public String getContextId() { + return this.contextId; + } + + /** + * Get list of conversation outputs. + * + * @return List{@link DaprConversationOutput}. + */ + public List getDaprConversationOutputs() { + return Collections.unmodifiableList(this.daprConversationOutputs); + } +} diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationRole.java b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationRole.java new file mode 100644 index 0000000000..11991d857d --- /dev/null +++ b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationRole.java @@ -0,0 +1,13 @@ +package io.dapr.ai.client; + +/** + * Conversation AI supported roles. + */ +public enum DaprConversationRole { + + USER, + + TOOL, + + ASSISSTANT +} diff --git a/sdk-ai/src/test/java/io/dapr/ai/AITest.java b/sdk-ai/src/test/java/io/dapr/ai/AITest.java new file mode 100644 index 0000000000..f45217ab1f --- /dev/null +++ b/sdk-ai/src/test/java/io/dapr/ai/AITest.java @@ -0,0 +1,14 @@ +package io.dapr.ai; + +import org.junit.Test; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +public class AITest { + + @Test + public void testAI() { + } +} diff --git a/sdk/pom.xml b/sdk/pom.xml index c608e30bb6..bbf7c45d32 100644 --- a/sdk/pom.xml +++ b/sdk/pom.xml @@ -127,6 +127,36 @@ grpc-inprocess test + + io.dapr + dapr-sdk-autogen + 1.14.0-SNAPSHOT + compile + + + io.dapr + dapr-sdk-autogen + 1.14.0-SNAPSHOT + compile + + + io.dapr + dapr-sdk-autogen + 1.14.0-SNAPSHOT + compile + + + io.dapr + dapr-sdk-autogen + 1.14.0-SNAPSHOT + compile + + + io.dapr + dapr-sdk-autogen + 1.14.0-SNAPSHOT + compile + From 88fe015edc69619a16a392e7ebcee27eb8df72d1 Mon Sep 17 00:00:00 2001 From: sirivarma Date: Thu, 6 Mar 2025 19:46:01 -0800 Subject: [PATCH 2/6] Add unit tests Signed-off-by: sirivarma --- .../conversation/DemoConversationAI.java | 2 +- .../java/io/dapr/ai/client/DaprAiClient.java | 13 +- .../ai/client/DaprConversationClient.java | 60 ++++-- .../ai/client/DaprConversationResponse.java | 9 - sdk-ai/src/test/java/io/dapr/ai/AITest.java | 14 -- .../ai/client/DaprConversationClientTest.java | 185 ++++++++++++++++++ 6 files changed, 245 insertions(+), 38 deletions(-) delete mode 100644 sdk-ai/src/test/java/io/dapr/ai/AITest.java create mode 100644 sdk-ai/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java diff --git a/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java b/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java index 7d876c030c..014ca686ea 100644 --- a/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java +++ b/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java @@ -16,7 +16,7 @@ public class DemoConversationAI { * @param args Input arguments (unused). */ public static void main(String[] args) { - try (DaprConversationClient client = new DaprConversationClient(null)) { + try (DaprConversationClient client = new DaprConversationClient()) { DaprConversationInput daprConversationInput = new DaprConversationInput("11"); // Component name is the name provided in the metadata block of the conversation.yaml file. diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprAiClient.java b/sdk-ai/src/main/java/io/dapr/ai/client/DaprAiClient.java index 193132e130..c3f6b536a8 100644 --- a/sdk-ai/src/main/java/io/dapr/ai/client/DaprAiClient.java +++ b/sdk-ai/src/main/java/io/dapr/ai/client/DaprAiClient.java @@ -23,7 +23,18 @@ interface DaprAiClient { Mono converse( String conversationComponentName, List daprConversationInputs, - @Nullable String contextId, + String contextId, boolean scrubPii, double temperature); + + /** + * Method to call the Dapr Converse API. + * + * @param conversationComponentName name for the conversation component. + * @param daprConversationInputs prompts that are part of the conversation. + * @return @ConversationResponse. + */ + Mono converse( + String conversationComponentName, + List daprConversationInputs); } diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationClient.java b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationClient.java index 1f78a2cc5f..55e4108a3f 100644 --- a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationClient.java +++ b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationClient.java @@ -12,9 +12,9 @@ import io.dapr.utils.NetworkUtils; import io.dapr.v1.DaprGrpc; import io.dapr.v1.DaprProtos; +import io.grpc.Channel; import io.grpc.ManagedChannel; import io.grpc.stub.StreamObserver; -import org.jetbrains.annotations.Nullable; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; import reactor.util.context.ContextView; @@ -32,11 +32,6 @@ public class DaprConversationClient implements AutoCloseable, DaprAiClient { */ private final DaprGrpc.DaprStub asyncStub; - /** - * The GRPC managed channel to be used. - */ - private final ManagedChannel channel; - /** * The retry policy. */ @@ -47,24 +42,50 @@ public class DaprConversationClient implements AutoCloseable, DaprAiClient { */ private final TimeoutPolicy timeoutPolicy; + /** + * Constructor to create conversation client. + */ + public DaprConversationClient() { + this(DaprGrpc.newStub(NetworkUtils.buildGrpcManagedChannel(new Properties())), null); + } + + /** + * Constructor. + * + * @param properties with client configuration options. + * @param resiliencyOptions retry options. + */ + public DaprConversationClient( + Properties properties, + ResiliencyOptions resiliencyOptions) { + this(DaprGrpc.newStub(NetworkUtils.buildGrpcManagedChannel(properties)), resiliencyOptions); + } + /** * ConversationClient constructor. * * @param resiliencyOptions timeout and retry policies. */ - public DaprConversationClient( - @Nullable ResiliencyOptions resiliencyOptions) { - this.channel = NetworkUtils.buildGrpcManagedChannel(new Properties()); - this.asyncStub = DaprGrpc.newStub(this.channel); + protected DaprConversationClient( + DaprGrpc.DaprStub asyncStub, + ResiliencyOptions resiliencyOptions) { + this.asyncStub = asyncStub; this.retryPolicy = new RetryPolicy(resiliencyOptions == null ? null : resiliencyOptions.getMaxRetries()); this.timeoutPolicy = new TimeoutPolicy(resiliencyOptions == null ? null : resiliencyOptions.getTimeout()); } + @Override + public Mono converse( + String conversationComponentName, + List daprConversationInputs) { + return converse(conversationComponentName, daprConversationInputs, null, false, 0.0d); + } + @Override public Mono converse( String conversationComponentName, List daprConversationInputs, - @Nullable String contextId, + String contextId, boolean scrubPii, double temperature) { @@ -87,8 +108,19 @@ public Mono converse( } for (DaprConversationInput input : daprConversationInputs) { - conversationRequest.addInputs(DaprProtos.ConversationInput.newBuilder() - .setContent(input.getContent()).build()); + if (input.getContent() == null || input.getContent().isEmpty()) { + throw new IllegalArgumentException("Conversation input content cannot be null or empty."); + } + + DaprProtos.ConversationInput.Builder conversationInputOrBuilder = DaprProtos.ConversationInput.newBuilder() + .setContent(input.getContent()) + .setScrubPII(input.isScrubPii()); + + if (input.getRole() != null) { + conversationInputOrBuilder.setRole(input.getRole().toString()); + } + + conversationRequest.addInputs(conversationInputOrBuilder.build()); } Mono conversationResponseMono = Mono.deferContextual( @@ -121,6 +153,8 @@ public Mono converse( @Override public void close() throws Exception { + ManagedChannel channel = (ManagedChannel) this.asyncStub.getChannel(); + DaprException.wrap(() -> { if (channel != null && !channel.isShutdown()) { channel.shutdown(); diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationResponse.java b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationResponse.java index 677352495c..e7ab001a5c 100644 --- a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationResponse.java +++ b/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationResponse.java @@ -12,15 +12,6 @@ public class DaprConversationResponse { private final List daprConversationOutputs; - /** - * Constructor. - * - * @param daprConversationOutputs outputs from the LLM. - */ - public DaprConversationResponse(List daprConversationOutputs) { - this.daprConversationOutputs = daprConversationOutputs; - } - /** * Constructor. * diff --git a/sdk-ai/src/test/java/io/dapr/ai/AITest.java b/sdk-ai/src/test/java/io/dapr/ai/AITest.java deleted file mode 100644 index f45217ab1f..0000000000 --- a/sdk-ai/src/test/java/io/dapr/ai/AITest.java +++ /dev/null @@ -1,14 +0,0 @@ -package io.dapr.ai; - -import org.junit.Test; - -import static org.junit.Assert.assertThrows; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.*; - -public class AITest { - - @Test - public void testAI() { - } -} diff --git a/sdk-ai/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java b/sdk-ai/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java new file mode 100644 index 0000000000..d39ef1413d --- /dev/null +++ b/sdk-ai/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java @@ -0,0 +1,185 @@ +package io.dapr.ai.client; + +import io.dapr.client.resiliency.ResiliencyOptions; +import io.dapr.config.Properties; +import io.dapr.v1.DaprGrpc; +import io.dapr.v1.DaprProtos; +import io.grpc.ManagedChannel; +import io.grpc.stub.StreamObserver; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.List; + +import static org.mockito.Mockito.*; + +public class DaprConversationClientTest { + + private DaprGrpc.DaprStub daprStub; + + @Before + public void initialize() { + + ManagedChannel channel = mock(ManagedChannel.class); + daprStub = mock(DaprGrpc.DaprStub.class); + when(daprStub.getChannel()).thenReturn(channel); + when(daprStub.withInterceptors(Mockito.any(), Mockito.any())).thenReturn(daprStub); + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenComponentNameIsNull() throws Exception { + try (DaprConversationClient daprConversationClient = new DaprConversationClient()) { + List daprConversationInputs = new ArrayList<>(); + daprConversationInputs.add(new DaprConversationInput("Hello there !")); + + IllegalArgumentException exception = + Assert.assertThrows(IllegalArgumentException.class, () -> + daprConversationClient.converse(null, daprConversationInputs).block()); + Assert.assertEquals("Conversation component name cannot be null or empty.", exception.getMessage()); + } + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenConversationComponentIsEmpty() throws Exception { + try (DaprConversationClient daprConversationClient = new DaprConversationClient()) { + List daprConversationInputs = new ArrayList<>(); + daprConversationInputs.add(new DaprConversationInput("Hello there !")); + + IllegalArgumentException exception = + Assert.assertThrows(IllegalArgumentException.class, () -> + daprConversationClient.converse("", daprConversationInputs).block()); + Assert.assertEquals("Conversation component name cannot be null or empty.", exception.getMessage()); + } + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputIsEmpty() throws Exception { + try (DaprConversationClient daprConversationClient = new DaprConversationClient()) { + List daprConversationInputs = new ArrayList<>(); + + IllegalArgumentException exception = + Assert.assertThrows(IllegalArgumentException.class, () -> + daprConversationClient.converse("openai", daprConversationInputs).block()); + Assert.assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage()); + } + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputIsNull() throws Exception { + try (DaprConversationClient daprConversationClient = + new DaprConversationClient(new Properties(), null)) { + + IllegalArgumentException exception = + Assert.assertThrows(IllegalArgumentException.class, () -> + daprConversationClient.converse("openai", null).block()); + Assert.assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage()); + } + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputContentIsNull() throws Exception { + try (DaprConversationClient daprConversationClient = new DaprConversationClient()) { + List daprConversationInputs = new ArrayList<>(); + daprConversationInputs.add(new DaprConversationInput(null)); + + IllegalArgumentException exception = + Assert.assertThrows(IllegalArgumentException.class, () -> + daprConversationClient.converse("openai", daprConversationInputs).block()); + Assert.assertEquals("Conversation input content cannot be null or empty.", exception.getMessage()); + } + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputContentIsEmpty() throws Exception { + try (DaprConversationClient daprConversationClient = new DaprConversationClient()) { + List daprConversationInputs = new ArrayList<>(); + daprConversationInputs.add(new DaprConversationInput("")); + + IllegalArgumentException exception = + Assert.assertThrows(IllegalArgumentException.class, () -> + daprConversationClient.converse("openai", daprConversationInputs).block()); + Assert.assertEquals("Conversation input content cannot be null or empty.", exception.getMessage()); + } + } + + @Test + public void converseShouldReturnConversationResponseWhenRequiredInputsAreValid() throws Exception { + DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder() + .addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build(); + + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(conversationResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any()); + + try (DaprConversationClient daprConversationClient = + new DaprConversationClient(daprStub, new ResiliencyOptions())) { + List daprConversationInputs = new ArrayList<>(); + daprConversationInputs.add(new DaprConversationInput("Hello there")); + + DaprConversationResponse daprConversationResponse = + daprConversationClient.converse("openai", daprConversationInputs).block(); + + ArgumentCaptor captor = + ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class); + verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any()); + + DaprProtos.ConversationRequest conversationRequest = captor.getValue(); + + Assert.assertEquals("openai", conversationRequest.getName()); + Assert.assertEquals("Hello there", conversationRequest.getInputs(0).getContent()); + Assert.assertEquals("Hello How are you", + daprConversationResponse.getDaprConversationOutputs().get(0).getResult()); + } + } + + @Test + public void converseShouldReturnConversationResponseWhenRequiredAndOptionalInputsAreValid() throws Exception { + DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder() + .setContextID("contextId") + .addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build(); + + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(conversationResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any()); + + try (DaprConversationClient daprConversationClient = new DaprConversationClient(daprStub, null)) { + DaprConversationInput daprConversationInput = new DaprConversationInput("Hello there") + .setRole(DaprConversationRole.ASSISSTANT) + .setScrubPii(true); + + List daprConversationInputs = new ArrayList<>(); + daprConversationInputs.add(daprConversationInput); + + DaprConversationResponse daprConversationResponse = + daprConversationClient.converse("openai", daprConversationInputs, + "contextId", true, 1.1d).block(); + + ArgumentCaptor captor = + ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class); + verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any()); + + DaprProtos.ConversationRequest conversationRequest = captor.getValue(); + + Assert.assertEquals("openai", conversationRequest.getName()); + Assert.assertEquals("contextId", conversationRequest.getContextID()); + Assert.assertTrue(conversationRequest.getScrubPII()); + Assert.assertEquals(1.1d, conversationRequest.getTemperature(), 0d); + Assert.assertEquals("Hello there", conversationRequest.getInputs(0).getContent()); + Assert.assertTrue(conversationRequest.getInputs(0).getScrubPII()); + Assert.assertEquals(DaprConversationRole.ASSISSTANT.toString(), conversationRequest.getInputs(0).getRole()); + Assert.assertEquals("contextId", daprConversationResponse.getContextId()); + Assert.assertEquals("Hello How are you", + daprConversationResponse.getDaprConversationOutputs().get(0).getResult()); + } + } +} From fada33fc61c8e8c3b43cbd520044b375c8fb81e1 Mon Sep 17 00:00:00 2001 From: sirivarma Date: Thu, 13 Mar 2025 17:30:11 -0700 Subject: [PATCH 3/6] change ai to conv Signed-off-by: sirivarma --- examples/pom.xml | 2 +- pom.xml | 2 +- {sdk-ai => sdk-conversation}/pom.xml | 11 ++++++++--- .../src/main/java/io/dapr/ai/client/DaprAiClient.java | 0 .../io/dapr/ai/client/DaprConversationClient.java | 0 .../java/io/dapr/ai/client/DaprConversationInput.java | 0 .../io/dapr/ai/client/DaprConversationOutput.java | 0 .../io/dapr/ai/client/DaprConversationResponse.java | 0 .../java/io/dapr/ai/client/DaprConversationRole.java | 0 .../io/dapr/ai/client/DaprConversationClientTest.java | 0 10 files changed, 10 insertions(+), 5 deletions(-) rename {sdk-ai => sdk-conversation}/pom.xml (95%) rename {sdk-ai => sdk-conversation}/src/main/java/io/dapr/ai/client/DaprAiClient.java (100%) rename {sdk-ai => sdk-conversation}/src/main/java/io/dapr/ai/client/DaprConversationClient.java (100%) rename {sdk-ai => sdk-conversation}/src/main/java/io/dapr/ai/client/DaprConversationInput.java (100%) rename {sdk-ai => sdk-conversation}/src/main/java/io/dapr/ai/client/DaprConversationOutput.java (100%) rename {sdk-ai => sdk-conversation}/src/main/java/io/dapr/ai/client/DaprConversationResponse.java (100%) rename {sdk-ai => sdk-conversation}/src/main/java/io/dapr/ai/client/DaprConversationRole.java (100%) rename {sdk-ai => sdk-conversation}/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java (100%) diff --git a/examples/pom.xml b/examples/pom.xml index 4a0026a566..4ae0be0240 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -121,7 +121,7 @@ io.dapr - dapr-sdk-ai + dapr-sdk-conversation ${project.version} diff --git a/pom.xml b/pom.xml index 3d94c55502..7ad9bd18ef 100644 --- a/pom.xml +++ b/pom.xml @@ -333,7 +333,7 @@ sdk-autogen sdk sdk-actors - sdk-ai + sdk-conversation sdk-workflows sdk-springboot dapr-spring diff --git a/sdk-ai/pom.xml b/sdk-conversation/pom.xml similarity index 95% rename from sdk-ai/pom.xml rename to sdk-conversation/pom.xml index fe42edfea6..a7e83674b1 100644 --- a/sdk-ai/pom.xml +++ b/sdk-conversation/pom.xml @@ -10,11 +10,11 @@ 1.15.0-SNAPSHOT - dapr-sdk-ai + dapr-sdk-conversation jar 1.15.0-SNAPSHOT - dapr-sdk-ai - SDK for AI on Dapr + dapr-sdk-conversation + SDK for Conversation false @@ -88,6 +88,11 @@ + + org.sonatype.plugins + nexus-staging-maven-plugin + + org.apache.maven.plugins maven-source-plugin diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprAiClient.java b/sdk-conversation/src/main/java/io/dapr/ai/client/DaprAiClient.java similarity index 100% rename from sdk-ai/src/main/java/io/dapr/ai/client/DaprAiClient.java rename to sdk-conversation/src/main/java/io/dapr/ai/client/DaprAiClient.java diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationClient.java b/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationClient.java similarity index 100% rename from sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationClient.java rename to sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationClient.java diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationInput.java b/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationInput.java similarity index 100% rename from sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationInput.java rename to sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationInput.java diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationOutput.java b/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationOutput.java similarity index 100% rename from sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationOutput.java rename to sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationOutput.java diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationResponse.java b/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationResponse.java similarity index 100% rename from sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationResponse.java rename to sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationResponse.java diff --git a/sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationRole.java b/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationRole.java similarity index 100% rename from sdk-ai/src/main/java/io/dapr/ai/client/DaprConversationRole.java rename to sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationRole.java diff --git a/sdk-ai/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java b/sdk-conversation/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java similarity index 100% rename from sdk-ai/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java rename to sdk-conversation/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java From 842b0e7ec24a9328bae06719eea3a35ee88d268d Mon Sep 17 00:00:00 2001 From: sirivarma Date: Tue, 18 Mar 2025 08:09:41 -0700 Subject: [PATCH 4/6] Move to single module Signed-off-by: sirivarma --- pom.xml | 1 - sdk-conversation/pom.xml | 169 --------------- .../java/io/dapr/ai/client/DaprAiClient.java | 40 ---- .../ai/client/DaprConversationClient.java | 198 ------------------ .../dapr/ai/client/DaprConversationRole.java | 13 -- .../ai/client/DaprConversationClientTest.java | 185 ---------------- .../java/io/dapr/client/DaprClientImpl.java | 78 +++++++ .../io/dapr/client/DaprPreviewClient.java | 11 + .../dapr/client/domain/ConversationInput.java | 27 ++- .../client/domain/ConversationOutput.java | 6 +- .../client/domain/ConversationRequest.java | 106 ++++++++++ .../client/domain/ConversationResponse.java | 12 +- .../dapr/client/domain/ConversationRole.java | 22 ++ .../client/DaprPreviewClientGrpcTest.java | 143 +++++++++++++ 14 files changed, 385 insertions(+), 626 deletions(-) delete mode 100644 sdk-conversation/pom.xml delete mode 100644 sdk-conversation/src/main/java/io/dapr/ai/client/DaprAiClient.java delete mode 100644 sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationClient.java delete mode 100644 sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationRole.java delete mode 100644 sdk-conversation/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java rename sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationInput.java => sdk/src/main/java/io/dapr/client/domain/ConversationInput.java (59%) rename sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationOutput.java => sdk/src/main/java/io/dapr/client/domain/ConversationOutput.java (83%) create mode 100644 sdk/src/main/java/io/dapr/client/domain/ConversationRequest.java rename sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationResponse.java => sdk/src/main/java/io/dapr/client/domain/ConversationResponse.java (65%) create mode 100644 sdk/src/main/java/io/dapr/client/domain/ConversationRole.java diff --git a/pom.xml b/pom.xml index c348c7344f..2c77855d05 100644 --- a/pom.xml +++ b/pom.xml @@ -338,7 +338,6 @@ sdk-autogen sdk sdk-actors - sdk-conversation sdk-workflows sdk-springboot dapr-spring diff --git a/sdk-conversation/pom.xml b/sdk-conversation/pom.xml deleted file mode 100644 index a7e83674b1..0000000000 --- a/sdk-conversation/pom.xml +++ /dev/null @@ -1,169 +0,0 @@ - - 4.0.0 - - - io.dapr - dapr-sdk-parent - 1.15.0-SNAPSHOT - - - dapr-sdk-conversation - jar - 1.15.0-SNAPSHOT - dapr-sdk-conversation - SDK for Conversation - - - false - - - - - io.dapr - dapr-sdk - ${project.parent.version} - - - io.dapr - dapr-sdk-autogen - 1.14.0-SNAPSHOT - compile - - - org.mockito - mockito-core - test - - - org.mockito - mockito-inline - 4.2.0 - test - - - org.junit.jupiter - junit-jupiter - test - - - org.junit.vintage - junit-vintage-engine - 5.7.0 - test - - - com.microsoft - durabletask-client - 1.5.0 - - - - com.fasterxml.jackson.core - jackson-core - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${jackson.version} - - - com.fasterxml.jackson.datatype - jackson-datatype-jsr310 - ${jackson.version} - - - - - - - org.sonatype.plugins - nexus-staging-maven-plugin - - - - org.apache.maven.plugins - maven-source-plugin - 3.2.1 - - - attach-sources - - jar-no-fork - - - - - - - org.apache.maven.plugins - maven-javadoc-plugin - 3.2.0 - - - attach-javadocs - - jar - - - - - - org.jacoco - jacoco-maven-plugin - 0.8.11 - - - default-prepare-agent - - prepare-agent - - - - report - test - - report - - - target/jacoco-report/ - - - - check - - check - - - - - BUNDLE - - - LINE - COVEREDRATIO - 80% - - - - - - - - - - - - diff --git a/sdk-conversation/src/main/java/io/dapr/ai/client/DaprAiClient.java b/sdk-conversation/src/main/java/io/dapr/ai/client/DaprAiClient.java deleted file mode 100644 index c3f6b536a8..0000000000 --- a/sdk-conversation/src/main/java/io/dapr/ai/client/DaprAiClient.java +++ /dev/null @@ -1,40 +0,0 @@ -package io.dapr.ai.client; - -import reactor.core.publisher.Mono; - -import javax.annotation.Nullable; -import java.util.List; - -/** - * Defines client operations for managing Dapr AI instances. - */ -interface DaprAiClient { - - /** - * Method to call the Dapr Converse API. - * - * @param conversationComponentName name for the conversation component. - * @param daprConversationInputs prompts that are part of the conversation. - * @param contextId identifier of an existing chat (like in ChatGPT) - * @param scrubPii data that comes from the LLM. - * @param temperature to optimize from creativity or predictability. - * @return @ConversationResponse. - */ - Mono converse( - String conversationComponentName, - List daprConversationInputs, - String contextId, - boolean scrubPii, - double temperature); - - /** - * Method to call the Dapr Converse API. - * - * @param conversationComponentName name for the conversation component. - * @param daprConversationInputs prompts that are part of the conversation. - * @return @ConversationResponse. - */ - Mono converse( - String conversationComponentName, - List daprConversationInputs); -} diff --git a/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationClient.java b/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationClient.java deleted file mode 100644 index 55e4108a3f..0000000000 --- a/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationClient.java +++ /dev/null @@ -1,198 +0,0 @@ -package io.dapr.ai.client; - -import com.google.protobuf.Any; -import io.dapr.client.resiliency.ResiliencyOptions; -import io.dapr.config.Properties; -import io.dapr.exceptions.DaprException; -import io.dapr.internal.exceptions.DaprHttpException; -import io.dapr.internal.grpc.interceptors.DaprTimeoutInterceptor; -import io.dapr.internal.grpc.interceptors.DaprTracingInterceptor; -import io.dapr.internal.resiliency.RetryPolicy; -import io.dapr.internal.resiliency.TimeoutPolicy; -import io.dapr.utils.NetworkUtils; -import io.dapr.v1.DaprGrpc; -import io.dapr.v1.DaprProtos; -import io.grpc.Channel; -import io.grpc.ManagedChannel; -import io.grpc.stub.StreamObserver; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoSink; -import reactor.util.context.ContextView; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; - -public class DaprConversationClient implements AutoCloseable, DaprAiClient { - - /** - * Stub that has the method to call the conversation apis. - */ - private final DaprGrpc.DaprStub asyncStub; - - /** - * The retry policy. - */ - private final RetryPolicy retryPolicy; - - /** - * The timeout policy. - */ - private final TimeoutPolicy timeoutPolicy; - - /** - * Constructor to create conversation client. - */ - public DaprConversationClient() { - this(DaprGrpc.newStub(NetworkUtils.buildGrpcManagedChannel(new Properties())), null); - } - - /** - * Constructor. - * - * @param properties with client configuration options. - * @param resiliencyOptions retry options. - */ - public DaprConversationClient( - Properties properties, - ResiliencyOptions resiliencyOptions) { - this(DaprGrpc.newStub(NetworkUtils.buildGrpcManagedChannel(properties)), resiliencyOptions); - } - - /** - * ConversationClient constructor. - * - * @param resiliencyOptions timeout and retry policies. - */ - protected DaprConversationClient( - DaprGrpc.DaprStub asyncStub, - ResiliencyOptions resiliencyOptions) { - this.asyncStub = asyncStub; - this.retryPolicy = new RetryPolicy(resiliencyOptions == null ? null : resiliencyOptions.getMaxRetries()); - this.timeoutPolicy = new TimeoutPolicy(resiliencyOptions == null ? null : resiliencyOptions.getTimeout()); - } - - @Override - public Mono converse( - String conversationComponentName, - List daprConversationInputs) { - return converse(conversationComponentName, daprConversationInputs, null, false, 0.0d); - } - - @Override - public Mono converse( - String conversationComponentName, - List daprConversationInputs, - String contextId, - boolean scrubPii, - double temperature) { - - try { - if ((conversationComponentName == null) || (conversationComponentName.trim().isEmpty())) { - throw new IllegalArgumentException("Conversation component name cannot be null or empty."); - } - - if ((daprConversationInputs == null) || (daprConversationInputs.isEmpty())) { - throw new IllegalArgumentException("Conversation inputs cannot be null or empty."); - } - - DaprProtos.ConversationRequest.Builder conversationRequest = DaprProtos.ConversationRequest.newBuilder() - .setTemperature(temperature) - .setScrubPII(scrubPii) - .setName(conversationComponentName); - - if (contextId != null) { - conversationRequest.setContextID(contextId); - } - - for (DaprConversationInput input : daprConversationInputs) { - if (input.getContent() == null || input.getContent().isEmpty()) { - throw new IllegalArgumentException("Conversation input content cannot be null or empty."); - } - - DaprProtos.ConversationInput.Builder conversationInputOrBuilder = DaprProtos.ConversationInput.newBuilder() - .setContent(input.getContent()) - .setScrubPII(input.isScrubPii()); - - if (input.getRole() != null) { - conversationInputOrBuilder.setRole(input.getRole().toString()); - } - - conversationRequest.addInputs(conversationInputOrBuilder.build()); - } - - Mono conversationResponseMono = Mono.deferContextual( - context -> this.createMono( - it -> intercept(context, asyncStub) - .converseAlpha1(conversationRequest.build(), it) - ) - ); - - return conversationResponseMono.map(conversationResponse -> { - - List daprConversationOutputs = new ArrayList<>(); - for (DaprProtos.ConversationResult conversationResult : conversationResponse.getOutputsList()) { - Map parameters = new HashMap<>(); - for (Map.Entry entrySet : conversationResult.getParametersMap().entrySet()) { - parameters.put(entrySet.getKey(), entrySet.getValue().toByteArray()); - } - - DaprConversationOutput daprConversationOutput = - new DaprConversationOutput(conversationResult.getResult(), parameters); - daprConversationOutputs.add(daprConversationOutput); - } - - return new DaprConversationResponse(conversationResponse.getContextID(), daprConversationOutputs); - }); - } catch (Exception ex) { - return DaprException.wrapMono(ex); - } - } - - @Override - public void close() throws Exception { - ManagedChannel channel = (ManagedChannel) this.asyncStub.getChannel(); - - DaprException.wrap(() -> { - if (channel != null && !channel.isShutdown()) { - channel.shutdown(); - } - - return true; - }).call(); - } - - private DaprGrpc.DaprStub intercept( - ContextView context, DaprGrpc.DaprStub client) { - return client.withInterceptors( - new DaprTimeoutInterceptor(this.timeoutPolicy), - new DaprTracingInterceptor(context)); - } - - private Mono createMono(Consumer> consumer) { - return retryPolicy.apply( - Mono.create(sink -> DaprException.wrap(() -> consumer.accept( - createStreamObserver(sink))).run())); - } - - private StreamObserver createStreamObserver(MonoSink sink) { - return new StreamObserver() { - @Override - public void onNext(T value) { - sink.success(value); - } - - @Override - public void onError(Throwable t) { - sink.error(DaprException.propagate(DaprHttpException.fromGrpcExecutionException(null, t))); - } - - @Override - public void onCompleted() { - sink.success(); - } - }; - } -} diff --git a/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationRole.java b/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationRole.java deleted file mode 100644 index 11991d857d..0000000000 --- a/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationRole.java +++ /dev/null @@ -1,13 +0,0 @@ -package io.dapr.ai.client; - -/** - * Conversation AI supported roles. - */ -public enum DaprConversationRole { - - USER, - - TOOL, - - ASSISSTANT -} diff --git a/sdk-conversation/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java b/sdk-conversation/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java deleted file mode 100644 index d39ef1413d..0000000000 --- a/sdk-conversation/src/test/java/io/dapr/ai/client/DaprConversationClientTest.java +++ /dev/null @@ -1,185 +0,0 @@ -package io.dapr.ai.client; - -import io.dapr.client.resiliency.ResiliencyOptions; -import io.dapr.config.Properties; -import io.dapr.v1.DaprGrpc; -import io.dapr.v1.DaprProtos; -import io.grpc.ManagedChannel; -import io.grpc.stub.StreamObserver; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; - -import java.util.ArrayList; -import java.util.List; - -import static org.mockito.Mockito.*; - -public class DaprConversationClientTest { - - private DaprGrpc.DaprStub daprStub; - - @Before - public void initialize() { - - ManagedChannel channel = mock(ManagedChannel.class); - daprStub = mock(DaprGrpc.DaprStub.class); - when(daprStub.getChannel()).thenReturn(channel); - when(daprStub.withInterceptors(Mockito.any(), Mockito.any())).thenReturn(daprStub); - } - - @Test - public void converseShouldThrowIllegalArgumentExceptionWhenComponentNameIsNull() throws Exception { - try (DaprConversationClient daprConversationClient = new DaprConversationClient()) { - List daprConversationInputs = new ArrayList<>(); - daprConversationInputs.add(new DaprConversationInput("Hello there !")); - - IllegalArgumentException exception = - Assert.assertThrows(IllegalArgumentException.class, () -> - daprConversationClient.converse(null, daprConversationInputs).block()); - Assert.assertEquals("Conversation component name cannot be null or empty.", exception.getMessage()); - } - } - - @Test - public void converseShouldThrowIllegalArgumentExceptionWhenConversationComponentIsEmpty() throws Exception { - try (DaprConversationClient daprConversationClient = new DaprConversationClient()) { - List daprConversationInputs = new ArrayList<>(); - daprConversationInputs.add(new DaprConversationInput("Hello there !")); - - IllegalArgumentException exception = - Assert.assertThrows(IllegalArgumentException.class, () -> - daprConversationClient.converse("", daprConversationInputs).block()); - Assert.assertEquals("Conversation component name cannot be null or empty.", exception.getMessage()); - } - } - - @Test - public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputIsEmpty() throws Exception { - try (DaprConversationClient daprConversationClient = new DaprConversationClient()) { - List daprConversationInputs = new ArrayList<>(); - - IllegalArgumentException exception = - Assert.assertThrows(IllegalArgumentException.class, () -> - daprConversationClient.converse("openai", daprConversationInputs).block()); - Assert.assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage()); - } - } - - @Test - public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputIsNull() throws Exception { - try (DaprConversationClient daprConversationClient = - new DaprConversationClient(new Properties(), null)) { - - IllegalArgumentException exception = - Assert.assertThrows(IllegalArgumentException.class, () -> - daprConversationClient.converse("openai", null).block()); - Assert.assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage()); - } - } - - @Test - public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputContentIsNull() throws Exception { - try (DaprConversationClient daprConversationClient = new DaprConversationClient()) { - List daprConversationInputs = new ArrayList<>(); - daprConversationInputs.add(new DaprConversationInput(null)); - - IllegalArgumentException exception = - Assert.assertThrows(IllegalArgumentException.class, () -> - daprConversationClient.converse("openai", daprConversationInputs).block()); - Assert.assertEquals("Conversation input content cannot be null or empty.", exception.getMessage()); - } - } - - @Test - public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputContentIsEmpty() throws Exception { - try (DaprConversationClient daprConversationClient = new DaprConversationClient()) { - List daprConversationInputs = new ArrayList<>(); - daprConversationInputs.add(new DaprConversationInput("")); - - IllegalArgumentException exception = - Assert.assertThrows(IllegalArgumentException.class, () -> - daprConversationClient.converse("openai", daprConversationInputs).block()); - Assert.assertEquals("Conversation input content cannot be null or empty.", exception.getMessage()); - } - } - - @Test - public void converseShouldReturnConversationResponseWhenRequiredInputsAreValid() throws Exception { - DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder() - .addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build(); - - doAnswer(invocation -> { - StreamObserver observer = invocation.getArgument(1); - observer.onNext(conversationResponse); - observer.onCompleted(); - return null; - }).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any()); - - try (DaprConversationClient daprConversationClient = - new DaprConversationClient(daprStub, new ResiliencyOptions())) { - List daprConversationInputs = new ArrayList<>(); - daprConversationInputs.add(new DaprConversationInput("Hello there")); - - DaprConversationResponse daprConversationResponse = - daprConversationClient.converse("openai", daprConversationInputs).block(); - - ArgumentCaptor captor = - ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class); - verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any()); - - DaprProtos.ConversationRequest conversationRequest = captor.getValue(); - - Assert.assertEquals("openai", conversationRequest.getName()); - Assert.assertEquals("Hello there", conversationRequest.getInputs(0).getContent()); - Assert.assertEquals("Hello How are you", - daprConversationResponse.getDaprConversationOutputs().get(0).getResult()); - } - } - - @Test - public void converseShouldReturnConversationResponseWhenRequiredAndOptionalInputsAreValid() throws Exception { - DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder() - .setContextID("contextId") - .addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build(); - - doAnswer(invocation -> { - StreamObserver observer = invocation.getArgument(1); - observer.onNext(conversationResponse); - observer.onCompleted(); - return null; - }).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any()); - - try (DaprConversationClient daprConversationClient = new DaprConversationClient(daprStub, null)) { - DaprConversationInput daprConversationInput = new DaprConversationInput("Hello there") - .setRole(DaprConversationRole.ASSISSTANT) - .setScrubPii(true); - - List daprConversationInputs = new ArrayList<>(); - daprConversationInputs.add(daprConversationInput); - - DaprConversationResponse daprConversationResponse = - daprConversationClient.converse("openai", daprConversationInputs, - "contextId", true, 1.1d).block(); - - ArgumentCaptor captor = - ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class); - verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any()); - - DaprProtos.ConversationRequest conversationRequest = captor.getValue(); - - Assert.assertEquals("openai", conversationRequest.getName()); - Assert.assertEquals("contextId", conversationRequest.getContextID()); - Assert.assertTrue(conversationRequest.getScrubPII()); - Assert.assertEquals(1.1d, conversationRequest.getTemperature(), 0d); - Assert.assertEquals("Hello there", conversationRequest.getInputs(0).getContent()); - Assert.assertTrue(conversationRequest.getInputs(0).getScrubPII()); - Assert.assertEquals(DaprConversationRole.ASSISSTANT.toString(), conversationRequest.getInputs(0).getRole()); - Assert.assertEquals("contextId", daprConversationResponse.getContextId()); - Assert.assertEquals("Hello How are you", - daprConversationResponse.getDaprConversationOutputs().get(0).getResult()); - } - } -} diff --git a/sdk/src/main/java/io/dapr/client/DaprClientImpl.java b/sdk/src/main/java/io/dapr/client/DaprClientImpl.java index 0c1264eb19..08037d22da 100644 --- a/sdk/src/main/java/io/dapr/client/DaprClientImpl.java +++ b/sdk/src/main/java/io/dapr/client/DaprClientImpl.java @@ -14,6 +14,7 @@ package io.dapr.client; import com.google.common.base.Strings; +import com.google.protobuf.Any; import com.google.protobuf.ByteString; import com.google.protobuf.Empty; import io.dapr.client.domain.ActorMetadata; @@ -26,6 +27,10 @@ import io.dapr.client.domain.CloudEvent; import io.dapr.client.domain.ComponentMetadata; import io.dapr.client.domain.ConfigurationItem; +import io.dapr.client.domain.ConversationInput; +import io.dapr.client.domain.ConversationOutput; +import io.dapr.client.domain.ConversationRequest; +import io.dapr.client.domain.ConversationResponse; import io.dapr.client.domain.DaprMetadata; import io.dapr.client.domain.DeleteStateRequest; import io.dapr.client.domain.ExecuteStateTransactionRequest; @@ -1402,6 +1407,79 @@ public Mono getMetadata() { }); } + /** + * {@inheritDoc} + */ + @Override + public Mono converse(ConversationRequest conversationRequest) { + + try { + validateConversationRequest(conversationRequest); + + DaprProtos.ConversationRequest.Builder protosConversationRequestBuilder = DaprProtos.ConversationRequest + .newBuilder().setTemperature(conversationRequest.getTemperature()) + .setScrubPII(conversationRequest.isScrubPii()) + .setName(conversationRequest.getLlmName()); + + if (conversationRequest.getContextId() != null) { + protosConversationRequestBuilder.setContextID(conversationRequest.getContextId()); + } + + for (ConversationInput input : conversationRequest.getConversationInputs()) { + if (input.getContent() == null || input.getContent().isEmpty()) { + throw new IllegalArgumentException("Conversation input content cannot be null or empty."); + } + + DaprProtos.ConversationInput.Builder conversationInputOrBuilder = DaprProtos.ConversationInput.newBuilder() + .setContent(input.getContent()) + .setScrubPII(input.isScrubPii()); + + if (input.getRole() != null) { + conversationInputOrBuilder.setRole(input.getRole().toString()); + } + + protosConversationRequestBuilder.addInputs(conversationInputOrBuilder.build()); + } + + Mono conversationResponseMono = Mono.deferContextual( + context -> this.createMono( + it -> intercept(context, asyncStub) + .converseAlpha1(protosConversationRequestBuilder.build(), it) + ) + ); + + return conversationResponseMono.map(conversationResponse -> { + + List conversationOutputs = new ArrayList<>(); + for (DaprProtos.ConversationResult conversationResult : conversationResponse.getOutputsList()) { + Map parameters = new HashMap<>(); + for (Map.Entry entrySet : conversationResult.getParametersMap().entrySet()) { + parameters.put(entrySet.getKey(), entrySet.getValue().toByteArray()); + } + + ConversationOutput conversationOutput = + new ConversationOutput(conversationResult.getResult(), parameters); + conversationOutputs.add(conversationOutput); + } + + return new ConversationResponse(conversationResponse.getContextID(), conversationOutputs); + }); + } catch (Exception ex) { + return DaprException.wrapMono(ex); + } + } + + private void validateConversationRequest(ConversationRequest conversationRequest) { + if ((conversationRequest.getLlmName() == null) || (conversationRequest.getLlmName().trim().isEmpty())) { + throw new IllegalArgumentException("LLM name cannot be null or empty."); + } + + if ((conversationRequest.getConversationInputs() == null) || (conversationRequest + .getConversationInputs().isEmpty())) { + throw new IllegalArgumentException("Conversation inputs cannot be null or empty."); + } + } + private DaprMetadata buildDaprMetadata(DaprProtos.GetMetadataResponse response) throws IOException { String id = response.getId(); String runtimeVersion = response.getRuntimeVersion(); diff --git a/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java b/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java index 95911efc23..cdcf595f4c 100644 --- a/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java +++ b/sdk/src/main/java/io/dapr/client/DaprPreviewClient.java @@ -17,6 +17,9 @@ import io.dapr.client.domain.BulkPublishRequest; import io.dapr.client.domain.BulkPublishResponse; import io.dapr.client.domain.BulkPublishResponseFailedEntry; +import io.dapr.client.domain.ConversationInput; +import io.dapr.client.domain.ConversationRequest; +import io.dapr.client.domain.ConversationResponse; import io.dapr.client.domain.LockRequest; import io.dapr.client.domain.QueryStateRequest; import io.dapr.client.domain.QueryStateResponse; @@ -268,4 +271,12 @@ Mono> publishEvents(String pubsubName, String topicNa */ Subscription subscribeToEvents( String pubsubName, String topic, SubscriptionListener listener, TypeRef type); + + /** + * Converse with an LLM. + * + * @param conversationRequest request to be passed to the LLM. + * @return {@link ConversationResponse}. + */ + Mono converse(ConversationRequest conversationRequest); } diff --git a/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationInput.java b/sdk/src/main/java/io/dapr/client/domain/ConversationInput.java similarity index 59% rename from sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationInput.java rename to sdk/src/main/java/io/dapr/client/domain/ConversationInput.java index 5189f1da40..a0e3729365 100644 --- a/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationInput.java +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationInput.java @@ -1,22 +1,27 @@ -package io.dapr.ai.client; +package io.dapr.client.domain; /** * Represents an input message for a conversation with an LLM. */ -public class DaprConversationInput { +public class ConversationInput { private final String content; - private DaprConversationRole role; + private ConversationRole role; private boolean scrubPii; - public DaprConversationInput(String content) { + /** + * Constructor. + * + * @param content for the llm. + */ + public ConversationInput(String content) { this.content = content; } /** - * Retrieves the content of the conversation input. + * The message content to send to the LLM. Required * * @return The content to be sent to the LLM. */ @@ -25,21 +30,21 @@ public String getContent() { } /** - * Retrieves the role associated with the conversation input. + * The role for the LLM to assume. * * @return this. */ - public DaprConversationRole getRole() { + public ConversationRole getRole() { return role; } /** - * Sets the role associated with the conversation input. + * Sets the role for LLM to assume. * * @param role The role to assign to the message. * @return this. */ - public DaprConversationInput setRole(DaprConversationRole role) { + public ConversationInput setRole(ConversationRole role) { this.role = role; return this; } @@ -54,12 +59,12 @@ public boolean isScrubPii() { } /** - * Sets whether to scrub Personally Identifiable Information (PII) before sending to the LLM. + * Enable obfuscation of sensitive information present in the content field. Optional * * @param scrubPii A boolean indicating whether to remove PII. * @return this. */ - public DaprConversationInput setScrubPii(boolean scrubPii) { + public ConversationInput setScrubPii(boolean scrubPii) { this.scrubPii = scrubPii; return this; } diff --git a/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationOutput.java b/sdk/src/main/java/io/dapr/client/domain/ConversationOutput.java similarity index 83% rename from sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationOutput.java rename to sdk/src/main/java/io/dapr/client/domain/ConversationOutput.java index a1a5d5ec76..6279ca6049 100644 --- a/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationOutput.java +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationOutput.java @@ -1,4 +1,4 @@ -package io.dapr.ai.client; +package io.dapr.client.domain; import java.util.Collections; import java.util.Map; @@ -6,7 +6,7 @@ /** * Returns the conversation output. */ -public class DaprConversationOutput { +public class ConversationOutput { private final String result; @@ -18,7 +18,7 @@ public class DaprConversationOutput { * @param result result for one of the conversation input. * @param parameters all custom fields. */ - public DaprConversationOutput(String result, Map parameters) { + public ConversationOutput(String result, Map parameters) { this.result = result; this.parameters = parameters; } diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationRequest.java b/sdk/src/main/java/io/dapr/client/domain/ConversationRequest.java new file mode 100644 index 0000000000..de58df6827 --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationRequest.java @@ -0,0 +1,106 @@ +package io.dapr.client.domain; + +import java.util.List; + +/** + * Represents a conversation configuration with details about component name, + * conversation inputs, context identifier, PII scrubbing, and temperature control. + */ +public class ConversationRequest { + + private final String llmName; + private final List daprConversationInputs; + private String contextId; + private boolean scrubPii; + private double temperature; + + /** + * Constructs a DaprConversation with a component name and conversation inputs. + * + * @param llmName The name of the LLM component. See a list of all available conversation components + * @see + * @param conversationInputs the list of Dapr conversation inputs + */ + public ConversationRequest(String llmName, List conversationInputs) { + this.llmName = llmName; + this.daprConversationInputs = conversationInputs; + } + + /** + * Gets the conversation component name. + * + * @return the conversation component name + */ + public String getLlmName() { + return llmName; + } + + /** + * Gets the list of Dapr conversation inputs. + * + * @return the list of conversation inputs + */ + public List getConversationInputs() { + return daprConversationInputs; + } + + /** + * Gets the context identifier. + * + * @return the context identifier + */ + public String getContextId() { + return contextId; + } + + /** + * Sets the context identifier. + * + * @param contextId the context identifier to set + * @return the current instance of {@link ConversationRequest} + */ + public ConversationRequest setContextId(String contextId) { + this.contextId = contextId; + return this; + } + + /** + * Checks if PII scrubbing is enabled. + * + * @return true if PII scrubbing is enabled, false otherwise + */ + public boolean isScrubPii() { + return scrubPii; + } + + /** + * Enable obfuscation of sensitive information returning from the LLM. Optional. + * + * @param scrubPii whether to enable PII scrubbing + * @return the current instance of {@link ConversationRequest} + */ + public ConversationRequest setScrubPii(boolean scrubPii) { + this.scrubPii = scrubPii; + return this; + } + + /** + * Gets the temperature of the model. Used to optimize for consistency and creativity. Optional + * + * @return the temperature value + */ + public double getTemperature() { + return temperature; + } + + /** + * Sets the temperature of the model. Used to optimize for consistency and creativity. Optional + * + * @param temperature the temperature value to set + * @return the current instance of {@link ConversationRequest} + */ + public ConversationRequest setTemperature(double temperature) { + this.temperature = temperature; + return this; + } +} \ No newline at end of file diff --git a/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationResponse.java b/sdk/src/main/java/io/dapr/client/domain/ConversationResponse.java similarity index 65% rename from sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationResponse.java rename to sdk/src/main/java/io/dapr/client/domain/ConversationResponse.java index e7ab001a5c..9e87b8b4ff 100644 --- a/sdk-conversation/src/main/java/io/dapr/ai/client/DaprConversationResponse.java +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationResponse.java @@ -1,4 +1,4 @@ -package io.dapr.ai.client; +package io.dapr.client.domain; import java.util.Collections; import java.util.List; @@ -6,11 +6,11 @@ /** * Response from the Dapr Conversation API. */ -public class DaprConversationResponse { +public class ConversationResponse { private String contextId; - private final List daprConversationOutputs; + private final List daprConversationOutputs; /** * Constructor. @@ -18,7 +18,7 @@ public class DaprConversationResponse { * @param contextId context id supplied to LLM. * @param daprConversationOutputs outputs from the LLM. */ - public DaprConversationResponse(String contextId, List daprConversationOutputs) { + public ConversationResponse(String contextId, List daprConversationOutputs) { this.contextId = contextId; this.daprConversationOutputs = daprConversationOutputs; } @@ -35,9 +35,9 @@ public String getContextId() { /** * Get list of conversation outputs. * - * @return List{@link DaprConversationOutput}. + * @return List{@link ConversationOutput}. */ - public List getDaprConversationOutputs() { + public List getConversationOutpus() { return Collections.unmodifiableList(this.daprConversationOutputs); } } diff --git a/sdk/src/main/java/io/dapr/client/domain/ConversationRole.java b/sdk/src/main/java/io/dapr/client/domain/ConversationRole.java new file mode 100644 index 0000000000..5374c7fa9d --- /dev/null +++ b/sdk/src/main/java/io/dapr/client/domain/ConversationRole.java @@ -0,0 +1,22 @@ +package io.dapr.client.domain; + +/** + * Conversation AI supported roles. + */ +public enum ConversationRole { + + /** + * User Role. + */ + USER, + + /** + * Tool Role. + */ + TOOL, + + /** + * Assistant Role. + */ + ASSISTANT, +} diff --git a/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java b/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java index a28dad0f42..21c16278ed 100644 --- a/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java +++ b/sdk/src/test/java/io/dapr/client/DaprPreviewClientGrpcTest.java @@ -21,6 +21,10 @@ import io.dapr.client.domain.BulkPublishRequest; import io.dapr.client.domain.BulkPublishResponse; import io.dapr.client.domain.CloudEvent; +import io.dapr.client.domain.ConversationInput; +import io.dapr.client.domain.ConversationRequest; +import io.dapr.client.domain.ConversationResponse; +import io.dapr.client.domain.ConversationRole; import io.dapr.client.domain.QueryStateItem; import io.dapr.client.domain.QueryStateRequest; import io.dapr.client.domain.QueryStateResponse; @@ -39,7 +43,9 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; +import org.mockito.Mockito; import org.mockito.stubbing.Answer; import reactor.core.publisher.Mono; @@ -64,6 +70,7 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @@ -539,6 +546,142 @@ public void onError(RuntimeException exception) { assertEquals(numDrops, dropCounter.get()); assertEquals(numErrors, errors.size()); } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenComponentNameIsNull() throws Exception { + List daprConversationInputs = new ArrayList<>(); + daprConversationInputs.add(new ConversationInput("Hello there !")); + + IllegalArgumentException exception = + Assertions.assertThrows(IllegalArgumentException.class, () -> + previewClient.converse(new ConversationRequest(null, daprConversationInputs)).block()); + Assertions.assertEquals("LLM name cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenConversationComponentIsEmpty() throws Exception { + List daprConversationInputs = new ArrayList<>(); + daprConversationInputs.add(new ConversationInput("Hello there !")); + + IllegalArgumentException exception = + Assertions.assertThrows(IllegalArgumentException.class, () -> + previewClient.converse(new ConversationRequest("", daprConversationInputs)).block()); + Assertions.assertEquals("LLM name cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputIsEmpty() throws Exception { + List daprConversationInputs = new ArrayList<>(); + + IllegalArgumentException exception = + Assertions.assertThrows(IllegalArgumentException.class, () -> + previewClient.converse(new ConversationRequest("openai", daprConversationInputs)).block()); + Assertions.assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputIsNull() throws Exception { + IllegalArgumentException exception = + Assertions.assertThrows(IllegalArgumentException.class, () -> + previewClient.converse(new ConversationRequest("openai", null)).block()); + Assertions.assertEquals("Conversation inputs cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputContentIsNull() throws Exception { + List daprConversationInputs = new ArrayList<>(); + daprConversationInputs.add(new ConversationInput(null)); + + IllegalArgumentException exception = + Assertions.assertThrows(IllegalArgumentException.class, () -> + previewClient.converse(new ConversationRequest("openai", daprConversationInputs)).block()); + Assertions.assertEquals("Conversation input content cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseShouldThrowIllegalArgumentExceptionWhenConversationInputContentIsEmpty() throws Exception { + List daprConversationInputs = new ArrayList<>(); + daprConversationInputs.add(new ConversationInput("")); + + IllegalArgumentException exception = + Assertions.assertThrows(IllegalArgumentException.class, () -> + previewClient.converse(new ConversationRequest("openai", daprConversationInputs)).block()); + Assertions.assertEquals("Conversation input content cannot be null or empty.", exception.getMessage()); + } + + @Test + public void converseShouldReturnConversationResponseWhenRequiredInputsAreValid() throws Exception { + DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder() + .addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build(); + + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(conversationResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any()); + + List daprConversationInputs = new ArrayList<>(); + daprConversationInputs.add(new ConversationInput("Hello there")); + ConversationResponse daprConversationResponse = + previewClient.converse(new ConversationRequest("openai", daprConversationInputs)).block(); + + ArgumentCaptor captor = + ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class); + verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any()); + + DaprProtos.ConversationRequest conversationRequest = captor.getValue(); + + Assertions.assertEquals("openai", conversationRequest.getName()); + Assertions.assertEquals("Hello there", conversationRequest.getInputs(0).getContent()); + Assertions.assertEquals("Hello How are you", + daprConversationResponse.getConversationOutpus().get(0).getResult()); + } + + @Test + public void converseShouldReturnConversationResponseWhenRequiredAndOptionalInputsAreValid() throws Exception { + DaprProtos.ConversationResponse conversationResponse = DaprProtos.ConversationResponse.newBuilder() + .setContextID("contextId") + .addOutputs(DaprProtos.ConversationResult.newBuilder().setResult("Hello How are you").build()).build(); + + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(conversationResponse); + observer.onCompleted(); + return null; + }).when(daprStub).converseAlpha1(any(DaprProtos.ConversationRequest.class), any()); + + ConversationInput daprConversationInput = new ConversationInput("Hello there") + .setRole(ConversationRole.ASSISTANT) + .setScrubPii(true); + + List daprConversationInputs = new ArrayList<>(); + daprConversationInputs.add(daprConversationInput); + + ConversationResponse daprConversationResponse = + previewClient.converse(new ConversationRequest("openai", daprConversationInputs) + .setContextId("contextId") + .setScrubPii(true) + .setTemperature(1.1d)).block(); + + ArgumentCaptor captor = + ArgumentCaptor.forClass(DaprProtos.ConversationRequest.class); + verify(daprStub, times(1)).converseAlpha1(captor.capture(), Mockito.any()); + + DaprProtos.ConversationRequest conversationRequest = captor.getValue(); + + Assertions.assertEquals("openai", conversationRequest.getName()); + Assertions.assertEquals("contextId", conversationRequest.getContextID()); + Assertions.assertTrue(conversationRequest.getScrubPII()); + Assertions.assertEquals(1.1d, conversationRequest.getTemperature(), 0d); + Assertions.assertEquals("Hello there", conversationRequest.getInputs(0).getContent()); + Assertions.assertTrue(conversationRequest.getInputs(0).getScrubPII()); + Assertions.assertEquals(ConversationRole.ASSISTANT.toString(), conversationRequest.getInputs(0).getRole()); + Assertions.assertEquals("contextId", daprConversationResponse.getContextId()); + Assertions.assertEquals("Hello How are you", + daprConversationResponse.getConversationOutpus().get(0).getResult()); + } + private DaprProtos.QueryStateResponse buildQueryStateResponse(List> resp,String token) throws JsonProcessingException { List items = new ArrayList<>(); From c18ccf52aa50973ed4be3331f92f81e47d3440a9 Mon Sep 17 00:00:00 2001 From: sirivarma Date: Tue, 18 Mar 2025 08:14:25 -0700 Subject: [PATCH 5/6] Remove module Signed-off-by: sirivarma --- examples/pom.xml | 5 ----- .../conversation/DemoConversationAI.java | 19 +++++++++++-------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/examples/pom.xml b/examples/pom.xml index 2d7b47a4b5..1bd0448596 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -118,11 +118,6 @@ dapr-sdk ${project.version} - - io.dapr - dapr-sdk-conversation - ${project.version} - com.evanlennick retry4j diff --git a/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java b/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java index 014ca686ea..07652ef223 100644 --- a/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java +++ b/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java @@ -1,9 +1,10 @@ package io.dapr.examples.conversation; -import io.dapr.ai.client.DaprConversationClient; -import io.dapr.ai.client.DaprConversationInput; -import io.dapr.ai.client.DaprConversationResponse; -import io.dapr.v1.DaprProtos; +import io.dapr.client.DaprClientBuilder; +import io.dapr.client.DaprPreviewClient; +import io.dapr.client.domain.ConversationInput; +import io.dapr.client.domain.ConversationRequest; +import io.dapr.client.domain.ConversationResponse; import reactor.core.publisher.Mono; import java.util.ArrayList; @@ -16,13 +17,15 @@ public class DemoConversationAI { * @param args Input arguments (unused). */ public static void main(String[] args) { - try (DaprConversationClient client = new DaprConversationClient()) { - DaprConversationInput daprConversationInput = new DaprConversationInput("11"); + try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) { + ConversationInput daprConversationInput = new ConversationInput("11"); // Component name is the name provided in the metadata block of the conversation.yaml file. - Mono instanceId = client.converse("openai", new ArrayList<>(Collections.singleton(daprConversationInput)), "1234", false, 0.0d); + Mono instanceId = client.converse(new ConversationRequest("openai", new ArrayList<>(Collections.singleton(daprConversationInput))) + .setContextId("contextId") + .setScrubPii(true).setTemperature(1.1d)); System.out.printf("Started a new chaining model workflow with instance ID: %s%n", instanceId); - DaprConversationResponse response = instanceId.block(); + ConversationResponse response = instanceId.block(); System.out.println(response); } catch (Exception e) { From a169d00fea5bcad73743aac25a2c7121b2419f1a Mon Sep 17 00:00:00 2001 From: siri-varma Date: Tue, 18 Mar 2025 13:33:50 -0700 Subject: [PATCH 6/6] Add Integration tests Signed-off-by: siri-varma --- .../conversation/DemoConversationAI.java | 2 +- .../it/testcontainers/DaprConversationIT.java | 119 ++++++++++++++++++ .../TestConversationApplication.java | 26 ++++ .../TestDaprConversationConfiguration.java | 41 ++++++ sdk/pom.xml | 30 ----- 5 files changed, 187 insertions(+), 31 deletions(-) create mode 100644 sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprConversationIT.java create mode 100644 sdk-tests/src/test/java/io/dapr/it/testcontainers/TestConversationApplication.java create mode 100644 sdk-tests/src/test/java/io/dapr/it/testcontainers/TestDaprConversationConfiguration.java diff --git a/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java b/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java index 07652ef223..574de137a7 100644 --- a/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java +++ b/examples/src/main/java/io/dapr/examples/conversation/DemoConversationAI.java @@ -18,7 +18,7 @@ public class DemoConversationAI { */ public static void main(String[] args) { try (DaprPreviewClient client = new DaprClientBuilder().buildPreviewClient()) { - ConversationInput daprConversationInput = new ConversationInput("11"); + ConversationInput daprConversationInput = new ConversationInput("Hello How are you ?"); // Component name is the name provided in the metadata block of the conversation.yaml file. Mono instanceId = client.converse(new ConversationRequest("openai", new ArrayList<>(Collections.singleton(daprConversationInput))) diff --git a/sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprConversationIT.java b/sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprConversationIT.java new file mode 100644 index 0000000000..2f3c3b3a86 --- /dev/null +++ b/sdk-tests/src/test/java/io/dapr/it/testcontainers/DaprConversationIT.java @@ -0,0 +1,119 @@ +package io.dapr.it.testcontainers; + +import io.dapr.client.DaprPreviewClient; +import io.dapr.client.domain.ConversationInput; +import io.dapr.client.domain.ConversationRequest; +import io.dapr.client.domain.ConversationResponse; +import io.dapr.testcontainers.Component; +import io.dapr.testcontainers.DaprContainer; +import io.dapr.testcontainers.DaprLogLevel; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.test.context.DynamicPropertyRegistry; +import org.springframework.test.context.DynamicPropertySource; +import org.testcontainers.containers.Network; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Random; + +@SpringBootTest( + webEnvironment = WebEnvironment.RANDOM_PORT, + classes = { + TestDaprConversationConfiguration.class, + TestConversationApplication.class + } +) +@Testcontainers +@Tag("testcontainers") +public class DaprConversationIT { + + private static final Network DAPR_NETWORK = Network.newNetwork(); + private static final Random RANDOM = new Random(); + private static final int PORT = RANDOM.nextInt(1000) + 8000; + + @Container + private static final DaprContainer DAPR_CONTAINER = new DaprContainer("daprio/daprd:1.15.2") + .withAppName("conversation-dapr-app") + .withComponent(new Component("echo", "conversation.echo", "v1", new HashMap<>())) + .withNetwork(DAPR_NETWORK) + .withDaprLogLevel(DaprLogLevel.DEBUG) + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withAppChannelAddress("host.testcontainers.internal") + .withAppPort(PORT); + + /** + * Expose the Dapr ports to the host. + * + * @param registry the dynamic property registry + */ + @DynamicPropertySource + static void daprProperties(DynamicPropertyRegistry registry) { + registry.add("dapr.http.endpoint", DAPR_CONTAINER::getHttpEndpoint); + registry.add("dapr.grpc.endpoint", DAPR_CONTAINER::getGrpcEndpoint); + registry.add("server.port", () -> PORT); + } + + @Autowired + private DaprPreviewClient daprPreviewClient; + + @BeforeEach + public void setUp(){ + org.testcontainers.Testcontainers.exposeHostPorts(PORT); + // Ensure the subscriptions are registered + } + + @Test + public void testConversationSDKShouldHaveSameOutputAndInput() { + ConversationInput conversationInput = new ConversationInput("input this"); + List conversationInputList = new ArrayList<>(); + conversationInputList.add(conversationInput); + + ConversationResponse response = + this.daprPreviewClient.converse(new ConversationRequest("echo", conversationInputList)).block(); + + Assertions.assertEquals("", response.getContextId()); + Assertions.assertEquals("input this", response.getConversationOutpus().get(0).getResult()); + } + + @Test + public void testConversationSDKShouldScrubPIIEntirelyWhenScrubPIIIsSetInRequestBody() { + List conversationInputList = new ArrayList<>(); + conversationInputList.add(new ConversationInput("input this abcd@gmail.com")); + conversationInputList.add(new ConversationInput("input this +12341567890")); + + ConversationResponse response = + this.daprPreviewClient.converse(new ConversationRequest("echo", conversationInputList) + .setScrubPii(true)).block(); + + Assertions.assertEquals("", response.getContextId()); + Assertions.assertEquals("input this ", + response.getConversationOutpus().get(0).getResult()); + Assertions.assertEquals("input this ", + response.getConversationOutpus().get(1).getResult()); + } + + @Test + public void testConversationSDKShouldScrubPIIOnlyForTheInputWhereScrubPIIIsSet() { + List conversationInputList = new ArrayList<>(); + conversationInputList.add(new ConversationInput("input this abcd@gmail.com")); + conversationInputList.add(new ConversationInput("input this +12341567890").setScrubPii(true)); + + ConversationResponse response = + this.daprPreviewClient.converse(new ConversationRequest("echo", conversationInputList)).block(); + + Assertions.assertEquals("", response.getContextId()); + Assertions.assertEquals("input this abcd@gmail.com", + response.getConversationOutpus().get(0).getResult()); + Assertions.assertEquals("input this ", + response.getConversationOutpus().get(1).getResult()); + } +} \ No newline at end of file diff --git a/sdk-tests/src/test/java/io/dapr/it/testcontainers/TestConversationApplication.java b/sdk-tests/src/test/java/io/dapr/it/testcontainers/TestConversationApplication.java new file mode 100644 index 0000000000..ec33bdf791 --- /dev/null +++ b/sdk-tests/src/test/java/io/dapr/it/testcontainers/TestConversationApplication.java @@ -0,0 +1,26 @@ +/* + * Copyright 2024 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.it.testcontainers; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +@SpringBootApplication +public class TestConversationApplication { + + public static void main(String[] args) { + SpringApplication.run(TestConversationApplication.class, args); + } + +} \ No newline at end of file diff --git a/sdk-tests/src/test/java/io/dapr/it/testcontainers/TestDaprConversationConfiguration.java b/sdk-tests/src/test/java/io/dapr/it/testcontainers/TestDaprConversationConfiguration.java new file mode 100644 index 0000000000..6a096c0584 --- /dev/null +++ b/sdk-tests/src/test/java/io/dapr/it/testcontainers/TestDaprConversationConfiguration.java @@ -0,0 +1,41 @@ +/* + * Copyright 2025 The Dapr Authors + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and +limitations under the License. +*/ + +package io.dapr.it.testcontainers; + +import io.dapr.client.DaprClientBuilder; +import io.dapr.client.DaprPreviewClient; +import io.dapr.config.Properties; +import io.dapr.config.Property; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import java.util.Map; + +@Configuration +public class TestDaprConversationConfiguration { + + @Bean + public DaprPreviewClient daprPreviewClient( + @Value("${dapr.http.endpoint}") String daprHttpEndpoint, + @Value("${dapr.grpc.endpoint}") String daprGrpcEndpoint + ){ + Map, String> overrides = Map.of( + Properties.HTTP_ENDPOINT, daprHttpEndpoint, + Properties.GRPC_ENDPOINT, daprGrpcEndpoint + ); + + return new DaprClientBuilder().withPropertyOverrides(overrides).buildPreviewClient(); + } +} \ No newline at end of file diff --git a/sdk/pom.xml b/sdk/pom.xml index 041edd8c7d..a97f506dff 100644 --- a/sdk/pom.xml +++ b/sdk/pom.xml @@ -127,36 +127,6 @@ grpc-inprocess test - - io.dapr - dapr-sdk-autogen - 1.14.0-SNAPSHOT - compile - - - io.dapr - dapr-sdk-autogen - 1.14.0-SNAPSHOT - compile - - - io.dapr - dapr-sdk-autogen - 1.14.0-SNAPSHOT - compile - - - io.dapr - dapr-sdk-autogen - 1.14.0-SNAPSHOT - compile - - - io.dapr - dapr-sdk-autogen - 1.14.0-SNAPSHOT - compile -