Skip to content

Commit

Permalink
[AIRFLOW-4205] Replace type comments with native Python typing (apache#…
Browse files Browse the repository at this point in the history
…5327)

We supported typing in Python 2 with mypy & type comments. Now that
we're dropping Python 2 support, we can switch to native Python types.
We support Python 3.5, which doesn't include variable type annotations
yet, so only function arguments and return values are typed.
  • Loading branch information
BasPH authored and potiuk committed May 25, 2019
1 parent 3322782 commit 05c06b0
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 153 deletions.
4 changes: 2 additions & 2 deletions airflow/contrib/operators/gcp_compute_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# specific language governing permissions and limitations
# under the License.
from copy import deepcopy
from typing import Dict

from googleapiclient.errors import HttpError

Expand Down Expand Up @@ -452,8 +453,7 @@ def __init__(self,
project_id=project_id, zone=self.zone, resource_id=resource_id,
gcp_conn_id=gcp_conn_id, api_version=api_version, *args, **kwargs)

def _possibly_replace_template(self, dictionary):
# type: (dict) -> None
def _possibly_replace_template(self, dictionary: Dict) -> None:
if dictionary.get('instanceTemplate') == self.source_template:
dictionary['instanceTemplate'] = self.destination_template
self._change_performed = True
Expand Down
3 changes: 1 addition & 2 deletions airflow/contrib/utils/gcp_field_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ class GcpBodyFieldSanitizer(LoggingMixin):
:type sanitize_specs: list[str]
"""
def __init__(self, sanitize_specs):
# type: (List[str]) -> None
def __init__(self, sanitize_specs: List[str]) -> None:
super().__init__()
self._sanitize_specs = sanitize_specs

Expand Down
22 changes: 8 additions & 14 deletions airflow/contrib/utils/gcp_field_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,7 @@ class GcpBodyFieldValidator(LoggingMixin):
:type api_version: str
"""
def __init__(self, validation_specs, api_version):
# type: (Sequence[Dict], str) -> None
def __init__(self, validation_specs: Sequence[str], api_version: str) -> None:
super().__init__()
self._validation_specs = validation_specs
self._api_version = api_version
Expand All @@ -207,9 +206,8 @@ def _get_field_name_with_parent(field_name, parent):
return field_name

@staticmethod
def _sanity_checks(children_validation_specs, field_type, full_field_path,
regexp, allow_empty, custom_validation, value):
# type: (dict, str, str, str, Callable, object) -> None
def _sanity_checks(children_validation_specs: Dict, field_type: str, full_field_path: str,
regexp: str, allow_empty: bool, custom_validation: Callable, value) -> None:
if value is None and field_type != 'union':
raise GcpFieldValidationException(
"The required body field '{}' is missing. Please add it.".
Expand All @@ -236,8 +234,7 @@ def _sanity_checks(children_validation_specs, field_type, full_field_path,
format(full_field_path))

@staticmethod
def _validate_regexp(full_field_path, regexp, value):
# type: (str, str, str) -> None
def _validate_regexp(full_field_path: str, regexp: str, value: str) -> None:
if not re.match(regexp, value):
# Note matching of only the beginning as we assume the regexps all-or-nothing
raise GcpFieldValidationException(
Expand All @@ -246,15 +243,13 @@ def _validate_regexp(full_field_path, regexp, value):
format(full_field_path, value, regexp))

@staticmethod
def _validate_is_empty(full_field_path, value):
# type: (str, str) -> None
def _validate_is_empty(full_field_path: str, value: str) -> None:
if not value:
raise GcpFieldValidationException(
"The body field '{}' can't be empty. Please provide a value."
.format(full_field_path, value))

def _validate_dict(self, children_validation_specs, full_field_path, value):
# type: (dict, str, dict) -> None
def _validate_dict(self, children_validation_specs: Dict, full_field_path: str, value: Dict) -> None:
for child_validation_spec in children_validation_specs:
self._validate_field(validation_spec=child_validation_spec,
dictionary_to_validate=value,
Expand All @@ -272,9 +267,8 @@ def _validate_dict(self, children_validation_specs, full_field_path, value):
self._get_field_name_with_parent(field_name, full_field_path),
children_validation_specs)

def _validate_union(self, children_validation_specs, full_field_path,
dictionary_to_validate):
# type: (dict, str, dict) -> None
def _validate_union(self, children_validation_specs: Dict, full_field_path: str,
dictionary_to_validate: Dict) -> None:
field_found = False
found_field_name = None
for child_validation_spec in children_validation_specs:
Expand Down
7 changes: 4 additions & 3 deletions airflow/hooks/base_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _get_connection_from_env(cls, conn_id):
return conn

@classmethod
def get_connections(cls, conn_id): # type: (str) -> Iterable[Connection]
def get_connections(cls, conn_id: str) -> Iterable[Connection]:
conn = cls._get_connection_from_env(conn_id)
if conn:
conns = [conn]
Expand All @@ -72,15 +72,16 @@ def get_connections(cls, conn_id): # type: (str) -> Iterable[Connection]
return conns

@classmethod
def get_connection(cls, conn_id): # type: (str) -> Connection
def get_connection(cls, conn_id: str) -> Connection:
conn = random.choice(list(cls.get_connections(conn_id)))
if conn.host:
log = LoggingMixin().log
log.info("Using connection to: %s", conn.debug_info())
return conn

@classmethod
def get_hook(cls, conn_id): # type: (str) -> BaseHook
def get_hook(cls, conn_id: str) -> "BaseHook":
# TODO: set method return type to BaseHook class when on 3.7+. See https://stackoverflow.com/a/33533514/3066428 # noqa: E501
connection = cls.get_connection(conn_id)
return connection.get_hook()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.

from typing import Dict

import yaml
from airflow.kubernetes.pod import Pod
from airflow.kubernetes.kubernetes_request_factory.kubernetes_request_factory \
Expand Down Expand Up @@ -42,8 +44,7 @@ def __init__(self):

pass

def create(self, pod):
# type: (Pod) -> dict
def create(self, pod: Pod) -> Dict:
req = yaml.safe_load(self._yaml)
self.extract_name(pod, req)
self.extract_labels(pod, req)
Expand Down Expand Up @@ -112,8 +113,7 @@ class ExtractXcomPodRequestFactory(KubernetesRequestFactory):
def __init__(self):
pass

def create(self, pod):
# type: (Pod) -> dict
def create(self, pod: Pod) -> Dict:
req = yaml.safe_load(self._yaml)
self.extract_name(pod, req)
self.extract_labels(pod, req)
Expand Down
4 changes: 1 addition & 3 deletions airflow/kubernetes/pod_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@ def run_pod(

return self._monitor_pod(pod, get_logs)

def _monitor_pod(self, pod, get_logs):
# type: (Pod, bool) -> Tuple[State, Optional[str]]

def _monitor_pod(self, pod: Pod, get_logs: bool) -> Tuple[State, Optional[str]]:
if get_logs:
logs = self.read_pod_logs(pod)
for line in logs:
Expand Down
75 changes: 36 additions & 39 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,40 +238,40 @@ class derived from this one results in the creation of a task object,
@apply_defaults
def __init__(
self,
task_id, # type: str
owner=configuration.conf.get('operators', 'DEFAULT_OWNER'), # type: str
email=None, # type: Optional[str]
email_on_retry=True, # type: bool
email_on_failure=True, # type: bool
retries=0, # type: int
retry_delay=timedelta(seconds=300), # type: timedelta
retry_exponential_backoff=False, # type: bool
max_retry_delay=None, # type: Optional[datetime]
start_date=None, # type: Optional[datetime]
end_date=None, # type: Optional[datetime]
task_id: str,
owner: str = configuration.conf.get('operators', 'DEFAULT_OWNER'),
email: Optional[str] = None,
email_on_retry: bool = True,
email_on_failure: bool = True,
retries: int = 0,
retry_delay: timedelta = timedelta(seconds=300),
retry_exponential_backoff: bool = False,
max_retry_delay: Optional[datetime] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
schedule_interval=None, # not hooked as of now
depends_on_past=False, # type: bool
wait_for_downstream=False, # type: bool
dag=None, # type: Optional[DAG]
params=None, # type: Optional[Dict]
default_args=None, # type: Optional[Dict]
priority_weight=1, # type: int
weight_rule=WeightRule.DOWNSTREAM, # type: str
queue=configuration.conf.get('celery', 'default_queue'), # type: str
pool=None, # type: Optional[str]
sla=None, # type: Optional[timedelta]
execution_timeout=None, # type: Optional[timedelta]
on_failure_callback=None, # type: Optional[Callable]
on_success_callback=None, # type: Optional[Callable]
on_retry_callback=None, # type: Optional[Callable]
trigger_rule=TriggerRule.ALL_SUCCESS, # type: str
resources=None, # type: Optional[Dict]
run_as_user=None, # type: Optional[str]
task_concurrency=None, # type: Optional[int]
executor_config=None, # type: Optional[Dict]
do_xcom_push=True, # type: bool
inlets=None, # type: Optional[Dict]
outlets=None, # type: Optional[Dict]
depends_on_past: bool = False,
wait_for_downstream: bool = False,
dag: Optional[DAG] = None,
params: Optional[Dict] = None,
default_args: Optional[Dict] = None,
priority_weight: int = 1,
weight_rule: str = WeightRule.DOWNSTREAM,
queue: str = configuration.conf.get('celery', 'default_queue'),
pool: Optional[str] = None,
sla: Optional[timedelta] = None,
execution_timeout: Optional[timedelta] = None,
on_failure_callback: Optional[Callable] = None,
on_success_callback: Optional[Callable] = None,
on_retry_callback: Optional[Callable] = None,
trigger_rule: str = TriggerRule.ALL_SUCCESS,
resources: Optional[Dict] = None,
run_as_user: Optional[str] = None,
task_concurrency: Optional[int] = None,
executor_config: Optional[Dict] = None,
do_xcom_push: bool = True,
inlets: Optional[Dict] = None,
outlets: Optional[Dict] = None,
*args,
**kwargs
):
Expand Down Expand Up @@ -957,8 +957,7 @@ def xcom_pull(
include_prior_dates=include_prior_dates)

@cached_property
def extra_links(self):
# type: () -> Iterable[str]
def extra_links(self) -> Iterable[str]:
return list(set(self.operator_extra_link_dict.keys())
.union(self.global_operator_extra_link_dict.keys()))

Expand Down Expand Up @@ -986,8 +985,7 @@ class BaseOperatorLink(metaclass=ABCMeta):

@property
@abstractmethod
def name(self):
# type: () -> str
def name(self) -> str:
"""
Name of the link. This will be the button name on the task UI.
Expand All @@ -996,8 +994,7 @@ def name(self):
pass

@abstractmethod
def get_link(self, operator, dttm):
# type: (BaseOperator, datetime) -> str
def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
"""
Link to external system.
Expand Down
47 changes: 23 additions & 24 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,30 +164,29 @@ class DAG(BaseDag, LoggingMixin):

def __init__(
self,
dag_id, # type: str
description='', # type: str
schedule_interval=timedelta(days=1), # type: Optional[ScheduleInterval]
start_date=None, # type: Optional[datetime]
end_date=None, # type: Optional[datetime]
full_filepath=None, # type: Optional[str]
template_searchpath=None, # type: Optional[Union[str, Iterable[str]]]
template_undefined=jinja2.Undefined, # type: Type[jinja2.Undefined]
user_defined_macros=None, # type: Optional[Dict]
user_defined_filters=None, # type: Optional[Dict]
default_args=None, # type: Optional[Dict]
concurrency=configuration.conf.getint('core', 'dag_concurrency'), # type: int
max_active_runs=configuration.conf.getint(
'core', 'max_active_runs_per_dag'), # type: int
dagrun_timeout=None, # type: Optional[timedelta]
sla_miss_callback=None, # type: Optional[Callable]
default_view=None, # type: Optional[str]
orientation=configuration.conf.get('webserver', 'dag_orientation'), # type: str
catchup=configuration.conf.getboolean('scheduler', 'catchup_by_default'), # type: bool
on_success_callback=None, # type: Optional[Callable]
on_failure_callback=None, # type: Optional[Callable]
doc_md=None, # type: Optional[str]
params=None, # type: Optional[Dict]
access_control=None # type: Optional[Dict]
dag_id: str,
description: str = '',
schedule_interval: Optional[ScheduleInterval] = timedelta(days=1),
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
full_filepath: Optional[str] = None,
template_searchpath: Optional[Union[str, Iterable[str]]] = None,
template_undefined: Type[jinja2.Undefined] = jinja2.Undefined,
user_defined_macros: Optional[Dict] = None,
user_defined_filters: Optional[Dict] = None,
default_args: Optional[Dict] = None,
concurrency: int = configuration.conf.getint('core', 'dag_concurrency'),
max_active_runs: int = configuration.conf.getint('core', 'max_active_runs_per_dag'),
dagrun_timeout: Optional[timedelta] = None,
sla_miss_callback: Optional[Callable] = None,
default_view: Optional[str] = None,
orientation: str = configuration.conf.get('webserver', 'dag_orientation'),
catchup: bool = configuration.conf.getboolean('scheduler', 'catchup_by_default'),
on_success_callback: Optional[Callable] = None,
on_failure_callback: Optional[Callable] = None,
doc_md: Optional[str] = None,
params: Optional[Dict] = None,
access_control: Optional[Dict] = None
):
self.user_defined_macros = user_defined_macros
self.user_defined_filters = user_defined_filters
Expand Down
12 changes: 6 additions & 6 deletions airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def setdefault(cls, key, default, deserialize_json=False):
@provide_session
def get(
cls,
key, # type: str
default_var=__NO_DEFAULT_SENTINEL, # type: Any
deserialize_json=False, # type: bool
key: str,
default_var: Any = __NO_DEFAULT_SENTINEL,
deserialize_json: bool = False,
session=None
):
obj = session.query(cls).filter(cls.key == key).first()
Expand All @@ -123,9 +123,9 @@ def get(
@provide_session
def set(
cls,
key, # type: str
value, # type: Any
serialize_json=False, # type: bool
key: str,
value: Any,
serialize_json: bool = False,
session=None
):

Expand Down
26 changes: 13 additions & 13 deletions airflow/operators/check_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class CheckOperator(BaseOperator):
@apply_defaults
def __init__(
self,
sql, # type: str
conn_id=None, # type: Optional[str]
sql: str,
conn_id: Optional[str] = None,
*args,
**kwargs
):
Expand Down Expand Up @@ -132,10 +132,10 @@ class ValueCheckOperator(BaseOperator):
@apply_defaults
def __init__(
self,
sql, # type: str
pass_value, # type: Any
tolerance=None, # type: Any
conn_id=None, # type: Optional[str]
sql: str,
pass_value: Any,
tolerance: Any = None,
conn_id: Optional[str] = None,
*args,
**kwargs
):
Expand Down Expand Up @@ -244,13 +244,13 @@ class IntervalCheckOperator(BaseOperator):
@apply_defaults
def __init__(
self,
table, # type: str
metrics_thresholds, # type: Dict[str, int]
date_filter_column='ds', # type: Optional[str]
days_back=-7, # type: SupportsAbs[int]
ratio_formula='max_over_min', # type: Optional[str]
ignore_zero=True, # type: Optional[bool]
conn_id=None, # type: Optional[str]
table: str,
metrics_thresholds: Dict[str, int],
date_filter_column: Optional[str] = 'ds',
days_back: SupportsAbs[int] = -7,
ratio_formula: Optional[str] = 'max_over_min',
ignore_zero: Optional[bool] = True,
conn_id: Optional[str] = None,
*args, **kwargs
):
super().__init__(*args, **kwargs)
Expand Down
Loading

0 comments on commit 05c06b0

Please sign in to comment.