Skip to content

Commit 76887e9

Browse files
committed
Added Vertex AI spans for request parameters
1 parent d2ae60f commit 76887e9

File tree

11 files changed

+749
-41
lines changed

11 files changed

+749
-41
lines changed

instrumentation-genai/opentelemetry-instrumentation-vertexai/CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased

- Added Vertex AI spans for request parameters
  ([#3192](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3192))
- Initial VertexAI instrumentation
  ([#3123](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3123))

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/__init__.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,17 @@
4141

4242
from typing import Any, Collection
4343

44+
from wrapt import (
45+
wrap_function_wrapper, # type: ignore[reportUnknownVariableType]
46+
)
47+
4448
from opentelemetry._events import get_event_logger
4549
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
4650
from opentelemetry.instrumentation.vertexai.package import _instruments
51+
from opentelemetry.instrumentation.vertexai.patch import (
52+
generate_content_create,
53+
)
54+
from opentelemetry.instrumentation.vertexai.utils import is_content_enabled
4755
from opentelemetry.semconv.schemas import Schemas
4856
from opentelemetry.trace import get_tracer
4957

@@ -55,20 +63,29 @@ def instrumentation_dependencies(self) -> Collection[str]:
5563
def _instrument(self, **kwargs: Any):
    """Enable VertexAI instrumentation.

    Honors optional ``tracer_provider`` and ``event_logger_provider``
    keyword arguments; the global providers are used when omitted.
    """
    tracer = get_tracer(
        __name__,
        "",
        kwargs.get("tracer_provider"),
        schema_url=Schemas.V1_28_0.value,
    )
    event_logger = get_event_logger(
        __name__,
        "",
        schema_url=Schemas.V1_28_0.value,
        event_logger_provider=kwargs.get("event_logger_provider"),
    )

    # Patching this base class also instruments the
    # vertexai.preview.generative_models package.
    wrap_function_wrapper(
        module="vertexai.generative_models._generative_models",
        name="_GenerativeModel.generate_content",
        wrapper=generate_content_create(
            tracer, event_logger, is_content_enabled()
        ),
    )
7289

7390
def _uninstrument(self, **kwargs: Any) -> None:
    """TODO: implemented in later PR"""

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/patch.py

+101
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,104 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional
18+
19+
from opentelemetry._events import EventLogger
20+
from opentelemetry.instrumentation.vertexai.utils import (
21+
GenerateContentParams,
22+
get_genai_request_attributes,
23+
get_span_name,
24+
handle_span_exception,
25+
)
26+
from opentelemetry.trace import SpanKind, Tracer
27+
28+
if TYPE_CHECKING:
29+
from vertexai.generative_models import (
30+
GenerationResponse,
31+
Tool,
32+
ToolConfig,
33+
)
34+
from vertexai.generative_models._generative_models import (
35+
ContentsType,
36+
GenerationConfigType,
37+
SafetySettingsType,
38+
_GenerativeModel,
39+
)
40+
41+
42+
def generate_content_create(
    tracer: Tracer, event_logger: EventLogger, capture_content: bool
):
    """Build a wrapt wrapper that traces `_GenerativeModel.generate_content`.

    The returned function opens a CLIENT span around each call, records the
    request parameters as span attributes, and re-raises any error after
    marking the span as failed.
    """

    # Mirrors the exact `generate_content` signature so positional and
    # keyword arguments are bound exactly as the wrapped method binds them.
    def _bind_params(
        contents: ContentsType,
        *,
        generation_config: Optional[GenerationConfigType] = None,
        safety_settings: Optional[SafetySettingsType] = None,
        tools: Optional[list[Tool]] = None,
        tool_config: Optional[ToolConfig] = None,
        labels: Optional[dict[str, str]] = None,
        stream: bool = False,
    ) -> GenerateContentParams:
        return GenerateContentParams(
            contents=contents,
            generation_config=generation_config,
            safety_settings=safety_settings,
            tools=tools,
            tool_config=tool_config,
            labels=labels,
            stream=stream,
        )

    def traced_method(
        wrapped: Callable[
            ..., GenerationResponse | Iterable[GenerationResponse]
        ],
        instance: _GenerativeModel,
        args: Any,
        kwargs: Any,
    ):
        params = _bind_params(*args, **kwargs)
        attributes = get_genai_request_attributes(instance, params)

        with tracer.start_as_current_span(
            name=get_span_name(attributes),
            kind=SpanKind.CLIENT,
            attributes=attributes,
            end_on_exit=False,
        ) as span:
            # TODO: emit request events (will use event_logger and
            # capture_content once implemented)

            try:
                result = wrapped(*args, **kwargs)
                # TODO: handle streaming responses (wrap the iterator and
                # end the span when it is exhausted)
                # TODO: add response attributes and events
                span.end()
                return result
            except Exception as error:
                handle_span_exception(span, error)
                raise

    return traced_method
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from dataclasses import dataclass
18+
from os import environ
19+
from typing import (
20+
TYPE_CHECKING,
21+
Dict,
22+
List,
23+
Mapping,
24+
Optional,
25+
TypedDict,
26+
cast,
27+
)
28+
29+
from opentelemetry.semconv._incubating.attributes import (
30+
gen_ai_attributes as GenAIAttributes,
31+
)
32+
from opentelemetry.semconv.attributes import (
33+
error_attributes as ErrorAttributes,
34+
)
35+
from opentelemetry.trace import Span
36+
from opentelemetry.trace.status import Status, StatusCode
37+
from opentelemetry.util.types import AttributeValue
38+
39+
if TYPE_CHECKING:
40+
from vertexai.generative_models import Tool, ToolConfig
41+
from vertexai.generative_models._generative_models import (
42+
ContentsType,
43+
GenerationConfigType,
44+
SafetySettingsType,
45+
_GenerativeModel,
46+
)
47+
48+
49+
@dataclass(frozen=True)
class GenerateContentParams:
    """Bound arguments of `_GenerativeModel.generate_content`.

    Field defaults mirror the wrapped method's keyword defaults, so an
    instance can be built from just `contents`; existing keyword-based
    construction is unaffected.
    """

    contents: ContentsType
    generation_config: Optional[GenerationConfigType] = None
    safety_settings: Optional[SafetySettingsType] = None
    tools: Optional[List["Tool"]] = None
    tool_config: Optional["ToolConfig"] = None
    labels: Optional[Dict[str, str]] = None
    stream: bool = False
58+
59+
60+
# Subset of the keys produced by vertexai GenerationConfig.to_dict() that
# this instrumentation reads; more fields exist but aren't needed yet.
GenerationConfigDict = TypedDict(
    "GenerationConfigDict",
    {
        "temperature": Optional[float],
        "top_p": Optional[float],
        "top_k": Optional[int],
        "max_output_tokens": Optional[int],
        "stop_sequences": Optional[List[str]],
        "presence_penalty": Optional[float],
        "frequency_penalty": Optional[float],
        "seed": Optional[int],
    },
    total=False,
)
70+
71+
72+
def get_genai_request_attributes(
    # TODO: use types
    instance: _GenerativeModel,
    params: GenerateContentParams,
    operation_name: GenAIAttributes.GenAiOperationNameValues = GenAIAttributes.GenAiOperationNameValues.CHAT,
):
    """Return GenAI semconv request attributes for a generate_content call.

    Attributes whose generation-config value is unset are omitted entirely.
    """
    generation_config = _get_generation_config(instance, params)

    attributes = {
        GenAIAttributes.GEN_AI_OPERATION_NAME: operation_name.value,
        GenAIAttributes.GEN_AI_SYSTEM: GenAIAttributes.GenAiSystemValues.VERTEX_AI.value,
        GenAIAttributes.GEN_AI_REQUEST_MODEL: _get_model_name(instance),
    }

    # Semconv attribute -> generation_config key; set only when present.
    config_mapping = {
        GenAIAttributes.GEN_AI_REQUEST_TEMPERATURE: "temperature",
        GenAIAttributes.GEN_AI_REQUEST_TOP_P: "top_p",
        GenAIAttributes.GEN_AI_REQUEST_MAX_TOKENS: "max_output_tokens",
        GenAIAttributes.GEN_AI_REQUEST_PRESENCE_PENALTY: "presence_penalty",
        GenAIAttributes.GEN_AI_REQUEST_FREQUENCY_PENALTY: "frequency_penalty",
        GenAIAttributes.GEN_AI_OPENAI_REQUEST_SEED: "seed",
        GenAIAttributes.GEN_AI_REQUEST_STOP_SEQUENCES: "stop_sequences",
    }
    for attribute, key in config_mapping.items():
        value = generation_config.get(key)
        if value is not None:
            attributes[attribute] = value

    return attributes
107+
108+
109+
def _get_generation_config(
110+
instance: _GenerativeModel,
111+
params: GenerateContentParams,
112+
) -> GenerationConfigDict:
113+
generation_config = params.generation_config or instance._generation_config
114+
if generation_config is None:
115+
return {}
116+
if isinstance(generation_config, dict):
117+
return cast(GenerationConfigDict, generation_config)
118+
return cast(GenerationConfigDict, generation_config.to_dict())
119+
120+
121+
_RESOURCE_PREFIX = "publishers/google/models/"
122+
123+
124+
def _get_model_name(instance: _GenerativeModel) -> str:
125+
model_name = instance._model_name
126+
127+
# Can use str.removeprefix() once 3.8 is dropped
128+
if model_name.startswith(_RESOURCE_PREFIX):
129+
model_name = model_name[len(_RESOURCE_PREFIX) :]
130+
return model_name
131+
132+
133+
# TODO: Everything below here should be replaced with
134+
# opentelemetry.instrumentation.genai_utils instead once it is released.
135+
# https://github.com/open-telemetry/opentelemetry-python-contrib/issues/3191
136+
137+
OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = (
    "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
)


def is_content_enabled() -> bool:
    """Whether the opt-in env var enables capturing message content.

    Only the (case-insensitive) value "true" enables capture; any other
    value, or an unset variable, disables it.
    """
    flag = environ.get(
        OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, "false"
    )
    return flag.lower() == "true"
148+
149+
150+
def get_span_name(span_attributes: Mapping[str, AttributeValue]):
    """Build the span name "{operation} {model}" from request attributes.

    Missing attributes contribute an empty string, matching the semconv
    span-naming recommendation for GenAI client spans.
    """
    operation = span_attributes.get(GenAIAttributes.GEN_AI_OPERATION_NAME, "")
    model = span_attributes.get(GenAIAttributes.GEN_AI_REQUEST_MODEL, "")
    return f"{operation} {model}"
154+
155+
156+
def handle_span_exception(span: Span, error: Exception):
    """Mark `span` as failed with `error`'s message and type, then end it."""
    span.set_status(Status(StatusCode.ERROR, str(error)))
    if span.is_recording():
        # Record the qualified exception class name per the error semconv.
        error_type = type(error).__qualname__
        span.set_attribute(ErrorAttributes.ERROR_TYPE, error_type)
    span.end()

0 commit comments

Comments
 (0)