Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dagster-airlift] Asset spec sequence #25926

Draft
wants to merge 1 commit into
base: executable_and_da_strawman
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from typing import AbstractSet

import dagster_airlift.core.predicates as predicates
from dagster import (
AutomationConditionSensorDefinition,
DefaultSensorStatus,
Definitions,
MaterializeResult,
multi_asset,
)
from dagster._core.definitions.asset_spec import merge_attributes
from dagster._core.definitions.declarative_automation.automation_condition import (
AutomationCondition,
)
from dagster_airlift.core import (
AirflowBasicAuthBackend,
AirflowInstance,
Expand All @@ -36,40 +31,25 @@
name="metrics",
)

warehouse_specs = load_airflow_dag_asset_specs(
airflow_instance=warehouse_airflow_instance,
)

def dag_id_matches(spec, dag_ids: AbstractSet[str]) -> bool:
    """Return True when *spec*'s "Dag ID" metadata entry is one of *dag_ids*.

    Specs carrying no "Dag ID" metadata entry never match.
    """
    dag_id = spec.metadata.get("Dag ID")
    return dag_id in dag_ids


warehouse_specs = load_airflow_dag_asset_specs(airflow_instance=warehouse_airflow_instance)

load_customers_dag_specs = [
spec for spec in warehouse_specs if dag_id_matches(spec, {"load_customers"})
]

is_customer_dag = lambda spec: dag_id_matches(spec, {"customer_metrics"})

metrics_specs = [
merge_attributes(
spec, deps=load_customers_dag_specs, automation_condition=AutomationCondition.eager()
)
if not is_customer_dag(spec)
else spec
for spec in load_airflow_dag_asset_specs(airflow_instance=metrics_airflow_instance)
]
is_customer_metrics = predicates.dag_name_in({"customer_metrics"})
is_load_customers = predicates.dag_name_in({"load_customers"})
metrics_specs = load_airflow_dag_asset_specs(
airflow_instance=metrics_airflow_instance,
).merge_attributes({"deps": warehouse_specs.filter(is_load_customers)}, where=is_customer_metrics)

customer_metrics, remaining_metrics_specs = (
[spec for spec in metrics_specs if is_customer_dag(spec)],
[spec for spec in metrics_specs if not is_customer_dag(spec)],
)
customer_metrics_specs, rest_of_metrics_specs = metrics_specs.split(is_customer_metrics)


@multi_asset(specs=customer_metrics)
@multi_asset(specs=[customer_metrics_specs[0]])
def run_customer_metrics() -> MaterializeResult:
run_id = metrics_airflow_instance.trigger_dag("customer_metrics")
metrics_airflow_instance.wait_for_run_completion("customer_metrics", run_id)
if metrics_airflow_instance.get_run_state("customer_metrics", run_id) == "success":
return MaterializeResult(asset_key=customer_metrics[0].key)
return MaterializeResult(asset_key=customer_metrics_specs[0].key)
else:
raise Exception("Dag run failed.")

Expand All @@ -91,6 +71,6 @@ def run_customer_metrics() -> MaterializeResult:
)

defs = Definitions(
assets=[run_customer_metrics, *remaining_metrics_specs, *warehouse_specs],
assets=[run_customer_metrics, *warehouse_specs, *rest_of_metrics_specs],
sensors=[warehouse_sensor, metrics_sensor, automation_sensor],
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
DagInfo,
SerializedAirflowDefinitionsData,
)
from dagster_airlift.core.utils import get_metadata_key, spec_iterator
from dagster_airlift.core.utils import AssetSpecSequence, get_metadata_key, spec_iterator


@dataclass
Expand Down Expand Up @@ -353,11 +353,11 @@ def load_airflow_dag_asset_specs(
airflow_instance: AirflowInstance,
mapped_assets: Optional[Sequence[MappedAsset]] = None,
dag_selector_fn: Optional[DagSelectorFn] = None,
) -> Sequence[AssetSpec]:
) -> AssetSpecSequence:
"""Load asset specs for Airflow DAGs from the provided :py:class:`AirflowInstance`, and link upstreams from mapped assets."""
serialized_data = AirflowInstanceDefsLoader(
airflow_instance=airflow_instance,
mapped_assets=mapped_assets or [],
dag_selector_fn=dag_selector_fn,
).get_or_fetch_state()
return list(spec_iterator(construct_dag_assets_defs(serialized_data)))
return AssetSpecSequence(list(spec_iterator(construct_dag_assets_defs(serialized_data))))
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import AbstractSet

from dagster import AssetSpec

from dagster_airlift.core.utils import AssetSpecPredicate


def dag_name_in(names: AbstractSet[str]) -> "AssetSpecPredicate":
    """Build a predicate matching asset specs peered from the given Airflow dags.

    Args:
        names: The set of Airflow dag ids to match against.

    Returns:
        A predicate that takes an ``AssetSpec`` and returns True when the
        spec's "Dag ID" metadata entry is one of *names*. Specs without a
        "Dag ID" metadata entry (e.g. assets not loaded from Airflow) simply
        never match, instead of raising ``KeyError``.
    """

    def _dag_name_in(spec: "AssetSpec") -> bool:
        # .get rather than []: tolerate specs with no dag-id metadata at all.
        return spec.metadata.get("Dag ID") in names

    return _dag_name_in
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from typing import TYPE_CHECKING, Iterable, Iterator, Optional, Set, Union
from typing import (
TYPE_CHECKING,
Callable,
Iterable,
Iterator,
Optional,
Sequence,
Set,
Tuple,
Union,
)

from dagster import (
AssetsDefinition,
AssetSpec,
SourceAsset,
_check as check,
)
from dagster._core.definitions.asset_spec import merge_attributes, replace_attributes
from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition
from dagster._core.definitions.utils import VALID_NAME_REGEX
from dagster._core.errors import DagsterInvariantViolationError
Expand Down Expand Up @@ -105,3 +116,63 @@ def peered_dag_handles_for_spec(spec: AssetSpec) -> Set["DagHandle"]:
DagHandle(dag_id=dag_handle_dict["dag_id"])
for dag_handle_dict in spec.metadata[PEERED_DAG_MAPPING_METADATA_KEY]
}


AssetSpecPredicate = Callable[["AssetSpec"], bool]


class AssetSpecSequence(Sequence["AssetSpec"]):
    """A sequence of asset specs with bulk, predicate-driven transformations.

    Transformations (:py:meth:`merge_attributes` / :py:meth:`replace_attributes`)
    are only permitted on the full set of specs as originally loaded. Derived
    views produced by :py:meth:`filter` and :py:meth:`split` are marked
    non-transformable. This removes the footgun where multiple subsets of the
    specs float around, a transformation is applied to one subset, and then a
    stale, untransformed subset is handed to ``Definitions``. The intended flow
    is: transform the full sequence first, then filter/split for consumption.
    """

    def __init__(self, asset_specs: Sequence["AssetSpec"], can_transform: bool = True):
        self._asset_specs = asset_specs
        # False for derived (filtered/split) views; checked by the transform methods.
        self._can_transform = can_transform

    def __getitem__(self, item: int) -> "AssetSpec":
        return self._asset_specs[item]

    def __len__(self) -> int:
        return len(self._asset_specs)

    def __iter__(self) -> Iterator["AssetSpec"]:
        return iter(self._asset_specs)

    def split(
        self, include: Callable[["AssetSpec"], bool]
    ) -> Tuple["AssetSpecSequence", "AssetSpecSequence"]:
        """Partition into ``(matching, non_matching)`` non-transformable views.

        The predicate is evaluated exactly once per spec (a single pass),
        so expensive or stateful predicates behave predictably.
        """
        matching: list = []
        non_matching: list = []
        for asset_spec in self:
            (matching if include(asset_spec) else non_matching).append(asset_spec)
        return (
            AssetSpecSequence(matching, can_transform=False),
            AssetSpecSequence(non_matching, can_transform=False),
        )

    def filter(self, where: AssetSpecPredicate) -> "AssetSpecSequence":
        """Return a non-transformable view of the specs matching *where*."""
        return AssetSpecSequence(
            [asset_spec for asset_spec in self if where(asset_spec)], can_transform=False
        )

    def _apply(
        self,
        transform: Callable,
        attrs: dict,
        where: Callable[["AssetSpec"], bool],
    ) -> "AssetSpecSequence":
        # Shared guard + apply logic for replace_attributes/merge_attributes.
        # We only want to perform operations on the full set of asset specs. This
        # removes the footgun where there are multiple sets of assets floating
        # around, and we perform a transformation on a subset, but then provide
        # the wrong subset to Definitions. The idea is that we always provide the
        # full set of assets to Definitions.
        if not self._can_transform:
            raise Exception("Cannot transform this sequence")
        return AssetSpecSequence(
            [
                transform(asset_spec, **attrs) if where(asset_spec) else asset_spec
                for asset_spec in self
            ],
        )

    def replace_attributes(
        self, attrs: dict, where: Callable[["AssetSpec"], bool]
    ) -> "AssetSpecSequence":
        """Return a new sequence with *attrs* replaced on specs matching *where*.

        Raises:
            Exception: if called on a filtered/split (non-transformable) view.
        """
        return self._apply(replace_attributes, attrs, where)

    def merge_attributes(
        self, attrs: dict, where: Callable[["AssetSpec"], bool]
    ) -> "AssetSpecSequence":
        """Return a new sequence with *attrs* merged into specs matching *where*.

        Raises:
            Exception: if called on a filtered/split (non-transformable) view.
        """
        return self._apply(merge_attributes, attrs, where)