fix the text_splitter and make the separator customizable
liyin2015 committed Jan 30, 2025
1 parent 735210a commit c97d720
Showing 4 changed files with 33 additions and 127 deletions.
14 changes: 8 additions & 6 deletions adalflow/adalflow/components/data_process/text_splitter.py
@@ -1,5 +1,4 @@
"""
Splitting texts is commonly used as a preprocessing step before embedding and retrieving texts.
"""Splitting texts is commonly used as a preprocessing step before embedding and retrieving texts.
We encourage you to process your data here and define your own embedding and retrieval methods. These methods can highly depend on the product environment and may extend beyond the scope of this library.
@@ -15,7 +14,7 @@
"""

from copy import deepcopy
from typing import List, Literal
from typing import List, Literal, Optional
from tqdm import tqdm
import logging

@@ -44,6 +43,7 @@
DEFAULT_CHUNK_OVERLAP = 200


# TODO: make it a non-component
class TextSplitter(Component):
"""
Text Splitter for Chunking Documents
@@ -162,6 +162,7 @@ def __init__(
chunk_size: int = DEFAULT_CHUNK_SIZE,
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
batch_size: int = 1000,
separators: Optional[dict] = SEPARATORS,
):
"""
Initializes the TextSplitter with the specified parameters for text splitting.
@@ -181,9 +182,10 @@ def __init__(
super().__init__()

self.split_by = split_by
if split_by not in SEPARATORS:
self.separators = separators
if split_by not in self.separators:
raise ValueError(
f"Invalid options for split_by. You must select from {list(SEPARATORS.keys())}."
f"Invalid options for split_by. You must select from {list(self.separators.keys())}."
)

if chunk_overlap >= chunk_size:
@@ -224,7 +226,7 @@ def split_text(self, text: str) -> List[str]:
log.info(
f"Splitting text with split_by: {self.split_by}, chunk_size: {self.chunk_size}, chunk_overlap: {self.chunk_overlap}"
)
separator = SEPARATORS[self.split_by]
separator = self.separators[self.split_by]
splits = self._split_text_into_units(text, separator)
log.info(f"Text split into {len(splits)} parts.")
chunks = self._merge_units_to_chunks(
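For context (not part of the commit), a minimal usage sketch of the now-customizable separators; it assumes only the constructor arguments and the split_text method visible in the diff above, and reuses the "question"/"?" mapping from the tutorial further down:

from adalflow.components.data_process.text_splitter import TextSplitter

# Pass a custom separator mapping; split_by must be a key of that mapping,
# otherwise __init__ raises the ValueError shown in the diff above.
splitter = TextSplitter(
    split_by="question",
    chunk_size=1,
    chunk_overlap=0,
    separators={"question": "?"},
)
chunks = splitter.split_text("What is your name? How old are you? Where do you live?")
for chunk in chunks:
    print(chunk)

Because validation now checks split_by against the supplied mapping rather than the module-level SEPARATORS, any key in the custom dict becomes a legal split_by value.
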
56 changes: 11 additions & 45 deletions tutorials/adalflow_logger.py
@@ -1,15 +1,15 @@
# TODO: needs to be checked
"""
This script demonstrates the usage of AdalFlow's Logger functionality.
It can be run independently to showcase logging capabilities.
"""

from adalflow.components import Generator
from adalflow.core import Generator
from adalflow.components.model_client import OpenAIClient
from adalflow.utils import setup_env
from adalflow.utils.logger import get_logger
import logging
from typing import Dict, Any
import json


def setup_logging(log_file: str = "adalflow.log") -> logging.Logger:
@@ -66,57 +66,26 @@ def process_query(
logger.info(f"Processing query: {query}")

try:
# Generate response
response = generator.generate(prompt_kwargs={"query": query})
response = generator(prompt_kwargs={"query": query})

# Log successful response
logger.info(f"Generated response: {response}")

return {"query": query, "response": str(response), "status": "success"}

except Exception as e:
# Log error if generation fails
logger.error(f"Error processing query: {str(e)}")
return {"query": query, "response": None, "status": "error", "error": str(e)}


def analyze_logs(log_file: str, logger: logging.Logger) -> Dict[str, int]:
"""
Analyze the log file to gather statistics.
Args:
log_file: Path to the log file
logger: Logger instance for recording the analysis
Returns:
Dictionary containing log statistics
"""
stats = {"total_queries": 0, "successful_queries": 0, "failed_queries": 0}

try:
with open(log_file, "r") as f:
for line in f:
if "Processing query:" in line:
stats["total_queries"] += 1
if "Generated response:" in line:
stats["successful_queries"] += 1
if "Error processing query:" in line:
stats["failed_queries"] += 1

logger.info(f"Log analysis complete: {json.dumps(stats, indent=2)}")
return stats

except Exception as e:
logger.error(f"Error analyzing logs: {str(e)}")
return stats


def main():
"""Main function to demonstrate logger functionality."""
# Setup
log_file = "adalflow.log"
logger = setup_logging(log_file)
generator = setup_generator()
generator = Generator(
model_client=OpenAIClient(),
model_kwargs={"model": "gpt-3.5-turbo", "temperature": 0, "max_tokens": 1000},
)

# Example queries
queries = [
@@ -130,14 +99,11 @@ def main():
for query in queries:
result = process_query(generator, query, logger)
results.append(result)
print(f"\nQuery: {query}")
print(f"Response: {result['response']}")

# Analyze logs
stats = analyze_logs(log_file, logger)
print("\nLog Analysis:")
print(json.dumps(stats, indent=2))


if __name__ == "__main__":
from adalflow.utils import setup_env

setup_env()

main()
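
For context (not part of the commit), a condensed sketch of the pattern the simplified tutorial now follows; the query string is illustrative, and calling get_logger() with its defaults is an assumption beyond what the diff shows:

from adalflow.core import Generator
from adalflow.components.model_client import OpenAIClient
from adalflow.utils import setup_env
from adalflow.utils.logger import get_logger

setup_env()            # expects OPENAI_API_KEY in a .env file
logger = get_logger()  # AdalFlow logging helper (assumed to return a standard logging.Logger)
generator = Generator(
    model_client=OpenAIClient(),
    model_kwargs={"model": "gpt-3.5-turbo", "temperature": 0, "max_tokens": 1000},
)
# Call the generator via __call__, mirroring process_query in the tutorial above.
response = generator(prompt_kwargs={"query": "What is AdalFlow?"})
logger.info(f"Generated response: {response}")
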
48 changes: 10 additions & 38 deletions tutorials/adalflow_modelclient_sync_and_async.py
@@ -1,28 +1,14 @@
import asyncio
import time
from adalflow.components.model_client import (
from adalflow.components.model_client.openai_client import (
OpenAIClient,
) # Assuming OpenAIClient with .call() and .acall() is available
)
from adalflow.core.types import ModelType

from getpass import getpass
import os

from adalflow.utils import setup_env

# Load environment variables - Make sure to have OPENAI_API_KEY in .env file and .env is present in current folder
if os.path.isfile(".env"):
setup_env(".env")

# Prompt user to enter their API keys securely
if "OPENAI_API_KEY" not in os.environ:
openai_api_key = getpass("Please enter your OpenAI API key: ")
# Set environment variables
os.environ["OPENAI_API_KEY"] = openai_api_key
print("API keys have been set.")


# Synchronous function for benchmarking .call()
def benchmark_sync_call(api_kwargs, runs=10):
"""
Benchmark the synchronous .call() method by running it multiple times.
@@ -31,33 +17,26 @@ def benchmark_sync_call(api_kwargs, runs=10):
- api_kwargs: The arguments to be passed to the API call
- runs: The number of times to run the call (default is 10)
"""
# List to store responses
responses = []

# Record the start time of the benchmark
start_time = time.time()

# Perform synchronous API calls for the specified number of runs
responses = [
openai_client.call(
api_kwargs=api_kwargs, # API arguments
model_type=ModelType.LLM, # Model type (e.g., LLM for language models)
api_kwargs=api_kwargs,
model_type=ModelType.LLM,
)
for _ in range(runs) # Repeat 'runs' times
for _ in range(runs)
]

# Record the end time after all calls are completed
end_time = time.time()

# Output the results of each synchronous call
for i, response in enumerate(responses):
print(f"sync call {i + 1} completed: {response}")

# Print the total time taken for all synchronous calls
print(f"\nSynchronous benchmark completed in {end_time - start_time:.2f} seconds")


# Asynchronous function for benchmarking .acall()
async def benchmark_async_acall(api_kwargs, runs=10):
"""
Benchmark the asynchronous .acall() method by running it multiple times concurrently.
@@ -66,44 +45,37 @@ async def benchmark_async_acall(api_kwargs, runs=10):
- api_kwargs: The arguments to be passed to the API call
- runs: The number of times to run the asynchronous call (default is 10)
"""
# Record the start time of the benchmark
start_time = time.time()

# Create a list of asynchronous tasks for the specified number of runs
tasks = [
openai_client.acall(
api_kwargs=api_kwargs, # API arguments
model_type=ModelType.LLM, # Model type (e.g., LLM for language models)
api_kwargs=api_kwargs,
model_type=ModelType.LLM,
)
for _ in range(runs) # Repeat 'runs' times
for _ in range(runs)
]

# Execute all tasks concurrently and wait for them to finish
responses = await asyncio.gather(*tasks)

# Record the end time after all tasks are completed
end_time = time.time()

# Output the results of each asynchronous call
for i, response in enumerate(responses):
print(f"Async call {i + 1} completed: {response}")

# Print the total time taken for all asynchronous calls
print(f"\nAsynchronous benchmark completed in {end_time - start_time:.2f} seconds")


if __name__ == "__main__":
# Initialize the OpenAI client
setup_env()
openai_client = OpenAIClient()

# Sample prompt for testing
prompt = "Tell me a joke."

model_kwargs = {"model": "gpt-3.5-turbo", "temperature": 0.5, "max_tokens": 100}
api_kwargs = openai_client.convert_inputs_to_api_kwargs(
input=prompt, model_kwargs=model_kwargs, model_type=ModelType.LLM
)
# Run both benchmarks

print("Starting synchronous benchmark...\n")
benchmark_sync_call(api_kwargs)

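The tail of this script is collapsed in the diff above; one typical way to drive the async benchmark from the same __main__ block (an assumption, since that portion is not shown) would be:

# asyncio is already imported at the top of the script
print("\nStarting asynchronous benchmark...\n")
asyncio.run(benchmark_async_acall(api_kwargs))
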
42 changes: 4 additions & 38 deletions tutorials/adalflow_text_splitter.py
@@ -1,3 +1,5 @@
"Code for tutorial: https://adalflow.sylph.ai/tutorials/text_splitter.html"

from adalflow.components.data_process.text_splitter import TextSplitter
from adalflow.core.types import Document
from typing import Optional, Dict
@@ -6,17 +8,7 @@
def split_by_words(
text: str, chunk_size: int = 5, chunk_overlap: int = 1, doc_id: Optional[str] = None
) -> list:
"""Split text by words with configurable parameters
Args:
text: Input text to split
chunk_size: Maximum number of words per chunk
chunk_overlap: Number of overlapping words between chunks
doc_id: Optional document ID

Returns:
List of Document objects containing the split text chunks
"""
text_splitter = TextSplitter(
split_by="word", chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
@@ -29,17 +21,7 @@ def split_by_words(
def split_by_tokens(
text: str, chunk_size: int = 5, chunk_overlap: int = 0, doc_id: Optional[str] = None
) -> list:
"""Split text by tokens with configurable parameters
Args:
text: Input text to split
chunk_size: Maximum number of tokens per chunk
chunk_overlap: Number of overlapping tokens between chunks
doc_id: Optional document ID

Returns:
List of Document objects containing the split text chunks
"""
text_splitter = TextSplitter(
split_by="token", chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
@@ -57,19 +39,7 @@ def split_by_custom(
chunk_overlap: int = 0,
doc_id: Optional[str] = None,
) -> list:
"""Split text using custom separator with configurable parameters
Args:
text: Input text to split
split_by: Custom split type that matches separator dict key
separators: Dictionary mapping split types to separator strings
chunk_size: Maximum chunk size
chunk_overlap: Overlap size between chunks
doc_id: Optional document ID
Returns:
List of Document objects containing the split text chunks
"""

text_splitter = TextSplitter(
split_by=split_by,
chunk_size=chunk_size,
@@ -83,27 +53,23 @@ def split_by_custom(


def example_usage():
"""Example showing how to use the text splitting functions"""
# Word splitting example
text = "Example text. More example text. Even more text to illustrate."
word_splits = split_by_words(text, chunk_size=5, chunk_overlap=1)
print("\nWord Split Example:")
for doc in word_splits:
print(doc)

# Token splitting example
token_splits = split_by_tokens(text, chunk_size=5, chunk_overlap=0)
print("\nToken Split Example:")
for doc in token_splits:
print(doc)

# Custom separator example
question_text = "What is your name? How old are you? Where do you live?"
custom_splits = split_by_custom(
text=question_text,
split_by="question",
separators={"question": "?"},
chunk_size=1,
separators={"question": "?"},
)
print("\nCustom Separator Example:")
for doc in custom_splits:
