From 692a0e56186952c3e646f29c4340950767971e67 Mon Sep 17 00:00:00 2001 From: Honglin <0mza987@gmail.com> Date: Thu, 7 Sep 2023 20:08:54 +0800 Subject: [PATCH] [Doc] Add doc string for public functions (#330) # Description Refine SDK public class's docstring. # All Promptflow Contribution checklist: - [x] **The pull request does not introduce [breaking changes]** - [x] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [x] **I have read the [contribution guidelines](../CONTRIBUTING.md).** ## General Guidelines and Best Practices - [x] Title of the pull request is clear and informative. - [x] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [ ] Pull request includes test coverage for the included changes. --- .../promptflow/_core/metric_logger.py | 9 +++ src/promptflow/promptflow/_core/tool.py | 12 ++- src/promptflow/promptflow/_sdk/_pf_client.py | 7 +- .../promptflow/_sdk/entities/_connection.py | 78 +++++++++++++------ .../promptflow/_sdk/entities/_run.py | 67 ++++++++++------ src/promptflow/promptflow/azure/_pf_client.py | 4 +- .../azure/operations/_flow_opearations.py | 1 + src/promptflow/promptflow/exceptions.py | 13 +++- 8 files changed, 141 insertions(+), 50 deletions(-) diff --git a/src/promptflow/promptflow/_core/metric_logger.py b/src/promptflow/promptflow/_core/metric_logger.py index b17891606a4..e5c14f2ceed 100644 --- a/src/promptflow/promptflow/_core/metric_logger.py +++ b/src/promptflow/promptflow/_core/metric_logger.py @@ -45,6 +45,15 @@ def remove_metric_logger(self, logger_func: Callable): def log_metric(key, value, variant_id=None): + """Log a metric for current promptflow run. + + :param key: Metric name. + :type key: str + :param value: Metric value. + :type value: float + :param variant_id: Variant id for the metric. + :type variant_id: str + """ MetricLoggerManager.get_instance().log_metric(key, value, variant_id) diff --git a/src/promptflow/promptflow/_core/tool.py b/src/promptflow/promptflow/_core/tool.py index a65b0b39d5d..48ab30db4dd 100644 --- a/src/promptflow/promptflow/_core/tool.py +++ b/src/promptflow/promptflow/_core/tool.py @@ -7,7 +7,7 @@ import logging from abc import ABC from enum import Enum -from typing import Optional +from typing import Callable, Optional module_logger = logging.getLogger(__name__) @@ -39,7 +39,15 @@ def active_instance(cls) -> Optional["ToolInvoker"]: return cls._active_tool_invoker -def tool(f): +def tool(f: Callable) -> Callable: + """Decorator for tool functions. The decorated function will be registered as a tool and can be used in a flow. + + :param f: The tool function. + :type f: Callable + :return: The decorated function. + :rtype: Callable + """ + @functools.wraps(f) def new_f(*args, **kwargs): tool_invoker = ToolInvoker.active_instance() diff --git a/src/promptflow/promptflow/_sdk/_pf_client.py b/src/promptflow/promptflow/_sdk/_pf_client.py index 385716ddfe8..ee7107c328f 100644 --- a/src/promptflow/promptflow/_sdk/_pf_client.py +++ b/src/promptflow/promptflow/_sdk/_pf_client.py @@ -101,6 +101,7 @@ def stream(self, run: Union[str, Run]) -> Run: :param run: Run object or name of the run. :type run: Union[str, ~promptflow.sdk.entities.Run] :return: flow run info. + :rtype: ~promptflow.sdk.entities.Run """ return self.runs.stream(run) @@ -134,15 +135,17 @@ def visualize(self, runs: Union[List[str], List[Run]]) -> None: @property def runs(self) -> RunOperations: + """Run operations that can manage runs.""" return self._runs @property def connections(self) -> ConnectionOperations: + """Connection operations that can manage connections.""" return self._connections @property def flows(self) -> FlowOperations: - """Operations on the flow, such as test/debug the flow, chat with chat flow.""" + """Operations on the flow that can manage flows.""" return self._flows def test( @@ -154,7 +157,7 @@ def test( node: str = None, environment_variables: dict = None, ) -> dict: - """Test flow or node + """Test flow or node. :param flow: path to flow directory to test :type flow: Union[str, PathLike] diff --git a/src/promptflow/promptflow/_sdk/entities/_connection.py b/src/promptflow/promptflow/_sdk/entities/_connection.py index f7536493736..0716d2f17f1 100644 --- a/src/promptflow/promptflow/_sdk/entities/_connection.py +++ b/src/promptflow/promptflow/_sdk/entities/_connection.py @@ -44,6 +44,20 @@ class _Connection(YAMLTranslatableMixin): + """A connection entity that stores the connection information. + + :param name: Connection name + :type name: str + :param type: Possible values include: "OpenAI", "AzureOpenAI", "Custom". + :type type: str + :param module: The module of connection class, used for execution. + :type module: str + :param configs: The configs kv pairs. + :type configs: Dict[str, str] + :param secrets: The secrets kv pairs. + :type secrets: Dict[str, str] + """ + TYPE = ConnectionType._NOT_SET def __init__( @@ -54,13 +68,6 @@ def __init__( secrets: Dict[str, str] = None, **kwargs, ): - """ - :param name: Connection name - :param type: Possible values include: "OpenAI", "AzureOpenAI", "Custom". - :param module: The module of connection class, used for execution. - :param configs: The configs kv pairs. - :param secrets: The secrets kv pairs. - """ self.name = name self.type = self.TYPE self.class_name = f"{self.TYPE.value}Connection" # The type in executor connection dict @@ -290,14 +297,18 @@ def api_key(self, value): class AzureOpenAIConnection(_StrongTypeConnection): - """ - Azure Open AI connection. + """Azure Open AI connection. :param api_key: The api key. + :type api_key: str :param api_base: The api base. + :type api_base: str :param api_type: The api type, default "azure". + :type api_type: str :param api_version: The api version, default "2023-07-01-preview". + :type api_version: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.AZURE_OPEN_AI @@ -345,12 +356,14 @@ def api_version(self, value): class OpenAIConnection(_StrongTypeConnection): - """ - Open AI connection. + """Open AI connection. :param api_key: The api key. + :type api_key: str :param organization: Optional. The unique identifier for your organization which can be used in API requests. + :type organization: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.OPEN_AI @@ -376,11 +389,12 @@ def organization(self, value): class SerpConnection(_StrongTypeConnection): - """ - Serp connection. + """Serp connection. :param api_key: The api key. + :type api_key: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.SERP @@ -412,12 +426,14 @@ def api_base(self, value): class QdrantConnection(_EmbeddingStoreConnection): - """ - Qdrant connection. + """Qdrant connection. :param api_key: The api key. + :type api_key: str :param api_base: The api base. + :type api_base: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.QDRANT @@ -428,12 +444,14 @@ def _get_schema_cls(cls): class WeaviateConnection(_EmbeddingStoreConnection): - """ - Weaviate connection. + """Weaviate connection. :param api_key: The api key. + :type api_key: str :param api_base: The api base. + :type api_base: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.WEAVIATE @@ -444,13 +462,16 @@ def _get_schema_cls(cls): class CognitiveSearchConnection(_StrongTypeConnection): - """ - Cognitive Search connection. + """Cognitive Search connection. :param api_key: The api key. + :type api_key: str :param api_base: The api base. + :type api_base: str :param api_version: The api version, default "2023-07-01-Preview". + :type api_version: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.COGNITIVE_SEARCH @@ -490,10 +511,15 @@ class AzureContentSafetyConnection(_StrongTypeConnection): Azure Content Safety connection. :param api_key: The api key. + :type api_key: str :param endpoint: The api endpoint. + :type endpoint: str :param api_version: The api version, default "2023-04-30-preview". + :type api_version: str :param api_type: The api type, default "Content Safety". + :type api_type: str :param name: Connection name. + :type name: str """ TYPE = ConnectionType.AZURE_CONTENT_SAFETY @@ -546,14 +572,18 @@ def api_type(self, value): class FormRecognizerConnection(AzureContentSafetyConnection): - """ - Form Recognizer connection. + """Form Recognizer connection. :param api_key: The api key. + :type api_key: str :param endpoint: The api endpoint. + :type endpoint: str :param api_version: The api version, default "2023-07-31". + :type api_version: str :param api_type: The api type, default "Form Recognizer". + :type api_type: str :param name: Connection name. + :type name: str """ # Note: FormRecognizer and ContentSafety are using CognitiveService type in ARM, so keys are the same. @@ -570,10 +600,14 @@ def _get_schema_cls(cls): class CustomConnection(_Connection): - """ + """Custom connection. + :param configs: The configs kv pairs. + :type configs: Dict[str, str] :param secrets: The secrets kv pairs. + :type secrets: Dict[str, str] :param name: Connection name + :type name: str """ TYPE = ConnectionType.CUSTOM diff --git a/src/promptflow/promptflow/_sdk/entities/_run.py b/src/promptflow/promptflow/_sdk/entities/_run.py index a3571fdb541..6094afd8945 100644 --- a/src/promptflow/promptflow/_sdk/entities/_run.py +++ b/src/promptflow/promptflow/_sdk/entities/_run.py @@ -45,15 +45,53 @@ class Run(YAMLTranslatableMixin): + """Flow run entity. + + :param flow: Path of the flow directory. + :type flow: Path + :param name: Name of the run. + :type name: Optional[str] + :param data: Input data for the run. + :type data: Optional[str] + :param variant: Variant of the run. + :type variant: Optional[str] + :param run: Parent run or run ID. + :type run: Optional[Union[Run, str]] + :param column_mapping: Column mapping for the run. Optional since it's not stored in the database. + :type column_mapping: Optional[dict] + :param display_name: Display name of the run. + :type display_name: Optional[str] + :param description: Description of the run. + :type description: Optional[str] + :param tags: Tags of the run. + :type tags: Optional[List[Dict[str, str]]] + :param created_on: Date and time the run was created. + :type created_on: Optional[datetime.datetime] + :param start_time: Date and time the run started. + :type start_time: Optional[datetime.datetime] + :param end_time: Date and time the run ended. + :type end_time: Optional[datetime.datetime] + :param status: Status of the run. + :type status: Optional[str] + :param environment_variables: Environment variables for the run. + :type environment_variables: Optional[Dict[str, str]] + :param connections: Connections for the run. + :type connections: Optional[Dict[str, Dict]] + :param properties: Properties of the run. + :type properties: Optional[Dict[str, Any]] + :param kwargs: Additional keyword arguments. + :type kwargs: Optional[dict] + """ + def __init__( self, flow: Path, - name: str = None, + name: Optional[str] = None, # input fields are optional since it's not stored in DB - data: str = None, - variant: str = None, - run: Union["Run", str] = None, - column_mapping: dict = None, + data: Optional[str] = None, + variant: Optional[str] = None, + run: Optional[Union["Run", str]] = None, + column_mapping: Optional[dict] = None, display_name: Optional[str] = None, description: Optional[str] = None, tags: Optional[List[Dict[str, str]]] = None, @@ -62,26 +100,11 @@ def __init__( start_time: Optional[datetime.datetime] = None, end_time: Optional[datetime.datetime] = None, status: Optional[str] = None, - environment_variables: Dict[str, str] = None, - connections: Dict[str, Dict] = None, + environment_variables: Optional[Dict[str, str]] = None, + connections: Optional[Dict[str, Dict]] = None, properties: Optional[Dict[str, Any]] = None, **kwargs, ): - """Flow run. - - :param name: Name of the run. - :type name: str - :param type: Type of the run, should be one of "bulk", "evaluate" or "pairwise_evaluate". - :type type: str - :param flow: Path of the flow directory. - :type flow: Path - :param display_name: Display name of the run. - :type display_name: str - :param description: Description of the run. - :type description: str - :param tags: Tags of the run. - :type tags: List[Dict[str, str]] - """ # TODO: remove when RUN CRUD don't depend on this self.type = RunTypes.BATCH self.data = data diff --git a/src/promptflow/promptflow/azure/_pf_client.py b/src/promptflow/promptflow/azure/_pf_client.py index 61370d97510..a7e544cbfb4 100644 --- a/src/promptflow/promptflow/azure/_pf_client.py +++ b/src/promptflow/promptflow/azure/_pf_client.py @@ -86,7 +86,7 @@ def __init__( @property def ml_client(self): - """Return a client class to interact with Azure ML services.""" + """Return a client to interact with Azure ML services.""" return self._ml_client @classmethod @@ -250,6 +250,7 @@ def load_as_component( ) -> "Component": """ Load a flow as a component. + :param source: Source of the flow. Should be a path to a flow dag yaml file or a flow directory. :type source: Union[str, PathLike, IO[AnyStr]] :param component_type: Type of the loaded component, support parallel only for now. @@ -300,4 +301,5 @@ def _add_user_agent(self, kwargs) -> None: @property def runs(self): + """Return the run operation object that can manage runs.""" return self._runs diff --git a/src/promptflow/promptflow/azure/operations/_flow_opearations.py b/src/promptflow/promptflow/azure/operations/_flow_opearations.py index 8d5ecd86f3f..398d5628fe2 100644 --- a/src/promptflow/promptflow/azure/operations/_flow_opearations.py +++ b/src/promptflow/promptflow/azure/operations/_flow_opearations.py @@ -149,6 +149,7 @@ def load_as_component( is_deterministic: bool = True, **kwargs, ) -> Component: + """Load a flow as a component.""" rest_object = LoadFlowAsComponentRequest( node_variant=variant, inputs_mapping=columns_mapping, diff --git a/src/promptflow/promptflow/exceptions.py b/src/promptflow/promptflow/exceptions.py index 512667b4ec9..f7666f5c2cc 100644 --- a/src/promptflow/promptflow/exceptions.py +++ b/src/promptflow/promptflow/exceptions.py @@ -8,6 +8,8 @@ class ErrorTarget(str, Enum): + """The target of the error, indicates which part of the system the error occurs.""" + EXECUTOR = "Executor" FLOW_EXECUTOR = "FlowExecutor" NODE_EXECUTOR = "NodeExecutor" @@ -52,6 +54,7 @@ def __init__( @property def message(self): + """The error message.""" if self._message: return self._message @@ -62,10 +65,12 @@ def message(self): @property def message_format(self): + """The error message format.""" return self._message_format @cached_property def message_parameters(self): + """The error message parameters.""" if not self._kwargs: return {} @@ -74,11 +79,12 @@ def message_parameters(self): @cached_property def serializable_message_parameters(self): + """The serializable error message parameters.""" return {k: str(v) for k, v in self.message_parameters.items()} @property def target(self): - """Return the error target. + """The error target. :return: The error target. :rtype: ErrorTarget @@ -87,6 +93,7 @@ def target(self): @target.setter def target(self, value): + """Set the error target.""" self._target = value @property @@ -155,6 +162,8 @@ def reversed_error_codes(): return result def get_arguments_from_message_format(self, message_format): + """Get the arguments from the message format.""" + def iter_field_name(): if not message_format: return @@ -185,4 +194,6 @@ class SystemErrorException(PromptflowException): class ValidationException(UserErrorException): + """Exception raised when validation fails.""" + pass