Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add AI Monitoring instrumentation for AmazonBedrockRuntimeClient.ConverseAsync() #2997

Merged
merged 2 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
// Copyright 2020 New Relic, Inc. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

using System;
using System.Collections.Concurrent;
using System.Net;
using System.Threading.Tasks;
using NewRelic.Agent.Api;
using NewRelic.Agent.Extensions.Helpers;
using NewRelic.Agent.Extensions.Llm;
using NewRelic.Agent.Extensions.Providers.Wrapper;
using NewRelic.Reflection;

namespace NewRelic.Providers.Wrapper.Bedrock;

public class ConverseAsyncWrapper : IWrapper
{
public bool IsTransactionRequired => true; // part of spec, only create events for transactions.

private static ConcurrentDictionary<Type, Func<object, object>> _getResultFromGenericTask = new();
private static ConcurrentDictionary<string, string> _libraryVersions = new();
private const string WrapperName = "BedrockConverseAsync";
private const string VendorName = "Bedrock";

public CanWrapResponse CanWrap(InstrumentedMethodInfo methodInfo)
{
return new CanWrapResponse(WrapperName.Equals(methodInfo.RequestedWrapperName));
}

public AfterWrappedMethodDelegate BeforeWrappedMethod(InstrumentedMethodCall instrumentedMethodCall, IAgent agent, ITransaction transaction)
{
// Don't do anything, including sending the version Supportability metric, if we're disabled
if (!agent.Configuration.AiMonitoringEnabled)
{
return Delegates.NoOp;
}

if (instrumentedMethodCall.IsAsync)
{
transaction.AttachToAsync();
}

dynamic converseRequest = instrumentedMethodCall.MethodCall.MethodArguments[0];
string modelId = converseRequest.ModelId.ToLower();

var operationType = "completion"; // Converse doesn't support embedding
var segment = transaction.StartCustomSegment(instrumentedMethodCall.MethodCall, $"Llm/{operationType}/{VendorName}/{instrumentedMethodCall.MethodCall.Method.MethodName}");

// required per spec
var version = GetOrAddLibraryVersion(instrumentedMethodCall.MethodCall.Method.Type.Assembly.ManifestModule.Assembly.FullName);
agent.RecordSupportabilityMetric($"DotNet/ML/{VendorName}/{version}");

return Delegates.GetAsyncDelegateFor<Task>(
agent,
segment,
false,
TryProcessConverseResponse,
TaskContinuationOptions.ExecuteSynchronously
);

void TryProcessConverseResponse(Task responseTask)
{
// We need the duration, so we end the segment before creating the events.
segment.End();

if (responseTask.IsFaulted)
{
HandleError(segment, converseRequest, responseTask, agent, modelId);
return;
}

dynamic converseResponse = GetTaskResult(responseTask);
if (converseResponse == null || converseResponse.HttpStatusCode >= HttpStatusCode.MultipleChoices)
{
agent.Logger.Warn($"Error processing Converse response for model {modelId}: Response {(converseResponse == null ? "is null" : $"has non-success HttpStatusCode: {converseResponse.HttpStatusCode}")}");
return;
}

ProcessConverseResponse(segment, converseRequest, converseResponse, agent, modelId);
}
}

private void ProcessConverseResponse(ISegment segment, dynamic converseRequest, dynamic converseResponse, IAgent agent, string requestModelId)
{
// if request message content doesn't have a non-null Text property, we can't support instrumentation
// last message is the current prompt
var requestMessage = converseRequest?.Messages?[converseRequest.Messages.Count - 1];
if (converseRequest == null || requestMessage == null || requestMessage.Content == null || requestMessage.Content.Count == 0 || requestMessage.Content[0].Text == null)
{
agent.Logger.Info($"Unable to process Converse response for model {requestModelId}: request was null or message content was not Text");
return;
}

if (converseResponse == null)
{
agent.Logger.Warn($"Error processing Converse response for model {requestModelId}: response was null");
return;
}

// if response message content doesn't have a non-null Text property, we can't support instrumentation
var responseMessage = converseResponse.Output?.Message;
if (responseMessage == null || responseMessage.Content == null || responseMessage.Content.Count == 0 || responseMessage.Content[0].Text == null)
{
agent.Logger.Info($"Unable to process Converse response for model {requestModelId}: response was null or message content was not Text");
return;
}

string requestRole = requestMessage.Role?.Value ?? "user";
string promptText = requestMessage.Content?[0]?.Text ?? "";

string responseRole = responseMessage.Role?.Value ?? "assistant";
string responseText = responseMessage.Content?[0]?.Text ?? "";
string stopReason = converseResponse.StopReason?.Value;

string requestId = converseResponse.ResponseMetadata?.RequestId;
int? requestMaxTokens = converseRequest.InferenceConfig?.MaxTokens;
float? requestTemperature = converseRequest.InferenceConfig?.Temperature;

int? inputTokens = converseResponse.Usage?.InputTokens;
int? outputTokens = converseResponse.Usage?.OutputTokens;

var completionId = EventHelper.CreateChatCompletionEvent(
agent,
segment,
requestId,
requestTemperature,
requestMaxTokens,
requestModelId,
requestModelId,
2, // one request, one response
stopReason,
VendorName,
false,
null, // not available in AWS
null
);

// Prompt
EventHelper.CreateChatMessageEvent(
agent,
segment,
requestId,
null,
requestModelId,
promptText,
requestRole,
0,
completionId,
false,
VendorName,
inputTokens);

// Response
EventHelper.CreateChatMessageEvent(
agent,
segment,
requestId,
null,
requestModelId,
responseText,
responseRole,
1,
completionId,
true,
VendorName,
outputTokens);
}

private void HandleError(ISegment segment, dynamic converseRequest, Task responseTask, IAgent agent, string modelId)
{
agent.Logger.Info($"Error processing Converse response for model {modelId}: {responseTask.Exception!.Message}");

dynamic bedrockException = responseTask.Exception!.InnerException;
if (bedrockException == null)
{
agent.Logger.Warn($"Error processing Converse response for model {modelId}: Task faulted but there was no inner exception");
return;
}

var requestMessage = converseRequest?.Messages?[converseRequest.Messages.Count - 1];

if (converseRequest == null || requestMessage == null)
{
agent.Logger.Warn($"Error processing Converse response for model {modelId}: request Message was null");
return;
}

HttpStatusCode statusCode = bedrockException.StatusCode;
string errorCode = bedrockException.ErrorCode;
string errorMessage = bedrockException.Message;
string requestId = bedrockException.RequestId;

var errorData = new LlmErrorData
{
HttpStatusCode = ((int)statusCode).ToString(),
ErrorCode = errorCode,
ErrorParam = null, // not available in AWS
ErrorMessage = errorMessage
};

string requestRole = requestMessage.Role?.Value ?? "user";
string promptText = requestMessage.Content?[0]?.Text ?? "";
int? requestMaxTokens = converseRequest.InferenceConfig?.MaxTokens;
float? requestTemperature = converseRequest.InferenceConfig?.Temperature;


var completionId = EventHelper.CreateChatCompletionEvent(
agent,
segment,
requestId,
requestTemperature,
requestMaxTokens,
converseRequest.ModelId,
null,
0,
null,
VendorName,
true,
null,
errorData);

// Prompt
EventHelper.CreateChatMessageEvent(
agent,
segment,
requestId,
null,
converseRequest.ModelId,
promptText,
requestRole,
0,
completionId,
false,
VendorName);
}


private string GetOrAddLibraryVersion(string assemblyFullName)
{
return _libraryVersions.GetOrAdd(assemblyFullName, VersionHelpers.GetLibraryVersion(assemblyFullName));
}

private static object GetTaskResult(object task)
{
if (((Task)task).IsFaulted)
{
return null;
}

var getResponse = _getResultFromGenericTask.GetOrAdd(task.GetType(), t => VisibilityBypasser.Instance.GeneratePropertyAccessor<object>(t, "Result"));
return getResponse(task);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ SPDX-License-Identifier: Apache-2.0
<exactMethodMatcher methodName="InvokeModelAsync" />
</match>
</tracerFactory>


<tracerFactory name="BedrockConverseAsync">
<!--
public virtual Task<ConverseResponse> ConverseAsync(ConverseRequest request, System.Threading.CancellationToken cancellationToken = default(CancellationToken))
-->
<match assemblyName="AWSSDK.BedrockRuntime" className="Amazon.BedrockRuntime.AmazonBedrockRuntimeClient">
<exactMethodMatcher methodName="ConverseAsync" />
</match>
</tracerFactory>
</instrumentation>
</extension>
Loading
Loading