diff --git a/examples/airlift-federation-tutorial/airlift_federation_tutorial/dagster_defs/stages/executable_and_da.py b/examples/airlift-federation-tutorial/airlift_federation_tutorial/dagster_defs/stages/executable_and_da.py
index e7bf744423c2b..be22ec9f28a67 100644
--- a/examples/airlift-federation-tutorial/airlift_federation_tutorial/dagster_defs/stages/executable_and_da.py
+++ b/examples/airlift-federation-tutorial/airlift_federation_tutorial/dagster_defs/stages/executable_and_da.py
@@ -1,5 +1,4 @@
-from typing import AbstractSet
-
+import dagster_airlift.core.predicates as predicates
 from dagster import (
     AutomationConditionSensorDefinition,
     DefaultSensorStatus,
@@ -7,10 +6,6 @@
     MaterializeResult,
     multi_asset,
 )
-from dagster._core.definitions.asset_spec import merge_attributes
-from dagster._core.definitions.declarative_automation.automation_condition import (
-    AutomationCondition,
-)
 from dagster_airlift.core import (
     AirflowBasicAuthBackend,
     AirflowInstance,
@@ -36,40 +31,25 @@
     name="metrics",
 )
 
+warehouse_specs = load_airflow_dag_asset_specs(
+    airflow_instance=warehouse_airflow_instance,
+)
 
-def dag_id_matches(spec, dag_ids: AbstractSet[str]) -> bool:
-    return spec.metadata.get("Dag ID") in dag_ids
-
-
-warehouse_specs = load_airflow_dag_asset_specs(airflow_instance=warehouse_airflow_instance)
-
-load_customers_dag_specs = [
-    spec for spec in warehouse_specs if dag_id_matches(spec, {"load_customers"})
-]
-
-is_customer_dag = lambda spec: dag_id_matches(spec, {"customer_metrics"})
-
-metrics_specs = [
-    merge_attributes(
-        spec, deps=load_customers_dag_specs, automation_condition=AutomationCondition.eager()
-    )
-    if not is_customer_dag(spec)
-    else spec
-    for spec in load_airflow_dag_asset_specs(airflow_instance=metrics_airflow_instance)
-]
+is_customer_metrics = predicates.dag_name_in({"customer_metrics"})
+is_load_customers = predicates.dag_name_in({"load_customers"})
+metrics_specs = load_airflow_dag_asset_specs(
+    airflow_instance=metrics_airflow_instance,
+).merge_attributes({"deps": warehouse_specs.filter(is_load_customers)}, where=is_customer_metrics)
 
-customer_metrics, remaining_metrics_specs = (
-    [spec for spec in metrics_specs if is_customer_dag(spec)],
-    [spec for spec in metrics_specs if not is_customer_dag(spec)],
-)
+customer_metrics_specs, rest_of_metrics_specs = metrics_specs.split(is_customer_metrics)
 
 
-@multi_asset(specs=customer_metrics)
+@multi_asset(specs=[customer_metrics_specs[0]])
 def run_customer_metrics() -> MaterializeResult:
     run_id = metrics_airflow_instance.trigger_dag("customer_metrics")
     metrics_airflow_instance.wait_for_run_completion("customer_metrics", run_id)
     if metrics_airflow_instance.get_run_state("customer_metrics", run_id) == "success":
-        return MaterializeResult(asset_key=customer_metrics[0].key)
+        return MaterializeResult(asset_key=customer_metrics_specs[0].key)
     else:
         raise Exception("Dag run failed.")
 
@@ -91,6 +71,6 @@ def run_customer_metrics() -> MaterializeResult:
 )
 
 defs = Definitions(
-    assets=[run_customer_metrics, *remaining_metrics_specs, *warehouse_specs],
+    assets=[run_customer_metrics, *warehouse_specs, *rest_of_metrics_specs],
     sensors=[warehouse_sensor, metrics_sensor, automation_sensor],
 )
diff --git a/examples/experimental/dagster-airlift/dagster_airlift/core/load_defs.py b/examples/experimental/dagster-airlift/dagster_airlift/core/load_defs.py
index a0f3a0a4ec8b6..70010ad4bb7e3 100644
--- a/examples/experimental/dagster-airlift/dagster_airlift/core/load_defs.py
+++ b/examples/experimental/dagster-airlift/dagster_airlift/core/load_defs.py
@@ -31,7 +31,7 @@
     DagInfo,
     SerializedAirflowDefinitionsData,
 )
-from dagster_airlift.core.utils import get_metadata_key, spec_iterator
+from dagster_airlift.core.utils import AssetSpecSequence, get_metadata_key, spec_iterator
 
 
 @dataclass
@@ -353,11 +353,11 @@ def load_airflow_dag_asset_specs(
     airflow_instance: AirflowInstance,
     mapped_assets: Optional[Sequence[MappedAsset]] = None,
     dag_selector_fn: Optional[DagSelectorFn] = None,
-) -> Sequence[AssetSpec]:
+) -> AssetSpecSequence:
     """Load asset specs for Airflow DAGs from the provided :py:class:`AirflowInstance`, and link upstreams from mapped assets."""
     serialized_data = AirflowInstanceDefsLoader(
         airflow_instance=airflow_instance,
         mapped_assets=mapped_assets or [],
         dag_selector_fn=dag_selector_fn,
     ).get_or_fetch_state()
-    return list(spec_iterator(construct_dag_assets_defs(serialized_data)))
+    return AssetSpecSequence(list(spec_iterator(construct_dag_assets_defs(serialized_data))))
diff --git a/examples/experimental/dagster-airlift/dagster_airlift/core/predicates.py b/examples/experimental/dagster-airlift/dagster_airlift/core/predicates.py
new file mode 100644
index 0000000000000..9bf9def633f18
--- /dev/null
+++ b/examples/experimental/dagster-airlift/dagster_airlift/core/predicates.py
@@ -0,0 +1,18 @@
+from typing import AbstractSet
+
+from dagster import AssetSpec
+
+from dagster_airlift.core.utils import AssetSpecPredicate
+
+
+def dag_name_in(names: AbstractSet[str]) -> AssetSpecPredicate:
+    """Return a predicate matching specs whose "Dag ID" metadata value is in ``names``.
+
+    Specs that carry no "Dag ID" metadata entry never match.
+    """
+
+    def _dag_name_in(spec: AssetSpec) -> bool:
+        # Use .get so a spec without the key is filtered out instead of raising KeyError.
+        return spec.metadata.get("Dag ID") in names
+
+    return _dag_name_in
diff --git a/examples/experimental/dagster-airlift/dagster_airlift/core/utils.py b/examples/experimental/dagster-airlift/dagster_airlift/core/utils.py
index 3be91d3a69595..2f23fd0c85ebd 100644
--- a/examples/experimental/dagster-airlift/dagster_airlift/core/utils.py
+++ b/examples/experimental/dagster-airlift/dagster_airlift/core/utils.py
@@ -1,4 +1,14 @@
-from typing import TYPE_CHECKING, Iterable, Iterator, Optional, Set, Union
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Iterable,
+    Iterator,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 
 from dagster import (
     AssetsDefinition,
@@ -6,6 +16,7 @@
     SourceAsset,
     _check as check,
 )
+from dagster._core.definitions.asset_spec import merge_attributes, replace_attributes
 from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition
 from dagster._core.definitions.utils import VALID_NAME_REGEX
 from dagster._core.errors import DagsterInvariantViolationError
@@ -105,3 +116,88 @@ def peered_dag_handles_for_spec(spec: AssetSpec) -> Set["DagHandle"]:
         DagHandle(dag_id=dag_handle_dict["dag_id"])
         for dag_handle_dict in spec.metadata[PEERED_DAG_MAPPING_METADATA_KEY]
     }
+
+
+AssetSpecPredicate = Callable[[AssetSpec], bool]
+
+
+class AssetSpecSequence(Sequence[AssetSpec]):
+    """A sequence of asset specs with predicate-based selection and transformation helpers.
+
+    Transformations (``replace_attributes`` / ``merge_attributes``) are only allowed on a
+    sequence created with ``can_transform=True``. The subsets returned by ``split`` and
+    ``filter`` are created with ``can_transform=False``. This removes the footgun where
+    there are multiple sets of assets floating around, and we perform a transformation on
+    a subset, but then provide the wrong subset to Definitions. The idea is that the full
+    set of assets is always what gets transformed and provided to Definitions.
+    """
+
+    def __init__(self, asset_specs: Sequence[AssetSpec], can_transform: bool = True):
+        self._asset_specs = asset_specs
+        self._can_transform = can_transform
+
+    def __getitem__(self, item: int) -> AssetSpec:
+        return self._asset_specs[item]
+
+    def __len__(self) -> int:
+        return len(self._asset_specs)
+
+    def __iter__(self) -> Iterator[AssetSpec]:
+        return iter(self._asset_specs)
+
+    def _check_can_transform(self) -> None:
+        """Raise if this sequence was created with ``can_transform=False``."""
+        if not self._can_transform:
+            raise Exception("Cannot transform this sequence")
+
+    def split(
+        self, include: Callable[[AssetSpec], bool]
+    ) -> Tuple["AssetSpecSequence", "AssetSpecSequence"]:
+        """Partition the specs into ``(matching, non_matching)`` according to ``include``.
+
+        Relative order is preserved in both halves, and neither half can be transformed.
+        """
+        # Evaluate the predicate once per spec so both halves come from the same answers.
+        tagged = [(asset_spec, include(asset_spec)) for asset_spec in self]
+        return (
+            AssetSpecSequence([spec for spec, keep in tagged if keep], can_transform=False),
+            AssetSpecSequence([spec for spec, keep in tagged if not keep], can_transform=False),
+        )
+
+    def filter(self, where: AssetSpecPredicate) -> "AssetSpecSequence":
+        """Return the specs matching ``where`` as a sequence that cannot be transformed."""
+        return AssetSpecSequence(
+            [asset_spec for asset_spec in self if where(asset_spec)], can_transform=False
+        )
+
+    def replace_attributes(
+        self, attrs: dict, where: Callable[[AssetSpec], bool]
+    ) -> "AssetSpecSequence":
+        """Return a new sequence with ``attrs`` replaced on each spec matching ``where``.
+
+        Non-matching specs are carried over unchanged. Raises ``Exception`` if this
+        sequence was created with ``can_transform=False``.
+        """
+        self._check_can_transform()
+        return AssetSpecSequence(
+            [
+                replace_attributes(asset_spec, **attrs) if where(asset_spec) else asset_spec
+                for asset_spec in self
+            ],
+        )
+
+    def merge_attributes(
+        self, attrs: dict, where: Callable[[AssetSpec], bool]
+    ) -> "AssetSpecSequence":
+        """Return a new sequence with ``attrs`` merged into each spec matching ``where``.
+
+        Non-matching specs are carried over unchanged. Raises ``Exception`` if this
+        sequence was created with ``can_transform=False``.
+        """
+        self._check_can_transform()
+        return AssetSpecSequence(
+            [
+                merge_attributes(asset_spec, **attrs) if where(asset_spec) else asset_spec
+                for asset_spec in self
+            ],
+        )