From d614febf089ad05cdecc79163f6b4fc6e68fc20d Mon Sep 17 00:00:00 2001 From: Zhengfei Wang <38847871+zhengfeiwang@users.noreply.github.com> Date: Tue, 23 Apr 2024 19:39:02 +0800 Subject: [PATCH] [trace][bugfix] Always use Azure CLI credential for trace local to cloud scenario (#2958) # Description - For trace-related credentials, use only the Azure CLI credential, which should be the only credential type we can rely on. - Add back OpenAI instrumentation in `promptflow.tracing.start_trace`. # All Promptflow Contribution checklist: - [x] **The pull request does not introduce [breaking changes].** - [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [ ] **Create an issue and link to the pull request to get a dedicated review from the promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [x] Title of the pull request is clear and informative. - [x] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [ ] Pull request includes test coverage for the included changes. 
--- .../promptflow/azure/_storage/blob/client.py | 7 +++++-- .../promptflow/azure/_storage/cosmosdb/client.py | 7 +++++-- src/promptflow-azure/promptflow/azure/_utils/_tracing.py | 4 ++-- src/promptflow-devkit/promptflow/_sdk/_service/app.py | 6 ++---- src/promptflow-tracing/promptflow/tracing/_start_trace.py | 5 +++++ 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/promptflow-azure/promptflow/azure/_storage/blob/client.py b/src/promptflow-azure/promptflow/azure/_storage/blob/client.py index cb4440a2db6..6f2085229d1 100644 --- a/src/promptflow-azure/promptflow/azure/_storage/blob/client.py +++ b/src/promptflow-azure/promptflow/azure/_storage/blob/client.py @@ -29,9 +29,12 @@ def get_datastore_container_client( ) -> Tuple[ContainerClient, str]: try: if credential is None: - from azure.identity import DefaultAzureCredential + # in cloud scenario, runtime will pass in credential + # so this is local to cloud only code, happens in prompt flow service + # which should rely on Azure CLI credential only + from azure.identity import AzureCliCredential - credential = DefaultAzureCredential() + credential = AzureCliCredential() datastore_definition, datastore_credential = _get_default_datastore( subscription_id, resource_group_name, workspace_name, credential diff --git a/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py b/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py index 5f175b090ed..6e013ad7cfc 100644 --- a/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py +++ b/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py @@ -29,9 +29,12 @@ def get_client( container_client = _get_client_from_map(client_key) if container_client is None: if credential is None: - from azure.identity import DefaultAzureCredential + # in cloud scenario, runtime will pass in credential + # so this is local to cloud only code, happens in prompt flow service + # which should rely on Azure CLI credential only + from 
azure.identity import AzureCliCredential - credential = DefaultAzureCredential() + credential = AzureCliCredential() token = _get_resource_token( container_name, subscription_id, resource_group_name, workspace_name, credential ) diff --git a/src/promptflow-azure/promptflow/azure/_utils/_tracing.py b/src/promptflow-azure/promptflow/azure/_utils/_tracing.py index 3b972ea07e2..67f04a5c7b2 100644 --- a/src/promptflow-azure/promptflow/azure/_utils/_tracing.py +++ b/src/promptflow-azure/promptflow/azure/_utils/_tracing.py @@ -7,12 +7,12 @@ from azure.ai.ml import MLClient from azure.core.exceptions import ResourceNotFoundError +from azure.identity import AzureCliCredential from promptflow._constants import AzureWorkspaceKind, CosmosDBContainerName from promptflow._sdk._utils import extract_workspace_triad_from_trace_provider from promptflow._utils.logger_utils import get_cli_sdk_logger from promptflow.azure import PFClient -from promptflow.azure._cli._utils import get_credentials_for_cli from promptflow.azure._restclient.flow_service_caller import FlowRequestException from promptflow.exceptions import ErrorTarget, UserErrorException @@ -73,7 +73,7 @@ def validate_trace_destination(value: str) -> None: # the resource exists _logger.debug("Validating resource exists...") ml_client = MLClient( - credential=get_credentials_for_cli(), + credential=AzureCliCredential(), # this validation only happens in CLI, so use CLI credential subscription_id=workspace_triad.subscription_id, resource_group_name=workspace_triad.resource_group_name, workspace_name=workspace_triad.workspace_name, diff --git a/src/promptflow-devkit/promptflow/_sdk/_service/app.py b/src/promptflow-devkit/promptflow/_sdk/_service/app.py index 89e431b04a9..225c0224ce1 100644 --- a/src/promptflow-devkit/promptflow/_sdk/_service/app.py +++ b/src/promptflow-devkit/promptflow/_sdk/_service/app.py @@ -202,13 +202,11 @@ def get_created_by_info_with_cache(): try: # The total time of collecting info is about 3s. 
import jwt - from azure.identity import DefaultAzureCredential + from azure.identity import AzureCliCredential from promptflow.azure._utils.general import get_arm_token - default_credential = DefaultAzureCredential() - - token = get_arm_token(credential=default_credential) + token = get_arm_token(credential=AzureCliCredential()) decoded_token = jwt.decode(token, options={"verify_signature": False}) created_by_for_local_to_cloud_trace.update( { diff --git a/src/promptflow-tracing/promptflow/tracing/_start_trace.py b/src/promptflow-tracing/promptflow/tracing/_start_trace.py index 4c19082fc41..795e805ea6f 100644 --- a/src/promptflow-tracing/promptflow/tracing/_start_trace.py +++ b/src/promptflow-tracing/promptflow/tracing/_start_trace.py @@ -44,6 +44,11 @@ def start_trace( logging.info("skip tracing local setup as the environment variable is set.") return + # openai instrumentation + logging.debug("injecting OpenAI API...") + inject_openai_api() + logging.debug("OpenAI API injected.") + # prepare resource.attributes and set tracer provider res_attrs = {ResourceAttributesFieldName.SERVICE_NAME: RESOURCE_ATTRIBUTES_SERVICE_NAME} if isinstance(resource_attributes, dict):