diff --git a/MANIFEST.in b/MANIFEST.in index 04f196a..ddf8476 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,4 @@ include README.md include LICENSE +include NOTES +include pyspark_db_utils/jars/*.jar diff --git a/NOTES b/NOTES new file mode 100644 index 0000000..0922a77 --- /dev/null +++ b/NOTES @@ -0,0 +1,18 @@ +# packaging project +python3.5 -m pip install --user --upgrade setuptools wheel +sudo apt-get install twine + +rm -f dist/* +python3.5 setup.py sdist bdist_wheel +twine upload --repository-url https://test.pypi.org/legacy/ dist/*.tar.gz + + +# gen doc +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.intersphinx', + 'sphinx.ext.ifconfig', + 'sphinxcontrib.napoleon', +] + +sphinx-apidoc -f -o ./source/ ../ && make html \ No newline at end of file diff --git a/pyspark_db_utils/__init__.py b/pyspark_db_utils/__init__.py index e69de29..0b3ea7c 100644 --- a/pyspark_db_utils/__init__.py +++ b/pyspark_db_utils/__init__.py @@ -0,0 +1,4 @@ +from .read_from_db import read_from_db +from .write_to_db import write_to_db +from .utils.db_connect import db_connect +from .batch import insert_values, update_many, execute_batch diff --git a/pyspark_db_utils/pg.py b/pyspark_db_utils/batch.py similarity index 53% rename from pyspark_db_utils/pg.py rename to pyspark_db_utils/batch.py index 4ef75be..3ae7c6f 100644 --- a/pyspark_db_utils/pg.py +++ b/pyspark_db_utils/batch.py @@ -3,185 +3,17 @@ Most useful are: :func:`read_from_pg`, :func:`write_to_pg`, :func:`execute_batch` """ -from typing import Dict, List, Set, Optional, Iterator, Iterable, Any -from contextlib import contextmanager -from itertools import chain +import functools from logging import Logger +from typing import Dict, List, Optional, Iterator, Iterable, Any -import jaydebeapi -import datetime -import string -import functools -from pyspark.sql import SQLContext from pyspark import SparkContext from pyspark.sql import DataFrame +from pyspark.sql import SQLContext -from pyspark_db_utils.utils.drop_columns import drop_other_columns - - -def read_from_pg(config: dict, sql: str, sc: SparkContext, logger: Optional[Logger]=None) -> DataFrame: - """ Read dataframe from postgres - - Args: - config: settings for connect - sql: sql to read, it may be one of these format - - - 'table_name' - - - 'schema_name.table_name' - - - '(select a, b, c from t1 join t2 ...) 
as foo' - - sc: specific current spark_context or None - logger: logger - - Returns: - selected DF - """ - if logger: - logger.info('read_from_pg:\n{}'.format(sql)) - sqlContext = SQLContext(sc) - df = sqlContext.read.format("jdbc").options( - url=config['PG_URL'], - dbtable=sql, - **config['PG_PROPERTIES'] - ).load().cache() - return df - - -def write_to_pg(df: DataFrame, config: dict, table: str, mode: str='append', logger: Optional[Logger]=None) -> None: - """ Write dataframe to postgres - - Args: - df: DataFrame to write - config: config dict - table: table_name - logger: logger - mode: mode, one of these: - - - append - create table if not exists (with all columns of DataFrame) - and write records to table (using fields only in table columns) - - - overwrite - truncate table (if exists) and write records (using fields only in table columns) - - - overwrite_full - drop table and create new one with all columns and DataFrame and append records to it - - - fail - fail if table is not exists, otherwise append records to it - """ - field_names = get_field_names(table, config) - table_exists = bool(field_names) - - if mode == 'fail': - if not table_exists: - raise Exception('table {} does not exist'.format(table)) - else: - mode = 'append' # if table exists just append records to it - - if mode == 'append': - if table_exists: - df = drop_other_columns(df, field_names) - elif mode == 'overwrite_full': - if table_exists: - run_sql('drop table {}'.format(table), config, logger=logger) - elif mode == 'overwrite': - if table_exists: - df = drop_other_columns(df, field_names) - run_sql('truncate {}'.format(table), config, logger=logger) - df.write.jdbc(url=config['PG_URL'], - table=table, - mode='append', # always just append because all logic already done - properties=config['PG_PROPERTIES']) - - -def run_sql(sql: str, config: Dict, logger: Optional[Logger]=None) -> None: - """ just run sql """ - if logger: - logger.info('run_sql: {}'.format(sql)) - with jdbc_connect(config, autocommit=True) as (conn, curs): - curs.execute(sql) - - -def get_field_names(table_name: str, config: Dict) -> Set[str]: - """ get field names of table """ - if len(table_name.split('.')) > 1: - table_name = table_name.split('.')[-1] - with jdbc_connect(config) as (conn, cur): - sql = "SELECT column_name FROM information_schema.columns WHERE table_name='{}'".format(table_name) - cur.execute(sql) - res = cur.fetchall() - field_names = list(chain(*res)) - return set(field_names) - - -def get_field_names_stub(df: DataFrame, config: Dict, table_name: str, sc: SparkContext) -> Set[str]: - """ get field names of table - - ! DONT USE IT ! Use get_field_names instead ! 
- - TODO: replace with get_field_names - """ - sql = '(select * from {} limit 1) as smth'.format(table_name) - df_tmp = read_from_pg(config, sql, sc) - columns_in_db = set(df_tmp.columns) - columns_in_df = set(df.columns) - field_names = columns_in_db.intersection(columns_in_df) - return set(field_names) - - -@contextmanager -def jdbc_connect(config: Dict, autocommit: bool=False): - """ context manager, opens and closes connection correctly - - Args: - config: config - autocommit: enable autocommit - - Yields: - tuple: connection, cursor - """ - conn = jaydebeapi.connect(config["PG_PROPERTIES"]['driver'], - config["PG_URL"], - {'user': config["PG_PROPERTIES"]['user'], - 'password': config["PG_PROPERTIES"]['password']}, - config["PG_DRIVER_PATH"] - ) - if not autocommit: - conn.jconn.setAutoCommit(False) - curs = conn.cursor() - yield conn, curs - curs.close() - conn.close() - - -def mogrify(val) -> str: - """ cast python values to raw-sql correctly and escape if necessary - - Args: - val: some value - - Returns: - mogrified value - """ - if isinstance(val, str): - escaped = val.replace("'", "''") - return "'{}'".format(escaped) - elif isinstance(val, (int, float)): - return str(val) - elif isinstance(val, datetime.datetime): - return "'{}'::TIMESTAMP".format(val) - elif isinstance(val, datetime.date): - return "'{}'::DATE".format(val) - elif val is None: - return 'null' - else: - raise TypeError('unknown type {} for mogrify'.format(type(val))) - - -class MogrifyFormatter(string.Formatter): - """ custom formatter to mogrify {}-like formatting strings """ - def get_value(self, key, args, kwargs) -> str: - row = args[0] - return mogrify(row[key]) +from pyspark_db_utils.mogrify import mogrify, mogrifier +from pyspark_db_utils.utils.db_connect import db_connect +from pyspark_db_utils.utils.ensure_columns_in_table import ensure_columns_in_table def batcher(iterable: Iterable, batch_size: int): @@ -204,32 +36,19 @@ def batcher(iterable: Iterable, batch_size: int): yield batch -mogrifier = MogrifyFormatter() - - -def _execute_batch_partition(partition: Iterator, sql_temp: str, config: Dict, batch_size: int) -> None: +def _execute_batch_partition(partition: Iterator, sql_temp: str, con_info: Dict, batch_size: int) -> None: """ execute sql_temp for rows in partition in batch """ - - # For debugging RAM - # def get_ram(): - # import os - # ram = os.popen('free -m').read() - # return ram - - with jdbc_connect(config) as (conn, curs): + with db_connect(con_info=con_info) as (conn, curs): for batch in batcher(partition, batch_size): sql = ';'.join( mogrifier.format(sql_temp, row) for row in batch ) - # if config.get('DEBUG_SQL'): - print('\n\nsql: {}\n\n'.format(sql[:500])) curs.execute(sql) - # print('\nFREE RAM: %s\n' % get_ram()) conn.commit() -def execute_batch(df: DataFrame, sql_temp: str, config: Dict, batch_size: int=1000) -> None: +def execute_batch(df: DataFrame, sql_temp: str, con_info: Dict, batch_size: int=1000) -> None: """ Very useful function to run custom SQL on each rows in DataFrame by batches. 
@@ -244,7 +63,7 @@ def execute_batch(df: DataFrame, sql_temp: str, config: Dict, batch_size: int=10 Examples: update table rows by id and values for DF records:: - >> execute_batch(df, config=config, + >> execute_batch(df, con_info=con_info, sql_temp='update %(table_name)s set out_date=%(filename_date)s where id={id}' % {'table_name': table_name, 'filename_date': filename_date}) @@ -258,17 +77,17 @@ def execute_batch(df: DataFrame, sql_temp: str, config: Dict, batch_size: int=10 lost_sales = EXTRACT(epoch FROM {check_date_time} - c.start_date) * (3.0 / 7) / (24 * 3600) WHERE c.id = {id} - ''', config=config) + ''', con_info=con_info) """ df.foreachPartition( - functools.partial(_execute_batch_partition, sql_temp=sql_temp, config=config, batch_size=batch_size)) + functools.partial(_execute_batch_partition, sql_temp=sql_temp, con_info=con_info, batch_size=batch_size)) def _update_many_partition(partition: Iterator, table_name: str, set_to: Dict[str, Any], - config: Dict, + con_info: Dict, batch_size: int, id_field: str='id' ) -> None: @@ -278,7 +97,7 @@ def _update_many_partition(partition: Iterator, partition: DataFrame partition table_name: table name set_to: dict such as {'field_name1': new_value1, 'field_name2': new_value2} - config: config + con_info: con_info batch_size: batch size id_field: id field """ @@ -286,7 +105,7 @@ def _update_many_partition(partition: Iterator, for field_name, new_value in set_to.items(): field_stmt_list.append('{}={}'.format(field_name, mogrify(new_value))) fields_stmt = ', '.join(field_stmt_list) - with jdbc_connect(config) as (conn, curs): + with db_connect(con_info) as (conn, curs): for batch in batcher(partition, batch_size): ids = [row[id_field] for row in batch] if not ids: @@ -300,7 +119,7 @@ def _update_many_partition(partition: Iterator, def update_many(df: DataFrame, table_name: str, set_to: Dict, - config: Dict, + con_info: Dict, batch_size: int=1000, id_field: str='id' ) -> None: @@ -314,19 +133,19 @@ def update_many(df: DataFrame, df: DataFrame table_name: table name set_to: dict such as {'field_name1': new_const_value1, 'field_name2': new_const_value2} - config: config + con_info: con_info batch_size: batch size id_field: id field """ df.foreachPartition( - functools.partial(_update_many_partition, table_name=table_name, set_to=set_to, config=config, + functools.partial(_update_many_partition, table_name=table_name, set_to=set_to, con_info=con_info, batch_size=batch_size, id_field=id_field)) def _insert_values_partition(partition: Iterator, sql_temp: str, values_temp: str, - config: Dict, + con_info: Dict, batch_size: int, fields_stmt: Optional[str]=None, table_name: Optional[str]=None, @@ -337,13 +156,13 @@ def _insert_values_partition(partition: Iterator, partition: DataFrame partition sql_temp: sql template (may consist values, fields, table_name formatting-arguments) values_temp: string template for values - config: config + con_info: con_info batch_size: batch size fields_stmt: string template for fields table_name: table name argument for string-formatting """ - with jdbc_connect(config) as (conn, curs): + with db_connect(con_info) as (conn, curs): for batch in batcher(partition, batch_size): values = ','.join( mogrifier.format(values_temp, row) @@ -358,12 +177,12 @@ def _insert_values_partition(partition: Iterator, def insert_values(df: DataFrame, - config: Dict, + con_info: Dict, batch_size: int=1000, fields: Optional[List[str]]=None, values_temp: Optional[str]=None, sql_temp: Optional[str]=None, - table_name: Optional[str]=None, + 
table: Optional[str]=None, on_conflict_do_nothing: bool=False, on_conflict_do_update: bool=False, drop_duplicates: bool=False, @@ -384,10 +203,10 @@ def insert_values(df: DataFrame, df: DataFrame sql_temp: sql template (may consist values, fields, table_name formatting-arguments) values_temp: string template for values - config: config + con_info: con_info fields: list of columns for insert (if None, all olumns will be used) batch_size: batch size - table_name: table name argument for string-formatting + table: table name argument for string-formatting on_conflict_do_nothing: add ON CONFLICT DO NOTHING statement to each INSERT on_conflict_do_update: add ON CONFLICT DO UPDATE statement to each INSERT drop_duplicates: drop duplicates if set to True @@ -400,9 +219,8 @@ def insert_values(df: DataFrame, cleaned_df = df.select(*df.columns) # select columns to write - if table_name: - field_names = get_field_names_stub(df, config, table_name, sc) - cleaned_df = df.select(*field_names) + if table: + cleaned_df = ensure_columns_in_table(con_info=con_info, table=table, df=df) if drop_duplicates: cleaned_df = cleaned_df.dropDuplicates(drop_duplicates) @@ -414,9 +232,9 @@ def insert_values(df: DataFrame, where {} is not null""".format(exclude_null_field)) # TODO: add mogrify values, not table_name, fields, etc - assert table_name or sql_temp + assert table or sql_temp if sql_temp is None: - sql_temp = 'INSERT INTO {table_name}({fields}) VALUES {values}' + sql_temp = 'INSERT INTO {table}({fields}) VALUES {values}' if on_conflict_do_nothing: sql_temp += ' ON CONFLICT DO NOTHING' @@ -434,6 +252,6 @@ def insert_values(df: DataFrame, cleaned_df.foreachPartition( functools.partial(_insert_values_partition, - sql_temp=sql_temp, values_temp=values_temp, fields_stmt=fields_stmt, table_name=table_name, - config=config, batch_size=batch_size, logger=logger)) + sql_temp=sql_temp, values_temp=values_temp, fields_stmt=fields_stmt, table=table, + con_info=con_info, batch_size=batch_size, logger=logger)) cleaned_df.unpersist() diff --git a/pyspark_db_utils/ch/__init__.py b/pyspark_db_utils/ch/__init__.py deleted file mode 100644 index f595ef1..0000000 --- a/pyspark_db_utils/ch/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .write_to_ch import write_to_ch -from .read_from_ch import read_from_ch diff --git a/pyspark_db_utils/ch/make_ch_model_for_df.py b/pyspark_db_utils/ch/make_ch_model_for_df.py deleted file mode 100644 index 21f747a..0000000 --- a/pyspark_db_utils/ch/make_ch_model_for_df.py +++ /dev/null @@ -1,65 +0,0 @@ -import types -from pyspark.sql.types import ( - StringType, BinaryType, BooleanType, DateType, - TimestampType, DecimalType, DoubleType, FloatType, ByteType, IntegerType, - LongType, ShortType) -from infi.clickhouse_orm import models, engines -from infi.clickhouse_orm.fields import ( - StringField, FixedStringField, DateField, DateTimeField, - UInt8Field, UInt16Field, UInt32Field, UInt64Field, - Int8Field, Int16Field, Int32Field, Int64Field, - Float32Field, Float64Field, Enum8Field, Enum16Field, NullableField, Field as CHField) -from pyspark.sql.types import StructField as SparkField - - -# mapping from spark type to ClickHouse type -SparkType2CHField = { - StringType: StringField, - BinaryType: StringField, - BooleanType: UInt8Field, # There are no bool type in ClickHouse - DateType: DateField, - TimestampType: DateTimeField, - DoubleType: Float64Field, - FloatType: Float32Field, - ByteType: UInt8Field, - IntegerType: Int32Field, - LongType: Int64Field, - ShortType: Int16Field, - DecimalType: 
Float64Field, -} - - -def spark_field2clickhouse_field(spark_field: SparkField) -> CHField: - """ spark field to clickhouse field """ - spark_type = type(spark_field.dataType) - clickhouse_field_class = SparkType2CHField[spark_type] - clickhouse_field = clickhouse_field_class() - # if spark_field.nullable: - # logger.warning('spark_field {} is nullable, it is not good for ClickHouse'.format(spark_field)) - # # IDEA - # # clickhouse_field = NullableField(clickhouse_field) - return clickhouse_field - - -def make_ch_model_for_df(df, date_field_name, table_name, pk_columns=None): - """ - creates ORM Model for DataFrame - models.Model is meta class so it is a bit tricky to dynamically create child-class with given attrivutes - ToDo: Add support for engine Memory and Log - :param df: PySpark DataFrame - :param date_field_name: Date-typed field for partitioning - :param pk_columns: primary key columns - :param table_name: table name in DB - :return: ORM Model class - """ - assert date_field_name in df.schema.names - assert 'engine' not in df.schema.names - if pk_columns is None: - pk_columns = df.schema.names - attrs = {'engine': engines.MergeTree(date_field_name, pk_columns)} - for field in df.schema.fields: - clickhouse_field = spark_field2clickhouse_field(field) - attrs[field.name] = clickhouse_field - Model = type('MyModel', (models.Model,), attrs) - Model.table_name = staticmethod(types.MethodType(lambda cls: table_name, Model)) - return Model diff --git a/pyspark_db_utils/ch/read_from_ch.py b/pyspark_db_utils/ch/read_from_ch.py deleted file mode 100644 index 01a2c3e..0000000 --- a/pyspark_db_utils/ch/read_from_ch.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Optional, Dict -from logging import Logger - -from pyspark.sql import SQLContext, DataFrame -from pyspark import SparkContext - - -def read_from_ch(config: Dict, - sql: str, - sc: SparkContext, - logger: Optional[Logger]=None - ) -> DataFrame: - """ Read DF from ClickHouse SQL - - Args: - config: config - sql: sql - it may be one of these format: - - 'table_name' - - 'schema_name.table_name' - - '(select a, b, c from t1 join t2 ...) as foo' - sc: spark context - logger: logger - - Returns: - DataFrame - """ - if logger: - logger.info('read_from_ch: {}'.format(sql)) - spark = SQLContext(sc) - df = spark.read.format("jdbc").options( - url=config['CH_JDBC_URL'], - dbtable=sql, - **config['CH_JDBC_PROPERTIES'] - ).load().cache() - return df diff --git a/pyspark_db_utils/ch/smart_ch_fillna.py b/pyspark_db_utils/ch/smart_ch_fillna.py deleted file mode 100644 index dd4b479..0000000 --- a/pyspark_db_utils/ch/smart_ch_fillna.py +++ /dev/null @@ -1,59 +0,0 @@ -from functools import reduce - -from pyspark.sql import DataFrame -import pyspark.sql.functions as F -from pyspark.sql.types import ( - StringType, BinaryType, BooleanType, DateType, - TimestampType, DecimalType, DoubleType, FloatType, ByteType, IntegerType, - LongType, ShortType) - - -# Default value (if null) for Spark types -# IMPORTANT! No default values for Date and TimeStamp types yet! 
-SparkType2Default = { - StringType: '', - BinaryType: '', - BooleanType: 0, - # DateType: '0001-01-03', # https://issues.apache.org/jira/browse/SPARK-22182 - # TimestampType: '0001-01-03 00:29:43', # https://issues.apache.org/jira/browse/SPARK-22182 - # DateType: '0001-01-01', - # TimestampType: '0001-01-01 00:00:00', - DoubleType: 0.0, - FloatType: 0.0, - ByteType: 0, - IntegerType: 0, - LongType: 0, - ShortType: 0, - DecimalType: 0.0, -} - - -def check_date_columns_for_nulls(df: DataFrame) -> bool: - """ returns True if any Date or Timestamp column consist NULL value """ - expr_list = [] - for field in df.schema.fields: - name = field.name - spark_type = type(field.dataType) - if spark_type in [DateType, TimestampType]: - expr_list.append(F.isnull(F.col(name))) - if not expr_list: - return False - expr = reduce(lambda x, y: x | y, expr_list) - df_nulls = df.filter(expr) - return not df_nulls.rdd.isEmpty() - - -def smart_ch_fillna(df: DataFrame) -> DataFrame: - """ change null-value to default values """ - mapping = {} - if check_date_columns_for_nulls(df): - raise Exception('Date and Timestamp columns mustn\'t be null!') - for field in df.schema.fields: - name = field.name - spark_type = type(field.dataType) - if spark_type in [DateType, TimestampType]: - continue - default_value = SparkType2Default[spark_type] - mapping[name] = default_value - df = df.fillna(mapping) - return df diff --git a/pyspark_db_utils/ch/write_to_ch.py b/pyspark_db_utils/ch/write_to_ch.py deleted file mode 100644 index 3359d6a..0000000 --- a/pyspark_db_utils/ch/write_to_ch.py +++ /dev/null @@ -1,117 +0,0 @@ -from typing import Union - -from infi.clickhouse_orm.database import Database, DatabaseException -from infi.clickhouse_orm.models import ModelBase -from pyspark.sql.types import DateType - -from pyspark_db_utils.ch.make_ch_model_for_df import make_ch_model_for_df -from pyspark_db_utils.ch.smart_ch_fillna import smart_ch_fillna - - -class CustomDatabase(Database): - """ ClickHouse database with useful functions """ - @staticmethod - def table2table_name(table: Union[str, ModelBase]) -> str: - """ get table_name for table - - Args: - table: may be string of table_name or db Model class - """ - if isinstance(table, ModelBase): - table_name = table.table_name() - elif isinstance(table, str): - table_name = table - else: - raise TypeError - return table_name - - def check_table_exist(self, table: Union[str, ModelBase]) -> bool: - """ check if table exists - - Args: - table: table to check - """ - table_name = self.table2table_name(table) - try: - # TODO: use EXISTS statement in CHSQL - resp = self.raw('select * from {} limit 0'.format(table_name)) - assert resp == '' - except DatabaseException: - exists = False - else: - exists = True - return exists - - def describe(self, table: Union[ModelBase, str]) -> str: - """ Returns result for DESCRIBE statement on table - - Args: - table: table - - Returns: - describe table - - Examples: - example of output:: - - plu Int64 - shop_id Int64 - check_date_time DateTime - clickhouse_date Date - created DateTime - type UInt8 - """ - table_name = self.table2table_name(table) - resp = self.raw('describe table {}'.format(table_name)) - return resp - - -def make_sure_exsit(df, date_field_name, table_name, mode, config, logger, pk_columns=None): - """ drop and create table if need """ - Model = make_ch_model_for_df(df, date_field_name, table_name, pk_columns=pk_columns) - db = CustomDatabase(db_name=config["CH_DB_NAME"], db_url=config["CH_URL"]) - if db.check_table_exist(Model): - 
if mode == 'fail': - raise Exception('table {}.{} already exists and mode={}'.format(db.db_name, Model.table_name(), mode)) - elif mode == 'overwrite': - db.drop_table(Model) - logger.info('DROP TABLE {}.{}'.format(db.db_name, table_name)) - db.create_table(Model) - logger.info('CREATE TABLE {}.{}'.format(db.db_name, table_name)) - elif mode == 'append': - pass - else: - db.create_table(Model) - logger.info('CREATE TABLE {}.{}'.format(db.db_name, table_name)) - db.describe(Model) - - -def write_to_ch(df, date_field_name, table_name, mode, config, logger, pk_columns=None) -> None: - """ Dumps PySpark DataFrame to ClickHouse, create or recreate table if needed. - - Args: - df: PySpark DataFrame - mode: describe, what do if table already exists - - must be one of 'overwrite' / 'append' / 'fail': - - - overwrite: drop and create table and insert rows (CH hasn't truncate operator) - - append: insert rows to exist table - - fail: raise Exception - - table_name: table name - date_field_name: date field for partitioning - pk_columns: list/tuple of primary key columns (None for all columns) - """ - assert mode in ['overwrite', 'append', 'fail'], "mode must be 'overwrite' / 'append' / 'fail'" - assert '.' not in table_name, 'dots are not allowed in table_name' - date_field = next(field for field in df.schema.fields if field.name == date_field_name) - assert type(date_field.dataType) == DateType, \ - "df['{}'].dataType={} must be DateType".format(date_field_name, date_field.dataType) - make_sure_exsit(df, date_field_name, table_name, mode, config=config, logger=logger, pk_columns=pk_columns) - full_table_name = '{}.{}'.format(config["CH_DB_NAME"], table_name) - # Spark JDBC CH Driver works correctly only in append mode - # and without NULL-s - df = smart_ch_fillna(df) - df.write.jdbc(url=config['CH_JDBC_URL'], table=full_table_name, mode='append', - properties=config['CH_JDBC_PROPERTIES']) diff --git a/pyspark_db_utils/example.py b/pyspark_db_utils/example.py deleted file mode 100644 index 5c39de1..0000000 --- a/pyspark_db_utils/example.py +++ /dev/null @@ -1,86 +0,0 @@ -""" It's just simple example of using lib - It asks you about DB connection parameters, makes DF, writes to DB, loads it back and shows. 
-""" - -import os - -from pyspark.sql import SparkSession -from pyspark import SparkContext, SparkConf -import pyspark.sql.functions as F - -from pyspark_db_utils.pg import write_to_pg, read_from_pg - - -SPARK_CONFIG = { - "MASTER": "local[*]", - "settings": { - "spark.executor.cores": "1", - "spark.executor.memory": "1g", - "spark.driver.cores": "1", - "spark.driver.memory": "1g", - "spark.cores.max": "1" - } -} - - -def get_pg_config() -> dict: - """ Ask DB connections params""" - host = input('host: ') - db = input('db: ') - user = input('user: ') - password = input('password: ') - - return { - "PG_PROPERTIES": { - "user": user, - "password": password, - "driver": "org.postgresql.Driver" - }, - "PG_DRIVER_PATH": "jars/postgresql-42.1.4.jar", - "PG_URL": "jdbc:postgresql://{host}/{db}".format(host=host, db=db), - } - - -def init_spark_context(appname: str) -> SparkContext: - """ init spark context """ - os.environ['PYSPARK_SUBMIT_ARGS'] = '--jars jars/postgresql-42.1.4.jar pyspark-shell' - conf = SparkConf() - conf.setMaster(SPARK_CONFIG['MASTER']) - conf.setAppName(appname) - - for setting, value in SPARK_CONFIG['settings'].items(): - conf.set(setting, value) - - sc = SparkContext(conf=conf) - - return sc - - -def main(spark) -> None: - """ run example """ - PG_CONFIG = get_pg_config() - - print('TRY: create df') - df = spark.range(1, 20, 1, 4).withColumn('mono_id', F.monotonically_increasing_id()) - print('OK: create df') - df.show() - - print('') - - print('TRY: write_to_pg') - write_to_pg(df=df, config=PG_CONFIG, table='test_table') - print('OK: write_to_pg') - - print('') - - print('TRY: read_from_pg') - df_loaded = read_from_pg(config=PG_CONFIG, sql='test_table', sc=sc) - print('OK: read_from_pg') - df_loaded.show() - - -if __name__ == '__main__': - sc = init_spark_context('app') - spark = SparkSession(sc) - main(spark) - spark.stop() diff --git a/pyspark_db_utils/get_jars.py b/pyspark_db_utils/get_jars.py new file mode 100644 index 0000000..02588e9 --- /dev/null +++ b/pyspark_db_utils/get_jars.py @@ -0,0 +1,15 @@ +from typing import List +import os + +DEFAULT_JARS_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'jars') + + +def get_jars() -> List[str]: + """ get list of jars paths + if os.environ variable JARS_PATH is set, find all in thise path + elsewhere find all jars inside `jars` directory + """ + jars_path = os.environ.get('JARS_PATH', DEFAULT_JARS_PATH) + jars = [os.path.join(jars_path, jar) + for jar in os.listdir(jars_path)] + return jars diff --git a/pyspark_db_utils/jars/clickhouse.jar b/pyspark_db_utils/jars/clickhouse.jar new file mode 100644 index 0000000..a6a0cbc Binary files /dev/null and b/pyspark_db_utils/jars/clickhouse.jar differ diff --git a/pyspark_db_utils/jars/postgresql.jar b/pyspark_db_utils/jars/postgresql.jar new file mode 100644 index 0000000..08a54b1 Binary files /dev/null and b/pyspark_db_utils/jars/postgresql.jar differ diff --git a/pyspark_db_utils/mogrify.py b/pyspark_db_utils/mogrify.py new file mode 100644 index 0000000..68fc849 --- /dev/null +++ b/pyspark_db_utils/mogrify.py @@ -0,0 +1,39 @@ +import json +import datetime +import string + + +def mogrify(val) -> str: + """ cast python values to raw-sql correctly and escape if necessary + + Args: + val: some value + + Returns: + mogrified value + """ + if isinstance(val, (tuple, list, dict)): + return json.dumps(val) # are you sure? 
+ elif isinstance(val, str): + escaped = val.replace("'", "''") + return "'{}'".format(escaped) + elif isinstance(val, (int, float)): + return str(val) + elif isinstance(val, datetime.datetime): + return "'{}'::TIMESTAMP".format(val) + elif isinstance(val, datetime.date): + return "'{}'::DATE".format(val) + elif val is None: + return 'NULL' + else: + raise TypeError('unknown type {} for mogrify'.format(type(val))) + + +class MogrifyFormatter(string.Formatter): + """ custom formatter to mogrify {}-like formatting strings """ + def get_value(self, key, args, kwargs) -> str: + row = args[0] + return mogrify(row[key]) + + +mogrifier = MogrifyFormatter() diff --git a/pyspark_db_utils/read_from_db.py b/pyspark_db_utils/read_from_db.py new file mode 100644 index 0000000..65a7f03 --- /dev/null +++ b/pyspark_db_utils/read_from_db.py @@ -0,0 +1,28 @@ +from pyspark_db_utils.utils.spark.get_spark_con_params import get_spark_con_params +from pyspark_db_utils.utils.spark.get_or_create_spark_session import get_or_create_spark_session + + +def read_from_db(con_info, table, spark_session=None, **kwargs): + """ + kwargs: kw arguments of pyspark.sql.readwriter.DataFrameReader.jdbc() + """ + if spark_session is None: + spark_session = get_or_create_spark_session() + + return spark_session.read.jdbc(table=table, + **get_spark_con_params(con_info), + **kwargs) + + +def read_from_ch(config, **kwargs): + """ + read DataFrame from ClickHouse + """ + return read_from_db(config['clickhouse'], **kwargs) + + +def read_from_pg(config, **kwargs): + """ + read DataFrame from PostgreSQL + """ + return read_from_db(config['postgresql'], **kwargs) diff --git a/pyspark_db_utils/utils/db_connect.py b/pyspark_db_utils/utils/db_connect.py new file mode 100644 index 0000000..24da9ca --- /dev/null +++ b/pyspark_db_utils/utils/db_connect.py @@ -0,0 +1,29 @@ +from contextlib import contextmanager + +import jaydebeapi + +from pyspark_db_utils.get_jars import get_jars + + +@contextmanager +def db_connect(con_info: dict, autocommit: bool=False): + """ context manager, opens and closes connection correctly + + Args: + con_info: con_info for db + autocommit: if False (default), disable autocommit mode + if True, enable autocommit mode + + Yields: + connection, cursor + """ + conn = jaydebeapi.connect(jclassname=con_info['driver'], + url='jdbc:{dbtype}://{host}/{dbname}'.format(**con_info), + driver_args={k: con_info[k] for k in ('user', 'password')}, + jars=get_jars() # jaydebeapi create java-process once, so add all possible drivers + ) + conn.jconn.setAutoCommit(autocommit) + curs = conn.cursor() + yield conn, curs + curs.close() + conn.close() diff --git a/pyspark_db_utils/utils/drop_columns.py b/pyspark_db_utils/utils/drop_columns.py deleted file mode 100644 index 0431d71..0000000 --- a/pyspark_db_utils/utils/drop_columns.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import List - -from pyspark.sql import DataFrame - - -def drop_other_columns(df: DataFrame, template_schema: List[str]) -> DataFrame: - """ drop all df columns that are absent in template_schema """ - columns_to_drop = set(df.schema.names) - set(template_schema) - for column in columns_to_drop: - df = df.drop(column) - return df diff --git a/pyspark_db_utils/utils/ensure_columns_in_table.py b/pyspark_db_utils/utils/ensure_columns_in_table.py new file mode 100644 index 0000000..6f783d0 --- /dev/null +++ b/pyspark_db_utils/utils/ensure_columns_in_table.py @@ -0,0 +1,22 @@ +from typing import Optional + +from pyspark.sql import DataFrame, SparkSession + +from 
pyspark_db_utils.utils.get_field_names import get_field_names +from pyspark_db_utils.utils.table_exists import table_exists + + +def ensure_columns_in_table(con_info: dict, + df: DataFrame, + table: str, + spark_session: Optional[SparkSession] = None + ) -> DataFrame: + """ drop columns in `df` which does not exist in DB table, + do nothing if DB table does not exist + """ + if not table_exists(con_info=con_info, table=table): + return df + table_columns = get_field_names(con_info=con_info, table=table, spark_session=spark_session) + columns_to_drop = set(df.columns) - set(table_columns) + df = df.drop(*columns_to_drop) + return df diff --git a/pyspark_db_utils/utils/ensure_from_expr.py b/pyspark_db_utils/utils/ensure_from_expr.py new file mode 100644 index 0000000..043a6a4 --- /dev/null +++ b/pyspark_db_utils/utils/ensure_from_expr.py @@ -0,0 +1,36 @@ +def ensure_from_expr(sql: str) -> str: + """ + Check that sql is either 'table_name' or '(select * from table) as foo' expression + Wraps in '({sql}) as foo'.format(sql) elsewhere + Args: + sql: sql to ensure + + Returns: + original sql or wrapped sql + + Examples: + >>> ensure_from_expr('data.posdata') + 'data.posdata' + >>> ensure_from_expr(' ( select * from data.posdata ) as foo ') + ' ( select * from data.posdata ) as foo ' + >>> ensure_from_expr('select * from data.posdata') # wraps in this case + '(select * from data.posdata) as foo' + + Warnings: + it is not really strict error-sensitive check, it's just hack covered all common use cases + """ + tokens = [t for t in sql.split() if t] + if len(tokens) == 1: + # sql is just table_name + return sql + elif len(tokens) < 3: + # its impossible because it could not be correct SQL + # but... ok, lets leave it on conscience of developer + return sql + else: + # TODO: use regexp + correct_open_barcket = tokens[0][0] == '(' + correct_close_barcket = tokens[-3][-1] == ')' + correct_as = tokens[-2].upper() == 'AS' + if not correct_open_barcket or not correct_close_barcket or not correct_as: + return '({sql}) as foo'.format(sql=sql) diff --git a/pyspark_db_utils/utils/execute_sql.py b/pyspark_db_utils/utils/execute_sql.py new file mode 100644 index 0000000..235c662 --- /dev/null +++ b/pyspark_db_utils/utils/execute_sql.py @@ -0,0 +1,19 @@ +from typing import Optional, List, Tuple + +from pyspark_db_utils.utils.db_connect import db_connect + + +def execute_sql(con_info, sql) -> Optional[List[Tuple]]: + """ execute sql and return rows if possible + + Args: + con_info: db connection info + sql: sql to execute + + Returns: + fetched + """ + with db_connect(con_info, autocommit=True) as (con, cur): + cur.execute(sql) + if cur.description: + return cur.fetchall() diff --git a/pyspark_db_utils/utils/get_field_names.py b/pyspark_db_utils/utils/get_field_names.py new file mode 100644 index 0000000..d1d6b76 --- /dev/null +++ b/pyspark_db_utils/utils/get_field_names.py @@ -0,0 +1,16 @@ +from typing import Optional, List + +from pyspark.sql import SparkSession + +from pyspark_db_utils.read_from_db import read_from_db + + +def get_field_names(con_info: dict, + table: str, + spark_session: Optional[SparkSession]=None) -> List[str]: + """ get field names of table + TODO: remove spark + """ + sql = '(select * from {} limit 0) as foo'.format(table) + df_tmp = read_from_db(con_info, sql, spark_session=spark_session) + return df_tmp.columns diff --git a/pyspark_db_utils/utils/spark/__init__.py b/pyspark_db_utils/utils/spark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/pyspark_db_utils/utils/spark/get_or_create_spark_session.py b/pyspark_db_utils/utils/spark/get_or_create_spark_session.py new file mode 100644 index 0000000..94f6988 --- /dev/null +++ b/pyspark_db_utils/utils/spark/get_or_create_spark_session.py @@ -0,0 +1,6 @@ +from pyspark.sql import SparkSession + + +def get_or_create_spark_session() -> SparkSession: + """ get or create spark session """ + return SparkSession.builder.config().getOrCreate() diff --git a/pyspark_db_utils/utils/spark/get_spark_con_params.py b/pyspark_db_utils/utils/spark/get_spark_con_params.py new file mode 100644 index 0000000..cb8f36f --- /dev/null +++ b/pyspark_db_utils/utils/spark/get_spark_con_params.py @@ -0,0 +1,5 @@ +def get_spark_con_params(con_info): + """ get DB jdbc connection string from config `con_info` """ + return dict(url='jdbc:{dbtype}://{host}/{dbname}'.format(**con_info), + properties={k: con_info[k] + for k in ('user', 'password', 'driver')}) diff --git a/pyspark_db_utils/utils/spark/get_spark_conf.py b/pyspark_db_utils/utils/spark/get_spark_conf.py new file mode 100644 index 0000000..2c6fada --- /dev/null +++ b/pyspark_db_utils/utils/spark/get_spark_conf.py @@ -0,0 +1,78 @@ +import os +import json + +import pyspark + +from pyspark_db_utils.get_jars import get_jars +from pyspark_db_utils.utils.execute_sql import execute_sql +from pyspark_db_utils.mogrify import mogrify + +CONFIGS_TABLE = 'analytics.configs' + + +def load_config_from_table(all_con_info, + key=None, + config_id=None, + table_name=CONFIGS_TABLE) -> dict: + """ load spark settings from DB """ + assert bool(key) ^ bool(config_id) + + if config_id is None: + config_id = execute_sql(all_con_info['postgresql'], ''' + SELECT MAX(id) + FROM {table} + WHERE key = {key} + '''.format(table=table_name, + key=mogrify(key)))[0][0] + value = execute_sql(all_con_info['postgresql'], ''' + SELECT value + FROM {table} + WHERE id = {id} + LIMIT 1 + '''.format(table=table_name, + id=mogrify(config_id))) + if not value: + raise ValueError( + 'No value in {table} for key={key!r} and config_id={id}'.format( + table=table_name, + key=key, + id=config_id)) + value = json.loads(str(value[0][0])) + return value + + +def get_spark_conf(spark_config=None, + con_info=None, + config_key=None, + config_id=None, + configs_table=CONFIGS_TABLE, + app_name=None, + master=None): + assert bool(spark_config) ^ bool(con_info), 'spark_config or con_info must be set' + + if spark_config: + master = spark_config['MASTER'] + values = spark_config['settings'] + jars = get_jars() + elif con_info: + values = {} + if con_info and (config_key or config_id): + values = load_config_from_table(con_info, config_key, config_id, + configs_table) + values = dict(values) + config_jars = {os.path.abspath(jar.replace('file://', '')) + for jar in values.get('spark.jars', '').split(',') + if jar} + jars = config_jars.union(set(get_jars())) + else: + raise ValueError('spark_config or con_info must be set') + + values['spark.jars'] = ','.join(jars) + values['spark.sql.execution.arrow.enabled'] = True + conf = pyspark.SparkConf() + conf.setAll(list(values.items())) + if app_name is not None: + conf = conf.setAppName(app_name) + if master is not None: + conf = conf.setMaster(master) + return conf diff --git a/pyspark_db_utils/utils/table_exists.py b/pyspark_db_utils/utils/table_exists.py new file mode 100644 index 0000000..fe569f7 --- /dev/null +++ b/pyspark_db_utils/utils/table_exists.py @@ -0,0 +1,15 @@ +import jaydebeapi +from pyspark_db_utils.utils.execute_sql import execute_sql + + +def 
table_exists(con_info, table) -> bool: + """ table exists or not + + Notes: + actually, just select LIMIT 0 from table and returns was it successful or not + """ + try: + execute_sql(con_info, 'SELECT * FROM {} LIMIT 0'.format(table)) + return True + except jaydebeapi.DatabaseError: + return False diff --git a/pyspark_db_utils/write_to_db.py b/pyspark_db_utils/write_to_db.py new file mode 100644 index 0000000..92ab7ba --- /dev/null +++ b/pyspark_db_utils/write_to_db.py @@ -0,0 +1,24 @@ +from pyspark_db_utils.utils.ensure_columns_in_table import ensure_columns_in_table +from pyspark_db_utils.utils.spark.get_spark_con_params import get_spark_con_params + + +def write_to_db(con_info, df, table, mode='error'): + """ + Note: With ClickHouse only mode='append' is supported + + mode: specifies the behavior of the save operation when data already exists. + * append: Append contents of this :class:`DataFrame` to existing data. + * overwrite: Overwrite existing data. + * ignore: Silently ignore this operation if data already exists. + * error: Throw an exception if data already exists. + """ + df = ensure_columns_in_table(df=df, con_info=con_info, table=table) + df.write.jdbc(table=table, mode=mode, **get_spark_con_params(con_info)) + + +def write_to_pg(config, **kwargs): + return write_to_db(config['postgresql'], **kwargs) + + +def write_to_ch(config, **kwargs): + return write_to_db(config['clickhouse'], **kwargs) diff --git a/requirements.txt b/requirements.txt index 8e02d61..17b5666 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,4 @@ sphinx-autodoc-typehints==1.2.5 # pyspark==2.2.0 JayDeBeApi==1.1.1 infi.clickhouse-orm==0.9.4 -psycopg2==2.7.4 -testing.postgresql==1.3.0 +# testing.postgresql==1.3.0 diff --git a/setup.py b/setup.py index 8baa3df..caebc70 100644 --- a/setup.py +++ b/setup.py @@ -14,12 +14,16 @@ long_description = f.read() setup(name='pyspark_db_utils', - version='0.0.5', + version='0.0.7', description='Usefull functions for working with Database in PySpark (PostgreSQL, ClickHouse)', url='https://github.com/osahp/pyspark_db_utils', author='Vladimir Smelov', - author_email='vladimirfol@gmail.com', - packages=['pyspark_db_utils', 'pyspark_db_utils.ch', 'pyspark_db_utils.utils'], + author_email='vsmelov@vsmelov.ru', + packages=[ + 'pyspark_db_utils', + 'pyspark_db_utils.utils', + 'pyspark_db_utils.utils.spark', + ], classifiers=[ 'Development Status :: 3 - Alpha', 'Programming Language :: Python :: 3', diff --git a/tests/test_ensure_from_expr.py b/tests/test_ensure_from_expr.py new file mode 100644 index 0000000..a14b5f0 --- /dev/null +++ b/tests/test_ensure_from_expr.py @@ -0,0 +1,27 @@ +from unittest import TestCase +from pyspark_db_utils.utils.ensure_from_expr import ensure_from_expr + + +class TestEnsureFromExpr(TestCase): + def test_table_name(self): + sql = 'data.posdata' + self.assertEqual(ensure_from_expr(sql), sql) + + def test_select_as(self): + sql = '(select * from data.posdata) as foo' + self.assertEqual(ensure_from_expr(sql), sql) + + def test_select_as_spaces(self): + sql = ' ( select * from data.posdata ) as foo ' + self.assertEqual(ensure_from_expr(sql), sql) + + def test_select_as_multiline(self): + sql = '''( + select * from data.posdata + ) as foo ''' + self.assertEqual(ensure_from_expr(sql), sql) + + def test_select(self): + sql = '''select * from data.posdata''' + sql_as = '''({sql}) as foo'''.format(sql=sql) + self.assertEqual(ensure_from_expr(sql), sql_as)
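Taken together, the new modules replace the old PG_*/CH_* config dicts with flat con_info dicts. A minimal end-to-end sketch, assuming a local PostgreSQL instance; every connection value, table and column name below is a placeholder, and the con_info keys follow the 'jdbc:{dbtype}://{host}/{dbname}' template used by db_connect() and get_spark_con_params()::

    from pyspark.sql import SparkSession

    from pyspark_db_utils import read_from_db, write_to_db
    from pyspark_db_utils.utils.spark.get_spark_conf import get_spark_conf

    # placeholder connection info (keys: dbtype, host, dbname, user, password, driver)
    pg_con_info = {
        'dbtype': 'postgresql',
        'host': 'localhost:5432',
        'dbname': 'mydb',
        'user': 'me',
        'password': 'secret',
        'driver': 'org.postgresql.Driver',
    }
    # the read_from_pg / write_to_ch wrappers expect these nested under one dict:
    # config = {'postgresql': pg_con_info, 'clickhouse': ch_con_info}

    # get_spark_conf() puts the jars bundled in pyspark_db_utils/jars on spark.jars
    conf = get_spark_conf(spark_config={'MASTER': 'local[*]',
                                        'settings': {'spark.executor.memory': '1g'}},
                          app_name='pyspark_db_utils_example')
    spark = SparkSession.builder.config(conf=conf).getOrCreate()

    # read via JDBC: a plain table name or '(select ...) as foo' both work
    df = read_from_db(pg_con_info,
                      table='(select id, price from data.posdata) as foo',
                      spark_session=spark)
    # append back, dropping DataFrame columns that the target table does not have
    write_to_db(pg_con_info, df=df, table='public.posdata_copy', mode='append')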
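The same con_info dict drives the plain JDBC helpers. db_connect() opens a jaydebeapi connection with autocommit off by default, so commit explicitly; execute_sql() runs a statement on an autocommit connection and fetches rows when the statement returns any. Continuing the sketch above (table and column names are again placeholders)::

    from pyspark_db_utils import db_connect
    from pyspark_db_utils.mogrify import mogrify
    from pyspark_db_utils.utils.execute_sql import execute_sql

    rows = execute_sql(pg_con_info, 'SELECT id, name FROM public.shops LIMIT 10')

    with db_connect(pg_con_info) as (conn, curs):
        # mogrify() escapes/casts plain Python values into SQL literals
        curs.execute('UPDATE public.shops SET name = {} WHERE id = {}'.format(
            mogrify("O'Hare"), mogrify(1)))
        conn.commit()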
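The batch helpers re-exported in __init__.py run custom SQL for DataFrame rows in batches on each partition; {column} placeholders in sql_temp are filled per row through the mogrifier. A sketch that assumes the df from above has id and price columns and that the placeholder target table already exists::

    from pyspark_db_utils import execute_batch, insert_values, update_many

    # INSERT df rows into an existing table, skipping rows that hit a conflict
    insert_values(df, con_info=pg_con_info, table='public.posdata_copy',
                  on_conflict_do_nothing=True)

    # set a constant value for every id present in df
    update_many(df, table_name='public.posdata_copy',
                set_to={'checked': 1}, con_info=pg_con_info)

    # arbitrary per-row statement; {price} and {id} are taken from each row
    execute_batch(df, con_info=pg_con_info,
                  sql_temp='UPDATE public.posdata_copy SET price = {price} WHERE id = {id}')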
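mogrify() is the escaping core behind all of the above helpers; a few doctest-style checks of its behaviour on common Python types::

    from datetime import date, datetime

    from pyspark_db_utils.mogrify import mogrify

    assert mogrify("O'Hare") == "'O''Hare'"          # single quotes are doubled
    assert mogrify(3.5) == '3.5'
    assert mogrify(None) == 'NULL'
    assert mogrify(date(2018, 1, 1)) == "'2018-01-01'::DATE"
    assert mogrify(datetime(2018, 1, 1, 12, 30)) == "'2018-01-01 12:30:00'::TIMESTAMP"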