test_main.py
import boto3
from hive_metastore_client.builders import ( # type: ignore
ColumnBuilder,
DatabaseBuilder,
SerDeInfoBuilder,
StorageDescriptorBuilder,
TableBuilder,
)
from hive_metastore_client import HiveMetastoreClient # type: ignore
from thrift_files.libraries.thrift_hive_metastore_client.ttypes import Database, Table # type: ignore
import psycopg2
import os
import typing as t
import unittest
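
# Connection details come from the environment: HIVE_HOST/HIVE_PORT for the
# metastore Thrift endpoint, POSTGRES_* for its backing store, and
# S3_BUCKET/S3_PREFIX/S3_ENDPOINT_URL for the warehouse bucket (presumably
# supplied by the surrounding compose/CI setup).
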
DATABASE_NAME = 'test_database'
Env = t.Dict[str, str]
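

# Thrift table definition for a partitioned "orders" table stored as Parquet
# under the test database's location.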
def get_test_table(db_location: str) -> Table:
    """Inspired by: https://github.com/quintoandar/hive-metastore-client/blob/6743fb7a383f4fa00cf5235f599c239f8af2a92c/examples/create_table.py"""
    # You must create a list with the table's columns.
    columns = [
        ColumnBuilder("id", "string", "col comment").build(),
        ColumnBuilder("client_name", "string").build(),
        ColumnBuilder("amount", "string").build(),
        ColumnBuilder("year", "string").build(),
        ColumnBuilder("month", "string").build(),
        ColumnBuilder("day", "string").build(),
    ]

    # If your table has partitions, create a list with the partition columns.
    # This list is similar to the columns list, and the year, month and day
    # columns are the same.
    partition_keys = [
        ColumnBuilder("year", "string").build(),
        ColumnBuilder("month", "string").build(),
        ColumnBuilder("day", "string").build(),
    ]

    serde_info = SerDeInfoBuilder(
        serialization_lib="org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"
    ).build()

    storage_descriptor = StorageDescriptorBuilder(
        columns=columns,
        location=f"{db_location}/orders",
        input_format="org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat",
        output_format="org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat",
        serde_info=serde_info,
    ).build()

    table = TableBuilder(
        table_name="orders",
        db_name=DATABASE_NAME,
        storage_descriptor=storage_descriptor,
        partition_keys=partition_keys,
    ).build()

    return table
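

# Creates the test database plus its "orders" table via the metastore's Thrift
# API and returns the Database object the metastore reports back.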
def put_database(db_location: str) -> Database:
    database_name = DATABASE_NAME
    database = DatabaseBuilder(name=database_name).build()
    with HiveMetastoreClient(os.environ['HIVE_HOST'], int(os.environ['HIVE_PORT'])) as client:
        client.create_database(database)
        client.create_table(get_test_table(db_location))
        return client.get_database(database_name)
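

# End-to-end test: create a database and table in the metastore, then check
# that the change is visible in the Postgres backing store and the S3 bucket.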
class TestCRUD(unittest.TestCase):
    env: Env

    def setUp(self) -> None:
        # Capture the environment as a plain dict for the helpers below.
        self.env = os.environ.copy()

    def tearDown(self) -> None:
        # Drop the test database (including its data) so runs do not interfere.
        with HiveMetastoreClient(self.env['HIVE_HOST'], int(self.env['HIVE_PORT'])) as client:
            client.drop_database(DATABASE_NAME, deleteData=True, cascade=True)

    def validate_postgres(self) -> None:
        # Assert that the Postgres database backing the metastore contains at
        # least one user table after the test database has been created.
        conn = psycopg2.connect(
            host=self.env["POSTGRES_HOST"],
            dbname=self.env["POSTGRES_DB"],
            user=self.env["POSTGRES_USER"],
            password=self.env["POSTGRES_PASSWORD"],
        )
        # Lifted from: https://stackoverflow.com/a/28668161
        sql = """
        SELECT
            pgClass.relname AS tableName,
            pgClass.reltuples AS rowCount
        FROM
            pg_class pgClass
        INNER JOIN
            pg_namespace pgNamespace ON (pgNamespace.oid = pgClass.relnamespace)
        WHERE
            pgNamespace.nspname NOT IN ('pg_catalog', 'information_schema')
            AND pgClass.relkind = 'r'
        """
        with conn.cursor() as cur:
            cur.execute(sql)
            records = cur.fetchall()
            self.assertGreater(len(records), 0)

    def gen_s3_location(self) -> str:
        # Warehouse location for the test database inside the S3 bucket.
        return f's3a://{self.env["S3_BUCKET"]}/{self.env["S3_PREFIX"]}/{DATABASE_NAME}.db'

    def validate_s3(self) -> None:
        # Assert that at least one object exists under the configured prefix
        # after the database and table have been created.
        client = boto3.client("s3", endpoint_url=self.env['S3_ENDPOINT_URL'])
        response = client.list_objects_v2(Bucket=self.env["S3_BUCKET"], Prefix=self.env["S3_PREFIX"])
        object_keys = [o["Key"] for o in response["Contents"]]
        self.assertGreater(len(object_keys), 0)

    def test_backend(self) -> None:
        # Creating the database and table should be reflected in the metastore
        # response, in Postgres, and in S3.
        db_location = self.gen_s3_location()
        database = put_database(db_location)
        self.validate_postgres()
        self.validate_s3()
        expected_database = Database(
            name=DATABASE_NAME,
            locationUri=db_location,
            parameters={},
            ownerType=1,  # PrincipalType.USER
            catalogName='hive',
        )
        self.assertEqual(expected_database, database)
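

# Hypothetical convenience entry point, assuming the suite may also be invoked
# directly (python test_main.py) instead of only through a test runner.
if __name__ == '__main__':
    unittest.main()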