Skip to content

Commit

Permalink
[SPARK-27839][SQL] Change UTF8String.replace() to operate on UTF8 bytes
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR significantly improves the performance of `UTF8String.replace()` by performing direct replacement over UTF8 bytes instead of decoding those bytes into Java Strings.

In cases where the search string is not found (i.e. no replacements are performed, a case which I expect to be common) this new implementation performs no object allocation or memory copying.

My implementation is modeled after `commons-lang3`'s `StringUtils.replace()` method. As part of my implementation, I needed a StringBuilder / resizable buffer, so I moved `UTF8StringBuilder` from the `catalyst` package to `unsafe`.

## How was this patch tested?

Copied tests from `StringExpressionSuite` to `UTF8StringSuite` and added a couple of new cases.

To evaluate performance, I did some quick local benchmarking by running the following code in `spark-shell` (with Java 1.8.0_191):

```scala
import org.apache.spark.unsafe.types.UTF8String

// Times 100 million calls to `UTF8String.replace()` on fixed inputs and prints the
// elapsed wall-clock time in milliseconds.
def benchmark(text: String, search: String, replace: String) {
  // Convert the inputs once, before the timer starts, so the conversions are not measured.
  val utf8Text = UTF8String.fromString(text)
  val utf8Search = UTF8String.fromString(search)
  val utf8Replace = UTF8String.fromString(replace)

  val start = System.currentTimeMillis
  var i = 0
  // 1000 * 1000 * 100 = 100 million iterations; the result of each call is discarded.
  while (i < 1000 * 1000 * 100) {
    utf8Text.replace(utf8Search, utf8Replace)
    i += 1
  }
  val end = System.currentTimeMillis

  // Elapsed time in milliseconds.
  println(end - start)
}

benchmark("ABCDEFGH", "DEF", "ZZZZ")  // replacement occurs
benchmark("ABCDEFGH", "Z", "")  // no replacement occurs
```

On my laptop this took ~54 / ~40 seconds before this patch's changes and ~6.5 / ~3.8 seconds afterwards.

Closes apache#24707 from JoshRosen/faster-string-replace.

Authored-by: Josh Rosen <[email protected]>
Signed-off-by: Josh Rosen <[email protected]>
  • Loading branch information
JoshRosen committed Jun 19, 2019
1 parent fe5145e commit fc65e0f
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.codegen;
package org.apache.spark.unsafe;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.UTF8String;

Expand All @@ -34,7 +33,18 @@ public class UTF8StringBuilder {

/**
 * Creates a builder whose backing buffer starts at 16 bytes.
 */
public UTF8StringBuilder() {
// Since initial buffer size is 16 in `StringBuilder`, we set the same size here
this.buffer = new byte[16];
// Delegates to the size-taking constructor with the same 16-byte default.
this(16);
}

public UTF8StringBuilder(int initialSize) {
if (initialSize < 0) {
throw new IllegalArgumentException("Size must be non-negative");
}
if (initialSize > ARRAY_MAX) {
throw new IllegalArgumentException(
"Size " + initialSize + " exceeded maximum size of " + ARRAY_MAX);
}
this.buffer = new byte[initialSize];
}

// Grows the buffer by at least `neededSize`
Expand Down Expand Up @@ -72,6 +82,17 @@ public void append(String value) {
append(UTF8String.fromString(value));
}

/**
 * Appends {@code length} bytes, read from {@code base} starting at {@code offset}, to this
 * builder.
 *
 * @param base object the bytes are copied from, passed through to {@code Platform.copyMemory}
 * @param offset offset of the first byte to copy, passed through to {@code Platform.copyMemory}
 * @param length number of bytes to copy
 */
public void appendBytes(Object base, long offset, int length) {
  // Make room before copying; grow() may change `buffer`, so it is read only afterwards.
  grow(length);
  Platform.copyMemory(base, offset, buffer, cursor, length);
  // Advance the write position past the bytes that were just copied.
  cursor += length;
}

/**
 * Builds a {@link UTF8String} from the first {@code totalSize()} bytes of the buffer.
 *
 * @return the string produced by {@code UTF8String.fromBytes} over the buffered bytes
 */
public UTF8String build() {
  final UTF8String result = UTF8String.fromBytes(buffer, 0, totalSize());
  return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.common.primitives.Ints;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UTF8StringBuilder;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;

Expand Down Expand Up @@ -1002,12 +1003,29 @@ public UTF8String[] split(UTF8String pattern, int limit) {
}

/**
 * Replaces every occurrence of {@code search} in this string with {@code replace}, working
 * on the UTF-8 bytes.
 *
 * Returns {@code this} unchanged (no new string is built) when this string is empty, when
 * {@code search} is empty, or when {@code search} does not occur.
 *
 * @param search the byte sequence to look for
 * @param replace the byte sequence substituted for each match
 * @return this string if nothing was replaced, otherwise a newly built string
 */
public UTF8String replace(UTF8String search, UTF8String replace) {
if (EMPTY_UTF8.equals(search)) {
// This implementation is loosely based on commons-lang3's StringUtils.replace().
if (numBytes == 0 || search.numBytes == 0) {
return this;
}
String replaced = toString().replace(
search.toString(), replace.toString());
return fromString(replaced);
// Find the first occurrence of the search string.
// `start` and `end` are byte positions relative to the beginning of this string.
int start = 0;
int end = this.find(search, start);
if (end == -1) {
// Search string was not found, so string is unchanged.
return this;
}
// At least one match was found. Estimate space needed for result.
// The 16x multiplier here is chosen to match commons-lang3's implementation.
// NOTE(review): this is plain int arithmetic; for very large inputs the `* 16` or the
// `numBytes + increase` sum below could overflow -- confirm whether a guard is needed.
int increase = Math.max(0, replace.numBytes - search.numBytes) * 16;
final UTF8StringBuilder buf = new UTF8StringBuilder(numBytes + increase);
while (end != -1) {
// Copy the bytes between the previous match (or the start) and this match, then the
// replacement in place of the matched bytes.
buf.appendBytes(this.base, this.offset + start, end - start);
buf.append(replace);
// Resume the search just past the matched bytes.
start = end + search.numBytes;
end = this.find(search, start);
}
// Copy whatever remains after the last match.
buf.appendBytes(this.base, this.offset + start, numBytes - start);
return buf.build();
}

// TODO: Need to use `Code Point` here instead of Char in case the character is longer than 2 bytes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,44 @@ public void split() {
new UTF8String[]{fromString("ab"), fromString("def,ghi,")}));
}

@Test
public void replace() {
  // basic replacement, deletion of the match, and an empty search string
  assertReplaced("re123ace", "replace", "pl", "123");
  assertReplaced("reace", "replace", "pl", "");
  assertReplaced("replace", "replace", "", "123");
  // tests for multiple replacements
  assertReplaced("a12ca12c", "abcabc", "b", "12");
  assertReplaced("adad", "abcdabcd", "bc", "");
  // tests for single character search and replacement strings
  assertReplaced("AbcAbc", "abcabc", "a", "A");
  assertReplaced("abcabc", "abcabc", "Z", "A");
  // Tests with non-ASCII characters
  assertReplaced("花ab界", "花花世界", "花世", "ab");
  assertReplaced("a水c", "a火c", "火", "水");
  // Tests for a large number of replacements, triggering UTF8StringBuilder resize
  assertEquals(
    fromString("abcd").repeat(17),
    fromString("a").repeat(17).replace(fromString("a"), fromString("abcd")));
}

/**
 * Asserts that replacing {@code search} with {@code replacement} in {@code input} yields
 * {@code expected}, with every argument converted through {@code fromString}.
 */
private static void assertReplaced(
    String expected, String input, String search, String replacement) {
  assertEquals(
    fromString(expected),
    fromString(input).replace(fromString(search), fromString(replacement)));
}

@Test
public void levenshteinDistance() {
assertEquals(0, EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
Expand Down

0 comments on commit fc65e0f

Please sign in to comment.