Source code for dbt_airflow_factory.tasks

"""Classes representing tasks corresponding to a single DBT model."""

from typing import Dict, Iterable, List, Optional

from airflow.models.baseoperator import BaseOperator


[docs]class ModelExecutionTask: """ Wrapper around tasks corresponding to a single DBT model. :param execution_airflow_task: Operator running DBT's ``run`` task. :type execution_airflow_task: BaseOperator :param test_airflow_task: Operator running DBT's ``test`` task (optional). :type test_airflow_task: BaseOperator :param task_group: TaskGroup consisting of ``run`` and ``test`` tasks (if Airflow version is at least 2). """ def __init__( # type: ignore self, execution_airflow_task: BaseOperator, test_airflow_task: Optional[BaseOperator] = None, task_group=None, ) -> None: self.execution_airflow_task = execution_airflow_task self.test_airflow_task = test_airflow_task self.task_group = task_group def __repr__(self) -> str: return ( repr(self.task_group) if self.task_group else repr( [self.execution_airflow_task] + ([self.test_airflow_task] if self.test_airflow_task else []) ) )
[docs] def get_start_task(self): # type: ignore """ Return model's first task. It is either a whole TaskGroup or ``run`` task. """ return self.task_group or self.execution_airflow_task
[docs] def get_end_task(self): # type: ignore """ Return model's last task. It is either a whole TaskGroup, ``test`` task, or ``run`` task, depending on version of Airflow and existence of ``test`` task. """ return self.task_group or self.test_airflow_task or self.execution_airflow_task
[docs]class ModelExecutionTasks: """ Dictionary of all Operators corresponding to DBT tasks. :param tasks: Dictionary of model tasks. :type tasks: Dict[str, ModelExecutionTask] :param starting_task_names: List of names of initial tasks (DAG sources). :type starting_task_names: List[str] :param ending_task_names: List of names of ending tasks (DAG sinks). :type ending_task_names: List[str] """ def __init__( self, tasks: Dict[str, ModelExecutionTask], starting_task_names: List[str], ending_task_names: List[str], ) -> None: self._tasks = tasks self._starting_task_names = starting_task_names self._ending_task_names = ending_task_names def __repr__(self) -> str: return f"ModelExecutionTasks(\n {self._tasks} \n)"
[docs] def get_task(self, node_name: str) -> ModelExecutionTask: """ Return :class:`ModelExecutionTask` for given model's **node_name**. :param node_name: Name of the task. :type node_name: str :return: Wrapper around tasks corresponding to a given model. :rtype: ModelExecutionTask """ return self._tasks[node_name]
[docs] def length(self) -> int: """Count TaskGroups corresponding to a single DBT model.""" return len(self._tasks)
[docs] def get_starting_tasks(self) -> List[ModelExecutionTask]: """ Get a list of all DAG sources. :return: List of all DAG sources. :rtype: List[ModelExecutionTask] """ return self._extract_by_keys(self._starting_task_names)
[docs] def get_ending_tasks(self) -> List[ModelExecutionTask]: """ Get a list of all DAG sinks. :return: List of all DAG sinks. :rtype: List[ModelExecutionTask] """ return self._extract_by_keys(self._ending_task_names)
def _extract_by_keys(self, keys: Iterable[str]) -> List[ModelExecutionTask]: tasks = [] for key in keys: tasks.append(self._tasks[key]) return tasks