diff --git a/src/promptflow/promptflow/_core/metric_logger.py b/src/promptflow/promptflow/_core/metric_logger.py index b17891606a4..e5c14f2ceed 100644 --- a/src/promptflow/promptflow/_core/metric_logger.py +++ b/src/promptflow/promptflow/_core/metric_logger.py @@ -45,6 +45,15 @@ def remove_metric_logger(self, logger_func: Callable): def log_metric(key, value, variant_id=None): + """Log a metric for current promptflow run. + + :param key: Metric name. + :type key: str + :param value: Metric value. + :type value: float + :param variant_id: Variant id for the metric. + :type variant_id: str + """ MetricLoggerManager.get_instance().log_metric(key, value, variant_id) diff --git a/src/promptflow/promptflow/_core/tool.py b/src/promptflow/promptflow/_core/tool.py index a65b0b39d5d..48ab30db4dd 100644 --- a/src/promptflow/promptflow/_core/tool.py +++ b/src/promptflow/promptflow/_core/tool.py @@ -7,7 +7,7 @@ import logging from abc import ABC from enum import Enum -from typing import Optional +from typing import Callable, Optional module_logger = logging.getLogger(__name__) @@ -39,7 +39,15 @@ def active_instance(cls) -> Optional["ToolInvoker"]: return cls._active_tool_invoker -def tool(f): +def tool(f: Callable) -> Callable: + """Decorator for tool functions. The decorated function will be registered as a tool and can be used in a flow. + + :param f: The tool function. + :type f: Callable + :return: The decorated function. + :rtype: Callable + """ + @functools.wraps(f) def new_f(*args, **kwargs): tool_invoker = ToolInvoker.active_instance() diff --git a/src/promptflow/promptflow/_sdk/_pf_client.py b/src/promptflow/promptflow/_sdk/_pf_client.py index 385716ddfe8..ee7107c328f 100644 --- a/src/promptflow/promptflow/_sdk/_pf_client.py +++ b/src/promptflow/promptflow/_sdk/_pf_client.py @@ -101,6 +101,7 @@ def stream(self, run: Union[str, Run]) -> Run: :param run: Run object or name of the run. :type run: Union[str, ~promptflow.sdk.entities.Run] :return: flow run info. + :rtype: ~promptflow.sdk.entities.Run """ return self.runs.stream(run) @@ -134,15 +135,17 @@ def visualize(self, runs: Union[List[str], List[Run]]) -> None: @property def runs(self) -> RunOperations: + """Run operations that can manage runs.""" return self._runs @property def connections(self) -> ConnectionOperations: + """Connection operations that can manage connections.""" return self._connections @property def flows(self) -> FlowOperations: - """Operations on the flow, such as test/debug the flow, chat with chat flow.""" + """Operations on the flow that can manage flows.""" return self._flows def test( @@ -154,7 +157,7 @@ def test( node: str = None, environment_variables: dict = None, ) -> dict: - """Test flow or node + """Test flow or node. :param flow: path to flow directory to test :type flow: Union[str, PathLike] diff --git a/src/promptflow/promptflow/_sdk/entities/_connection.py b/src/promptflow/promptflow/_sdk/entities/_connection.py index f7536493736..0716d2f17f1 100644 --- a/src/promptflow/promptflow/_sdk/entities/_connection.py +++ b/src/promptflow/promptflow/_sdk/entities/_connection.py @@ -44,6 +44,20 @@ class _Connection(YAMLTranslatableMixin): + """A connection entity that stores the connection information. + + :param name: Connection name + :type name: str + :param type: Possible values include: "OpenAI", "AzureOpenAI", "Custom". + :type type: str + :param module: The module of connection class, used for execution. + :type module: str + :param configs: The configs kv pairs. + :type configs: Dict[str, str] + :param secrets: The secrets kv pairs. + :type secrets: Dict[str, str] + """ + TYPE = ConnectionType._NOT_SET def __init__( @@ -54,13 +68,6 @@ def __init__( secrets: Dict[str, str] = None, **kwargs, ): - """ - :param name: Connection name - :param type: Possible values include: "OpenAI", "AzureOpenAI", "Custom". - :param module: The module of connection class, used for execution. - :param configs: The configs kv pairs. - :param secrets: The secrets kv pairs. - """ self.name = name self.type = self.TYPE self.class_name = f"{self.TYPE.value}Connection" # The type in executor connection dict @@ -290,14 +297,18 @@ def api_key(self, value): class AzureOpenAIConnection(_StrongTypeConnection): - """ - Azure Open AI connection. + """Azure Open AI connection. :param api_key: The api key. + :type api_key: str :param api_base: The api base. + :type api_base: str :param api_type: The api type, default "azure". + :type api_type: str :param api_version: The api version, default "2023-07-01-preview". + :type api_version: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.AZURE_OPEN_AI @@ -345,12 +356,14 @@ def api_version(self, value): class OpenAIConnection(_StrongTypeConnection): - """ - Open AI connection. + """Open AI connection. :param api_key: The api key. + :type api_key: str :param organization: Optional. The unique identifier for your organization which can be used in API requests. + :type organization: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.OPEN_AI @@ -376,11 +389,12 @@ def organization(self, value): class SerpConnection(_StrongTypeConnection): - """ - Serp connection. + """Serp connection. :param api_key: The api key. + :type api_key: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.SERP @@ -412,12 +426,14 @@ def api_base(self, value): class QdrantConnection(_EmbeddingStoreConnection): - """ - Qdrant connection. + """Qdrant connection. :param api_key: The api key. + :type api_key: str :param api_base: The api base. + :type api_base: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.QDRANT @@ -428,12 +444,14 @@ def _get_schema_cls(cls): class WeaviateConnection(_EmbeddingStoreConnection): - """ - Weaviate connection. + """Weaviate connection. :param api_key: The api key. + :type api_key: str :param api_base: The api base. + :type api_base: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.WEAVIATE @@ -444,13 +462,16 @@ def _get_schema_cls(cls): class CognitiveSearchConnection(_StrongTypeConnection): - """ - Cognitive Search connection. + """Cognitive Search connection. :param api_key: The api key. + :type api_key: str :param api_base: The api base. + :type api_base: str :param api_version: The api version, default "2023-07-01-Preview". + :type api_version: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.COGNITIVE_SEARCH @@ -490,10 +511,15 @@ class AzureContentSafetyConnection(_StrongTypeConnection): Azure Content Safety connection. :param api_key: The api key. + :type api_key: str :param endpoint: The api endpoint. + :type endpoint: str :param api_version: The api version, default "2023-04-30-preview". + :type api_version: str :param api_type: The api type, default "Content Safety". + :type api_type: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.AZURE_CONTENT_SAFETY @@ -546,14 +572,18 @@ def api_type(self, value): class FormRecognizerConnection(AzureContentSafetyConnection): - """ - Form Recognizer connection. + """Form Recognizer connection. :param api_key: The api key. + :type api_key: str :param endpoint: The api endpoint. + :type endpoint: str :param api_version: The api version, default "2023-07-31". + :type api_version: str :param api_type: The api type, default "Form Recognizer". + :type api_type: str :param name: Connection name. + :type name: str """ # Note: FormRecognizer and ContentSafety are using CognitiveService type in ARM, so keys are the same. @@ -570,10 +600,14 @@ def _get_schema_cls(cls): class CustomConnection(_Connection): - """ + """Custom connection. + :param configs: The configs kv pairs. + :type configs: Dict[str, str] :param secrets: The secrets kv pairs. + :type secrets: Dict[str, str] :param name: Connection name + :type name: str """ TYPE = ConnectionType.CUSTOM diff --git a/src/promptflow/promptflow/_sdk/entities/_run.py b/src/promptflow/promptflow/_sdk/entities/_run.py index a3571fdb541..6094afd8945 100644 --- a/src/promptflow/promptflow/_sdk/entities/_run.py +++ b/src/promptflow/promptflow/_sdk/entities/_run.py @@ -45,15 +45,53 @@ class Run(YAMLTranslatableMixin): + """Flow run entity. + + :param flow: Path of the flow directory. + :type flow: Path + :param name: Name of the run. + :type name: Optional[str] + :param data: Input data for the run. + :type data: Optional[str] + :param variant: Variant of the run. + :type variant: Optional[str] + :param run: Parent run or run ID. + :type run: Optional[Union[Run, str]] + :param column_mapping: Column mapping for the run. Optional since it's not stored in the database. + :type column_mapping: Optional[dict] + :param display_name: Display name of the run. + :type display_name: Optional[str] + :param description: Description of the run. + :type description: Optional[str] + :param tags: Tags of the run. + :type tags: Optional[List[Dict[str, str]]] + :param created_on: Date and time the run was created. + :type created_on: Optional[datetime.datetime] + :param start_time: Date and time the run started. + :type start_time: Optional[datetime.datetime] + :param end_time: Date and time the run ended. + :type end_time: Optional[datetime.datetime] + :param status: Status of the run. + :type status: Optional[str] + :param environment_variables: Environment variables for the run. + :type environment_variables: Optional[Dict[str, str]] + :param connections: Connections for the run. + :type connections: Optional[Dict[str, Dict]] + :param properties: Properties of the run. + :type properties: Optional[Dict[str, Any]] + :param kwargs: Additional keyword arguments. + :type kwargs: Optional[dict] + """ + def __init__( self, flow: Path, - name: str = None, + name: Optional[str] = None, # input fields are optional since it's not stored in DB - data: str = None, - variant: str = None, - run: Union["Run", str] = None, - column_mapping: dict = None, + data: Optional[str] = None, + variant: Optional[str] = None, + run: Optional[Union["Run", str]] = None, + column_mapping: Optional[dict] = None, display_name: Optional[str] = None, description: Optional[str] = None, tags: Optional[List[Dict[str, str]]] = None, @@ -62,26 +100,11 @@ def __init__( start_time: Optional[datetime.datetime] = None, end_time: Optional[datetime.datetime] = None, status: Optional[str] = None, - environment_variables: Dict[str, str] = None, - connections: Dict[str, Dict] = None, + environment_variables: Optional[Dict[str, str]] = None, + connections: Optional[Dict[str, Dict]] = None, properties: Optional[Dict[str, Any]] = None, **kwargs, ): - """Flow run. - - :param name: Name of the run. - :type name: str - :param type: Type of the run, should be one of "bulk", "evaluate" or "pairwise_evaluate". - :type type: str - :param flow: Path of the flow directory. - :type flow: Path - :param display_name: Display name of the run. - :type display_name: str - :param description: Description of the run. - :type description: str - :param tags: Tags of the run. - :type tags: List[Dict[str, str]] - """ # TODO: remove when RUN CRUD don't depend on this self.type = RunTypes.BATCH self.data = data diff --git a/src/promptflow/promptflow/azure/_pf_client.py b/src/promptflow/promptflow/azure/_pf_client.py index 61370d97510..a7e544cbfb4 100644 --- a/src/promptflow/promptflow/azure/_pf_client.py +++ b/src/promptflow/promptflow/azure/_pf_client.py @@ -86,7 +86,7 @@ def __init__( @property def ml_client(self): - """Return a client class to interact with Azure ML services.""" + """Return a client to interact with Azure ML services.""" return self._ml_client @classmethod @@ -250,6 +250,7 @@ def load_as_component( ) -> "Component": """ Load a flow as a component. + :param source: Source of the flow. Should be a path to a flow dag yaml file or a flow directory. :type source: Union[str, PathLike, IO[AnyStr]] :param component_type: Type of the loaded component, support parallel only for now. @@ -300,4 +301,5 @@ def _add_user_agent(self, kwargs) -> None: @property def runs(self): + """Return the run operation object that can manage runs.""" return self._runs diff --git a/src/promptflow/promptflow/azure/operations/_flow_opearations.py b/src/promptflow/promptflow/azure/operations/_flow_opearations.py index 8d5ecd86f3f..398d5628fe2 100644 --- a/src/promptflow/promptflow/azure/operations/_flow_opearations.py +++ b/src/promptflow/promptflow/azure/operations/_flow_opearations.py @@ -149,6 +149,7 @@ def load_as_component( is_deterministic: bool = True, **kwargs, ) -> Component: + """Load a flow as a component.""" rest_object = LoadFlowAsComponentRequest( node_variant=variant, inputs_mapping=columns_mapping, diff --git a/src/promptflow/promptflow/exceptions.py b/src/promptflow/promptflow/exceptions.py index 512667b4ec9..f7666f5c2cc 100644 --- a/src/promptflow/promptflow/exceptions.py +++ b/src/promptflow/promptflow/exceptions.py @@ -8,6 +8,8 @@ class ErrorTarget(str, Enum): + """The target of the error, indicates which part of the system the error occurs.""" + EXECUTOR = "Executor" FLOW_EXECUTOR = "FlowExecutor" NODE_EXECUTOR = "NodeExecutor" @@ -52,6 +54,7 @@ def __init__( @property def message(self): + """The error message.""" if self._message: return self._message @@ -62,10 +65,12 @@ def message(self): @property def message_format(self): + """The error message format.""" return self._message_format @cached_property def message_parameters(self): + """The error message parameters.""" if not self._kwargs: return {} @@ -74,11 +79,12 @@ def message_parameters(self): @cached_property def serializable_message_parameters(self): + """The serializable error message parameters.""" return {k: str(v) for k, v in self.message_parameters.items()} @property def target(self): - """Return the error target. + """The error target. :return: The error target. :rtype: ErrorTarget @@ -87,6 +93,7 @@ def target(self): @target.setter def target(self, value): + """Set the error target.""" self._target = value @property @@ -155,6 +162,8 @@ def reversed_error_codes(): return result def get_arguments_from_message_format(self, message_format): + """Get the arguments from the message format.""" + def iter_field_name(): if not message_format: return @@ -185,4 +194,6 @@ class SystemErrorException(PromptflowException): class ValidationException(UserErrorException): + """Exception raised when validation fails.""" + pass