[SPARK-27163][PYTHON] Cleanup and consolidate Pandas UDF functionality
## What changes were proposed in this pull request?

This change is a cleanup and consolidation of 3 areas related to Pandas UDFs:

1) `ArrowStreamPandasSerializer` now inherits from `ArrowStreamSerializer` and uses the base class's `dump_stream` and `load_stream` to create the Arrow reader/writer and send Arrow record batches. `ArrowStreamPandasSerializer` itself only handles the conversions to/from Pandas, producing iterators of Arrow record batches. This removes the previously duplicated creation of Arrow readers/writers (see the sketch after this list).

2) `createDataFrame` with Arrow now uses `ArrowStreamPandasSerializer` instead of doing its own conversions from Pandas to Arrow and sending record batches through `ArrowStreamSerializer`.

3) Grouped Map UDFs now reuse the existing logic in `ArrowStreamPandasSerializer` to send Pandas DataFrame results as a single `StructType` column instead of separating each column from the DataFrame. This makes the code a little more consistent with the Python worker, but it does require that the returned `StructType` column be flattened out in `FlatMapGroupsInPandasExec` on the Scala side.
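
As a rough sketch of the delegation in 1), here is a minimal standalone version of the pattern (simplified classes, not the actual Spark serializers, which also handle timestamps, type coercion, safe-cast checks and the struct path used by Grouped Map UDFs):

```python
# Simplified sketch: the base class owns the Arrow IPC stream, the pandas-aware
# subclass only converts between pandas.Series and Arrow record batches.
import io

import pandas as pd
import pyarrow as pa


class ArrowStreamSerializerSketch(object):
    """Writes/reads iterators of Arrow record batches with the Arrow IPC stream format."""

    def dump_stream(self, batches, stream):
        writer = None
        try:
            for batch in batches:
                if writer is None:
                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                writer.write_batch(batch)
        finally:
            if writer is not None:
                writer.close()

    def load_stream(self, stream):
        reader = pa.ipc.open_stream(stream)
        for batch in reader:
            yield batch


class ArrowStreamPandasSerializerSketch(ArrowStreamSerializerSketch):
    """Converts pandas.Series <-> Arrow record batches, delegating stream I/O to the base class."""

    def dump_stream(self, iterator, stream):
        batches = (
            pa.RecordBatch.from_arrays(
                [pa.Array.from_pandas(s) for s in series_list],
                ["_%d" % i for i in range(len(series_list))])
            for series_list in iterator)
        super(ArrowStreamPandasSerializerSketch, self).dump_stream(batches, stream)

    def load_stream(self, stream):
        for batch in super(ArrowStreamPandasSerializerSketch, self).load_stream(stream):
            table = pa.Table.from_batches([batch])
            yield [column.to_pandas() for column in table.itercolumns()]


# Round trip: two "partitions", each a list of pandas.Series.
buf = io.BytesIO()
data = [[pd.Series([1, 2, 3]), pd.Series([4.0, 5.0, 6.0])],
        [pd.Series([7, 8]), pd.Series([9.0, 10.0])]]
ArrowStreamPandasSerializerSketch().dump_stream(iter(data), buf)
buf.seek(0)
for series_list in ArrowStreamPandasSerializerSketch().load_stream(buf):
    print([s.tolist() for s in series_list])
```

The round trip at the bottom prints the column values of both batches, showing that the subclass only builds and unpacks record batches while the base class handles the Arrow reader/writer.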

## How was this patch tested?

Existing tests; also ran the tests with pyarrow 0.12.0.

Closes apache#24095 from BryanCutler/arrow-refactor-cleanup-UDFs.

Authored-by: Bryan Cutler <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
BryanCutler authored and HyukjinKwon committed Mar 21, 2019
1 parent b1857a4 commit be08b41
Showing 5 changed files with 161 additions and 135 deletions.
218 changes: 119 additions & 99 deletions python/pyspark/serializers.py
@@ -245,92 +245,13 @@ def __repr__(self):
return "ArrowStreamSerializer"


def _create_batch(series, timezone, safecheck, assign_cols_by_name):
class ArrowStreamPandasSerializer(ArrowStreamSerializer):
"""
Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
Serializes Pandas.Series as Arrow data with Arrow streaming format.
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
:param timezone: A timezone to respect when handling timestamp values
:return: Arrow RecordBatch
"""
import decimal
from distutils.version import LooseVersion
import pandas as pd
import pyarrow as pa
from pyspark.sql.types import _check_series_convert_timestamps_internal
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
(len(series) == 2 and isinstance(series[1], pa.DataType)):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

def create_array(s, t):
mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
# TODO: maybe don't need None check anymore as of Arrow 0.9.1
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s.fillna(0), timezone)
# TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
elif t is not None and pa.types.is_string(t) and sys.version < '3':
# TODO: need decode before converting to Arrow in Python 2
# TODO: don't need as of Arrow 0.9.1
return pa.Array.from_pandas(s.apply(
lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
elif t is not None and pa.types.is_decimal(t) and \
LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
# TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
return pa.Array.from_pandas(s.apply(
lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
# TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
return pa.Array.from_pandas(s, mask=mask, type=t)

try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=safecheck)
except pa.ArrowException as e:
error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
"Array (%s). It can be caused by overflows or other unsafe " + \
"conversions warned by Arrow. Arrow safe type check can be " + \
"disabled by using SQL config " + \
"`spark.sql.execution.pandas.arrowSafeTypeConversion`."
raise RuntimeError(error_msg % (s.dtype, t), e)
return array

arrs = []
for s, t in series:
if t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
raise ValueError("A field of type StructType expects a pandas.DataFrame, "
"but got: %s" % str(type(s)))

# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
if len(s) == 0 and len(s.columns) == 0:
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
# Assign result columns by schema name if user labeled with strings
elif assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns):
arrs_names = [(create_array(s[field.name], field.type), field.name) for field in t]
# Assign result columns by position
else:
arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
for i, field in enumerate(t)]

struct_arrs, struct_names = zip(*arrs_names)

# TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version
if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
else:
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(create_array(s, t))

return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])


class ArrowStreamPandasSerializer(Serializer):
"""
Serializes Pandas.Series as Arrow data with Arrow streaming format.
:param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation
:param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name
"""

def __init__(self, timezone, safecheck, assign_cols_by_name):
@@ -347,39 +268,138 @@ def arrow_to_pandas(self, arrow_column):
s = _check_series_localize_timestamps(s, self._timezone)
return s

def _create_batch(self, series):
"""
Create an Arrow record batch from the given pandas.Series or list of Series,
with optional type.
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
:return: Arrow RecordBatch
"""
import decimal
from distutils.version import LooseVersion
import pandas as pd
import pyarrow as pa
from pyspark.sql.types import _check_series_convert_timestamps_internal
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
(len(series) == 2 and isinstance(series[1], pa.DataType)):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

def create_array(s, t):
mask = s.isnull()
# Ensure timestamp series are in expected form for Spark internal representation
# TODO: maybe don't need None check anymore as of Arrow 0.9.1
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s.fillna(0), self._timezone)
# TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
elif t is not None and pa.types.is_string(t) and sys.version < '3':
# TODO: need decode before converting to Arrow in Python 2
# TODO: don't need as of Arrow 0.9.1
return pa.Array.from_pandas(s.apply(
lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
elif t is not None and pa.types.is_decimal(t) and \
LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
# TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
return pa.Array.from_pandas(s.apply(
lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
# TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
return pa.Array.from_pandas(s, mask=mask, type=t)

try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
except pa.ArrowException as e:
error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
"Array (%s). It can be caused by overflows or other unsafe " + \
"conversions warned by Arrow. Arrow safe type check can be " + \
"disabled by using SQL config " + \
"`spark.sql.execution.pandas.arrowSafeTypeConversion`."
raise RuntimeError(error_msg % (s.dtype, t), e)
return array

arrs = []
for s, t in series:
if t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
raise ValueError("A field of type StructType expects a pandas.DataFrame, "
"but got: %s" % str(type(s)))

# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
if len(s) == 0 and len(s.columns) == 0:
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
# Assign result columns by schema name if user labeled with strings
elif self._assign_cols_by_name and any(isinstance(name, basestring)
for name in s.columns):
arrs_names = [(create_array(s[field.name], field.type), field.name)
for field in t]
# Assign result columns by position
else:
arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
for i, field in enumerate(t)]

struct_arrs, struct_names = zip(*arrs_names)

# TODO: from_arrays args switched for v0.9.0, remove when bump min pyarrow version
if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
else:
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
else:
arrs.append(create_array(s, t))

return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])

def dump_stream(self, iterator, stream):
"""
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
a list of series accompanied by an optional pyarrow type to coerce the data to.
"""
import pyarrow as pa
writer = None
try:
for series in iterator:
batch = _create_batch(series, self._timezone, self._safecheck,
self._assign_cols_by_name)
if writer is None:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
writer.write_batch(batch)
finally:
if writer is not None:
writer.close()
batches = (self._create_batch(series) for series in iterator)
super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)

def load_stream(self, stream):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
import pyarrow as pa
reader = pa.ipc.open_stream(stream)

for batch in reader:
for batch in batches:
yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]

def __repr__(self):
return "ArrowStreamPandasSerializer"


class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
"""
Serializer used by Python worker to evaluate Pandas UDFs
"""

def dump_stream(self, iterator, stream):
"""
Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
This should be sent after creating the first record batch so in case of an error, it can
be sent back to the JVM before the Arrow stream starts.
"""

def init_stream_yield_batches():
should_write_start_length = True
for series in iterator:
batch = self._create_batch(series)
if should_write_start_length:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
should_write_start_length = False
yield batch

return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)

def __repr__(self):
return "ArrowStreamPandasUDFSerializer"


class BatchedSerializer(Serializer):

"""
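To make the struct branch of `_create_batch` above more concrete, here is a hedged standalone sketch of how a grouped map result `pandas.DataFrame` can travel as a single Arrow struct column and then be recovered as top-level columns (in Spark the flattening happens in `FlatMapGroupsInPandasExec` on the Scala side; the snippet below is illustrative only):

```python
# Illustrative only: a grouped map result DataFrame packed into one Arrow
# StructArray column, then flattened back into its top-level columns.
import pandas as pd
import pyarrow as pa

result = pd.DataFrame({"id": [1, 2, 3], "v": [0.1, 0.2, 0.3]})

# One StructArray holds the whole DataFrame; column names become field names.
struct_arr = pa.StructArray.from_arrays(
    [pa.Array.from_pandas(result[name]) for name in result.columns],
    names=list(result.columns))
batch = pa.RecordBatch.from_arrays([struct_arr], ["_0"])

# "Flattening" on the receiving side recovers the original columns.
print(batch.schema)
print([arr.to_pandas().tolist() for arr in struct_arr.flatten()])
```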
42 changes: 24 additions & 18 deletions python/pyspark/sql/session.py
@@ -530,15 +530,29 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
data types will be used to coerce the data in Pandas to Arrow conversion.
"""
from pyspark.serializers import ArrowStreamSerializer, _create_batch
from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
from distutils.version import LooseVersion
from pyspark.serializers import ArrowStreamPandasSerializer
from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType
from pyspark.sql.utils import require_minimum_pandas_version, \
require_minimum_pyarrow_version

require_minimum_pandas_version()
require_minimum_pyarrow_version()

from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
import pyarrow as pa

# Create the Spark schema from list of names passed in with Arrow types
if isinstance(schema, (list, tuple)):
if LooseVersion(pa.__version__) < LooseVersion("0.12.0"):
temp_batch = pa.RecordBatch.from_pandas(pdf[0:100], preserve_index=False)
arrow_schema = temp_batch.schema
else:
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
struct = StructType()
for name, field in zip(schema, arrow_schema):
struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
schema = struct

# Determine arrow types to coerce data when creating batches
if isinstance(schema, StructType):
@@ -555,32 +569,24 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))

# Create Arrow record batches
safecheck = self._wrapped._conf.arrowSafeTypeConversion()
col_by_name = True # col by name only applies to StructType columns, can't happen here
batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
timezone, safecheck, col_by_name)
for pdf_slice in pdf_slices]

# Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
if isinstance(schema, (list, tuple)):
struct = from_arrow_schema(batches[0].schema)
for i, name in enumerate(schema):
struct.fields[i].name = name
struct.names[i] = name
schema = struct
# Create list of Arrow (columns, type) for serializer dump_stream
arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
for pdf_slice in pdf_slices]

jsqlContext = self._wrapped._jsqlContext

safecheck = self._wrapped._conf.arrowSafeTypeConversion()
col_by_name = True # col by name only applies to StructType columns, can't happen here
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)

def reader_func(temp_filename):
return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)

def create_RDD_server():
return self._jvm.ArrowRDDServer(jsqlContext)

# Create Spark DataFrame from Arrow stream file, using one batch per partition
jrdd = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func,
create_RDD_server)
jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
df = DataFrame(jdf, self._wrapped)
df._schema = schema
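For context, the `session.py` change above means the familiar Arrow-enabled `createDataFrame` path now streams its pandas slices through `ArrowStreamPandasSerializer`. A minimal usage sketch (config names as of this era of Spark; `spark.sql.execution.arrow.enabled` was renamed in later releases):

```python
# Usage sketch for the createDataFrame-with-Arrow path touched above.
import pandas as pd
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .appName("arrow-createDataFrame-sketch")
         .config("spark.sql.execution.arrow.enabled", "true")
         .config("spark.sql.execution.pandas.arrowSafeTypeConversion", "true")
         .getOrCreate())

pdf = pd.DataFrame({"id": list(range(8)), "v": [0.5 * i for i in range(8)]})

# The schema can be a list of column names; Arrow types are then inferred
# from the pandas DataFrame, as in the new code path above.
df = spark.createDataFrame(pdf, schema=["id", "v"])
df.show()
spark.stop()
```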
23 changes: 7 additions & 16 deletions python/pyspark/worker.py
@@ -38,7 +38,7 @@
from pyspark.rdd import PythonEvalType
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
BatchedSerializer, ArrowStreamPandasUDFSerializer
from pyspark.sql.types import to_arrow_type, StructType
from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark import shuffle
@@ -103,10 +103,7 @@ def verify_result_length(*a):
return lambda *a: (verify_result_length(*a), arrow_return_type)


def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
assign_cols_by_name = runner_conf.get(
"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")
assign_cols_by_name = assign_cols_by_name.lower() == "true"
def wrap_grouped_map_pandas_udf(f, return_type, argspec):

def wrapped(key_series, value_series):
import pandas as pd
@@ -125,15 +122,9 @@ def wrapped(key_series, value_series):
"Number of columns of the returned pandas.DataFrame "
"doesn't match specified schema. "
"Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
return result

# Assign result columns by schema name if user labeled with strings, else use position
if assign_cols_by_name and any(isinstance(name, basestring) for name in result.columns):
return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type]
else:
return [(result[result.columns[i]], to_arrow_type(field.dataType))
for i, field in enumerate(return_type)]

return wrapped
return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]


def wrap_grouped_agg_pandas_udf(f, return_type):
@@ -227,7 +218,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
argspec = _get_argspec(row_func) # signature was lost when wrapping it
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
@@ -257,12 +248,12 @@ def read_udfs(pickleSer, infile, eval_type):
timezone = runner_conf.get("spark.sql.session.timeZone", None)
safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
"false").lower() == 'true'
# NOTE: this is duplicated from wrap_grouped_map_pandas_udf
# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
assign_cols_by_name = runner_conf.get(
"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
.lower() == "true"

ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name)
else:
ser = BatchedSerializer(PickleSerializer(), 100)

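Finally, the grouped map wrapping changed in `worker.py` is invisible at the API level; the usual grouped map UDF still drives it. A minimal example of that public API (the standard subtract-mean pattern, unchanged by this PR from the user's point of view):

```python
# The user-facing grouped map API whose result is now sent as one struct column.
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = SparkSession.builder.appName("grouped-map-sketch").getOrCreate()
df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ("id", "v"))

@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    # pdf is the pandas.DataFrame for one group; the returned DataFrame must
    # match the declared schema (matched by column name here, since the
    # labels are strings).
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupby("id").apply(subtract_mean).show()
spark.stop()
```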