Skip to content

Commit ec3c51d

Browse files
authored
Added Vertex AI spans for request parameters (#3192)
* Added Vertex AI spans for request parameters * small fixes, get CI passing * Use standard OTel tracing error handling * move nested util * Actually use GAPIC client since that's what we use under the hood. Also this is what LangChain uses * Comment out seed for now * Remove unnecessary dict.get() calls * Typing improvements to check that we support both v1 and v1beta1 * Add more test cases for error conditions and fix span name bug * fix typing * Add todos for error.type
1 parent 3f50c08 commit ec3c51d

12 files changed

+848
-41
lines changed

instrumentation-genai/opentelemetry-instrumentation-vertexai/CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10+
- Added Vertex AI spans for request parameters
11+
([#3192](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3192))
1012
- Initial VertexAI instrumentation
1113
([#3123](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3123))

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/__init__.py

+25-3
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,17 @@
4141

4242
from typing import Any, Collection
4343

44+
from wrapt import (
45+
wrap_function_wrapper, # type: ignore[reportUnknownVariableType]
46+
)
47+
4448
from opentelemetry._events import get_event_logger
4549
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
4650
from opentelemetry.instrumentation.vertexai.package import _instruments
51+
from opentelemetry.instrumentation.vertexai.patch import (
52+
generate_content_create,
53+
)
54+
from opentelemetry.instrumentation.vertexai.utils import is_content_enabled
4755
from opentelemetry.semconv.schemas import Schemas
4856
from opentelemetry.trace import get_tracer
4957

@@ -55,20 +63,34 @@ def instrumentation_dependencies(self) -> Collection[str]:
5563
def _instrument(self, **kwargs: Any):
    """Enable VertexAI instrumentation.

    Wraps ``PredictionServiceClient.generate_content`` for both the
    v1beta1 and v1 GAPIC clients so request parameters are traced.
    """
    tracer_provider = kwargs.get("tracer_provider")
    tracer = get_tracer(
        __name__,
        "",
        tracer_provider,
        schema_url=Schemas.V1_28_0.value,
    )
    event_logger_provider = kwargs.get("event_logger_provider")
    event_logger = get_event_logger(
        __name__,
        "",
        schema_url=Schemas.V1_28_0.value,
        event_logger_provider=event_logger_provider,
    )

    # Patch both API versions; each patch gets its own wrapper instance.
    for module_name in (
        "google.cloud.aiplatform_v1beta1.services.prediction_service.client",
        "google.cloud.aiplatform_v1.services.prediction_service.client",
    ):
        wrap_function_wrapper(
            module=module_name,
            name="PredictionServiceClient.generate_content",
            wrapper=generate_content_create(
                tracer, event_logger, is_content_enabled()
            ),
        )
7294

7395
def _uninstrument(self, **kwargs: Any) -> None:
    """Disable VertexAI instrumentation. TODO: implemented in later PR."""

instrumentation-genai/opentelemetry-instrumentation-vertexai/src/opentelemetry/instrumentation/vertexai/patch.py

+121
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,124 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import (
18+
TYPE_CHECKING,
19+
Any,
20+
Callable,
21+
MutableSequence,
22+
)
23+
24+
from opentelemetry._events import EventLogger
25+
from opentelemetry.instrumentation.vertexai.utils import (
26+
GenerateContentParams,
27+
get_genai_request_attributes,
28+
get_span_name,
29+
)
30+
from opentelemetry.trace import SpanKind, Tracer
31+
32+
if TYPE_CHECKING:
33+
from google.cloud.aiplatform_v1.services.prediction_service import client
34+
from google.cloud.aiplatform_v1.types import (
35+
content,
36+
prediction_service,
37+
)
38+
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
39+
client as client_v1beta1,
40+
)
41+
from google.cloud.aiplatform_v1beta1.types import (
42+
content as content_v1beta1,
43+
)
44+
from google.cloud.aiplatform_v1beta1.types import (
45+
prediction_service as prediction_service_v1beta1,
46+
)
47+
48+
49+
# Mirrors the parameter signature of
# https://github.com/googleapis/python-aiplatform/blob/v1.76.0/google/cloud/aiplatform_v1/services/prediction_service/client.py#L2088
# so that positional and keyword invocations are both handled.
def _extract_params(
    request: prediction_service.GenerateContentRequest
    | prediction_service_v1beta1.GenerateContentRequest
    | dict[Any, Any]
    | None = None,
    *,
    model: str | None = None,
    contents: MutableSequence[content.Content]
    | MutableSequence[content_v1beta1.Content]
    | None = None,
    **_kwargs: Any,
) -> GenerateContentParams:
    """Normalize ``generate_content`` arguments into GenerateContentParams.

    ``request`` and the named parameters are mutually exclusive, or the
    RPC will fail.
    """
    if not request:
        # Only the keyword form (or nothing) was supplied.
        return GenerateContentParams(model=model or "", contents=contents)

    if isinstance(request, dict):
        # Plain-dict request: keys already line up with the dataclass fields.
        return GenerateContentParams(**request)

    # Proto message request: copy over the fields we trace.
    fields = {
        "model": request.model,
        "contents": request.contents,
        "system_instruction": request.system_instruction,
        "tools": request.tools,
        "tool_config": request.tool_config,
        "labels": request.labels,
        "safety_settings": request.safety_settings,
        "generation_config": request.generation_config,
    }
    return GenerateContentParams(**fields)
84+
85+
86+
def generate_content_create(
    tracer: Tracer, event_logger: EventLogger, capture_content: bool
):
    """Wrap the `generate_content` method of the `GenerativeModel` class to trace it."""

    def traced_method(
        wrapped: Callable[
            ...,
            prediction_service.GenerateContentResponse
            | prediction_service_v1beta1.GenerateContentResponse,
        ],
        instance: client.PredictionServiceClient
        | client_v1beta1.PredictionServiceClient,
        args: Any,
        kwargs: Any,
    ):
        params = _extract_params(*args, **kwargs)
        attributes = get_genai_request_attributes(params)
        with tracer.start_as_current_span(
            name=get_span_name(attributes),
            kind=SpanKind.CLIENT,
            attributes=attributes,
        ) as _span:
            # TODO: emit request events through event_logger (honoring
            # capture_content) once event support lands.
            # TODO: set the error.type attribute per the GenAI span
            # semantic conventions:
            # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md
            result = wrapped(*args, **kwargs)
            # TODO: wrap streaming responses and record response
            # attributes/events before returning.
            return result

    return traced_method
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import re
18+
from dataclasses import dataclass
19+
from os import environ
20+
from typing import (
21+
TYPE_CHECKING,
22+
Mapping,
23+
Sequence,
24+
)
25+
26+
from opentelemetry.semconv._incubating.attributes import (
27+
gen_ai_attributes as GenAIAttributes,
28+
)
29+
from opentelemetry.util.types import AttributeValue
30+
31+
if TYPE_CHECKING:
32+
from google.cloud.aiplatform_v1.types import content, tool
33+
from google.cloud.aiplatform_v1beta1.types import (
34+
content as content_v1beta1,
35+
)
36+
from google.cloud.aiplatform_v1beta1.types import (
37+
tool as tool_v1beta1,
38+
)
39+
40+
41+
@dataclass(frozen=True)
class GenerateContentParams:
    """Normalized parameters of PredictionServiceClient.generate_content.

    Field types accept both the v1 and v1beta1 aiplatform proto types so a
    single dataclass covers both GAPIC clients.
    """

    # Model resource name (possibly a full "projects/.../models/..." path)
    # or a bare model ID; required.
    model: str
    contents: (
        Sequence[content.Content] | Sequence[content_v1beta1.Content] | None
    ) = None
    system_instruction: content.Content | content_v1beta1.Content | None = None
    tools: Sequence[tool.Tool] | Sequence[tool_v1beta1.Tool] | None = None
    tool_config: tool.ToolConfig | tool_v1beta1.ToolConfig | None = None
    labels: Mapping[str, str] | None = None
    safety_settings: (
        Sequence[content.SafetySetting]
        | Sequence[content_v1beta1.SafetySetting]
        | None
    ) = None
    generation_config: (
        content.GenerationConfig | content_v1beta1.GenerationConfig | None
    ) = None
59+
60+
61+
def get_genai_request_attributes(
    params: GenerateContentParams,
    operation_name: GenAIAttributes.GenAiOperationNameValues = GenAIAttributes.GenAiOperationNameValues.CHAT,
):
    """Map request parameters onto GenAI semantic-convention span attributes."""
    attributes: dict[str, AttributeValue] = {
        GenAIAttributes.GEN_AI_OPERATION_NAME: operation_name.value,
        GenAIAttributes.GEN_AI_SYSTEM: GenAIAttributes.GenAiSystemValues.VERTEX_AI.value,
        GenAIAttributes.GEN_AI_REQUEST_MODEL: _get_model_name(params.model),
    }

    generation_config = params.generation_config
    if not generation_config:
        return attributes

    # Optional proto-plus fields: presence is detected with ``in``, see
    # https://proto-plus-python.readthedocs.io/en/stable/fields.html#optional-fields
    for field_name, attribute_key in (
        ("temperature", GenAIAttributes.GEN_AI_REQUEST_TEMPERATURE),
        ("top_p", GenAIAttributes.GEN_AI_REQUEST_TOP_P),
        ("max_output_tokens", GenAIAttributes.GEN_AI_REQUEST_MAX_TOKENS),
        ("presence_penalty", GenAIAttributes.GEN_AI_REQUEST_PRESENCE_PENALTY),
        (
            "frequency_penalty",
            GenAIAttributes.GEN_AI_REQUEST_FREQUENCY_PENALTY,
        ),
        # TODO: map "seed" to GEN_AI_REQUEST_SEED once it is released in
        # semconv 1.30
        # (https://github.com/open-telemetry/semantic-conventions/pull/1710)
        ("stop_sequences", GenAIAttributes.GEN_AI_REQUEST_STOP_SEQUENCES),
    ):
        if field_name in generation_config:
            attributes[attribute_key] = getattr(
                generation_config, field_name
            )

    return attributes
110+
111+
112+
_MODEL_STRIP_RE = re.compile(
113+
r"^projects/(.*)/locations/(.*)/publishers/google/models/"
114+
)
115+
116+
117+
def _get_model_name(model: str) -> str:
118+
return _MODEL_STRIP_RE.sub("", model)
119+
120+
121+
OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = (
    "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
)


def is_content_enabled() -> bool:
    """Return True when the capture-content env var is "true" (case-insensitive)."""
    setting = environ.get(
        OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, "false"
    )
    return setting.lower() == "true"
132+
133+
134+
def get_span_name(span_attributes: Mapping[str, AttributeValue]) -> str:
    """Build the span name "<operation> <model>" for a GenAI request span.

    Falls back to the operation name alone when the model attribute is
    empty. Uses ``Mapping.get`` with an empty-string default so a mapping
    missing either key degrades gracefully instead of raising KeyError
    (behavior is unchanged for attributes built by
    get_genai_request_attributes, which always sets both keys).
    """
    name = span_attributes.get(GenAIAttributes.GEN_AI_OPERATION_NAME, "")
    model = span_attributes.get(GenAIAttributes.GEN_AI_REQUEST_MODEL, "")
    if not model:
        return f"{name}"
    return f"{name} {model}"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
interactions:
2+
- request:
3+
body: |-
4+
{
5+
"contents": [
6+
{
7+
"role": "user",
8+
"parts": [
9+
{
10+
"text": "Say this is a test"
11+
}
12+
]
13+
}
14+
]
15+
}
16+
headers:
17+
Accept:
18+
- '*/*'
19+
Accept-Encoding:
20+
- gzip, deflate
21+
Connection:
22+
- keep-alive
23+
Content-Length:
24+
- '141'
25+
Content-Type:
26+
- application/json
27+
User-Agent:
28+
- python-requests/2.32.3
29+
method: POST
30+
uri: https://us-central1-aiplatform.googleapis.com/v1/projects/fake-project/locations/us-central1/publishers/google/models/gemini-1.5-flash-002:generateContent?%24alt=json%3Benum-encoding%3Dint
31+
response:
32+
body:
33+
string: |-
34+
{
35+
"candidates": [
36+
{
37+
"content": {
38+
"role": "model",
39+
"parts": [
40+
{
41+
"text": "Okay, I understand. I'm ready for your test. Please proceed.\n"
42+
}
43+
]
44+
},
45+
"finishReason": 1,
46+
"avgLogprobs": -0.005692833348324424
47+
}
48+
],
49+
"usageMetadata": {
50+
"promptTokenCount": 5,
51+
"candidatesTokenCount": 19,
52+
"totalTokenCount": 24
53+
},
54+
"modelVersion": "gemini-1.5-flash-002"
55+
}
56+
headers:
57+
Content-Type:
58+
- application/json; charset=UTF-8
59+
Transfer-Encoding:
60+
- chunked
61+
Vary:
62+
- Origin
63+
- X-Origin
64+
- Referer
65+
content-length:
66+
- '453'
67+
status:
68+
code: 200
69+
message: OK
70+
version: 1

0 commit comments

Comments
 (0)