diff --git a/src/promptflow-azure/promptflow/azure/_storage/blob/client.py b/src/promptflow-azure/promptflow/azure/_storage/blob/client.py index cb4440a2db6..6f2085229d1 100644 --- a/src/promptflow-azure/promptflow/azure/_storage/blob/client.py +++ b/src/promptflow-azure/promptflow/azure/_storage/blob/client.py @@ -29,9 +29,12 @@ def get_datastore_container_client( ) -> Tuple[ContainerClient, str]: try: if credential is None: - from azure.identity import DefaultAzureCredential + # in cloud scenario, runtime will pass in credential + # so this is local to cloud only code, happens in prompt flow service + # which should rely on Azure CLI credential only + from azure.identity import AzureCliCredential - credential = DefaultAzureCredential() + credential = AzureCliCredential() datastore_definition, datastore_credential = _get_default_datastore( subscription_id, resource_group_name, workspace_name, credential diff --git a/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py b/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py index 5f175b090ed..6e013ad7cfc 100644 --- a/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py +++ b/src/promptflow-azure/promptflow/azure/_storage/cosmosdb/client.py @@ -29,9 +29,12 @@ def get_client( container_client = _get_client_from_map(client_key) if container_client is None: if credential is None: - from azure.identity import DefaultAzureCredential + # in cloud scenario, runtime will pass in credential + # so this is local to cloud only code, happens in prompt flow service + # which should rely on Azure CLI credential only + from azure.identity import AzureCliCredential - credential = DefaultAzureCredential() + credential = AzureCliCredential() token = _get_resource_token( container_name, subscription_id, resource_group_name, workspace_name, credential ) diff --git a/src/promptflow-azure/promptflow/azure/_utils/_tracing.py b/src/promptflow-azure/promptflow/azure/_utils/_tracing.py index 3b972ea07e2..67f04a5c7b2 100644 --- a/src/promptflow-azure/promptflow/azure/_utils/_tracing.py +++ b/src/promptflow-azure/promptflow/azure/_utils/_tracing.py @@ -7,12 +7,12 @@ from azure.ai.ml import MLClient from azure.core.exceptions import ResourceNotFoundError +from azure.identity import AzureCliCredential from promptflow._constants import AzureWorkspaceKind, CosmosDBContainerName from promptflow._sdk._utils import extract_workspace_triad_from_trace_provider from promptflow._utils.logger_utils import get_cli_sdk_logger from promptflow.azure import PFClient -from promptflow.azure._cli._utils import get_credentials_for_cli from promptflow.azure._restclient.flow_service_caller import FlowRequestException from promptflow.exceptions import ErrorTarget, UserErrorException @@ -73,7 +73,7 @@ def validate_trace_destination(value: str) -> None: # the resource exists _logger.debug("Validating resource exists...") ml_client = MLClient( - credential=get_credentials_for_cli(), + credential=AzureCliCredential(), # this validation only happens in CLI, so use CLI credential subscription_id=workspace_triad.subscription_id, resource_group_name=workspace_triad.resource_group_name, workspace_name=workspace_triad.workspace_name, diff --git a/src/promptflow-devkit/promptflow/_sdk/_service/app.py b/src/promptflow-devkit/promptflow/_sdk/_service/app.py index 89e431b04a9..225c0224ce1 100644 --- a/src/promptflow-devkit/promptflow/_sdk/_service/app.py +++ b/src/promptflow-devkit/promptflow/_sdk/_service/app.py @@ -202,13 +202,11 @@ def get_created_by_info_with_cache(): try: # The total time of collecting info is about 3s. import jwt - from azure.identity import DefaultAzureCredential + from azure.identity import AzureCliCredential from promptflow.azure._utils.general import get_arm_token - default_credential = DefaultAzureCredential() - - token = get_arm_token(credential=default_credential) + token = get_arm_token(credential=AzureCliCredential()) decoded_token = jwt.decode(token, options={"verify_signature": False}) created_by_for_local_to_cloud_trace.update( { diff --git a/src/promptflow-tracing/promptflow/tracing/_start_trace.py b/src/promptflow-tracing/promptflow/tracing/_start_trace.py index 4c19082fc41..795e805ea6f 100644 --- a/src/promptflow-tracing/promptflow/tracing/_start_trace.py +++ b/src/promptflow-tracing/promptflow/tracing/_start_trace.py @@ -44,6 +44,11 @@ def start_trace( logging.info("skip tracing local setup as the environment variable is set.") return + # openai instrumentation + logging.debug("injecting OpenAI API...") + inject_openai_api() + logging.debug("OpenAI API injected.") + # prepare resource.attributes and set tracer provider res_attrs = {ResourceAttributesFieldName.SERVICE_NAME: RESOURCE_ATTRIBUTES_SERVICE_NAME} if isinstance(resource_attributes, dict):