Skip to content

Feature (TokenTextCplitter): Add overlapping function for text segmentation #2123 #3780

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
import org.springframework.util.Assert;

/**
* A {@link TextSplitter} that splits text into chunks of a target size in tokens.
* A {@link TextSplitter} that splits text into chunks of a target size in tokens. Now
* supports overlapping tokens between chunks.
*
* @author Raphael Yu
* @author Christian Tzolov
* @author Ricken Bazolo
* @author Enginner JiaXing
*/
public class TokenTextSplitter extends TextSplitter {

Expand All @@ -46,39 +48,42 @@ public class TokenTextSplitter extends TextSplitter {

private static final boolean KEEP_SEPARATOR = true;

private static final int DEFAULT_OVERLAP_SIZE = 0;

private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry();

private final Encoding encoding = this.registry.getEncoding(EncodingType.CL100K_BASE);

// The target size of each text chunk in tokens
private final int chunkSize;

// The minimum size of each text chunk in characters
private final int minChunkSizeChars;

// Discard chunks shorter than this
private final int minChunkLengthToEmbed;

// The maximum number of chunks to generate from a text
private final int maxNumChunks;

private final boolean keepSeparator;

private final int overlapSize;

public TokenTextSplitter() {
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR);
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR,
DEFAULT_OVERLAP_SIZE);
}

public TokenTextSplitter(boolean keepSeparator) {
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator);
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator,
DEFAULT_OVERLAP_SIZE);
}

public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
boolean keepSeparator) {
boolean keepSeparator, int overlapSize) {
this.chunkSize = chunkSize;
this.minChunkSizeChars = minChunkSizeChars;
this.minChunkLengthToEmbed = minChunkLengthToEmbed;
this.maxNumChunks = maxNumChunks;
this.keepSeparator = keepSeparator;
this.overlapSize = overlapSize;
}

public static Builder builder() {
Expand All @@ -97,59 +102,52 @@ protected List<String> doSplit(String text, int chunkSize) {

List<Integer> tokens = getEncodedTokens(text);
List<String> chunks = new ArrayList<>();

int start = 0;
int num_chunks = 0;
while (!tokens.isEmpty() && num_chunks < this.maxNumChunks) {
List<Integer> chunk = tokens.subList(0, Math.min(chunkSize, tokens.size()));

while (start < tokens.size() && num_chunks < this.maxNumChunks) {
int end = Math.min(start + chunkSize, tokens.size());
List<Integer> chunk = tokens.subList(start, end);
String chunkText = decodeTokens(chunk);

// Skip the chunk if it is empty or whitespace
if (chunkText.trim().isEmpty()) {
tokens = tokens.subList(chunk.size(), tokens.size());
start = end;
continue;
}

// Find the last period or punctuation mark in the chunk
int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'),
Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n'))));

if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) {
// Truncate the chunk text at the punctuation mark
chunkText = chunkText.substring(0, lastPunctuation + 1);
}

String chunkTextToAppend = (this.keepSeparator) ? chunkText.trim()
String chunkTextToAppend = this.keepSeparator ? chunkText.trim()
: chunkText.replace(System.lineSeparator(), " ").trim();

if (chunkTextToAppend.length() > this.minChunkLengthToEmbed) {
chunks.add(chunkTextToAppend);
num_chunks++;
}

// Remove the tokens corresponding to the chunk text from the remaining tokens
tokens = tokens.subList(getEncodedTokens(chunkText).size(), tokens.size());

num_chunks++;
}

// Handle the remaining tokens
if (!tokens.isEmpty()) {
String remaining_text = decodeTokens(tokens).replace(System.lineSeparator(), " ").trim();
if (remaining_text.length() > this.minChunkLengthToEmbed) {
chunks.add(remaining_text);
}
// Move start forward by chunkSize - overlapSize to allow overlap
start += chunkSize - this.overlapSize;
}

return chunks;
}

private List<Integer> getEncodedTokens(String text) {
List<Integer> getEncodedTokens(String text) {
Assert.notNull(text, "Text must not be null");
return this.encoding.encode(text).boxed();
}

private String decodeTokens(List<Integer> tokens) {
Assert.notNull(tokens, "Tokens must not be null");
var tokensIntArray = new IntArrayList(tokens.size());
tokens.forEach(tokensIntArray::add);
return this.encoding.decode(tokensIntArray);
IntArrayList tokenArray = new IntArrayList(tokens.size());
tokens.forEach(tokenArray::add);
return this.encoding.decode(tokenArray);
}

public static final class Builder {
Expand All @@ -164,6 +162,8 @@ public static final class Builder {

private boolean keepSeparator = KEEP_SEPARATOR;

private int overlapSize = DEFAULT_OVERLAP_SIZE;

private Builder() {
}

Expand Down Expand Up @@ -192,9 +192,14 @@ public Builder withKeepSeparator(boolean keepSeparator) {
return this;
}

public Builder withOverlapSize(int overlapSize) {
this.overlapSize = overlapSize;
return this;
}

public TokenTextSplitter build() {
return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed,
this.maxNumChunks, this.keepSeparator);
this.maxNumChunks, this.keepSeparator, this.overlapSize);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,24 @@
import java.util.Map;

import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import org.springframework.ai.document.DefaultContentFormatter;
import org.springframework.ai.document.Document;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
* @author Ricken Bazolo
*/
public class TokenTextSplitterTest {

private final String SAMPLE_TEXT = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
+ "Vestibulum volutpat augue et turpis facilisis, id porta ligula interdum. "
+ "Proin condimentum justo sed lectus fermentum, a pretium orci iaculis. "
+ "Mauris nec pharetra libero. Nulla facilisi. Sed consequat velit id eros volutpat dignissim.";

@Test
public void testTokenTextSplitterBuilderWithDefaultValues() {

Expand Down Expand Up @@ -112,4 +119,83 @@ public void testTokenTextSplitterBuilderWithAllFields() {
assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1");
}

@Test
void testSplitWithOverlap() {
TokenTextSplitter splitter = TokenTextSplitter.builder()
.withChunkSize(40)
.withOverlapSize(10)
.withMinChunkLengthToEmbed(5)
.build();

List<String> chunks = splitter.splitText(SAMPLE_TEXT);

assertNotNull(chunks);
assertTrue(chunks.size() > 1, "Text should be split into multiple chunks");

// Compare overlapping tokens between consecutive chunks
List<Integer> allTokens = splitter.getEncodedTokens(SAMPLE_TEXT);

for (int i = 1; i < chunks.size(); i++) {
List<Integer> prevTokens = splitter.getEncodedTokens(chunks.get(i - 1));
List<Integer> currTokens = splitter.getEncodedTokens(chunks.get(i));

int overlap = getOverlapSize(prevTokens, currTokens);

// Allow some deviation due to punctuation or sentence trimming
assertTrue(overlap >= 5 && overlap <= 15,
"Expected ~10 overlapping tokens between chunks, but got " + overlap);
}
}

@Test
void testSplitWithoutOverlap() {
TokenTextSplitter splitter = TokenTextSplitter.builder().withChunkSize(40).withOverlapSize(0).build();

List<String> chunks = splitter.splitText(SAMPLE_TEXT);

assertNotNull(chunks);
assertTrue(chunks.size() > 1);

for (int i = 1; i < chunks.size(); i++) {
List<Integer> prev = splitter.getEncodedTokens(chunks.get(i - 1));
List<Integer> curr = splitter.getEncodedTokens(chunks.get(i));

assertTrue(noOverlap(prev, curr), "There should be no overlap between chunks");
}
}

@Test
void testEmptyText() {
TokenTextSplitter splitter = TokenTextSplitter.builder().withChunkSize(50).withOverlapSize(10).build();

List<String> chunks = splitter.splitText(" ");
assertTrue(chunks.isEmpty(), "Empty or whitespace-only input should return no chunks");
}

/**
* Calculate the number of overlapping tokens between the end of the previous chunk
* and the start of the current chunk.
*/
private int getOverlapSize(List<Integer> prev, List<Integer> curr) {
int maxOverlap = Math.min(prev.size(), curr.size());
for (int i = maxOverlap; i > 0; i--) {
if (prev.subList(prev.size() - i, prev.size()).equals(curr.subList(0, i))) {
return i;
}
}
return 0;
}

/**
* Check whether there is no overlap between the two token lists.
*/
private boolean noOverlap(List<Integer> prev, List<Integer> curr) {
for (int len = Math.min(prev.size(), curr.size()); len > 0; len--) {
if (prev.subList(prev.size() - len, prev.size()).equals(curr.subList(0, len))) {
return false;
}
}
return true;
}

}