Skip to content

Commit

Permalink
Vertex response gen_ai.choice events
Browse files Browse the repository at this point in the history
  • Loading branch information
aabmass committed Feb 3, 2025
1 parent 7398657 commit ca06c60
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
schematized in YAML and the Weaver tool supports it.
"""

from __future__ import annotations

from dataclasses import asdict, dataclass
from typing import Literal

from opentelemetry._events import Event
from opentelemetry.semconv._incubating.attributes import gen_ai_attributes
from opentelemetry.util.types import AnyValue
Expand Down Expand Up @@ -89,3 +94,46 @@ def system_event(
},
body=body,
)


@dataclass
class ChoiceMessage:
    """The message field for a gen_ai.choice event"""

    # Message content; left as None when content capture is disabled so the
    # null-filtering dict_factory in choice_event() drops the key entirely.
    content: AnyValue = None
    # Producer role. Defaults to "assistant" per the semconv, although Vertex
    # responses report "model" instead (see the caller's ChoiceMessage usage).
    role: str = "assistant"


# Finish reasons enumerated by the OTel GenAI semantic conventions for the
# gen_ai.choice event. Callers may also pass other vendor-specific strings,
# which is why APIs below accept `FinishReason | str`.
FinishReason = Literal[
    "content_filter", "error", "length", "stop", "tool_calls"
]


# TODO add tool calls
# https://github.com/open-telemetry/opentelemetry-python-contrib/issues/3216
def choice_event(
    *,
    finish_reason: FinishReason | str,
    index: int,
    message: ChoiceMessage,
) -> Event:
    """Creates a choice event, which describes the Gen AI response message.
    https://github.com/open-telemetry/semantic-conventions/blob/v1.28.0/docs/gen-ai/gen-ai-events.md#event-gen_aichoice
    """

    def drop_nulls(kvs) -> dict[str, AnyValue]:
        # Omit fields whose value is None (e.g. content when capture is off).
        return {key: value for key, value in kvs if value is not None}

    event_body: dict[str, AnyValue] = {
        "finish_reason": finish_reason,
        "index": index,
        "message": asdict(message, dict_factory=drop_nulls),
    }
    return Event(
        name="gen_ai.choice",
        attributes={
            gen_ai_attributes.GEN_AI_SYSTEM: gen_ai_attributes.GenAiSystemValues.VERTEX_AI.value,
        },
        body=event_body,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
get_server_attributes,
get_span_name,
request_to_events,
response_to_events,
)
from opentelemetry.trace import SpanKind, Tracer

Expand Down Expand Up @@ -131,10 +132,11 @@ def traced_method(

if span.is_recording():
span.set_attributes(get_genai_response_attributes(response))
# TODO: add response attributes and events
# _set_response_attributes(
# span, result, event_logger, capture_content
# )
for event in response_to_events(
response=response, capture_content=capture_content
):
event_logger.emit(event)

return response

return traced_method
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@

from opentelemetry._events import Event
from opentelemetry.instrumentation.vertexai.events import (
ChoiceMessage,
FinishReason,
assistant_event,
choice_event,
system_event,
user_event,
)
Expand All @@ -55,6 +58,9 @@
)


# Role string Vertex AI uses for model-generated content, where the OTel
# semantic conventions would normally say "assistant".
_MODEL = "model"


@dataclass(frozen=True)
class GenerateContentParams:
model: str
Expand Down Expand Up @@ -204,7 +210,7 @@ def request_to_events(

for content in params.contents or []:
# Assistant message
if content.role == "model":
if content.role == _MODEL:
request_content = _parts_to_any_value(
capture_content=capture_content, parts=content.parts
)
Expand All @@ -218,6 +224,27 @@ def request_to_events(
yield user_event(role=content.role, content=request_content)


def response_to_events(
    *,
    response: prediction_service.GenerateContentResponse
    | prediction_service_v1beta1.GenerateContentResponse,
    capture_content: bool,
) -> Iterable[Event]:
    """Yields one gen_ai.choice event per candidate in the response."""
    for candidate in response.candidates:
        content_value = _parts_to_any_value(
            capture_content=capture_content,
            parts=candidate.content.parts,
        )
        # Vertex reports the role as "model" rather than "assistant"; fall
        # back to "model" when the response leaves the role unset.
        message = ChoiceMessage(
            role=candidate.content.role or _MODEL,
            content=content_value,
        )
        yield choice_event(
            finish_reason=_map_finish_reason(candidate.finish_reason),
            index=candidate.index,
            message=message,
        )


def _parts_to_any_value(
*,
capture_content: bool,
Expand All @@ -230,3 +257,26 @@ def _parts_to_any_value(
cast("dict[str, AnyValue]", type(part).to_dict(part)) # type: ignore[reportUnknownMemberType]
for part in parts
]


def _map_finish_reason(
    finish_reason: content.Candidate.FinishReason
    | content_v1beta1.Candidate.FinishReason,
) -> FinishReason | str:
    """Maps a Vertex finish reason enum member to a semconv finish_reason string.

    Values with no direct semconv equivalent fall through to the lowercased
    Vertex enum member name.
    """
    # Resolve the concrete enum class so both the v1 and v1beta1 variants work.
    enum_type = type(finish_reason)
    if finish_reason in (
        enum_type.FINISH_REASON_UNSPECIFIED,
        enum_type.OTHER,
    ):
        return "error"
    if finish_reason is enum_type.STOP:
        return "stop"
    if finish_reason is enum_type.MAX_TOKENS:
        return "length"

    # There are a lot of specific enum values from Vertex that would map to "content_filter".
    # I'm worried trying to map the enum obfuscates the telemetry because 1) it over
    # generalizes and 2) half of the values are from the OTel enum and others from the vertex
    # enum. See for reference
    # https://github.com/googleapis/python-aiplatform/blob/c5023698c7068e2f84523f91b824641c9ef2d694/google/cloud/aiplatform_v1/types/content.py#L786-L822
    return finish_reason.name.lower()
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,23 @@ interactions:
"usageMetadata": {
"promptTokenCount": 25,
"candidatesTokenCount": 9,
"totalTokenCount": 34
"totalTokenCount": 34,
"promptTokensDetails": [
{
"modality": 1,
"tokenCount": 25
}
],
"candidatesTokensDetails": [
{
"modality": 1,
"tokenCount": 9
}
]
},
"modelVersion": "gemini-1.5-flash-002"
"modelVersion": "gemini-1.5-flash-002",
"createTime": "2025-02-03T23:33:06.046251Z",
"responseId": "MlKhZ6vpAoifnvgPhceYyA4"
}
headers:
Content-Type:
Expand All @@ -87,7 +101,7 @@ interactions:
- X-Origin
- Referer
content-length:
- '422'
- '715'
status:
code: 200
message: OK
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,44 @@ def test_generate_content(
"server.port": 443,
}

# Emits content event
# Emits user and choice events
logs = log_exporter.get_finished_logs()
assert len(logs) == 1
log_record = logs[0].log_record
assert len(logs) == 2
user_log, choice_log = [log_data.log_record for log_data in logs]

span_context = spans[0].get_span_context()
assert log_record.trace_id == span_context.trace_id
assert log_record.span_id == span_context.span_id
assert log_record.trace_flags == span_context.trace_flags
assert log_record.attributes == {
assert user_log.trace_id == span_context.trace_id
assert user_log.span_id == span_context.span_id
assert user_log.trace_flags == span_context.trace_flags
assert user_log.attributes == {
"gen_ai.system": "vertex_ai",
"event.name": "gen_ai.user.message",
}
assert log_record.body == {
assert user_log.body == {
"content": [{"text": "Say this is a test"}],
"role": "user",
}

assert choice_log.trace_id == span_context.trace_id
assert choice_log.span_id == span_context.span_id
assert choice_log.trace_flags == span_context.trace_flags
assert choice_log.attributes == {
"gen_ai.system": "vertex_ai",
"event.name": "gen_ai.choice",
}
assert choice_log.body == {
"finish_reason": "stop",
"index": 0,
"message": {
"content": [
{
"text": "Okay, I understand. I'm ready for your test. Please proceed.\n"
}
],
"role": "model",
},
}


@pytest.mark.vcr
def test_generate_content_without_events(
Expand Down Expand Up @@ -94,15 +115,25 @@ def test_generate_content_without_events(
"server.port": 443,
}

# Emits event without body.content
# Emits user and choice event without body.content
logs = log_exporter.get_finished_logs()
assert len(logs) == 1
log_record = logs[0].log_record
assert log_record.attributes == {
assert len(logs) == 2
user_log, choice_log = [log_data.log_record for log_data in logs]
assert user_log.attributes == {
"gen_ai.system": "vertex_ai",
"event.name": "gen_ai.user.message",
}
assert log_record.body == {"role": "user"}
assert user_log.body == {"role": "user"}

assert choice_log.attributes == {
"gen_ai.system": "vertex_ai",
"event.name": "gen_ai.choice",
}
assert choice_log.body == {
"finish_reason": "stop",
"index": 0,
"message": {"role": "model"},
}


@pytest.mark.vcr
Expand Down Expand Up @@ -286,7 +317,7 @@ def assert_span_error(span: ReadableSpan) -> None:


@pytest.mark.vcr
def test_generate_content_all_input_events(
def test_generate_content_all_events(
log_exporter: InMemoryLogExporter,
instrument_with_content: VertexAIInstrumentor,
):
Expand All @@ -311,10 +342,10 @@ def test_generate_content_all_input_events(
],
)

# Emits a system event, 2 users events, and a assistant event
# Emits a system event, 2 users events, an assistant event, and the choice (response) event
logs = log_exporter.get_finished_logs()
assert len(logs) == 4
system_log, user_log1, assistant_log, user_log2 = [
assert len(logs) == 5
system_log, user_log1, assistant_log, user_log2, choice_log = [
log_data.log_record for log_data in logs
]

Expand Down Expand Up @@ -354,3 +385,16 @@ def test_generate_content_all_input_events(
"content": [{"text": "Address me by name and say this is a test"}],
"role": "user",
}

assert choice_log.attributes == {
"gen_ai.system": "vertex_ai",
"event.name": "gen_ai.choice",
}
assert choice_log.body == {
"finish_reason": "stop",
"index": 0,
"message": {
"content": [{"text": "OpenTelemetry, this is a test.\n"}],
"role": "model",
},
}

0 comments on commit ca06c60

Please sign in to comment.