diff --git a/examples/airlift-federation-tutorial/airlift_federation_tutorial/dagster_defs/stages/executable_and_da.py b/examples/airlift-federation-tutorial/airlift_federation_tutorial/dagster_defs/stages/executable_and_da.py new file mode 100644 index 0000000000000..e6f6b044ab1bb --- /dev/null +++ b/examples/airlift-federation-tutorial/airlift_federation_tutorial/dagster_defs/stages/executable_and_da.py @@ -0,0 +1,89 @@ +from dagster import ( + AutomationConditionSensorDefinition, + DefaultSensorStatus, + Definitions, + MaterializeResult, + multi_asset, +) +from dagster._core.definitions.asset_spec import replace_attributes +from dagster._core.definitions.declarative_automation.automation_condition import ( + AutomationCondition, +) +from dagster_airlift.core import ( + AirflowBasicAuthBackend, + AirflowInstance, + build_airflow_polling_sensor, + load_airflow_dag_asset_specs, +) + +upstream_airflow_instance = AirflowInstance( + auth_backend=AirflowBasicAuthBackend( + webserver_url="http://localhost:8081", + username="admin", + password="admin", + ), + name="upstream", +) + +downstream_airflow_instance = AirflowInstance( + auth_backend=AirflowBasicAuthBackend( + webserver_url="http://localhost:8082", + username="admin", + password="admin", + ), + name="downstream", +) + +load_customers_dag_asset = next( + iter( + load_airflow_dag_asset_specs( + airflow_instance=upstream_airflow_instance, + dag_selector_fn=lambda dag: dag.dag_id == "load_customers", + ) + ) +) +customer_metrics_dag_asset = replace_attributes( + next( + iter( + load_airflow_dag_asset_specs( + airflow_instance=downstream_airflow_instance, + dag_selector_fn=lambda dag: dag.dag_id == "customer_metrics", + ) + ) + # Add a dependency on the load_customers_dag_asset + ), + deps=[load_customers_dag_asset], + automation_condition=AutomationCondition.eager(), +) + + +@multi_asset(specs=[customer_metrics_dag_asset]) +def run_customer_metrics() -> MaterializeResult: + run_id = downstream_airflow_instance.trigger_dag("customer_metrics") + downstream_airflow_instance.wait_for_run_completion("customer_metrics", run_id) + if downstream_airflow_instance.get_run_state("customer_metrics", run_id) == "success": + return MaterializeResult(asset_key=customer_metrics_dag_asset.key) + else: + raise Exception("Dag run failed.") + + +upstream_sensor = build_airflow_polling_sensor( + mapped_assets=[load_customers_dag_asset], + airflow_instance=upstream_airflow_instance, +) +downstream_sensor = build_airflow_polling_sensor( + mapped_assets=[customer_metrics_dag_asset], + airflow_instance=downstream_airflow_instance, +) + +automation_sensor = AutomationConditionSensorDefinition( + name="automation_sensor", + target="*", + default_status=DefaultSensorStatus.RUNNING, + minimum_interval_seconds=1, +) + +defs = Definitions( + assets=[load_customers_dag_asset, run_customer_metrics], + sensors=[upstream_sensor, downstream_sensor], +) diff --git a/examples/airlift-federation-tutorial/airlift_federation_tutorial_tests/test_executable_stage.py b/examples/airlift-federation-tutorial/airlift_federation_tutorial_tests/test_executable_stage.py new file mode 100644 index 0000000000000..8b682d3c9404c --- /dev/null +++ b/examples/airlift-federation-tutorial/airlift_federation_tutorial_tests/test_executable_stage.py @@ -0,0 +1,43 @@ +import subprocess +from typing import Generator + +import pytest +import requests +from airlift_federation_tutorial_tests.conftest import ORIG_DEFS_FILE, makefile_dir, replace_file +from dagster_airlift.in_airflow.gql_queries import ASSET_NODES_QUERY +from dagster_airlift.test.shared_fixtures import stand_up_dagster + +EXECUTABLE_STAGE_FILE = ORIG_DEFS_FILE.parent / "stages" / "executable_and_da.py" + + +@pytest.fixture +def completed_stage() -> Generator[None, None, None]: + with replace_file(ORIG_DEFS_FILE, EXECUTABLE_STAGE_FILE): + yield + + +@pytest.fixture(name="dagster_dev") +def dagster_fixture( + upstream_airflow: subprocess.Popen, downstream_airflow: subprocess.Popen, completed_stage: None +) -> Generator[subprocess.Popen, None, None]: + process = None + try: + with stand_up_dagster( + dagster_dev_cmd=["make", "-C", str(makefile_dir()), "dagster_run"], + port=3000, + ) as process: + yield process + finally: + if process: + process.terminate() + + +def test_executable_stage(dagster_dev: subprocess.Popen) -> None: + response = requests.post( + # Timeout in seconds + "http://localhost:3000/graphql", + json={"query": ASSET_NODES_QUERY}, + timeout=3, + ) + assert response.status_code == 200 + assert len(response.json()["data"]["assetNodes"]) == 2