Skip to content

Commit

Permalink
[dagster-airlift] Asset spec sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Nov 15, 2024
1 parent 60bb96e commit 85d6cf7
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from typing import AbstractSet

import dagster_airlift.core.predicates as predicates
from dagster import (
AutomationConditionSensorDefinition,
DefaultSensorStatus,
Definitions,
MaterializeResult,
multi_asset,
)
from dagster._core.definitions.asset_spec import merge_attributes
from dagster._core.definitions.declarative_automation.automation_condition import (
AutomationCondition,
)
from dagster_airlift.core import (
AirflowBasicAuthBackend,
AirflowInstance,
Expand All @@ -36,40 +31,25 @@
name="metrics",
)

warehouse_specs = load_airflow_dag_asset_specs(
airflow_instance=warehouse_airflow_instance,
)

def dag_id_matches(spec, dag_ids: AbstractSet[str]) -> bool:
return spec.metadata.get("Dag ID") in dag_ids


warehouse_specs = load_airflow_dag_asset_specs(airflow_instance=warehouse_airflow_instance)

load_customers_dag_specs = [
spec for spec in warehouse_specs if dag_id_matches(spec, {"load_customers"})
]

is_customer_dag = lambda spec: dag_id_matches(spec, {"customer_metrics"})

metrics_specs = [
merge_attributes(
spec, deps=load_customers_dag_specs, automation_condition=AutomationCondition.eager()
)
if not is_customer_dag(spec)
else spec
for spec in load_airflow_dag_asset_specs(airflow_instance=metrics_airflow_instance)
]
is_customer_metrics = predicates.dag_name_in({"customer_metrics"})
is_load_customers = predicates.dag_name_in({"load_customers"})
metrics_specs = load_airflow_dag_asset_specs(
airflow_instance=metrics_airflow_instance,
).merge_attributes({"deps": warehouse_specs.filter(is_load_customers)}, where=is_customer_metrics)

customer_metrics, remaining_metrics_specs = (
[spec for spec in metrics_specs if is_customer_dag(spec)],
[spec for spec in metrics_specs if not is_customer_dag(spec)],
)
customer_metrics_specs, rest_of_metrics_specs = metrics_specs.split(is_customer_metrics)


@multi_asset(specs=customer_metrics)
@multi_asset(specs=[customer_metrics_specs[0]])
def run_customer_metrics() -> MaterializeResult:
run_id = metrics_airflow_instance.trigger_dag("customer_metrics")
metrics_airflow_instance.wait_for_run_completion("customer_metrics", run_id)
if metrics_airflow_instance.get_run_state("customer_metrics", run_id) == "success":
return MaterializeResult(asset_key=customer_metrics[0].key)
return MaterializeResult(asset_key=customer_metrics_specs[0].key)
else:
raise Exception("Dag run failed.")

Expand All @@ -91,6 +71,6 @@ def run_customer_metrics() -> MaterializeResult:
)

defs = Definitions(
assets=[run_customer_metrics, *remaining_metrics_specs, *warehouse_specs],
assets=[run_customer_metrics, *warehouse_specs, *rest_of_metrics_specs],
sensors=[warehouse_sensor, metrics_sensor, automation_sensor],
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
DagInfo,
SerializedAirflowDefinitionsData,
)
from dagster_airlift.core.utils import get_metadata_key, spec_iterator
from dagster_airlift.core.utils import AssetSpecSequence, get_metadata_key, spec_iterator


@dataclass
Expand Down Expand Up @@ -353,11 +353,11 @@ def load_airflow_dag_asset_specs(
airflow_instance: AirflowInstance,
mapped_assets: Optional[Sequence[MappedAsset]] = None,
dag_selector_fn: Optional[DagSelectorFn] = None,
) -> Sequence[AssetSpec]:
) -> AssetSpecSequence:
"""Load asset specs for Airflow DAGs from the provided :py:class:`AirflowInstance`, and link upstreams from mapped assets."""
serialized_data = AirflowInstanceDefsLoader(
airflow_instance=airflow_instance,
mapped_assets=mapped_assets or [],
dag_selector_fn=dag_selector_fn,
).get_or_fetch_state()
return list(spec_iterator(construct_dag_assets_defs(serialized_data)))
return AssetSpecSequence(list(spec_iterator(construct_dag_assets_defs(serialized_data))))
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import AbstractSet

from dagster import AssetSpec

from dagster_airlift.core.utils import AssetSpecPredicate


def dag_name_in(names: AbstractSet[str]) -> "AssetSpecPredicate":
    """Build a predicate that selects asset specs by their "Dag ID" metadata.

    Args:
        names: The Dag IDs to match against.

    Returns:
        A predicate that returns True when the spec's ``metadata["Dag ID"]`` value is a
        member of ``names``, and False otherwise -- including when the spec has no
        "Dag ID" metadata entry at all.
    """

    def _dag_name_in(spec: "AssetSpec") -> bool:
        # Look the key up with .get so that specs lacking a "Dag ID" entry are simply
        # treated as non-matching instead of raising KeyError mid-filter.
        dag_id = spec.metadata.get("Dag ID")
        return dag_id is not None and dag_id in names

    return _dag_name_in
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from typing import TYPE_CHECKING, Iterable, Iterator, Optional, Set, Union
from typing import (
TYPE_CHECKING,
Callable,
Iterable,
Iterator,
Optional,
Sequence,
Set,
Tuple,
Union,
)

from dagster import (
AssetsDefinition,
AssetSpec,
SourceAsset,
_check as check,
)
from dagster._core.definitions.asset_spec import merge_attributes, replace_attributes
from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition
from dagster._core.definitions.utils import VALID_NAME_REGEX
from dagster._core.errors import DagsterInvariantViolationError
Expand Down Expand Up @@ -105,3 +116,63 @@ def peered_dag_handles_for_spec(spec: AssetSpec) -> Set["DagHandle"]:
DagHandle(dag_id=dag_handle_dict["dag_id"])
for dag_handle_dict in spec.metadata[PEERED_DAG_MAPPING_METADATA_KEY]
}


# Callable that takes a single AssetSpec and returns True when that spec is selected.
AssetSpecPredicate = Callable[[AssetSpec], bool]


class AssetSpecSequence(Sequence[AssetSpec]):
    """A read-only sequence of asset specs with predicate-based helpers.

    A sequence constructed with ``can_transform=True`` (the default) can be transformed
    with :py:meth:`replace_attributes` and :py:meth:`merge_attributes`. The subsets
    returned by :py:meth:`split` and :py:meth:`filter` are constructed with
    ``can_transform=False`` and raise if a transformation is attempted on them.

    We only want to perform transformations on the full set of asset specs. This removes
    the footgun where there are multiple sets of assets floating around, a transformation
    is performed on a subset, and then the wrong subset is provided to Definitions. The
    idea is that the full set of assets is always what gets provided to Definitions.
    """

    def __init__(self, asset_specs: Sequence[AssetSpec], can_transform: bool = True):
        """Wrap ``asset_specs``.

        Args:
            asset_specs: The specs to hold.
            can_transform: Whether the attribute-transforming methods are allowed on
                this sequence.
        """
        # Snapshot the input so that later mutation of the caller's list cannot change
        # the contents of this sequence behind its back.
        self._asset_specs = list(asset_specs)
        self._can_transform = can_transform

    def __getitem__(self, item: int) -> AssetSpec:
        return self._asset_specs[item]

    def __len__(self) -> int:
        return len(self._asset_specs)

    def __iter__(self) -> Iterator[AssetSpec]:
        return iter(self._asset_specs)

    def split(
        self, include: AssetSpecPredicate
    ) -> Tuple["AssetSpecSequence", "AssetSpecSequence"]:
        """Partition the specs into ``(matching, non_matching)`` according to ``include``.

        Both returned sequences preserve the original order and cannot be transformed.
        """
        included = []
        excluded = []
        # Single pass: evaluate the predicate exactly once per spec so that every spec
        # lands in exactly one of the two halves.
        for asset_spec in self:
            if include(asset_spec):
                included.append(asset_spec)
            else:
                excluded.append(asset_spec)
        return (
            AssetSpecSequence(included, can_transform=False),
            AssetSpecSequence(excluded, can_transform=False),
        )

    def filter(self, where: AssetSpecPredicate) -> "AssetSpecSequence":
        """Return the specs for which ``where`` is true, as a non-transformable sequence."""
        return AssetSpecSequence(
            [asset_spec for asset_spec in self if where(asset_spec)], can_transform=False
        )

    def replace_attributes(
        self, attrs: dict, where: Optional[AssetSpecPredicate] = None
    ) -> "AssetSpecSequence":
        """Return a new sequence with ``attrs`` replaced on every spec selected by ``where``.

        Args:
            attrs: Keyword arguments forwarded to ``replace_attributes`` for each
                selected spec.
            where: Selects the specs to change; unselected specs are carried over
                unchanged. When omitted, every spec is selected.

        Raises:
            Exception: If this sequence was created with ``can_transform=False``.
        """
        return self._transform_where(
            lambda asset_spec: replace_attributes(asset_spec, **attrs), where
        )

    def merge_attributes(
        self, attrs: dict, where: Optional[AssetSpecPredicate] = None
    ) -> "AssetSpecSequence":
        """Return a new sequence with ``attrs`` merged into every spec selected by ``where``.

        Args:
            attrs: Keyword arguments forwarded to ``merge_attributes`` for each
                selected spec.
            where: Selects the specs to change; unselected specs are carried over
                unchanged. When omitted, every spec is selected.

        Raises:
            Exception: If this sequence was created with ``can_transform=False``.
        """
        return self._transform_where(
            lambda asset_spec: merge_attributes(asset_spec, **attrs), where
        )

    def _transform_where(
        self,
        transform: Callable[[AssetSpec], AssetSpec],
        where: Optional[AssetSpecPredicate],
    ) -> "AssetSpecSequence":
        """Apply ``transform`` to the selected specs after checking transforms are allowed."""
        if not self._can_transform:
            raise Exception("Cannot transform this sequence")
        return AssetSpecSequence(
            [
                transform(asset_spec) if where is None or where(asset_spec) else asset_spec
                for asset_spec in self
            ],
        )

0 comments on commit 85d6cf7

Please sign in to comment.