Skip to content

Commit

Permalink
[Orchestrator] Add experiment orchestrator (#1847)
Browse files Browse the repository at this point in the history
# Description
- Support parallel execution of experiment nodes.
- Support canceling a running experiment.



Please add an informative description that covers the changes made by
the pull request and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
  • Loading branch information
lalala123123 authored Feb 5, 2024
1 parent 5d64e41 commit fde9f4d
Show file tree
Hide file tree
Showing 15 changed files with 998 additions and 167 deletions.
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
"otel",
"OTLP",
"spawnv",
"spawnve",
"addrs"
],
"flagWords": [
Expand Down
27 changes: 26 additions & 1 deletion src/promptflow/promptflow/_cli/_pf/_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,24 @@ def add_experiment_start(subparsers):
)


def add_experiment_stop(subparsers):
    """Register the ``stop`` sub-command of ``pf experiment`` on ``subparsers``.

    The command takes the experiment name parameter plus the shared base
    parameters; registration itself is delegated to ``activate_action``.
    """
    stop_summary = "Stop an experiment."
    usage_examples = """
Examples:
# Stop an experiment:
pf experiment stop -n my_experiment
"""
    activate_action(
        name="stop",
        description=stop_summary,
        epilog=usage_examples,
        add_params=[add_param_name, *base_params],
        subparsers=subparsers,
        help_message=stop_summary,
        action_param_name="sub_action",
    )


def add_experiment_parser(subparsers):
experiment_parser = subparsers.add_parser(
"experiment",
Expand All @@ -128,6 +146,7 @@ def add_experiment_parser(subparsers):
add_experiment_list(subparsers)
add_experiment_show(subparsers)
add_experiment_start(subparsers)
add_experiment_stop(subparsers)
experiment_parser.set_defaults(action="experiment")


Expand All @@ -147,7 +166,7 @@ def dispatch_experiment_commands(args: argparse.Namespace):
elif args.sub_action == "delete":
pass
elif args.sub_action == "stop":
pass
stop_experiment(args)
elif args.sub_action == "test":
pass
elif args.sub_action == "clone":
Expand Down Expand Up @@ -185,3 +204,9 @@ def show_experiment(args: argparse.Namespace):
def start_experiment(args: argparse.Namespace):
result = _get_pf_client()._experiments.start(args.name)
print(json.dumps(result._to_dict(), indent=4))


@exception_handler("Stop experiment")
def stop_experiment(args: argparse.Namespace):
    """Stop the experiment named ``args.name`` and print the result as indented JSON."""
    experiment_operations = _get_pf_client()._experiments
    stopped_experiment = experiment_operations.stop(args.name)
    serialized = json.dumps(stopped_experiment._to_dict(), indent=4)
    print(serialized)
13 changes: 13 additions & 0 deletions src/promptflow/promptflow/_sdk/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def _prepare_home_dir() -> Path:
RUN_INFO_CREATED_ON_INDEX_NAME = "idx_run_info_created_on"
CONNECTION_TABLE_NAME = "connection"
EXPERIMENT_TABLE_NAME = "experiment"
ORCHESTRATOR_TABLE_NAME = "orchestrator"
EXP_NODE_RUN_TABLE_NAME = "exp_node_run"
EXPERIMENT_CREATED_ON_INDEX_NAME = "idx_experiment_created_on"
BASE_PATH_CONTEXT_KEY = "base_path"
SCHEMA_KEYS_CONTEXT_CONFIG_KEY = "schema_configs_keys"
Expand Down Expand Up @@ -408,15 +410,26 @@ class DownloadedRun:

class ExperimentNodeType:
    """String constants naming the supported kinds of experiment node."""

    FLOW = "flow"
    CHAT_GROUP = "chat_group"
    COMMAND = "command"


class ExperimentStatus:
    """String constants for the status values an experiment can be in."""

    NOT_STARTED = "NotStarted"
    QUEUING = "Queuing"
    IN_PROGRESS = "InProgress"
    TERMINATED = "Terminated"


class ExperimentNodeRunStatus:
    """String constants for the status values an experiment node run can be in."""

    NOT_STARTED = "NotStarted"
    QUEUING = "Queuing"
    IN_PROGRESS = "InProgress"
    COMPLETED = "Completed"
    FAILED = "Failed"
    CANCELED = "Canceled"


class ExperimentContextKey:
EXPERIMENT = "experiment"
# Note: referenced id not used for lineage, only for evaluation
Expand Down
12 changes: 12 additions & 0 deletions src/promptflow/promptflow/_sdk/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,18 @@ class DownloadInternalError(SDKInternalError):
pass


class ExperimentNodeRunFailedError(SDKError):
    """Raised by the orchestrator when an experiment node run fails."""


class ExperimentNodeRunNotFoundError(SDKError):
    """Raised when a requested experiment node run record cannot be found."""


class ExperimentCommandRunError(SDKError):
"""Exception raised if experiment validation failed."""

Expand Down
4 changes: 4 additions & 0 deletions src/promptflow/promptflow/_sdk/_orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore

from promptflow._sdk._orm.run_info import RunInfo
from promptflow._sdk._orm.orchestrator import Orchestrator
from promptflow._sdk._orm.experiment_node_run import ExperimentNodeRun

from .connection import Connection
from .experiment import Experiment
Expand All @@ -14,5 +16,7 @@
"RunInfo",
"Connection",
"Experiment",
"ExperimentNodeRun",
"Orchestrator",
"mgmt_db_session",
]
95 changes: 95 additions & 0 deletions src/promptflow/promptflow/_sdk/_orm/experiment_node_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from sqlalchemy import TEXT, Column
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import declarative_base

from promptflow._sdk._constants import EXP_NODE_RUN_TABLE_NAME, ExperimentNodeRunStatus
from promptflow._sdk._errors import ExperimentNodeRunNotFoundError

from .retry import sqlite_retry
from .session import mgmt_db_session

# Declarative base class that the ORM model defined in this module derives from.
Base = declarative_base()


class ExperimentNodeRun(Base):
    """ORM model for a single node run of an experiment.

    Rows live in the table named by ``EXP_NODE_RUN_TABLE_NAME`` and are keyed
    by ``run_id``. All database access goes through ``mgmt_db_session`` and is
    wrapped with ``sqlite_retry``.
    """

    __tablename__ = EXP_NODE_RUN_TABLE_NAME

    run_id = Column(TEXT, primary_key=True)
    snapshot_id = Column(TEXT)
    node_name = Column(TEXT, nullable=False)
    experiment_name = Column(TEXT, nullable=False)
    status = Column(TEXT, nullable=False)

    # schema version, increase the version number when you change the schema
    __pf_schema_version__ = "1"

    @staticmethod
    @sqlite_retry
    def create_or_update(node_run: "ExperimentNodeRun") -> None:
        """Insert ``node_run``, or update the existing row with the same ``run_id``.

        The insert is attempted first; if it violates an integrity constraint,
        the stored row is updated with the public attributes of ``node_run``.
        """
        run_id = node_run.run_id
        try:
            # The context manager closes the session even when the commit fails,
            # so the failed transaction does not linger on an abandoned session.
            with mgmt_db_session() as session:
                session.add(node_run)
                session.commit()
        except IntegrityError:
            # Remove the _sa_instance_state
            update_dict = {k: v for k, v in node_run.__dict__.items() if not k.startswith("_")}
            with mgmt_db_session() as session:
                session.query(ExperimentNodeRun).filter(ExperimentNodeRun.run_id == run_id).update(update_dict)
                session.commit()

    @staticmethod
    @sqlite_retry
    def delete(snapshot_id: str) -> None:
        """Delete every node run row whose ``snapshot_id`` matches."""
        with mgmt_db_session() as session:
            session.query(ExperimentNodeRun).filter(ExperimentNodeRun.snapshot_id == snapshot_id).delete()
            session.commit()

    @staticmethod
    @sqlite_retry
    def get(run_id: str, raise_error=True) -> "ExperimentNodeRun":
        """Return the node run with the given ``run_id``.

        :param run_id: Primary key of the node run.
        :param raise_error: When true, a missing row raises; otherwise ``None`` is returned.
        :raises ExperimentNodeRunNotFoundError: If no row matches and ``raise_error`` is true.
        """
        with mgmt_db_session() as session:
            node_run = session.query(ExperimentNodeRun).filter(ExperimentNodeRun.run_id == run_id).first()
            if node_run is None and raise_error:
                raise ExperimentNodeRunNotFoundError(f"Not found the node run {run_id!r}.")
            return node_run

    @staticmethod
    @sqlite_retry
    def get_completed_node_by_snapshot_id(
        snapshot_id: str, experiment_name: str, raise_error=True
    ) -> "ExperimentNodeRun":
        """Return the first completed node run for a snapshot within an experiment.

        :param snapshot_id: Snapshot id the node run must match.
        :param experiment_name: Experiment the node run must belong to.
        :param raise_error: When true, a missing row raises; otherwise ``None`` is returned.
        :raises ExperimentNodeRunNotFoundError: If no completed row matches and ``raise_error`` is true.
        """
        with mgmt_db_session() as session:
            node_run = (
                session.query(ExperimentNodeRun)
                .filter(
                    ExperimentNodeRun.snapshot_id == snapshot_id,
                    ExperimentNodeRun.experiment_name == experiment_name,
                    ExperimentNodeRun.status == ExperimentNodeRunStatus.COMPLETED,
                )
                .first()
            )
            if node_run is None and raise_error:
                raise ExperimentNodeRunNotFoundError(
                    f"Not found the completed node run with snapshot id {snapshot_id!r}."
                )
            return node_run

    @staticmethod
    @sqlite_retry
    def get_node_runs_by_experiment(experiment_name: str) -> "list[ExperimentNodeRun]":
        """Return all node runs recorded for ``experiment_name`` (possibly an empty list)."""
        with mgmt_db_session() as session:
            node_runs = (
                session.query(ExperimentNodeRun).filter(ExperimentNodeRun.experiment_name == experiment_name).all()
            )
            return node_runs

    @sqlite_retry
    def update_status(self, status: str) -> None:
        """Persist ``status`` on the stored row identified by ``self.run_id``.

        Only the database row is written; this instance's in-memory ``status``
        attribute is not assigned here.
        """
        update_dict = {"status": status}
        with mgmt_db_session() as session:
            session.query(ExperimentNodeRun).filter(ExperimentNodeRun.run_id == self.run_id).update(update_dict)
            session.commit()
55 changes: 55 additions & 0 deletions src/promptflow/promptflow/_sdk/_orm/orchestrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from sqlalchemy import INTEGER, TEXT, Column
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import declarative_base

from promptflow._sdk._constants import ORCHESTRATOR_TABLE_NAME
from promptflow._sdk._errors import ExperimentNotFoundError

from .retry import sqlite_retry
from .session import mgmt_db_session

# Declarative base class that the ORM model defined in this module derives from.
Base = declarative_base()


class Orchestrator(Base):
    """ORM model holding the orchestrator record of an experiment.

    Rows live in the table named by ``ORCHESTRATOR_TABLE_NAME`` and are keyed
    by ``experiment_name``. All database access goes through
    ``mgmt_db_session`` and is wrapped with ``sqlite_retry``.
    """

    __tablename__ = ORCHESTRATOR_TABLE_NAME

    experiment_name = Column(TEXT, primary_key=True)
    # NOTE(review): presumably the process id of the running orchestrator - confirm against the caller.
    pid = Column(INTEGER, nullable=True)
    status = Column(TEXT, nullable=False)
    # schema version, increase the version number when you change the schema
    __pf_schema_version__ = "1"

    @staticmethod
    @sqlite_retry
    def create_or_update(orchestrator: "Orchestrator") -> None:
        """Insert ``orchestrator``, or update the existing row with the same ``experiment_name``.

        The insert is attempted first; if it violates an integrity constraint,
        the stored row is updated with the public attributes of ``orchestrator``.
        """
        experiment_name = orchestrator.experiment_name
        try:
            # The context manager closes the session even when the commit fails,
            # so the failed transaction does not linger on an abandoned session.
            with mgmt_db_session() as session:
                session.add(orchestrator)
                session.commit()
        except IntegrityError:
            # Remove the _sa_instance_state
            update_dict = {k: v for k, v in orchestrator.__dict__.items() if not k.startswith("_")}
            with mgmt_db_session() as session:
                session.query(Orchestrator).filter(Orchestrator.experiment_name == experiment_name).update(update_dict)
                session.commit()

    @staticmethod
    @sqlite_retry
    def get(experiment_name: str, raise_error=True) -> "Orchestrator":
        """Return the orchestrator record for ``experiment_name``.

        :param experiment_name: Primary key of the orchestrator record.
        :param raise_error: When true, a missing row raises; otherwise ``None`` is returned.
        :raises ExperimentNotFoundError: If no row matches and ``raise_error`` is true.
        """
        with mgmt_db_session() as session:
            orchestrator = session.query(Orchestrator).filter(Orchestrator.experiment_name == experiment_name).first()
            if orchestrator is None and raise_error:
                raise ExperimentNotFoundError(f"The experiment {experiment_name!r} hasn't been started yet.")
            return orchestrator

    @staticmethod
    @sqlite_retry
    def delete(name: str) -> None:
        """Delete the orchestrator record whose ``experiment_name`` equals ``name``."""
        with mgmt_db_session() as session:
            session.query(Orchestrator).filter(Orchestrator.experiment_name == name).delete()
            session.commit()
8 changes: 6 additions & 2 deletions src/promptflow/promptflow/_sdk/_orm/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from promptflow._sdk._configuration import Configuration
from promptflow._sdk._constants import (
CONNECTION_TABLE_NAME,
EXP_NODE_RUN_TABLE_NAME,
EXPERIMENT_CREATED_ON_INDEX_NAME,
EXPERIMENT_TABLE_NAME,
LOCAL_MGMT_DB_PATH,
LOCAL_MGMT_DB_SESSION_ACQUIRE_LOCK_PATH,
ORCHESTRATOR_TABLE_NAME,
RUN_INFO_CREATED_ON_INDEX_NAME,
RUN_INFO_TABLENAME,
SCHEMA_INFO_TABLENAME,
Expand Down Expand Up @@ -83,13 +85,15 @@ def mgmt_db_session() -> Session:
return session_maker()
if not LOCAL_MGMT_DB_PATH.parent.is_dir():
LOCAL_MGMT_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
engine = create_engine(f"sqlite:///{str(LOCAL_MGMT_DB_PATH)}", future=True)
engine = create_engine(f"sqlite:///{str(LOCAL_MGMT_DB_PATH)}?check_same_thread=False", future=True)
engine = support_transaction(engine)

from promptflow._sdk._orm import Connection, Experiment, RunInfo
from promptflow._sdk._orm import Connection, Experiment, ExperimentNodeRun, Orchestrator, RunInfo

create_or_update_table(engine, orm_class=RunInfo, tablename=RUN_INFO_TABLENAME)
create_table_if_not_exists(engine, CONNECTION_TABLE_NAME, Connection)
create_table_if_not_exists(engine, ORCHESTRATOR_TABLE_NAME, Orchestrator)
create_table_if_not_exists(engine, EXP_NODE_RUN_TABLE_NAME, ExperimentNodeRun)

create_index_if_not_exists(engine, RUN_INFO_CREATED_ON_INDEX_NAME, RUN_INFO_TABLENAME, "created_on")
if Configuration.get_instance().is_internal_features_enabled():
Expand Down
Loading

0 comments on commit fde9f4d

Please sign in to comment.