Skip to content

Commit

Permalink
[dagster-airlift] Airflow instance methods
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Nov 13, 2024
1 parent ef3869a commit 645cd02
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ AirflowInstance
^^^^^^^^^^^^^^^^^

.. autoclass:: AirflowInstance
:members:

.. autoclass:: AirflowAuthBackend

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import requests
from dagster import _check as check
from dagster._annotations import public
from dagster._core.definitions.utils import check_valid_name
from dagster._core.errors import DagsterError
from dagster._record import record
Expand Down Expand Up @@ -265,7 +266,19 @@ def get_dag_runs_batch(
f"Failed to fetch dag runs for {dag_ids}. Status code: {response.status_code}, Message: {response.text}"
)

@public
def trigger_dag(self, dag_id: str, logical_date: Optional[datetime.datetime] = None) -> str:
"""Trigger a dag run for the given dag_id.
Does not wait for the run to finish. To wait for the completed run to finish, use :py:meth:`wait_for_run_completion`.
Args:
dag_id (str): The dag id to trigger.
logical_date (Optional[datetime.datetime]): The logical date to use for the dag run. If not provided, the current time will be used.
Returns:
str: The dag run id.
"""
params = {} if not logical_date else {"logical_date": logical_date.isoformat()}
response = self.auth_backend.get_session().post(
f"{self.get_api_url()}/dags/{dag_id}/dagRuns",
Expand Down Expand Up @@ -303,7 +316,18 @@ def unpause_dag(self, dag_id: str) -> None:
f"Failed to unpause dag {dag_id}. Status code: {response.status_code}, Message: {response.text}"
)

@public
def wait_for_run_completion(self, dag_id: str, run_id: str, timeout: int = 30) -> None:
"""Given a run ID of an airflow dag, wait for that run to reach a completed state.
Args:
dag_id (str): The dag id.
run_id (str): The run id.
timeout (int): The number of seconds to wait before timing out.
Returns:
None
"""
start_time = get_current_datetime()
while get_current_datetime() - start_time < datetime.timedelta(seconds=timeout):
dag_run = self.get_dag_run(dag_id, run_id)
Expand All @@ -314,7 +338,17 @@ def wait_for_run_completion(self, dag_id: str, run_id: str, timeout: int = 30) -
) # Sleep for a second before checking again. This way we don't flood the rest API with requests.
raise DagsterError(f"Timed out waiting for airflow run {run_id} to finish.")

@public
def get_run_state(self, dag_id: str, run_id: str) -> str:
"""Given a run ID of an airflow dag, return the state of that run.
Args:
dag_id (str): The dag id.
run_id (str): The run id.
Returns:
str: The state of the run. Will be one of the states defined by Airflow.
"""
return self.get_dag_run(dag_id, run_id).state

def delete_run(self, dag_id: str, run_id: str) -> None:
Expand Down

0 comments on commit 645cd02

Please sign in to comment.