Skip to content

Commit 63dce8f

Browse files
authored
Merge pull request datafold#827 from datafold/type-annotate-everything-2
Type annotate some things ("no-brainers")
2 parents d5a4d12 + ff76f94 commit 63dce8f

35 files changed

+180
-162
lines changed

data_diff/__main__.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _get_schema(pair: Tuple[Database, DbPath]) -> Dict[str, RawColumnInfo]:
7777
return db.query_table_schema(table_path)
7878

7979

80-
def diff_schemas(table1, table2, schema1, schema2, columns):
80+
def diff_schemas(table1, table2, schema1, schema2, columns) -> None:
8181
logging.info("Diffing schemas...")
8282
attrs = "name", "type", "datetime_precision", "numeric_precision", "numeric_scale"
8383
for c in columns:
@@ -103,7 +103,7 @@ def diff_schemas(table1, table2, schema1, schema2, columns):
103103

104104

105105
class MyHelpFormatter(click.HelpFormatter):
106-
def __init__(self, **kwargs):
106+
def __init__(self, **kwargs) -> None:
107107
super().__init__(self, **kwargs)
108108
self.indent_increment = 6
109109

@@ -281,7 +281,7 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
281281
default=None,
282282
help="Override the dbt production schema configuration within dbt_project.yml",
283283
)
284-
def main(conf, run, **kw):
284+
def main(conf, run, **kw) -> None:
285285
log_handlers = _get_log_handlers(kw["dbt"])
286286
if kw["table2"] is None and kw["database2"]:
287287
# Use the "database table table" form
@@ -341,9 +341,7 @@ def main(conf, run, **kw):
341341
production_schema_flag=kw["prod_schema"],
342342
)
343343
else:
344-
return _data_diff(
345-
dbt_project_dir=project_dir_override, dbt_profiles_dir=profiles_dir_override, state=state, **kw
346-
)
344+
_data_diff(dbt_project_dir=project_dir_override, dbt_profiles_dir=profiles_dir_override, state=state, **kw)
347345
except Exception as e:
348346
logging.error(e)
349347
raise
@@ -389,7 +387,7 @@ def _data_diff(
389387
threads1=None,
390388
threads2=None,
391389
__conf__=None,
392-
):
390+
) -> None:
393391
if limit and stats:
394392
logging.error("Cannot specify a limit when using the -s/--stats switch")
395393
return

data_diff/abcs/database_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ class Integer(NumericType, IKey):
290290
precision: int = 0
291291
python_type: type = int
292292

293-
def __attrs_post_init__(self):
293+
def __attrs_post_init__(self) -> None:
294294
assert self.precision == 0
295295

296296

data_diff/cloud/data_source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def process_response(self, value: str) -> str:
4646
return value
4747

4848

49-
def _validate_temp_schema(temp_schema: str):
49+
def _validate_temp_schema(temp_schema: str) -> None:
5050
if len(temp_schema.split(".")) != 2:
5151
raise ValueError("Temporary schema should have a format <database>.<schema>")
5252

data_diff/cloud/datafold_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ class DatafoldAPI:
185185
host: str = "https://app.datafold.com"
186186
timeout: int = 30
187187

188-
def __attrs_post_init__(self):
188+
def __attrs_post_init__(self) -> None:
189189
self.host = self.host.rstrip("/")
190190
self.headers = {
191191
"Authorization": f"Key {self.api_key}",

data_diff/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _apply_config(config: Dict[str, Any], run_name: str, kw: Dict[str, Any]):
9999
_ENV_VAR_PATTERN = r"\$\{([A-Za-z0-9_]+)\}"
100100

101101

102-
def _resolve_env(config: Dict[str, Any]):
102+
def _resolve_env(config: Dict[str, Any]) -> None:
103103
"""
104104
Resolve environment variables referenced as ${ENV_VAR_NAME}.
105105
Missing environment variables are replaced with an empty string.

data_diff/databases/_connect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class Connect:
100100
database_by_scheme: Dict[str, Database]
101101
conn_cache: MutableMapping[Hashable, Database]
102102

103-
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
103+
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME) -> None:
104104
super().__init__()
105105
self.database_by_scheme = database_by_scheme
106106
self.conn_cache = weakref.WeakValueDictionary()

data_diff/databases/base.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,22 @@
55
import math
66
import sys
77
import logging
8-
from typing import Any, Callable, ClassVar, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar
8+
from typing import (
9+
Any,
10+
Callable,
11+
ClassVar,
12+
Dict,
13+
Generator,
14+
Iterator,
15+
NewType,
16+
Tuple,
17+
Optional,
18+
Sequence,
19+
Type,
20+
List,
21+
Union,
22+
TypeVar,
23+
)
924
from functools import partial, wraps
1025
from concurrent.futures import ThreadPoolExecutor
1126
import threading
@@ -116,7 +131,7 @@ def dialect(self) -> "BaseDialect":
116131
def compile(self, elem, params=None) -> str:
117132
return self.dialect.compile(self, elem, params)
118133

119-
def new_unique_name(self, prefix="tmp"):
134+
def new_unique_name(self, prefix="tmp") -> str:
120135
self._counter[0] += 1
121136
return f"{prefix}{self._counter[0]}"
122137

@@ -173,7 +188,7 @@ class ThreadLocalInterpreter:
173188
compiler: Compiler
174189
gen: Generator
175190

176-
def apply_queries(self, callback: Callable[[str], Any]):
191+
def apply_queries(self, callback: Callable[[str], Any]) -> None:
177192
q: Expr = next(self.gen)
178193
while True:
179194
sql = self.compiler.database.dialect.compile(self.compiler, q)
@@ -885,20 +900,21 @@ def optimizer_hints(self, hints: str) -> str:
885900

886901

887902
T = TypeVar("T", bound=BaseDialect)
903+
Row = Sequence[Any]
888904

889905

890906
@attrs.define(frozen=True)
891907
class QueryResult:
892-
rows: list
908+
rows: List[Row]
893909
columns: Optional[list] = None
894910

895-
def __iter__(self):
911+
def __iter__(self) -> Iterator[Row]:
896912
return iter(self.rows)
897913

898-
def __len__(self):
914+
def __len__(self) -> int:
899915
return len(self.rows)
900916

901-
def __getitem__(self, i):
917+
def __getitem__(self, i) -> Row:
902918
return self.rows[i]
903919

904920

@@ -1209,7 +1225,7 @@ class ThreadedDatabase(Database):
12091225
_queue: Optional[ThreadPoolExecutor] = None
12101226
thread_local: threading.local = attrs.field(factory=threading.local)
12111227

1212-
def __attrs_post_init__(self):
1228+
def __attrs_post_init__(self) -> None:
12131229
self._queue = ThreadPoolExecutor(self.thread_count, initializer=self.set_conn)
12141230
logger.info(f"[{self.name}] Starting a threadpool, size={self.thread_count}.")
12151231

data_diff/databases/bigquery.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,10 @@ class Dialect(BaseDialect):
8585
def random(self) -> str:
8686
return "RAND()"
8787

88-
def quote(self, s: str):
88+
def quote(self, s: str) -> str:
8989
return f"`{s}`"
9090

91-
def to_string(self, s: str):
91+
def to_string(self, s: str) -> str:
9292
return f"cast({s} as string)"
9393

9494
def type_repr(self, t) -> str:
@@ -212,7 +212,7 @@ class BigQuery(Database):
212212
dataset: str
213213
_client: Any
214214

215-
def __init__(self, project, *, dataset, bigquery_credentials=None, **kw):
215+
def __init__(self, project, *, dataset, bigquery_credentials=None, **kw) -> None:
216216
super().__init__()
217217
credentials = bigquery_credentials
218218
bigquery = import_bigquery()

data_diff/databases/clickhouse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ class Clickhouse(ThreadedDatabase):
175175

176176
_args: Dict[str, Any]
177177

178-
def __init__(self, *, thread_count: int, **kw):
178+
def __init__(self, *, thread_count: int, **kw) -> None:
179179
super().__init__(thread_count=thread_count)
180180

181181
self._args = kw

data_diff/databases/databricks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def type_repr(self, t) -> str:
6565
except KeyError:
6666
return super().type_repr(t)
6767

68-
def quote(self, s: str):
68+
def quote(self, s: str) -> str:
6969
return f"`{s}`"
7070

7171
def to_string(self, s: str) -> str:
@@ -118,7 +118,7 @@ class Databricks(ThreadedDatabase):
118118
catalog: str
119119
_args: Dict[str, Any]
120120

121-
def __init__(self, *, thread_count, **kw):
121+
def __init__(self, *, thread_count, **kw) -> None:
122122
super().__init__(thread_count=thread_count)
123123
logging.getLogger("databricks.sql").setLevel(logging.WARNING)
124124

data_diff/databases/duckdb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class DuckDB(Database):
126126
_args: Dict[str, Any] = attrs.field(init=False)
127127
_conn: Any = attrs.field(init=False)
128128

129-
def __init__(self, **kw):
129+
def __init__(self, **kw) -> None:
130130
super().__init__()
131131
self._args = kw
132132
self._conn = self.create_connection()

data_diff/databases/mssql.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class Dialect(BaseDialect):
7676
"json": JSON,
7777
}
7878

79-
def quote(self, s: str):
79+
def quote(self, s: str) -> str:
8080
return f"[{s}]"
8181

8282
def set_timezone_to_utc(self) -> str:
@@ -93,7 +93,7 @@ def current_schema(self) -> str:
9393
FROM sys.database_principals
9494
WHERE name = CURRENT_USER"""
9595

96-
def to_string(self, s: str):
96+
def to_string(self, s: str) -> str:
9797
# Both convert(varchar(max), …) and convert(text, …) do work.
9898
return f"CONVERT(VARCHAR(MAX), {s})"
9999

@@ -168,7 +168,7 @@ class MsSQL(ThreadedDatabase):
168168
_args: Dict[str, Any]
169169
_mssql: Any
170170

171-
def __init__(self, host, port, user, password, *, database, thread_count, **kw):
171+
def __init__(self, host, port, user, password, *, database, thread_count, **kw) -> None:
172172
super().__init__(thread_count=thread_count)
173173

174174
args = dict(server=host, port=port, database=database, user=user, password=password, **kw)

data_diff/databases/mysql.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,10 @@ class Dialect(BaseDialect):
7070
"boolean": Boolean,
7171
}
7272

73-
def quote(self, s: str):
73+
def quote(self, s: str) -> str:
7474
return f"`{s}`"
7575

76-
def to_string(self, s: str):
76+
def to_string(self, s: str) -> str:
7777
return f"cast({s} as char)"
7878

7979
def is_distinct_from(self, a: str, b: str) -> str:
@@ -129,7 +129,7 @@ class MySQL(ThreadedDatabase):
129129

130130
_args: Dict[str, Any]
131131

132-
def __init__(self, *, thread_count, **kw):
132+
def __init__(self, *, thread_count, **kw) -> None:
133133
super().__init__(thread_count=thread_count)
134134
self._args = kw
135135

data_diff/databases/oracle.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ class Dialect(
5959
ROUNDS_ON_PREC_LOSS = True
6060
PLACEHOLDER_TABLE = "DUAL"
6161

62-
def quote(self, s: str):
62+
def quote(self, s: str) -> str:
6363
return f'"{s}"'
6464

65-
def to_string(self, s: str):
65+
def to_string(self, s: str) -> str:
6666
return f"cast({s} as varchar(1024))"
6767

6868
def limit_select(
@@ -164,7 +164,7 @@ class Oracle(ThreadedDatabase):
164164
kwargs: Dict[str, Any]
165165
_oracle: Any
166166

167-
def __init__(self, *, host, database, thread_count, **kw):
167+
def __init__(self, *, host, database, thread_count, **kw) -> None:
168168
super().__init__(thread_count=thread_count)
169169
self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)
170170
self.default_schema = kw.get("user").upper()

data_diff/databases/postgresql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class PostgreSQL(ThreadedDatabase):
163163
_args: Dict[str, Any]
164164
_conn: Any
165165

166-
def __init__(self, *, thread_count, **kw):
166+
def __init__(self, *, thread_count, **kw) -> None:
167167
super().__init__(thread_count=thread_count)
168168
self._args = kw
169169
self.default_schema = "public"

data_diff/databases/presto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ class Presto(Database):
152152

153153
_conn: Any
154154

155-
def __init__(self, **kw):
155+
def __init__(self, **kw) -> None:
156156
super().__init__()
157157
self.default_schema = "public"
158158
prestodb = import_presto()

data_diff/databases/snowflake.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ class Snowflake(Database):
104104

105105
_conn: Any
106106

107-
def __init__(self, *, schema: str, key: Optional[str] = None, key_content: Optional[str] = None, **kw):
107+
def __init__(self, *, schema: str, key: Optional[str] = None, key_content: Optional[str] = None, **kw) -> None:
108108
super().__init__()
109109
snowflake, serialization, default_backend = import_snowflake()
110110
logging.getLogger("snowflake.connector").setLevel(logging.WARNING)

data_diff/databases/trino.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class Trino(presto.Presto):
4040

4141
_conn: Any
4242

43-
def __init__(self, **kw):
43+
def __init__(self, **kw) -> None:
4444
super().__init__()
4545
trino = import_trino()
4646

data_diff/databases/vertica.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class Dialect(BaseDialect):
6060
# https://www.vertica.com/docs/9.3.x/HTML/Content/Authoring/SQLReferenceManual/DataTypes/Numeric/NUMERIC.htm#Default
6161
DEFAULT_NUMERIC_PRECISION = 15
6262

63-
def quote(self, s: str):
63+
def quote(self, s: str) -> str:
6464
return f'"{s}"'
6565

6666
def concat(self, items: List[str]) -> str:
@@ -137,7 +137,7 @@ class Vertica(ThreadedDatabase):
137137

138138
_args: Dict[str, Any]
139139

140-
def __init__(self, *, thread_count, **kw):
140+
def __init__(self, *, thread_count, **kw) -> None:
141141
super().__init__(thread_count=thread_count)
142142
self._args = kw
143143
self._args["AUTOCOMMIT"] = False

data_diff/dbt_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def try_get_dbt_runner():
5050

5151
# ProfileRenderer.render_data() fails without instantiating global flag MACRO_DEBUGGING in dbt-core 1.5
5252
# hacky but seems to be a bug on dbt's end
53-
def try_set_dbt_flags():
53+
def try_set_dbt_flags() -> None:
5454
try:
5555
from dbt.flags import set_flags
5656

data_diff/diff_tables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from enum import Enum
77
from contextlib import contextmanager
88
from operator import methodcaller
9-
from typing import Dict, Set, List, Tuple, Iterator, Optional, Union
9+
from typing import Any, Dict, Set, List, Tuple, Iterator, Optional, Union
1010
from concurrent.futures import ThreadPoolExecutor, as_completed
1111

1212
import attrs
@@ -89,7 +89,7 @@ class DiffResultWrapper:
8989
stats: dict
9090
result_list: list = attrs.field(factory=list)
9191

92-
def __iter__(self):
92+
def __iter__(self) -> Iterator[Any]:
9393
yield from self.result_list
9494
for i in self.diff:
9595
self.result_list.append(i)

data_diff/hashdiff_tables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class HashDiffer(TableDiffer):
9696

9797
stats: dict = attrs.field(factory=dict)
9898

99-
def __attrs_post_init__(self):
99+
def __attrs_post_init__(self) -> None:
100100
# Validate options
101101
if self.bisection_factor >= self.bisection_threshold:
102102
raise ValueError("Incorrect param values (bisection factor must be lower than threshold)")

0 commit comments

Comments
 (0)