Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class ArrowStringArrayMixin:
def __init__(self, *args, **kwargs) -> None:
raise NotImplementedError

def _from_pyarrow_array(self, pa_array) -> Self:
raise NotImplementedError

def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
# Convert a bool-dtype result to the appropriate result type
raise NotImplementedError
Expand All @@ -50,31 +53,31 @@ def _str_len(self):
return self._convert_int_result(result)

def _str_lower(self) -> Self:
return type(self)(pc.utf8_lower(self._pa_array))
return self._from_pyarrow_array(pc.utf8_lower(self._pa_array))

def _str_upper(self) -> Self:
return type(self)(pc.utf8_upper(self._pa_array))
return self._from_pyarrow_array(pc.utf8_upper(self._pa_array))

def _str_strip(self, to_strip=None) -> Self:
if to_strip is None:
result = pc.utf8_trim_whitespace(self._pa_array)
else:
result = pc.utf8_trim(self._pa_array, characters=to_strip)
return type(self)(result)
return self._from_pyarrow_array(result)

def _str_lstrip(self, to_strip=None) -> Self:
if to_strip is None:
result = pc.utf8_ltrim_whitespace(self._pa_array)
else:
result = pc.utf8_ltrim(self._pa_array, characters=to_strip)
return type(self)(result)
return self._from_pyarrow_array(result)

def _str_rstrip(self, to_strip=None) -> Self:
if to_strip is None:
result = pc.utf8_rtrim_whitespace(self._pa_array)
else:
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
return type(self)(result)
return self._from_pyarrow_array(result)

def _str_pad(
self,
Expand Down Expand Up @@ -104,7 +107,9 @@ def _str_pad(
raise ValueError(
f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
)
return type(self)(pa_pad(self._pa_array, width=width, padding=fillchar))
return self._from_pyarrow_array(
pa_pad(self._pa_array, width=width, padding=fillchar)
)

def _str_get(self, i: int) -> Self:
lengths = pc.utf8_length(self._pa_array)
Expand All @@ -124,15 +129,17 @@ def _str_get(self, i: int) -> Self:
)
null_value = pa.scalar(None, type=self._pa_array.type)
result = pc.if_else(not_out_of_bounds, selected, null_value)
return type(self)(result)
return self._from_pyarrow_array(result)

def _str_slice(
self, start: int | None = None, stop: int | None = None, step: int | None = None
) -> Self:
if pa_version_under13p0:
# GH#59724
result = self._apply_elementwise(lambda val: val[start:stop:step])
return type(self)(pa.chunked_array(result, type=self._pa_array.type))
return self._from_pyarrow_array(
pa.chunked_array(result, type=self._pa_array.type)
)
if start is None:
if step is not None and step < 0:
# GH#59710
Expand All @@ -141,7 +148,7 @@ def _str_slice(
start = 0
if step is None:
step = 1
return type(self)(
return self._from_pyarrow_array(
pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step)
)

Expand All @@ -154,7 +161,9 @@ def _str_slice_replace(
start = 0
if stop is None:
stop = np.iinfo(np.int64).max
return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))
return self._from_pyarrow_array(
pc.utf8_replace_slice(self._pa_array, start, stop, repl)
)

def _str_replace(
self,
Expand All @@ -181,32 +190,32 @@ def _str_replace(
replacement=repl,
max_replacements=pa_max_replacements,
)
return type(self)(result)
return self._from_pyarrow_array(result)

def _str_capitalize(self) -> Self:
return type(self)(pc.utf8_capitalize(self._pa_array))
return self._from_pyarrow_array(pc.utf8_capitalize(self._pa_array))

def _str_title(self) -> Self:
return type(self)(pc.utf8_title(self._pa_array))
return self._from_pyarrow_array(pc.utf8_title(self._pa_array))

def _str_swapcase(self) -> Self:
return type(self)(pc.utf8_swapcase(self._pa_array))
return self._from_pyarrow_array(pc.utf8_swapcase(self._pa_array))

def _str_removeprefix(self, prefix: str):
if not pa_version_under13p0:
starts_with = pc.starts_with(self._pa_array, pattern=prefix)
removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
result = pc.if_else(starts_with, removed, self._pa_array)
return type(self)(result)
return self._from_pyarrow_array(result)
predicate = lambda val: val.removeprefix(prefix)
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))
return self._from_pyarrow_array(pa.chunked_array(result))

def _str_removesuffix(self, suffix: str):
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)
return self._from_pyarrow_array(result)

def _str_startswith(
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
Expand Down
Loading
Loading