Skip to content

Commit

Permalink
[dagster-airlift] enrich mapped assets
Browse files Browse the repository at this point in the history
  • Loading branch information
dpeng817 committed Nov 12, 2024
1 parent 924767f commit 8cb9eb6
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from .load_defs import (
AirflowInstance as AirflowInstance,
DagSelectorFn as DagSelectorFn,
build_airflow_mapped_defs as build_airflow_mapped_defs,
build_defs_from_airflow_instance as build_defs_from_airflow_instance,
)
from .multiple_tasks import (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from collections import defaultdict
from functools import cached_property
from typing import AbstractSet, Mapping, Set
from typing import AbstractSet, Iterable, Mapping, Set, Union

from dagster import AssetKey, AssetSpec, Definitions
from dagster import AssetKey, AssetsDefinition, AssetSpec
from dagster._annotations import public
from dagster._record import record

Expand All @@ -18,9 +18,12 @@
is_peered_dag_asset_spec,
is_task_mapped_asset_spec,
peered_dag_handles_for_spec,
spec_iterator,
task_handles_for_spec,
)

MappedAsset = Union[AssetSpec, AssetsDefinition]


@record
class AirflowDefinitionsData:
Expand All @@ -32,21 +35,25 @@ class AirflowDefinitionsData:
"""

airflow_instance: AirflowInstance
mapped_defs: Definitions
mapped_assets: Iterable[MappedAsset]

@public
@property
def instance_name(self) -> str:
    """Name of the underlying Airflow instance these definitions map."""
    name = self.airflow_instance.name
    return name

@cached_property
def spec_iterator(self) -> Iterable[AssetSpec]:
    """Flattened view of every asset spec contained in ``mapped_assets``.

    Cached because the mapped assets do not change after construction.
    """
    # NOTE: `spec_iterator` in the body resolves to the module-level helper
    # import, not this property — class attributes are not in scope inside
    # method bodies.
    all_specs = spec_iterator(self.mapped_assets)
    return all_specs

@cached_property
def mapping_info(self) -> AirliftMetadataMappingInfo:
return build_airlift_metadata_mapping_info(self.mapped_defs)
return build_airlift_metadata_mapping_info(self.mapped_assets)

@cached_property
def all_asset_specs_by_key(self) -> Mapping[AssetKey, AssetSpec]:
return {spec.key: spec for spec in self.mapped_defs.get_all_asset_specs()}
return {spec.key: spec for spec in self.spec_iterator}

@public
def task_ids_in_dag(self, dag_id: str) -> Set[str]:
Expand All @@ -64,7 +71,7 @@ def dag_ids_with_mapped_asset_keys(self) -> AbstractSet[str]:
@cached_property
def mapped_asset_keys_by_task_handle(self) -> Mapping[TaskHandle, AbstractSet[AssetKey]]:
asset_keys_per_handle = defaultdict(set)
for spec in self.mapped_defs.get_all_asset_specs():
for spec in self.spec_iterator:
if is_task_mapped_asset_spec(spec):
task_handles = task_handles_for_spec(spec)
for task_handle in task_handles:
Expand All @@ -74,7 +81,7 @@ def mapped_asset_keys_by_task_handle(self) -> Mapping[TaskHandle, AbstractSet[As
@cached_property
def mapped_asset_keys_by_dag_handle(self) -> Mapping[DagHandle, AbstractSet[AssetKey]]:
asset_keys_per_handle = defaultdict(set)
for spec in self.mapped_defs.get_all_asset_specs():
for spec in self.spec_iterator:
if is_dag_mapped_asset_spec(spec):
dag_handles = dag_handles_for_spec(spec)
for dag_handle in dag_handles:
Expand All @@ -84,7 +91,7 @@ def mapped_asset_keys_by_dag_handle(self) -> Mapping[DagHandle, AbstractSet[Asse
@cached_property
def peered_dag_asset_keys_by_dag_handle(self) -> Mapping[DagHandle, AbstractSet[AssetKey]]:
asset_keys_per_handle = defaultdict(set)
for spec in self.mapped_defs.get_all_asset_specs():
for spec in self.spec_iterator:
if is_peered_dag_asset_spec(spec):
dag_handles = peered_dag_handles_for_spec(spec)
for dag_handle in dag_handles:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Callable, Iterable, Iterator, Optional, Union
from typing import Any, Callable, Iterable, Iterator, Optional, Sequence, Union

from dagster import (
AssetsDefinition,
Expand All @@ -11,6 +11,7 @@
from dagster._core.definitions.external_asset import external_asset_from_spec
from dagster._utils.warnings import suppress_dagster_warnings

from dagster_airlift.core.airflow_defs_data import MappedAsset
from dagster_airlift.core.airflow_instance import AirflowInstance
from dagster_airlift.core.sensor.event_translation import (
DagsterEventTransformerFn,
Expand All @@ -36,7 +37,7 @@
@dataclass
class AirflowInstanceDefsLoader(StateBackedDefinitionsLoader[SerializedAirflowDefinitionsData]):
airflow_instance: AirflowInstance
explicit_defs: Definitions
mapped_assets: Iterable[MappedAsset]
sensor_minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS
dag_selector_fn: Optional[Callable[[DagInfo], bool]] = None

Expand All @@ -47,32 +48,21 @@ def defs_key(self) -> str:
def fetch_state(self) -> SerializedAirflowDefinitionsData:
return compute_serialized_data(
airflow_instance=self.airflow_instance,
defs=self.explicit_defs,
mapped_assets=self.mapped_assets,
dag_selector_fn=self.dag_selector_fn,
)

def defs_from_state(
self, serialized_airflow_data: SerializedAirflowDefinitionsData
) -> Definitions:
return Definitions.merge(
enrich_explicit_defs_with_airflow_metadata(self.explicit_defs, serialized_airflow_data),
construct_dag_assets_defs(serialized_airflow_data),
return Definitions(
assets=[
*_apply_airflow_data_to_specs(self.mapped_assets, serialized_airflow_data),
*construct_dag_assets_defs(serialized_airflow_data),
]
)


def build_airflow_mapped_defs(
*,
airflow_instance: AirflowInstance,
defs: Optional[Definitions] = None,
dag_selector_fn: Optional[DagSelectorFn] = None,
) -> Definitions:
return AirflowInstanceDefsLoader(
airflow_instance=airflow_instance,
explicit_defs=defs or Definitions(),
dag_selector_fn=dag_selector_fn,
).build_defs()


@suppress_dagster_warnings
def build_defs_from_airflow_instance(
*,
Expand Down Expand Up @@ -216,13 +206,25 @@ def only_include_dag(dag_info: DagInfo) -> bool:
)
"""
mapped_defs = build_airflow_mapped_defs(
airflow_instance=airflow_instance, defs=defs, dag_selector_fn=dag_selector_fn
defs = defs or Definitions()
mapped_assets = _type_narrow_defs_assets(defs)
serialized_airflow_data = AirflowInstanceDefsLoader(
airflow_instance=airflow_instance,
mapped_assets=mapped_assets,
dag_selector_fn=dag_selector_fn,
).get_or_fetch_state()
mapped_and_constructed_assets = [
*_apply_airflow_data_to_specs(mapped_assets, serialized_airflow_data),
*construct_dag_assets_defs(serialized_airflow_data),
]
defs_with_airflow_assets = replace_assets_in_defs(
defs=defs, assets=mapped_and_constructed_assets
)

return Definitions.merge(
mapped_defs,
defs_with_airflow_assets,
build_airflow_polling_sensor_defs(
mapped_defs=mapped_defs,
mapped_assets=mapped_and_constructed_assets,
airflow_instance=airflow_instance,
minimum_interval_seconds=sensor_minimum_interval_seconds,
event_transformer_fn=event_transformer_fn,
Expand All @@ -233,7 +235,7 @@ def only_include_dag(dag_info: DagInfo) -> bool:
@dataclass
class FullAutomappedDagsLoader(StateBackedDefinitionsLoader[SerializedAirflowDefinitionsData]):
airflow_instance: AirflowInstance
explicit_defs: Definitions
mapped_assets: Iterable[MappedAsset]
sensor_minimum_interval_seconds: int

@property
Expand All @@ -242,15 +244,19 @@ def defs_key(self) -> str:

def fetch_state(self) -> SerializedAirflowDefinitionsData:
return compute_serialized_data(
airflow_instance=self.airflow_instance, defs=self.explicit_defs, dag_selector_fn=None
airflow_instance=self.airflow_instance,
mapped_assets=self.mapped_assets,
dag_selector_fn=None,
)

def defs_from_state(
self, serialized_airflow_data: SerializedAirflowDefinitionsData
) -> Definitions:
return Definitions.merge(
enrich_explicit_defs_with_airflow_metadata(self.explicit_defs, serialized_airflow_data),
construct_automapped_dag_assets_defs(serialized_airflow_data),
return Definitions(
assets=[
*_apply_airflow_data_to_specs(self.mapped_assets, serialized_airflow_data),
*construct_automapped_dag_assets_defs(serialized_airflow_data),
]
)


Expand All @@ -260,50 +266,51 @@ def build_full_automapped_dags_from_airflow_instance(
sensor_minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS,
defs: Optional[Definitions] = None,
) -> Definitions:
resolved_defs = FullAutomappedDagsLoader(
defs = defs or Definitions()
mapped_assets = _type_narrow_defs_assets(defs or Definitions())
serialized_data = FullAutomappedDagsLoader(
airflow_instance=airflow_instance,
sensor_minimum_interval_seconds=sensor_minimum_interval_seconds,
explicit_defs=defs or Definitions(),
).build_defs()
mapped_assets=mapped_assets,
).get_or_fetch_state()
airflow_assets = [
*_apply_airflow_data_to_specs(mapped_assets, serialized_data),
*construct_automapped_dag_assets_defs(serialized_data),
]
resolved_defs = replace_assets_in_defs(defs=defs, assets=airflow_assets)
return Definitions.merge(
resolved_defs,
build_airflow_polling_sensor_defs(
minimum_interval_seconds=sensor_minimum_interval_seconds,
mapped_defs=resolved_defs,
mapped_assets=mapped_assets,
airflow_instance=airflow_instance,
),
)


def enrich_explicit_defs_with_airflow_metadata(
explicit_defs: Definitions, serialized_data: SerializedAirflowDefinitionsData
) -> Definitions:
return Definitions(
assets=list(_apply_airflow_data_to_specs(explicit_defs, serialized_data)),
asset_checks=explicit_defs.asset_checks,
sensors=explicit_defs.sensors,
schedules=explicit_defs.schedules,
jobs=explicit_defs.jobs,
executor=explicit_defs.executor,
loggers=explicit_defs.loggers,
resources=explicit_defs.resources,
metadata=explicit_defs.metadata,
def _type_check_asset(asset: Any) -> MappedAsset:
    """Validate that *asset* is an ``AssetSpec`` or ``AssetsDefinition``.

    Returns the asset unchanged on success; raises a check error otherwise.
    """
    allowed_types = (AssetSpec, AssetsDefinition)
    return check.inst(
        asset,
        allowed_types,
        "Expected passed assets to all be AssetsDefinitions or AssetSpecs.",
    )


def _type_narrow_defs_assets(defs: Definitions) -> Iterable[Union[AssetSpec, AssetsDefinition]]:
    """Return the assets of *defs*, each checked to be an AssetSpec or AssetsDefinition."""
    return list(map(_type_check_asset, defs.assets or []))


def _apply_airflow_data_to_specs(
explicit_defs: Definitions,
assets: Iterable[Union[AssetSpec, AssetsDefinition]],
serialized_data: SerializedAirflowDefinitionsData,
) -> Iterator[AssetsDefinition]:
"""Apply asset spec transformations to the asset definitions."""
for asset in explicit_defs.assets or []:
asset = check.inst( # noqa: PLW2901
asset,
(AssetSpec, AssetsDefinition),
"Expected passed assets to all be AssetsDefinitions or AssetSpecs.",
)
for asset in assets:
narrowed_asset = _type_check_asset(asset)
assets_def = (
asset if isinstance(asset, AssetsDefinition) else external_asset_from_spec(asset)
narrowed_asset
if isinstance(narrowed_asset, AssetsDefinition)
else external_asset_from_spec(narrowed_asset)
)
yield assets_def.map_asset_specs(get_airflow_data_to_spec_mapper(serialized_data))

Expand All @@ -323,11 +330,12 @@ def replace_assets_in_defs(
)


def assets_def_of_defs(defs: Definitions) -> Iterator[AssetsDefinition]:
for asset in defs.assets or []:
asset = check.inst( # noqa: PLW2901
asset,
(AssetSpec, AssetsDefinition),
"Expected passed assets to all be AssetsDefinitions or AssetSpecs.",
)
yield asset if isinstance(asset, AssetsDefinition) else external_asset_from_spec(asset)
def enrich_airflow_mapped_assets(
    mapped_assets: Iterable[MappedAsset],
    airflow_instance: AirflowInstance,
) -> Sequence[AssetsDefinition]:
    """Enrich Airflow-mapped assets with metadata from the provided :py:class:`AirflowInstance`."""
    # Pull (or reuse previously fetched) serialized Airflow state for these assets.
    loader = AirflowInstanceDefsLoader(
        airflow_instance=airflow_instance,
        mapped_assets=mapped_assets,
    )
    serialized_data = loader.get_or_fetch_state()
    # Apply the Airflow-derived metadata onto each mapped asset's specs.
    enriched = _apply_airflow_data_to_specs(mapped_assets, serialized_data)
    return list(enriched)
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
EFFECTIVE_TIMESTAMP_METADATA_KEY,
TASK_ID_TAG_KEY,
)
from dagster_airlift.core.airflow_defs_data import AirflowDefinitionsData
from dagster_airlift.core.airflow_defs_data import AirflowDefinitionsData, MappedAsset
from dagster_airlift.core.airflow_instance import AirflowInstance, DagRun, TaskInstance
from dagster_airlift.core.sensor.event_translation import (
AssetEvent,
Expand Down Expand Up @@ -80,7 +80,7 @@ def check_keys_for_asset_keys(

def build_airflow_polling_sensor_defs(
*,
mapped_defs: Definitions,
mapped_assets: Iterable[MappedAsset],
airflow_instance: AirflowInstance,
event_transformer_fn: DagsterEventTransformerFn = default_event_transformer,
minimum_interval_seconds: int = DEFAULT_AIRFLOW_SENSOR_INTERVAL_SECONDS,
Expand All @@ -105,7 +105,7 @@ def build_airflow_polling_sensor_defs(
Definitions: A `Definitions` object containing the constructed sensor.
"""
airflow_data = AirflowDefinitionsData(
airflow_instance=airflow_instance, mapped_defs=mapped_defs
airflow_instance=airflow_instance, mapped_assets=mapped_assets
)

@sensor(
Expand Down Expand Up @@ -390,7 +390,7 @@ def automapped_tasks_asset_keys(
asset_keys_to_emit = set()
asset_keys = airflow_data.asset_keys_in_task(dag_run.dag_id, task_instance.task_id)
for asset_key in asset_keys:
spec = airflow_data.mapped_defs.get_assets_def(asset_key).get_asset_spec(asset_key)
spec = airflow_data.all_asset_specs_by_key[asset_key]
if spec.metadata.get(AUTOMAPPED_TASK_METADATA_KEY):
asset_keys_to_emit.add(asset_key)
return asset_keys_to_emit
Loading

0 comments on commit 8cb9eb6

Please sign in to comment.