diff --git a/.gitmodules b/.gitmodules
index 2fb2537..04dde04 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -10,3 +10,6 @@
 [submodule "third-party/json"]
 	path = third-party/json
 	url = https://github.com/nlohmann/json.git
+[submodule "third-party/pcre2"]
+	path = third-party/pcre2
+	url = https://github.com/PCRE2Project/pcre2.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c5eac98..f0ce71c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -29,6 +29,19 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/abseil-cpp)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece)
+
+# Configure PCRE2
+set(PCRE2_BUILD_PCRE2_8 ON)
+set(PCRE2_BUILD_PCRE2_16 OFF)
+set(PCRE2_BUILD_PCRE2_32 OFF)
+set(PCRE2_BUILD_TESTS OFF)
+set(PCRE2_BUILD_PCRE2GREP OFF)
+set(PCRE2_BUILD_PCRE2TEST OFF)
+set(PCRE2_BUILD_PCRE2GPERF OFF)
+set(PCRE2_BUILD_DOCS OFF)
+set(PCRE2_BUILD_LIBPCRE2_PDB OFF)
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/pcre2)
+
 set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
 
 file(GLOB tokenizers_source_files ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp)
@@ -45,9 +58,10 @@ target_include_directories(
   ${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece/src
   ${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2
   ${CMAKE_CURRENT_SOURCE_DIR}/third-party/json/single_include
-  ${CMAKE_CURRENT_SOURCE_DIR}/third-party/llama.cpp-unicode/include)
+  ${CMAKE_CURRENT_SOURCE_DIR}/third-party/llama.cpp-unicode/include
+  ${CMAKE_CURRENT_SOURCE_DIR}/third-party/pcre2/src)
 
-target_link_libraries(tokenizers PUBLIC sentencepiece-static re2::re2)
+target_link_libraries(tokenizers PUBLIC sentencepiece-static re2::re2 pcre2-8)
 
 # Build test
 if(TOKENIZERS_BUILD_TEST)
@@ -77,7 +91,8 @@ if(TOKENIZERS_BUILD_TEST)
     ${CMAKE_CURRENT_SOURCE_DIR}/include
     ${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece
     ${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2
-    ${CMAKE_CURRENT_SOURCE_DIR}/third-party/json/single_include)
+    ${CMAKE_CURRENT_SOURCE_DIR}/third-party/json/single_include
+    ${CMAKE_CURRENT_SOURCE_DIR}/third-party/pcre2/src)
   target_link_libraries(${test_name} gtest_main GTest::gmock tokenizers)
   add_test(${test_name} "${test_name}")
   set_tests_properties(${test_name} PROPERTIES ENVIRONMENT ${test_env})
diff --git a/include/pytorch/tokenizers/bpe_tokenizer_base.h b/include/pytorch/tokenizers/bpe_tokenizer_base.h
index 32da84d..16c0456 100644
--- a/include/pytorch/tokenizers/bpe_tokenizer_base.h
+++ b/include/pytorch/tokenizers/bpe_tokenizer_base.h
@@ -18,11 +18,9 @@
 #include <string>
 #include <vector>
 
-// Third Party
-#include <re2/re2.h>
-
 // Local
 #include <pytorch/tokenizers/error.h>
+#include <pytorch/tokenizers/regex.h>
 #include <pytorch/tokenizers/result.h>
 #include <pytorch/tokenizers/string_integer_map.h>
 #include <pytorch/tokenizers/tokenizer.h>
@@ -30,7 +28,6 @@
 namespace tokenizers {
 namespace detail {
 
-using Re2UPtr = std::unique_ptr<re2::RE2>;
 using TokenMap = StringIntegerMap<>;
 
 template <typename TToken, typename TRank>
@@ -119,9 +116,15 @@ class BPETokenizerBase : public Tokenizer {
   explicit BPETokenizerBase() {}
   virtual ~BPETokenizerBase() override {}
 
-  std::pair<std::optional<std::string>, re2::StringPiece>
+  std::pair<std::optional<std::string>, std::string>
+  split_with_allowed_special_token_(
+      const std::string& input,
+      const TokenMap& allowed_special) const;
+
+  std::pair<std::optional<std::string>, std::string>
   split_with_allowed_special_token_(
-      re2::StringPiece& input,
+      const std::string& input,
+      size_t offset,
       const TokenMap& allowed_special) const;
 
   Result<std::pair<std::vector<uint64_t>, uint64_t>> encode_with_special_token_(
@@ -133,17 +136,17 @@
       const TokenMap& encoder) const;
 
   // Protected members that can be overloaded by other BPE tokenizers
-  Re2UPtr special_token_regex_;
+  std::unique_ptr<IRegex> special_token_regex_;
   std::optional<TokenMap> token_map_;
   std::optional<TokenMap> special_token_map_;
 
  private:
   virtual Error _encode(
-      re2::StringPiece& input,
+      const std::string& input,
       std::vector<uint64_t>& ret,
       uint64_t& last_piece_token_len) const = 0;
 
-  virtual void _decode(re2::StringPiece input, std::string& ret) const = 0;
+  virtual void _decode(const std::string& input, std::string& ret) const = 0;
 };
 
 } // namespace detail
diff --git a/include/pytorch/tokenizers/hf_tokenizer.h b/include/pytorch/tokenizers/hf_tokenizer.h
index 4f8301a..54869c7 100644
--- a/include/pytorch/tokenizers/hf_tokenizer.h
+++ b/include/pytorch/tokenizers/hf_tokenizer.h
@@ -15,9 +15,6 @@
 // Standard
 #include <string>
 
-// Third Party
-#include <re2/re2.h>
-
 // Local
 #include <pytorch/tokenizers/bpe_tokenizer_base.h>
 #include <pytorch/tokenizers/error.h>
@@ -43,11 +40,11 @@ class HFTokenizer : public detail::BPETokenizerBase {
 
  private:
   Error _encode(
-      re2::StringPiece& input,
+      const std::string& input,
       std::vector<uint64_t>& ret,
       uint64_t& last_piece_token_len) const override;
 
-  void _decode(re2::StringPiece input, std::string& ret) const override;
+  void _decode(const std::string& input, std::string& ret) const override;
 
   PreTokenizer::Ptr _pretokenizer;
   TokenDecoder::Ptr _decoder;
diff --git a/include/pytorch/tokenizers/pcre2_regex.h b/include/pytorch/tokenizers/pcre2_regex.h
new file mode 100644
index 0000000..1015c43
--- /dev/null
+++ b/include/pytorch/tokenizers/pcre2_regex.h
@@ -0,0 +1,54 @@
+#pragma once
+
+#include <memory>
+#include <string>
+
+// Define PCRE2 code unit width before including pcre2.h
+#define PCRE2_CODE_UNIT_WIDTH 8
+
+// Third Party
+#include <pcre2.h>
+
+// Local
+#include "regex.h"
+
+/**
+ * @brief PCRE2-based implementation of IRegex.
+ */
+class Pcre2Regex : public IRegex {
+ public:
+  /**
+   * @brief Construct a PCRE2 regex with the given pattern.
+   *
+   * @param pattern The regex pattern to compile.
+   */
+  explicit Pcre2Regex(const std::string& pattern);
+
+  /**
+   * @brief Destructor to clean up PCRE2 resources.
+   */
+  ~Pcre2Regex();
+
+  /**
+   * @brief Return all non-overlapping matches found in the input string.
+   */
+  virtual std::vector<Match> findAll(const std::string& text) const override;
+
+  /**
+   * @brief Check if PCRE2 compiled the pattern successfully.
+   */
+  bool ok() const override;
+
+ protected:
+  /**
+   * @brief Expose internal PCRE2 pointer to the factory if needed.
+   */
+  const pcre2_code* rawRegex() const;
+
+ private:
+  pcre2_code* regex_;
+  pcre2_match_data* match_data_;
+  bool is_valid_;
+
+  friend std::unique_ptr<IRegex> createRegex(const std::string& pattern);
+};
\ No newline at end of file
diff --git a/include/pytorch/tokenizers/pre_tokenizer.h b/include/pytorch/tokenizers/pre_tokenizer.h
index 56218c7..8462c9f 100644
--- a/include/pytorch/tokenizers/pre_tokenizer.h
+++ b/include/pytorch/tokenizers/pre_tokenizer.h
@@ -19,6 +19,9 @@
 #include <nlohmann/json.hpp>
 #include <re2/re2.h>
 
+// Local
+#include <pytorch/tokenizers/regex.h>
+
 namespace tokenizers {
 
 // -- Base ---------------------------------------------------------------------
@@ -42,7 +45,7 @@ class PreTokenizer {
    * https://abseil.io/docs/cpp/guides/strings#string_view
    */
  virtual std::vector<std::string> pre_tokenize(
-      re2::StringPiece input) const = 0;
+      const std::string& input) const = 0;
 
   virtual ~PreTokenizer() = default;
 }; // end class PreTokenizer
@@ -138,18 +141,16 @@ class PreTokenizerConfig {
 
 class RegexPreTokenizer : public PreTokenizer {
  public:
-  typedef std::unique_ptr<re2::RE2> Re2UPtr;
-
   explicit RegexPreTokenizer(const std::string& pattern)
       : regex_(RegexPreTokenizer::create_regex_(pattern)) {}
 
   /** Pre-tokenize with the stored regex */
-  std::vector<std::string> pre_tokenize(re2::StringPiece input) const;
+  std::vector<std::string> pre_tokenize(const std::string& input) const;
 
  protected:
-  static Re2UPtr create_regex_(const std::string& pattern);
+  static std::unique_ptr<IRegex> create_regex_(const std::string& pattern);
 
-  Re2UPtr regex_;
+  std::unique_ptr<IRegex> regex_;
 
 }; // end class RegexPreTokenizer
@@ -185,7 +186,8 @@ class ByteLevelPreTokenizer : public PreTokenizer {
       : ByteLevelPreTokenizer(true, pattern) {}
 
   /** Perform pre-tokenization */
-  std::vector<std::string> pre_tokenize(re2::StringPiece input) const override;
+  std::vector<std::string> pre_tokenize(
+      const std::string& input) const override;
 
  private:
   const std::string pattern_;
@@ -206,7 +208,8 @@ class SequencePreTokenizer : public PreTokenizer {
   explicit SequencePreTokenizer(std::vector<PreTokenizer::Ptr> pre_tokenizers);
 
   /** Perform pre-tokenization */
-  std::vector<std::string> pre_tokenize(re2::StringPiece input) const override;
+  std::vector<std::string> pre_tokenize(
+      const std::string& input) const override;
 
  private:
   const std::vector<PreTokenizer::Ptr> pre_tokenizers_;
diff --git a/include/pytorch/tokenizers/re2_regex.h b/include/pytorch/tokenizers/re2_regex.h
new file mode 100644
index 0000000..c615713
--- /dev/null
+++ b/include/pytorch/tokenizers/re2_regex.h
@@ -0,0 +1,44 @@
+#pragma once
+
+#include <memory>
+#include <string>
+
+// Third Party
+#include <re2/re2.h>
+
+// Local
+#include "regex.h"
+
+/**
+ * @brief RE2-based implementation of IRegex.
+ */
+class Re2Regex : public IRegex {
+ public:
+  /**
+   * @brief Construct a RE2 regex with the given pattern.
+   *
+   * @param pattern The regex pattern to compile.
+   */
+  explicit Re2Regex(const std::string& pattern);
+
+  /**
+   * @brief Return all non-overlapping matches found in the input string.
+   */
+  virtual std::vector<Match> findAll(const std::string& text) const override;
+
+  /**
+   * @brief Check if RE2 compiled the pattern successfully.
+   */
+  bool ok() const override;
+
+ protected:
+  /**
+   * @brief Expose internal RE2 pointer to the factory if needed.
+   */
+  const re2::RE2* rawRegex() const;
+
+ private:
+  std::unique_ptr<re2::RE2> regex_;
+
+  friend std::unique_ptr<IRegex> createRegex(const std::string& pattern);
+};
diff --git a/include/pytorch/tokenizers/regex.h b/include/pytorch/tokenizers/regex.h
new file mode 100644
index 0000000..98dbc9f
--- /dev/null
+++ b/include/pytorch/tokenizers/regex.h
@@ -0,0 +1,41 @@
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+struct Match {
+  std::string text;
+  size_t position;
+};
+
+/**
+ * @brief Abstract interface for regex wrappers.
+ */
+class IRegex {
+ public:
+  virtual ~IRegex() = default;
+
+  /**
+   * @brief Find all non-overlapping matches in the input string.
+   *
+   * @param text The input string to search.
+   * @return A vector of Match results covering all matched substrings.
+   */
+  virtual std::vector<Match> findAll(const std::string& text) const = 0;
+
+  /**
+   * @brief Check if the regex pattern was compiled successfully.
+   *
+   * @return true if the pattern is valid and ready to use, false otherwise.
+   */
+  virtual bool ok() const = 0;
+};
+
+/**
+ * @brief Creates a regex instance. Tries RE2 first, then falls back to
+ * PCRE2 and finally std::regex.
+ *
+ * @param pattern The regex pattern to compile.
+ * @return A unique pointer to an IRegex-compatible object.
+ */
+std::unique_ptr<IRegex> createRegex(const std::string& pattern);
diff --git a/include/pytorch/tokenizers/std_regex.h b/include/pytorch/tokenizers/std_regex.h
new file mode 100644
index 0000000..e49127b
--- /dev/null
+++ b/include/pytorch/tokenizers/std_regex.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include <memory>
+#include <regex>
+#include <string>
+#include "regex.h"
+
+/**
+ * @brief std::regex-based implementation of IRegex.
+ */
+class StdRegex : public IRegex {
+ public:
+  /**
+   * @brief Construct a std::regex wrapper with the given pattern.
+   *
+   * @param pattern The regex pattern to compile.
+   * @throws std::regex_error if the pattern is invalid.
+   */
+  explicit StdRegex(const std::string& pattern);
+
+  /**
+   * @brief Find all non-overlapping matches in the input string.
+   */
+  virtual std::vector<Match> findAll(const std::string& text) const override;
+
+  /**
+   * @brief Check if std::regex compiled the pattern successfully.
+   *
+   * @return true if the pattern is valid, false otherwise.
+   */
+  bool ok() const override;
+
+ private:
+  std::regex regex_;
+};
diff --git a/include/pytorch/tokenizers/tiktoken.h b/include/pytorch/tokenizers/tiktoken.h
index f4e4f9e..7cd3263 100644
--- a/include/pytorch/tokenizers/tiktoken.h
+++ b/include/pytorch/tokenizers/tiktoken.h
@@ -15,10 +15,11 @@
 #include <cstdint>
 
 // Third Party
-#include "re2/re2.h"
+#include <re2/re2.h>
 
 // Local
 #include <pytorch/tokenizers/bpe_tokenizer_base.h>
+#include <pytorch/tokenizers/regex.h>
 #include <pytorch/tokenizers/result.h>
 #include <pytorch/tokenizers/tokenizer.h>
@@ -77,11 +78,11 @@ class Tiktoken : public detail::BPETokenizerBase {
   }
 
   Error _encode(
-      re2::StringPiece& input,
+      const std::string& input,
       std::vector<uint64_t>& ret,
       uint64_t& last_piece_token_len) const override;
 
-  void _decode(re2::StringPiece input, std::string& ret) const override;
+  void _decode(const std::string& input, std::string& ret) const override;
 
   detail::TokenMap _build_special_token_map(ssize_t num_base_tokens) const;
@@ -93,7 +94,7 @@ class Tiktoken : public detail::BPETokenizerBase {
   const std::string _pattern =
       R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";
-  detail::Re2UPtr _regex;
+  std::unique_ptr<IRegex> _regex;
 };
 
 } // namespace tokenizers
diff --git a/include/pytorch/tokenizers/token_decoder.h b/include/pytorch/tokenizers/token_decoder.h
index 825e95a..df2eac6 100644
--- a/include/pytorch/tokenizers/token_decoder.h
+++ b/include/pytorch/tokenizers/token_decoder.h
@@ -45,7 +45,7 @@ class TokenDecoder {
    *
    * @returns decoded: The decoded token string
    */
-  virtual std::string decode(re2::StringPiece token) const = 0;
+  virtual std::string decode(const std::string& token) const = 0;
 
   // virtual destructor
   virtual ~TokenDecoder() = default;
@@ -92,7 +92,7 @@ class TokenDecoderConfig {
 
 class ByteLevelTokenDecoder : public TokenDecoder {
  public:
-  std::string decode(re2::StringPiece token) const override;
+  std::string decode(const std::string& token) const override;
 
 }; // end class ByteLevelTokenDecoder
diff --git a/src/bpe_tokenizer_base.cpp b/src/bpe_tokenizer_base.cpp
index 63882c5..3c4d336 100644
--- a/src/bpe_tokenizer_base.cpp
+++ b/src/bpe_tokenizer_base.cpp
@@ -130,42 +130,24 @@ static std::vector<uint64_t> _byte_pair_merge(
 // ---- Helper utils end -------------------------------------------------------
 
 // ---- protected start --------------------------------------------------------
 
-std::pair<std::optional<std::string>, re2::StringPiece>
+std::pair<std::optional<std::string>, std::string>
 BPETokenizerBase::split_with_allowed_special_token_(
-    re2::StringPiece& input,
+    const std::string& input,
+    size_t offset,
     const TokenMap& allowed_special) const {
   if (!special_token_regex_) {
-    return std::make_pair(std::nullopt, input);
+    return std::make_pair(std::nullopt, input.substr(offset));
   }
 
-#if __cplusplus >= 202002L
-  auto start = input.begin();
-#else
-  const char* start = input.data();
-#endif
+  auto matches = special_token_regex_->findAll(input.substr(offset));
 
-  std::string special;
-  while (true) {
-    if (!re2::RE2::FindAndConsume(&input, *special_token_regex_, &special)) {
-      // No special token.
-      break;
+  for (const auto& m : matches) {
+    if (allowed_special.tryGetInteger(m.text).has_value()) {
+      return {m.text, input.substr(offset, m.position)};
     }
-
-    if (allowed_special.tryGetInteger(special).has_value()) {
-      // Found an allowed special token, split the text with it.
-#if __cplusplus >= 202002L
-      return std::make_pair(
-          special,
-          re2::StringPiece(start, input.begin() - start - special.size()));
-#else
-      return std::make_pair(
-          special,
-          re2::StringPiece(start, (input.data() - start) - special.size()));
-#endif
-    } // else try to find the next special token
   }
 
-  return std::make_pair(std::nullopt, input);
+  return {std::nullopt, input.substr(offset)};
 }
 
 Result<std::pair<std::vector<uint64_t>, uint64_t>>
@@ -174,33 +156,31 @@ BPETokenizerBase::encode_with_special_token_(
     const std::string& text,
     const TokenMap& allowed_special) const {
   std::vector<uint64_t> tokens;
   uint64_t last_piece_token_len = 0;
-  re2::StringPiece input(text);
-  while (true) {
+  size_t offset = 0;
+
+  while (offset < text.size()) {
     auto [special, sub_input] =
-        split_with_allowed_special_token_(input, allowed_special);
+        split_with_allowed_special_token_(text, offset, allowed_special);
 
     TK_CHECK_OK_OR_RETURN_ERROR(
         _encode(sub_input, tokens, last_piece_token_len));
+    offset += sub_input.size();
 
     if (special) {
       const auto result = special_token_map_->tryGetInteger(*special);
       if (!result) {
-        // Should never go here, since special pattern includes all special
-        // chars.
        TK_LOG(Error, "unknown special token: %s\n", special->c_str());
         return Error::EncodeFailure;
       }
 
       tokens.push_back(*result);
       last_piece_token_len = 0;
+      offset += special->size(); // advance past the matched token
     } else {
       break;
     }
   }
 
-  // last_piece_token_len is how many tokens came from the last regex split.
-  // This is used for determining unstable tokens, since you can't merge
-  // across (stable) regex splits
   return std::make_pair(tokens, last_piece_token_len);
 }
@@ -273,7 +253,7 @@ Result<std::string> BPETokenizerBase::decode(uint64_t prev, uint64_t cur)
   } else {
     token_bytes = *result;
   }
-  _decode(token_bytes, ret);
+  _decode(std::string(token_bytes), ret);
 
   return ret;
 }
diff --git a/src/hf_tokenizer.cpp b/src/hf_tokenizer.cpp
index 4a50673..1f97540 100644
--- a/src/hf_tokenizer.cpp
+++ b/src/hf_tokenizer.cpp
@@ -100,9 +100,11 @@ Error HFTokenizer::load(const std::string& path) {
 
   // Set up the pre-tokenizer
   try {
+    std::cout << "Setting up pretokenizer..." << std::endl;
     _pretokenizer = PreTokenizerConfig()
                         .parse_json(parsed_json.at("pre_tokenizer"))
                         .create();
+    std::cout << "Pretokenizer set up" << std::endl;
   } catch (const json::out_of_range& e) {
     fprintf(stderr, "Could not parse pre_tokenizer: %s\n", e.what());
     return Error::LoadFailure;
@@ -231,7 +233,7 @@ Error HFTokenizer::load(const std::string& path) {
 // -------------------------private method start--------------------------------
 
 Error HFTokenizer::_encode(
-    re2::StringPiece& input,
+    const std::string& input,
     std::vector<uint64_t>& ret,
     uint64_t& last_piece_token_len) const {
   for (const auto& piece : _pretokenizer->pre_tokenize(input)) {
@@ -249,15 +251,11 @@ Error HFTokenizer::_encode(
   return Error::Ok;
 }
 
-void HFTokenizer::_decode(re2::StringPiece input, std::string& ret) const {
+void HFTokenizer::_decode(const std::string& input, std::string& ret) const {
   if (_decoder) {
     ret += _decoder->decode(input);
   } else {
-#ifdef _USE_INTERNAL_STRING_VIEW
-    ret += input.as_string();
-#else
     ret += input;
-#endif
   }
 }
diff --git a/src/pcre2_regex.cpp b/src/pcre2_regex.cpp
new file mode 100644
index 0000000..b484ed3
--- /dev/null
+++ b/src/pcre2_regex.cpp
@@ -0,0 +1,111 @@
+#include "pytorch/tokenizers/pcre2_regex.h"
+
+#include <iostream>
+#include <string>
+
+Pcre2Regex::Pcre2Regex(const std::string& pattern)
+    : regex_(nullptr), match_data_(nullptr), is_valid_(false) {
+  int error_code;
+  PCRE2_SIZE error_offset;
+
+  // Compile the pattern
+  regex_ = pcre2_compile(
+      reinterpret_cast<PCRE2_SPTR>(pattern.c_str()),
+      pattern.length(),
+      PCRE2_UCP | PCRE2_UTF, // Enable Unicode support and UTF-8 mode
+      &error_code,
+      &error_offset,
+      nullptr);
+
+  if (regex_ == nullptr) {
+    PCRE2_UCHAR error_buffer[256];
+    pcre2_get_error_message(error_code, error_buffer, sizeof(error_buffer));
+    std::cerr << "PCRE2 compilation failed at offset " << error_offset << ": "
+              << error_buffer << std::endl;
+    return;
+  }
+
+  // Create match data
+  match_data_ = pcre2_match_data_create_from_pattern(regex_, nullptr);
+  if (match_data_ == nullptr) {
+    pcre2_code_free(regex_);
+    regex_ = nullptr;
+    std::cerr << "Failed to create PCRE2 match data" << std::endl;
+    return;
+  }
+
+  is_valid_ = true;
+}
+
+Pcre2Regex::~Pcre2Regex() {
+  if (match_data_) {
+    pcre2_match_data_free(match_data_);
+  }
+  if (regex_) {
+    pcre2_code_free(regex_);
+  }
+}
+
+std::vector<Match> Pcre2Regex::findAll(const std::string& text) const {
+  std::vector<Match> result;
+
+  if (!is_valid_ || !regex_ || !match_data_) {
+    return result;
+  }
+
+  PCRE2_SIZE* ovector;
+  PCRE2_SPTR subject = reinterpret_cast<PCRE2_SPTR>(text.c_str());
+  PCRE2_SIZE subject_length = text.length();
+  PCRE2_SIZE offset = 0;
+
+  while (offset < subject_length) {
+    int rc = pcre2_match(
+        regex_,
+        subject,
+        subject_length,
+        offset,
+        0, // Default options
+        match_data_,
+        nullptr);
+
+    if (rc < 0) {
+      if (rc == PCRE2_ERROR_NOMATCH) {
+        break; // No more matches
+      } else {
+        // Error occurred
+        PCRE2_UCHAR error_buffer[256];
+        pcre2_get_error_message(rc, error_buffer, sizeof(error_buffer));
+        std::cerr << "PCRE2 matching error: " << error_buffer << std::endl;
+        break;
+      }
+    }
+
+    ovector = pcre2_get_ovector_pointer(match_data_);
+
+    // Extract the match
+    size_t match_start = ovector[0];
+    size_t match_length = ovector[1] - ovector[0];
+
+    // Add the match to the result
+    result.push_back({text.substr(match_start, match_length), match_start});
+
+    // Move to the next position after the match
+    offset = ovector[1];
+
+    // If the match was empty, move forward by one character to avoid infinite
+    // loop
+    if (ovector[0] == ovector[1]) {
+      offset++;
+    }
+  }
+
+  return result;
+}
+
+bool Pcre2Regex::ok() const {
+  return is_valid_ && regex_ != nullptr && match_data_ != nullptr;
+}
+
+const pcre2_code* Pcre2Regex::rawRegex() const {
+  return regex_;
+}
\ No newline at end of file
diff --git a/src/pre_tokenizer.cpp b/src/pre_tokenizer.cpp
index 956403d..a1025bb 100644
--- a/src/pre_tokenizer.cpp
+++ b/src/pre_tokenizer.cpp
@@ -13,6 +13,7 @@
 // Standard
 #include <algorithm>
+#include <iostream>
 #include <iterator>
 #include <utility>
@@ -105,20 +106,21 @@ PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) {
 
 // RegexPreTokenizer ///////////////////////////////////////////////////////////
 
-RegexPreTokenizer::Re2UPtr RegexPreTokenizer::create_regex_(
+std::unique_ptr<IRegex> RegexPreTokenizer::create_regex_(
     const std::string& pattern) {
   assert(!pattern.empty());
-  return std::make_unique<re2::RE2>("(" + pattern + ")");
+  return createRegex(pattern);
 }
 
 std::vector<std::string> RegexPreTokenizer::pre_tokenize(
-    re2::StringPiece input) const {
-  std::vector<std::string> result;
-  std::string piece;
-  while (RE2::FindAndConsume(&input, *regex_, &piece)) {
-    result.emplace_back(piece);
+    const std::string& input) const {
+  if (!regex_)
+    return {};
+  std::vector<std::string> results;
+  for (const auto& match : regex_->findAll(input)) {
+    results.push_back(match.text);
   }
-  return result;
+  return results;
 }
 
 // ByteLevelPreTokenizer ///////////////////////////////////////////////////////
@@ -146,14 +148,15 @@ ByteLevelPreTokenizer::ByteLevelPreTokenizer(
       add_prefix_space_(add_prefix_space) {}
 
 std::vector<std::string> ByteLevelPreTokenizer::pre_tokenize(
-    re2::StringPiece input) const {
-  // Add the prefix space if configured to do so
-  std::string input_str(input);
-  if (add_prefix_space_ && !input_str.empty() && input_str[0] != ' ') {
-    input_str.insert(input_str.begin(), ' ');
+    const std::string& input) const {
+  // Add the prefix space if configured to do so.
+  std::string formatted_input = input;
+  if (add_prefix_space_ && !formatted_input.empty() &&
+      formatted_input[0] != ' ') {
+    formatted_input.insert(formatted_input.begin(), ' ');
   }
 
-  return unicode_regex_split(input_str, {pattern_});
+  return unicode_regex_split(formatted_input, {pattern_});
 }
 
 // SequencePreTokenizer ////////////////////////////////////////////////////////
@@ -163,7 +166,7 @@ SequencePreTokenizer::SequencePreTokenizer(
     std::vector<PreTokenizer::Ptr> pre_tokenizers)
     : pre_tokenizers_(std::move(pre_tokenizers)) {}
 
 std::vector<std::string> SequencePreTokenizer::pre_tokenize(
-    re2::StringPiece input) const {
+    const std::string& input) const {
   std::vector<std::string> pieces{std::string(input)};
   for (const auto& pre_tokenizer : pre_tokenizers_) {
     std::vector<std::string> new_pieces;
diff --git a/src/re2_regex.cpp b/src/re2_regex.cpp
new file mode 100644
index 0000000..98cf8f5
--- /dev/null
+++ b/src/re2_regex.cpp
@@ -0,0 +1,32 @@
+#include "pytorch/tokenizers/re2_regex.h"
+
+Re2Regex::Re2Regex(const std::string& pattern) {
+  regex_ = std::make_unique<re2::RE2>(pattern);
+  // Warmup re2 as it is slow on the first run, void the return value as it's
+  // not needed. Refer to
+  // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
+  (void)regex_->ReverseProgramSize();
+}
+
+std::vector<Match> Re2Regex::findAll(const std::string& text) const {
+  std::vector<Match> result;
+  re2::StringPiece input(text);
+  re2::StringPiece piece;
+
+  const char* base = input.data();
+
+  while (RE2::FindAndConsume(&input, *regex_, &piece)) {
+    size_t pos = piece.data() - base;
+    result.push_back({std::string(piece.data(), piece.size()), pos});
+  }
+
+  return result;
+}
+
+bool Re2Regex::ok() const {
+  return regex_ && regex_->ok();
+}
+
+const re2::RE2* Re2Regex::rawRegex() const {
+  return regex_.get();
+}
diff --git a/src/regex.cpp b/src/regex.cpp
new file mode 100644
index 0000000..0df2895
--- /dev/null
+++ b/src/regex.cpp
@@ -0,0 +1,49 @@
+#include "pytorch/tokenizers/regex.h"
+#include "pytorch/tokenizers/pcre2_regex.h"
+#include "pytorch/tokenizers/re2_regex.h"
+#include "pytorch/tokenizers/std_regex.h"
+
+#include <iostream>
+#include <memory>
+#include <regex>
+
+/**
+ * @brief Factory function that creates a regex object using RE2 if possible.
+ * Falls back to PCRE2 if RE2 rejects the pattern, then to std::regex if
+ * PCRE2 fails.
+ */
+std::unique_ptr<IRegex> createRegex(const std::string& pattern) {
+  // Try RE2 first
+  auto re2 = std::make_unique<Re2Regex>("(" + pattern + ")");
+
+  if (re2->ok()) {
+    return re2;
+  }
+
+  const re2::RE2* raw = re2->rawRegex();
+  if (raw && raw->error_code() == re2::RE2::ErrorBadPerlOp) {
+    // RE2 doesn't support some Perl features, try PCRE2
+    auto pcre2 = std::make_unique<Pcre2Regex>("(" + pattern + ")");
+
+    if (pcre2->ok()) {
+      std::cout
+          << "RE2 is unable to support things such as negative lookaheads in "
+          << pattern << ", using PCRE2 instead." << std::endl;
+      return pcre2;
+    }
+
+    // If PCRE2 also fails, fall back to std::regex
+    try {
+      std::cout
+          << "PCRE2 failed to compile pattern, falling back to std::regex.";
+      return std::make_unique<StdRegex>("(" + pattern + ")");
+    } catch (const std::regex_error& e) {
+      std::cerr << "std::regex failed: " << e.what() << std::endl;
+      return nullptr;
+    }
+  } else {
+    std::cerr << "RE2 failed to compile pattern: " << pattern << "\n";
+    std::cerr << "Error: " << (raw ? raw->error() : "unknown") << std::endl;
+    return nullptr;
+  }
+}
diff --git a/src/std_regex.cpp b/src/std_regex.cpp
new file mode 100644
index 0000000..5e0e248
--- /dev/null
+++ b/src/std_regex.cpp
@@ -0,0 +1,26 @@
+#include "pytorch/tokenizers/std_regex.h"
+#include <regex>
+
+StdRegex::StdRegex(const std::string& pattern) : regex_(pattern) {}
+
+std::vector<Match> StdRegex::findAll(const std::string& text) const {
+  std::vector<Match> result;
+  std::sregex_iterator iter(text.begin(), text.end(), regex_);
+  std::sregex_iterator end;
+
+  for (; iter != end; ++iter) {
+    const auto& match = *iter;
+    result.push_back({
+        match[1].str(), // capture group 1
+        static_cast<size_t>(match.position(1)) // position of group 1
+    });
+  }
+
+  return result;
+}
+
+bool StdRegex::ok() const {
+  // std::regex constructor throws if the pattern is invalid.
+  // If we got here, the pattern is valid.
+  return true;
+}
diff --git a/src/tiktoken.cpp b/src/tiktoken.cpp
index 6a86eed..fa900c5 100644
--- a/src/tiktoken.cpp
+++ b/src/tiktoken.cpp
@@ -41,13 +41,12 @@ using namespace detail;
 // ------------------------------Util start------------------------------------
 namespace {
 
-static Re2UPtr _create_regex(const std::string& pattern) {
+static std::unique_ptr<IRegex> _create_regex(const std::string& pattern) {
   assert(!pattern.empty());
-
-  return std::make_unique<re2::RE2>("(" + pattern + ")");
+  return createRegex(pattern);
 }
 
-static Re2UPtr _build_special_token_regex(
+static std::unique_ptr<IRegex> _build_special_token_regex(
     const std::vector<std::pair<std::string, uint64_t>>& special_encoder) {
   std::string special_pattern;
   for (const auto& ele : special_encoder) {
@@ -56,11 +55,9 @@
     }
     special_pattern += re2::RE2::QuoteMeta(ele.first);
   }
-
   if (special_pattern.empty()) {
     return nullptr;
   }
-
   return _create_regex(special_pattern);
 }
@@ -114,26 +111,26 @@ static Result<TokenMap> _load_token_map(const std::string& path) {
 // -------------------------private method start-------------------------------
 
 Error Tiktoken::_encode(
-    re2::StringPiece& input,
+    const std::string& input,
     std::vector<uint64_t>& ret,
     uint64_t& last_piece_token_len) const {
   std::string piece;
   assert(_regex);
-  while (re2::RE2::FindAndConsume(&input, *_regex, &piece)) {
-    const auto result = token_map_->tryGetInteger(piece);
+  for (const auto& match : _regex->findAll(input)) {
+    const auto result = token_map_->tryGetInteger(match.text);
     if (result) {
       last_piece_token_len = 1;
       ret.push_back(*result);
       continue;
     }
-    auto tokens = TK_UNWRAP(byte_pair_encode_(piece, *token_map_));
+    auto tokens = TK_UNWRAP(byte_pair_encode_(match.text, *token_map_));
 
     last_piece_token_len = tokens.size();
     ret.insert(ret.end(), tokens.begin(), tokens.end());
   }
   return Error::Ok;
 }
 
-void Tiktoken::_decode(re2::StringPiece input, std::string& ret) const {
+void Tiktoken::_decode(const std::string& input, std::string& ret) const {
 #ifdef _USE_INTERNAL_STRING_VIEW
   ret += input.as_string();
 #else
@@ -156,14 +153,7 @@ Error Tiktoken::load(const std::string& path) {
   special_token_map_.emplace(TokenMap(special_token_map));
 
   _regex = _create_regex(_pattern);
-  // Warmup re2 as it is slow on the first run, void the return value as it's
-  // not needed Refer to
-  // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
-  (void)_regex->ReverseProgramSize();
-
   special_token_regex_ = _build_special_token_regex(special_token_map);
-  // Same as above, warm up re2
-  (void)special_token_regex_->ReverseProgramSize();
 
   // initialize vocab_size, bos_tok, eos_tok
   vocab_size_ = token_map_->size() + special_token_map_->size();
diff --git a/src/token_decoder.cpp b/src/token_decoder.cpp
index 7933840..3bc028e 100644
--- a/src/token_decoder.cpp
+++ b/src/token_decoder.cpp
@@ -71,15 +71,14 @@ static std::string format(const char* fmt, ...) {
 
 } // namespace
 
-std::string ByteLevelTokenDecoder::decode(re2::StringPiece token) const {
+std::string ByteLevelTokenDecoder::decode(const std::string& token) const {
   // This is borrowed and lightly tweaked from llama.cpp
   // CITE:
   // https://github.com/ggerganov/llama.cpp/blob/master/src/llama-vocab.cpp#L1755
   std::string decoded_text;
   // TODO: This could be more efficient since what we really need is a string
   // const ref.
-  std::string text(token);
-  const auto cpts = unicode_cpts_from_utf8(text);
+  const auto cpts = unicode_cpts_from_utf8(token);
   for (const auto cpt : cpts) {
     const auto utf8 = unicode_cpt_to_utf8(cpt);
     try {
@@ -89,7 +88,7 @@
       for (const auto c : utf8) {
         decoded_text += format("%02x", (uint8_t)c);
       }
-      decoded_text += text + "]";
+      decoded_text += token + "]";
     }
   }
diff --git a/test/test_pre_tokenizer.cpp b/test/test_pre_tokenizer.cpp
index baa795b..d6e2736 100644
--- a/test/test_pre_tokenizer.cpp
+++ b/test/test_pre_tokenizer.cpp
@@ -23,8 +23,7 @@ static void assert_split_match(
     const PreTokenizer& ptok,
     const std::string& prompt,
     const std::vector<std::string>& expected) {
-  re2::StringPiece prompt_view(prompt);
-  const auto& got = ptok.pre_tokenize(prompt_view);
+  const auto& got = ptok.pre_tokenize(prompt);
   EXPECT_EQ(expected.size(), got.size());
   for (auto i = 0; i < got.size(); ++i) {
     EXPECT_EQ(expected[i], got[i]);
diff --git a/test/test_regex.cpp b/test/test_regex.cpp
new file mode 100644
index 0000000..915c62c
--- /dev/null
+++ b/test/test_regex.cpp
@@ -0,0 +1,83 @@
+#include <gtest/gtest.h>
+
+#include "pytorch/tokenizers/regex.h"
+#include "pytorch/tokenizers/re2_regex.h"
+#include "pytorch/tokenizers/pcre2_regex.h"
+
+// Test basic functionality
+TEST(RegexTest, BasicMatching) {
+  auto regex = createRegex("\\w+");
+  ASSERT_TRUE(regex->ok());
+
+  std::string text = "Hello world";
+  auto matches = regex->findAll(text);
+  ASSERT_EQ(matches.size(), 2);
+  EXPECT_EQ(matches[0].text, "Hello");
+  EXPECT_EQ(matches[0].position, 0);
+  EXPECT_EQ(matches[1].text, "world");
+  EXPECT_EQ(matches[1].position, 6);
+}
+
+// Test pattern that only PCRE2 supports (lookbehind)
+TEST(RegexTest, Pcre2Specific) {
+  // First verify that RE2 cannot handle this pattern
+  const std::string pattern = "(?<=@)\\w+";
+  Re2Regex re2_regex(pattern);
+  ASSERT_FALSE(re2_regex.ok());
+
+  // Now verify that the factory function falls back to a PCRE2 regex
+  auto regex = createRegex(pattern);
+  ASSERT_TRUE(regex->ok());
+
+  std::string text = "user@example.com";
+  auto matches = regex->findAll(text);
+  ASSERT_EQ(matches.size(), 1);
+  EXPECT_EQ(matches[0].text, "example");
+  EXPECT_EQ(matches[0].position, 5);
+}
+
+// Test complex pattern with negative lookahead that should fall back to PCRE2.
+// This specific pattern is from the Qwen2.5 1.5B pretokenizer.
+// https://huggingface.co/Qwen/Qwen2.5-1.5B/raw/main/tokenizer.json
+TEST(RegexTest, ComplexPatternWithNegativeLookahead) {
+  const std::string complex_pattern =
+      "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+
+  // First verify that RE2 cannot handle this pattern
+  Re2Regex re2_regex(complex_pattern);
+  ASSERT_FALSE(re2_regex.ok());
+
+  // Now verify that the factory function falls back to a PCRE2 regex
+  auto regex = createRegex(complex_pattern);
+  ASSERT_TRUE(regex->ok());
+
+  // Test the pattern with some sample text
+  std::string text = "Hello's world\n  test";
+  auto matches = regex->findAll(text);
+
+  // We expect to match:
+  // 1. "Hello" (word)
+  // 2. "'s" (contraction)
+  // 3. " world" (word with leading space)
+  // 4. "\n" (newline)
+  // 5. " " (whitespace)
+  // 6. " test" (word with leading space)
+  ASSERT_EQ(matches.size(), 6);
+
+  EXPECT_EQ(matches[0].text, "Hello");
+  EXPECT_EQ(matches[0].position, 0);
+
+  EXPECT_EQ(matches[1].text, "'s");
+  EXPECT_EQ(matches[1].position, 5);
+
+  EXPECT_EQ(matches[2].text, " world");
+  EXPECT_EQ(matches[2].position, 7);
+
+  EXPECT_EQ(matches[3].text, "\n");
+  EXPECT_EQ(matches[3].position, 13);
+
+  EXPECT_EQ(matches[4].text, " ");
+  EXPECT_EQ(matches[4].position, 14);
+
+  EXPECT_EQ(matches[5].text, " test");
+  EXPECT_EQ(matches[5].position, 15);
+}
\ No newline at end of file
diff --git a/third-party/pcre2 b/third-party/pcre2
new file mode 160000
index 0000000..2e03e32
--- /dev/null
+++ b/third-party/pcre2
@@ -0,0 +1 @@
+Subproject commit 2e03e323339ab692640626f02f8d8d6f95bff9c6
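For reviewers, a minimal usage sketch of the `IRegex` factory introduced by this diff (illustrative only; the `main` harness below is not part of the change, and it assumes the `pytorch/tokenizers/regex.h` include path added above):

```cpp
// Illustrative harness, not part of the diff. Exercises createRegex() and
// IRegex::findAll() as declared in include/pytorch/tokenizers/regex.h.
#include <iostream>

#include <pytorch/tokenizers/regex.h>

int main() {
  // createRegex() tries RE2 first and falls back to PCRE2 (then std::regex)
  // when the pattern needs features RE2 lacks, e.g. negative lookahead.
  auto regex = createRegex(R"(\w+|\s+(?!\S))");
  if (!regex || !regex->ok()) {
    std::cerr << "pattern failed to compile under every backend\n";
    return 1;
  }

  // Each Match carries the matched text and its byte offset in the input,
  // regardless of which backend compiled the pattern.
  for (const auto& match : regex->findAll("Hello world")) {
    std::cout << match.position << ": '" << match.text << "'\n";
  }
  return 0;
}
```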