From 934ffa3b5cc16b94fa2c63f766d8bd9082fcaa91 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 14 Apr 2025 12:49:20 -0700 Subject: [PATCH 1/9] Add regex interface --- include/pytorch/tokenizers/re2_regex.h | 42 ++++++++++++++++++++++++++ include/pytorch/tokenizers/regex.h | 34 +++++++++++++++++++++ include/pytorch/tokenizers/std_regex.h | 28 +++++++++++++++++ src/re2_regex.cpp | 33 ++++++++++++++++++++ src/regex.cpp | 37 +++++++++++++++++++++++ src/std_regex.cpp | 22 ++++++++++++++ 6 files changed, 196 insertions(+) create mode 100644 include/pytorch/tokenizers/re2_regex.h create mode 100644 include/pytorch/tokenizers/regex.h create mode 100644 include/pytorch/tokenizers/std_regex.h create mode 100644 src/re2_regex.cpp create mode 100644 src/regex.cpp create mode 100644 src/std_regex.cpp diff --git a/include/pytorch/tokenizers/re2_regex.h b/include/pytorch/tokenizers/re2_regex.h new file mode 100644 index 0000000..1012911 --- /dev/null +++ b/include/pytorch/tokenizers/re2_regex.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include "regex.h" + +// Third Party +#include + +/** + * @brief RE2-based implementation of IRegex. + */ +class Re2Regex : public IRegex { + public: + /** + * @brief Construct a RE2 regex with the given pattern. + * + * @param pattern The regex pattern to compile. + */ + explicit Re2Regex(const std::string& pattern); + + /** + * @brief Return all non-overlapping matches found in the input string. + */ + virtual std::vector findAll(const std::string& text) const override; + + protected: + /** + * @brief Check if RE2 compiled the pattern successfully. + */ + bool ok() const; + + /** + * @brief Expose internal RE2 pointer to the factory if needed. + */ + const re2::RE2* rawRegex() const; + + private: + std::unique_ptr regex_; + + friend std::unique_ptr createRegex(const std::string& pattern); +}; diff --git a/include/pytorch/tokenizers/regex.h b/include/pytorch/tokenizers/regex.h new file mode 100644 index 0000000..ac6c80b --- /dev/null +++ b/include/pytorch/tokenizers/regex.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +struct Match { + std::string text; + size_t position; +}; + +/** + * @brief Abstract interface for regex wrappers. + */ +class IRegex { + public: + virtual ~IRegex() = default; + + /** + * @brief Find all non-overlapping matches in the input string. + * + * @param text The input string to search. + * @return A vector of strings containing all matched substrings. + */ + virtual std::vector findAll(const std::string& text) const = 0; +}; + +/** + * @brief Creates a regex instance. Tries RE2 first, falls back to std::regex. + * + * @param pattern The regex pattern to compile. + * @return A unique pointer to an IRegex-compatible object. + */ +std::unique_ptr createRegex(const std::string& pattern); diff --git a/include/pytorch/tokenizers/std_regex.h b/include/pytorch/tokenizers/std_regex.h new file mode 100644 index 0000000..41828bf --- /dev/null +++ b/include/pytorch/tokenizers/std_regex.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include +#include "regex.h" + +/** + * @brief std::regex-based implementation of IRegex. + */ +class StdRegex : public IRegex { + public: + /** + * @brief Construct a std::regex wrapper with the given pattern. + * + * @param pattern The regex pattern to compile. + * @throws std::regex_error if the pattern is invalid. + */ + explicit StdRegex(const std::string& pattern); + + /** + * @brief Find all non-overlapping matches in the input string. + */ + virtual std::vector findAll(const std::string& text) const override; + + private: + std::regex regex_; +}; diff --git a/src/re2_regex.cpp b/src/re2_regex.cpp new file mode 100644 index 0000000..394032d --- /dev/null +++ b/src/re2_regex.cpp @@ -0,0 +1,33 @@ +#include "pytorch/tokenizers/re2_regex.h" +#include + +Re2Regex::Re2Regex(const std::string& pattern) { + regex_ = std::make_unique("(" + pattern + ")"); + // Warmup re2 as it is slow on the first run, void the return value as it's + // not needed Refer to + // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141 + (void)regex_->ReverseProgramSize(); +} + +bool Re2Regex::ok() const { + return regex_ && regex_->ok(); +} + +const re2::RE2* Re2Regex::rawRegex() const { + return regex_.get(); +} + +std::vector Re2Regex::findAll(const std::string& text) const { + std::vector result; + re2::StringPiece input(text); + re2::StringPiece piece; + + const char* base = input.data(); + + while (RE2::FindAndConsume(&input, *regex_, &piece)) { + size_t pos = piece.data() - base; + result.push_back({ std::string(piece.data(), piece.size()), pos }); + } + + return result; +} \ No newline at end of file diff --git a/src/regex.cpp b/src/regex.cpp new file mode 100644 index 0000000..9bf81d7 --- /dev/null +++ b/src/regex.cpp @@ -0,0 +1,37 @@ +#include "pytorch/tokenizers/regex.h" +#include "pytorch/tokenizers/re2_regex.h" +#include "pytorch/tokenizers/std_regex.h" + +#include +#include +#include + +/** + * @brief Factory function that creates a regex object using RE2 if possible. + * Falls back to std::regex if RE2 rejects the pattern with + * ErrorBadPerlOp. + */ +std::unique_ptr createRegex(const std::string& pattern) { + auto re2 = std::make_unique(pattern); + + if (re2->ok()) { + return re2; + } + + const re2::RE2* raw = re2->rawRegex(); + if (raw && raw->error_code() == re2::RE2::ErrorBadPerlOp) { + try { + std::cout + << "RE2 is unable to support things such as negative lookaheads in " + << pattern << ", defaulting to std::regex."; + return std::make_unique(pattern); + } catch (const std::regex_error& e) { + std::cerr << "std::regex failed: " << e.what() << std::endl; + return nullptr; + } + } else { + std::cerr << "RE2 failed to compile pattern: " << pattern << "\n"; + std::cerr << "Error: " << (raw ? raw->error() : "unknown") << std::endl; + return nullptr; + } +} diff --git a/src/std_regex.cpp b/src/std_regex.cpp new file mode 100644 index 0000000..c2b98e0 --- /dev/null +++ b/src/std_regex.cpp @@ -0,0 +1,22 @@ +#include "pytorch/tokenizers/std_regex.h" +#include + +StdRegex::StdRegex(const std::string& pattern) + : regex_("(" + pattern + ")") // Add parentheses like RE2 version +{} + +std::vector StdRegex::findAll(const std::string& text) const { + std::vector result; + std::sregex_iterator iter(text.begin(), text.end(), regex_); + std::sregex_iterator end; + + for (; iter != end; ++iter) { + const auto& match = *iter; + result.push_back({ + match[1].str(), // capture group 1 + static_cast(match.position(1)) // position of group 1 + }); + } + + return result; +} From a78b408e00424f86334cd6194ecc1a876274c5de Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 14 Apr 2025 17:04:54 -0700 Subject: [PATCH 2/9] Use IRegex interface --- .../pytorch/tokenizers/bpe_tokenizer_base.h | 21 ++++---- include/pytorch/tokenizers/hf_tokenizer.h | 7 +-- include/pytorch/tokenizers/pre_tokenizer.h | 19 ++++--- include/pytorch/tokenizers/re2_regex.h | 4 +- include/pytorch/tokenizers/regex.h | 7 +++ include/pytorch/tokenizers/tiktoken.h | 9 ++-- include/pytorch/tokenizers/token_decoder.h | 4 +- src/bpe_tokenizer_base.cpp | 52 ++++++------------- src/hf_tokenizer.cpp | 10 ++-- src/pre_tokenizer.cpp | 32 ++++++------ src/re2_regex.cpp | 21 ++++---- src/regex.cpp | 22 ++++++-- src/std_regex.cpp | 4 +- src/tiktoken.cpp | 26 +++------- src/token_decoder.cpp | 7 ++- test/test_pre_tokenizer.cpp | 3 +- 16 files changed, 121 insertions(+), 127 deletions(-) diff --git a/include/pytorch/tokenizers/bpe_tokenizer_base.h b/include/pytorch/tokenizers/bpe_tokenizer_base.h index 32da84d..16c0456 100644 --- a/include/pytorch/tokenizers/bpe_tokenizer_base.h +++ b/include/pytorch/tokenizers/bpe_tokenizer_base.h @@ -18,11 +18,9 @@ #include #include -// Third Party -#include - // Local #include +#include #include #include #include @@ -30,7 +28,6 @@ namespace tokenizers { namespace detail { -using Re2UPtr = std::unique_ptr; using TokenMap = StringIntegerMap<>; template @@ -119,9 +116,15 @@ class BPETokenizerBase : public Tokenizer { explicit BPETokenizerBase() {} virtual ~BPETokenizerBase() override {} - std::pair, re2::StringPiece> + std::pair, std::string> + split_with_allowed_special_token_( + const std::string& input, + const TokenMap& allowed_special) const; + + std::pair, std::string> split_with_allowed_special_token_( - re2::StringPiece& input, + const std::string& input, + size_t offset, const TokenMap& allowed_special) const; Result, uint64_t>> encode_with_special_token_( @@ -133,17 +136,17 @@ class BPETokenizerBase : public Tokenizer { const TokenMap& encoder) const; // Protected members that can be overloaded by other BPE tokenizers - Re2UPtr special_token_regex_; + std::unique_ptr special_token_regex_; std::optional token_map_; std::optional special_token_map_; private: virtual Error _encode( - re2::StringPiece& input, + const std::string& input, std::vector& ret, uint64_t& last_piece_token_len) const = 0; - virtual void _decode(re2::StringPiece input, std::string& ret) const = 0; + virtual void _decode(const std::string& input, std::string& ret) const = 0; }; } // namespace detail diff --git a/include/pytorch/tokenizers/hf_tokenizer.h b/include/pytorch/tokenizers/hf_tokenizer.h index 4f8301a..54869c7 100644 --- a/include/pytorch/tokenizers/hf_tokenizer.h +++ b/include/pytorch/tokenizers/hf_tokenizer.h @@ -15,9 +15,6 @@ // Standard #include -// Third Party -#include - // Local #include #include @@ -43,11 +40,11 @@ class HFTokenizer : public detail::BPETokenizerBase { private: Error _encode( - re2::StringPiece& input, + const std::string& input, std::vector& ret, uint64_t& last_piece_token_len) const override; - void _decode(re2::StringPiece input, std::string& ret) const override; + void _decode(const std::string& input, std::string& ret) const override; PreTokenizer::Ptr _pretokenizer; TokenDecoder::Ptr _decoder; diff --git a/include/pytorch/tokenizers/pre_tokenizer.h b/include/pytorch/tokenizers/pre_tokenizer.h index 56218c7..8462c9f 100644 --- a/include/pytorch/tokenizers/pre_tokenizer.h +++ b/include/pytorch/tokenizers/pre_tokenizer.h @@ -19,6 +19,9 @@ #include #include +// Local +#include + namespace tokenizers { // -- Base --------------------------------------------------------------------- @@ -42,7 +45,7 @@ class PreTokenizer { * https://abseil.io/docs/cpp/guides/strings#string_view */ virtual std::vector pre_tokenize( - re2::StringPiece input) const = 0; + const std::string& input) const = 0; virtual ~PreTokenizer() = default; }; // end class PreTokenizer @@ -138,18 +141,16 @@ class PreTokenizerConfig { class RegexPreTokenizer : public PreTokenizer { public: - typedef std::unique_ptr Re2UPtr; - explicit RegexPreTokenizer(const std::string& pattern) : regex_(RegexPreTokenizer::create_regex_(pattern)) {} /** Pre-tokenize with the stored regex */ - std::vector pre_tokenize(re2::StringPiece input) const; + std::vector pre_tokenize(const std::string& input) const; protected: - static Re2UPtr create_regex_(const std::string& pattern); + static std::unique_ptr create_regex_(const std::string& pattern); - Re2UPtr regex_; + std::unique_ptr regex_; }; // end class RegexPreTokenizer @@ -185,7 +186,8 @@ class ByteLevelPreTokenizer : public PreTokenizer { : ByteLevelPreTokenizer(true, pattern) {} /** Perform pre-tokenization */ - std::vector pre_tokenize(re2::StringPiece input) const override; + std::vector pre_tokenize( + const std::string& input) const override; private: const std::string pattern_; @@ -206,7 +208,8 @@ class SequencePreTokenizer : public PreTokenizer { explicit SequencePreTokenizer(std::vector pre_tokenizers); /** Perform pre-tokenization */ - std::vector pre_tokenize(re2::StringPiece input) const override; + std::vector pre_tokenize( + const std::string& input) const override; private: const std::vector pre_tokenizers_; diff --git a/include/pytorch/tokenizers/re2_regex.h b/include/pytorch/tokenizers/re2_regex.h index 1012911..7a3c64c 100644 --- a/include/pytorch/tokenizers/re2_regex.h +++ b/include/pytorch/tokenizers/re2_regex.h @@ -2,11 +2,13 @@ #include #include -#include "regex.h" // Third Party #include +// Local +#include "regex.h" + /** * @brief RE2-based implementation of IRegex. */ diff --git a/include/pytorch/tokenizers/regex.h b/include/pytorch/tokenizers/regex.h index ac6c80b..0ade91a 100644 --- a/include/pytorch/tokenizers/regex.h +++ b/include/pytorch/tokenizers/regex.h @@ -32,3 +32,10 @@ class IRegex { * @return A unique pointer to an IRegex-compatible object. */ std::unique_ptr createRegex(const std::string& pattern); + +// /** +// * Factory functions for creating specific regex instances, prefer using +// * createRegex unless you know you need to use a specific regex. +// */ +// std::unique_ptr MakeRe2Regex(const std::string& pattern); +// std::unique_ptr MakeStdRegex(const std::string& pattern); diff --git a/include/pytorch/tokenizers/tiktoken.h b/include/pytorch/tokenizers/tiktoken.h index f4e4f9e..7cd3263 100644 --- a/include/pytorch/tokenizers/tiktoken.h +++ b/include/pytorch/tokenizers/tiktoken.h @@ -15,10 +15,11 @@ #include // Third Party -#include "re2/re2.h" +#include // Local #include +#include #include #include @@ -77,11 +78,11 @@ class Tiktoken : public detail::BPETokenizerBase { } Error _encode( - re2::StringPiece& input, + const std::string& input, std::vector& ret, uint64_t& last_piece_token_len) const override; - void _decode(re2::StringPiece input, std::string& ret) const override; + void _decode(const std::string& input, std::string& ret) const override; detail::TokenMap _build_special_token_map(ssize_t num_base_tokens) const; @@ -93,7 +94,7 @@ class Tiktoken : public detail::BPETokenizerBase { const std::string _pattern = R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)"; - detail::Re2UPtr _regex; + std::unique_ptr _regex; }; } // namespace tokenizers diff --git a/include/pytorch/tokenizers/token_decoder.h b/include/pytorch/tokenizers/token_decoder.h index 825e95a..df2eac6 100644 --- a/include/pytorch/tokenizers/token_decoder.h +++ b/include/pytorch/tokenizers/token_decoder.h @@ -45,7 +45,7 @@ class TokenDecoder { * * @returns decoded: The decoded token string */ - virtual std::string decode(re2::StringPiece token) const = 0; + virtual std::string decode(const std::string& token) const = 0; // virtual destructor virtual ~TokenDecoder() = default; @@ -92,7 +92,7 @@ class TokenDecoderConfig { class ByteLevelTokenDecoder : public TokenDecoder { public: - std::string decode(re2::StringPiece token) const override; + std::string decode(const std::string& token) const override; }; // end class ByteLevelTokenDecoder diff --git a/src/bpe_tokenizer_base.cpp b/src/bpe_tokenizer_base.cpp index 63882c5..3c4d336 100644 --- a/src/bpe_tokenizer_base.cpp +++ b/src/bpe_tokenizer_base.cpp @@ -130,42 +130,24 @@ static std::vector _byte_pair_merge( // ---- Helper utils end ------------------------------------------------------- // ---- protected start -------------------------------------------------------- -std::pair, re2::StringPiece> +std::pair, std::string> BPETokenizerBase::split_with_allowed_special_token_( - re2::StringPiece& input, + const std::string& input, + size_t offset, const TokenMap& allowed_special) const { if (!special_token_regex_) { - return std::make_pair(std::nullopt, input); + return std::make_pair(std::nullopt, input.substr(offset)); } -#if __cplusplus >= 202002L - auto start = input.begin(); -#else - const char* start = input.data(); -#endif + auto matches = special_token_regex_->findAll(input.substr(offset)); - std::string special; - while (true) { - if (!re2::RE2::FindAndConsume(&input, *special_token_regex_, &special)) { - // No special token. - break; + for (const auto& m : matches) { + if (allowed_special.tryGetInteger(m.text).has_value()) { + return {m.text, input.substr(offset, m.position)}; } - - if (allowed_special.tryGetInteger(special).has_value()) { - // Found an allowed special token, split the text with it. -#if __cplusplus >= 202002L - return std::make_pair( - special, - re2::StringPiece(start, input.begin() - start - special.size())); -#else - return std::make_pair( - special, - re2::StringPiece(start, (input.data() - start) - special.size())); -#endif - } // else try to find the next special token } - return std::make_pair(std::nullopt, input); + return {std::nullopt, input.substr(offset)}; } Result, uint64_t>> @@ -174,33 +156,31 @@ BPETokenizerBase::encode_with_special_token_( const TokenMap& allowed_special) const { std::vector tokens; uint64_t last_piece_token_len = 0; - re2::StringPiece input(text); - while (true) { + size_t offset = 0; + + while (offset < text.size()) { auto [special, sub_input] = - split_with_allowed_special_token_(input, allowed_special); + split_with_allowed_special_token_(text, offset, allowed_special); TK_CHECK_OK_OR_RETURN_ERROR( _encode(sub_input, tokens, last_piece_token_len)); + offset += sub_input.size(); if (special) { const auto result = special_token_map_->tryGetInteger(*special); if (!result) { - // Should never go here, since special pattern includes all special - // chars. TK_LOG(Error, "unknown special token: %s\n", special->c_str()); return Error::EncodeFailure; } tokens.push_back(*result); last_piece_token_len = 0; + offset += special->size(); // advance past the matched token } else { break; } } - // last_piece_token_len is how many tokens came from the last regex split. - // This is used for determining unstable tokens, since you can't merge - // across (stable) regex splits return std::make_pair(tokens, last_piece_token_len); } @@ -273,7 +253,7 @@ Result BPETokenizerBase::decode(uint64_t prev, uint64_t cur) } else { token_bytes = *result; } - _decode(token_bytes, ret); + _decode(std::string(token_bytes), ret); return ret; } diff --git a/src/hf_tokenizer.cpp b/src/hf_tokenizer.cpp index 4a50673..1f97540 100644 --- a/src/hf_tokenizer.cpp +++ b/src/hf_tokenizer.cpp @@ -100,9 +100,11 @@ Error HFTokenizer::load(const std::string& path) { // Set up the pre-tokenizer try { + std::cout << "Setting up pretokenizer..." << std::endl; _pretokenizer = PreTokenizerConfig() .parse_json(parsed_json.at("pre_tokenizer")) .create(); + std::cout << "Pretokenizer set up" << std::endl; } catch (const json::out_of_range& e) { fprintf(stderr, "Could not parse pre_tokenizer: %s\n", e.what()); return Error::LoadFailure; @@ -231,7 +233,7 @@ Error HFTokenizer::load(const std::string& path) { // -------------------------private method start-------------------------------- Error HFTokenizer::_encode( - re2::StringPiece& input, + const std::string& input, std::vector& ret, uint64_t& last_piece_token_len) const { for (const auto& piece : _pretokenizer->pre_tokenize(input)) { @@ -249,15 +251,11 @@ Error HFTokenizer::_encode( return Error::Ok; } -void HFTokenizer::_decode(re2::StringPiece input, std::string& ret) const { +void HFTokenizer::_decode(const std::string& input, std::string& ret) const { if (_decoder) { ret += _decoder->decode(input); } else { -#ifdef _USE_INTERNAL_STRING_VIEW - ret += input.as_string(); -#else ret += input; -#endif } } diff --git a/src/pre_tokenizer.cpp b/src/pre_tokenizer.cpp index 956403d..9b6cb2c 100644 --- a/src/pre_tokenizer.cpp +++ b/src/pre_tokenizer.cpp @@ -13,6 +13,7 @@ // Standard #include +#include #include #include @@ -105,20 +106,21 @@ PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) { // RegexPreTokenizer /////////////////////////////////////////////////////////// -RegexPreTokenizer::Re2UPtr RegexPreTokenizer::create_regex_( +std::unique_ptr RegexPreTokenizer::create_regex_( const std::string& pattern) { assert(!pattern.empty()); - return std::make_unique("(" + pattern + ")"); + return createRegex(pattern); } std::vector RegexPreTokenizer::pre_tokenize( - re2::StringPiece input) const { - std::vector result; - std::string piece; - while (RE2::FindAndConsume(&input, *regex_, &piece)) { - result.emplace_back(piece); + const std::string& input) const { + if (!regex_) + return {}; + std::vector results; + for (const auto& match : regex_->findAll(input)) { + results.push_back(match.text); } - return result; + return results; } // ByteLevelPreTokenizer /////////////////////////////////////////////////////// @@ -146,14 +148,14 @@ ByteLevelPreTokenizer::ByteLevelPreTokenizer( add_prefix_space_(add_prefix_space) {} std::vector ByteLevelPreTokenizer::pre_tokenize( - re2::StringPiece input) const { - // Add the prefix space if configured to do so - std::string input_str(input); - if (add_prefix_space_ && !input_str.empty() && input_str[0] != ' ') { - input_str.insert(input_str.begin(), ' '); + const std::string& input) const { + // Add the prefix space if configured to do so. + std::string formatted_input = input; + if (add_prefix_space_ && !input.empty() && input[0] != ' ') { + formatted_input.insert(input.begin(), ' '); } - return unicode_regex_split(input_str, {pattern_}); + return unicode_regex_split(formatted_input, {pattern_}); } // SequencePreTokenizer //////////////////////////////////////////////////////// @@ -163,7 +165,7 @@ SequencePreTokenizer::SequencePreTokenizer( : pre_tokenizers_(std::move(pre_tokenizers)) {} std::vector SequencePreTokenizer::pre_tokenize( - re2::StringPiece input) const { + const std::string& input) const { std::vector pieces{std::string(input)}; for (const auto& pre_tokenizer : pre_tokenizers_) { std::vector new_pieces; diff --git a/src/re2_regex.cpp b/src/re2_regex.cpp index 394032d..ee47024 100644 --- a/src/re2_regex.cpp +++ b/src/re2_regex.cpp @@ -1,22 +1,13 @@ #include "pytorch/tokenizers/re2_regex.h" -#include Re2Regex::Re2Regex(const std::string& pattern) { - regex_ = std::make_unique("(" + pattern + ")"); + regex_ = std::make_unique(pattern); // Warmup re2 as it is slow on the first run, void the return value as it's // not needed Refer to // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141 (void)regex_->ReverseProgramSize(); } -bool Re2Regex::ok() const { - return regex_ && regex_->ok(); -} - -const re2::RE2* Re2Regex::rawRegex() const { - return regex_.get(); -} - std::vector Re2Regex::findAll(const std::string& text) const { std::vector result; re2::StringPiece input(text); @@ -30,4 +21,12 @@ std::vector Re2Regex::findAll(const std::string& text) const { } return result; -} \ No newline at end of file +} + +bool Re2Regex::ok() const { + return regex_ && regex_->ok(); +} + +const re2::RE2* Re2Regex::rawRegex() const { + return regex_.get(); +} diff --git a/src/regex.cpp b/src/regex.cpp index 9bf81d7..246742c 100644 --- a/src/regex.cpp +++ b/src/regex.cpp @@ -9,10 +9,10 @@ /** * @brief Factory function that creates a regex object using RE2 if possible. * Falls back to std::regex if RE2 rejects the pattern with - * ErrorBadPerlOp. + * ErrorBadPerlOp. */ std::unique_ptr createRegex(const std::string& pattern) { - auto re2 = std::make_unique(pattern); + auto re2 = std::make_unique("(" + pattern + ")"); if (re2->ok()) { return re2; @@ -24,7 +24,7 @@ std::unique_ptr createRegex(const std::string& pattern) { std::cout << "RE2 is unable to support things such as negative lookaheads in " << pattern << ", defaulting to std::regex."; - return std::make_unique(pattern); + return std::make_unique("(" + pattern + ")"); } catch (const std::regex_error& e) { std::cerr << "std::regex failed: " << e.what() << std::endl; return nullptr; @@ -35,3 +35,19 @@ std::unique_ptr createRegex(const std::string& pattern) { return nullptr; } } + +// std::unique_ptr createRe2Regex(const std::string& pattern) { +// auto re2 = std::make_unique(pattern); + +// if (re2->ok()) { +// return re2; +// } + +// std::cerr << "RE2 failed to compile pattern: " << pattern << "\n"; +// std::cerr << "Error: " << (raw ? raw->error() : "unknown") << std::endl; +// return nullptr; +// } + +// std::unique_ptr CreateStdRegex(const std::string& pattern) { +// return std::make_unique(pattern); +// } diff --git a/src/std_regex.cpp b/src/std_regex.cpp index c2b98e0..83c8e6d 100644 --- a/src/std_regex.cpp +++ b/src/std_regex.cpp @@ -1,9 +1,7 @@ #include "pytorch/tokenizers/std_regex.h" #include -StdRegex::StdRegex(const std::string& pattern) - : regex_("(" + pattern + ")") // Add parentheses like RE2 version -{} +StdRegex::StdRegex(const std::string& pattern) : regex_(pattern) {} std::vector StdRegex::findAll(const std::string& text) const { std::vector result; diff --git a/src/tiktoken.cpp b/src/tiktoken.cpp index 6a86eed..fa900c5 100644 --- a/src/tiktoken.cpp +++ b/src/tiktoken.cpp @@ -41,13 +41,12 @@ using namespace detail; // ------------------------------Util start------------------------------------ namespace { -static Re2UPtr _create_regex(const std::string& pattern) { +static std::unique_ptr _create_regex(const std::string& pattern) { assert(!pattern.empty()); - - return std::make_unique("(" + pattern + ")"); + return createRegex(pattern); } -static Re2UPtr _build_special_token_regex( +static std::unique_ptr _build_special_token_regex( const std::vector>& special_encoder) { std::string special_pattern; for (const auto& ele : special_encoder) { @@ -56,11 +55,9 @@ static Re2UPtr _build_special_token_regex( } special_pattern += re2::RE2::QuoteMeta(ele.first); } - if (special_pattern.empty()) { return nullptr; } - return _create_regex(special_pattern); } @@ -114,26 +111,26 @@ static Result _load_token_map(const std::string& path) { // -------------------------private method start------------------------------- Error Tiktoken::_encode( - re2::StringPiece& input, + const std::string& input, std::vector& ret, uint64_t& last_piece_token_len) const { std::string piece; assert(_regex); - while (re2::RE2::FindAndConsume(&input, *_regex, &piece)) { - const auto result = token_map_->tryGetInteger(piece); + for (const auto& match : _regex->findAll(input)) { + const auto result = token_map_->tryGetInteger(match.text); if (result) { last_piece_token_len = 1; ret.push_back(*result); continue; } - auto tokens = TK_UNWRAP(byte_pair_encode_(piece, *token_map_)); + auto tokens = TK_UNWRAP(byte_pair_encode_(match.text, *token_map_)); last_piece_token_len = tokens.size(); ret.insert(ret.end(), tokens.begin(), tokens.end()); } return Error::Ok; } -void Tiktoken::_decode(re2::StringPiece input, std::string& ret) const { +void Tiktoken::_decode(const std::string& input, std::string& ret) const { #ifdef _USE_INTERNAL_STRING_VIEW ret += input.as_string(); #else @@ -156,14 +153,7 @@ Error Tiktoken::load(const std::string& path) { special_token_map_.emplace(TokenMap(special_token_map)); _regex = _create_regex(_pattern); - // Warmup re2 as it is slow on the first run, void the return value as it's - // not needed Refer to - // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141 - (void)_regex->ReverseProgramSize(); - special_token_regex_ = _build_special_token_regex(special_token_map); - // Same as above, warm up re2 - (void)special_token_regex_->ReverseProgramSize(); // initialize vocab_size, bos_tok, eos_tok vocab_size_ = token_map_->size() + special_token_map_->size(); diff --git a/src/token_decoder.cpp b/src/token_decoder.cpp index 7933840..3bc028e 100644 --- a/src/token_decoder.cpp +++ b/src/token_decoder.cpp @@ -71,15 +71,14 @@ static std::string format(const char* fmt, ...) { } // namespace -std::string ByteLevelTokenDecoder::decode(re2::StringPiece token) const { +std::string ByteLevelTokenDecoder::decode(const std::string& token) const { // This is borrowed and lightly tweaked from llama.cpp // CITE: // https://github.com/ggerganov/llama.cpp/blob/master/src/llama-vocab.cpp#L1755 std::string decoded_text; // TODO: This could be more efficient since what we really need is a string // const ref. - std::string text(token); - const auto cpts = unicode_cpts_from_utf8(text); + const auto cpts = unicode_cpts_from_utf8(token); for (const auto cpt : cpts) { const auto utf8 = unicode_cpt_to_utf8(cpt); try { @@ -89,7 +88,7 @@ std::string ByteLevelTokenDecoder::decode(re2::StringPiece token) const { for (const auto c : utf8) { decoded_text += format("%02x", (uint8_t)c); } - decoded_text += text + "]"; + decoded_text += token + "]"; } } diff --git a/test/test_pre_tokenizer.cpp b/test/test_pre_tokenizer.cpp index baa795b..d6e2736 100644 --- a/test/test_pre_tokenizer.cpp +++ b/test/test_pre_tokenizer.cpp @@ -23,8 +23,7 @@ static void assert_split_match( const PreTokenizer& ptok, const std::string& prompt, const std::vector& expected) { - re2::StringPiece prompt_view(prompt); - const auto& got = ptok.pre_tokenize(prompt_view); + const auto& got = ptok.pre_tokenize(prompt); EXPECT_EQ(expected.size(), got.size()); for (auto i = 0; i < got.size(); ++i) { EXPECT_EQ(expected[i], got[i]); From 6fdaa9f62512e6c8977a9a5450f578dda0fd9051 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 14 Apr 2025 22:05:53 -0700 Subject: [PATCH 3/9] Rename re2_regex.cpp (?) --- src/{re2_regex.cpp => re2_regex.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/{re2_regex.cpp => re2_regex.cpp} (100%) diff --git a/src/re2_regex.cpp b/src/re2_regex.cpp similarity index 100% rename from src/re2_regex.cpp rename to src/re2_regex.cpp From 2d3899dba459f275e4a0b67ef65acb6965c83c2a Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 15 Apr 2025 08:11:20 -0700 Subject: [PATCH 4/9] Fix seg fault --- src/pre_tokenizer.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pre_tokenizer.cpp b/src/pre_tokenizer.cpp index 9b6cb2c..a1025bb 100644 --- a/src/pre_tokenizer.cpp +++ b/src/pre_tokenizer.cpp @@ -151,8 +151,9 @@ std::vector ByteLevelPreTokenizer::pre_tokenize( const std::string& input) const { // Add the prefix space if configured to do so. std::string formatted_input = input; - if (add_prefix_space_ && !input.empty() && input[0] != ' ') { - formatted_input.insert(input.begin(), ' '); + if (add_prefix_space_ && !formatted_input.empty() && + formatted_input[0] != ' ') { + formatted_input.insert(formatted_input.begin(), ' '); } return unicode_regex_split(formatted_input, {pattern_}); From fc4d9268c4dd406b3d6aa741ca0041c0222b0821 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 15 Apr 2025 08:11:52 -0700 Subject: [PATCH 5/9] Lint --- src/re2_regex.cpp | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/re2_regex.cpp b/src/re2_regex.cpp index ee47024..98cf8f5 100644 --- a/src/re2_regex.cpp +++ b/src/re2_regex.cpp @@ -1,32 +1,32 @@ #include "pytorch/tokenizers/re2_regex.h" Re2Regex::Re2Regex(const std::string& pattern) { - regex_ = std::make_unique(pattern); - // Warmup re2 as it is slow on the first run, void the return value as it's - // not needed Refer to - // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141 - (void)regex_->ReverseProgramSize(); + regex_ = std::make_unique(pattern); + // Warmup re2 as it is slow on the first run, void the return value as it's + // not needed Refer to + // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141 + (void)regex_->ReverseProgramSize(); } std::vector Re2Regex::findAll(const std::string& text) const { - std::vector result; - re2::StringPiece input(text); - re2::StringPiece piece; + std::vector result; + re2::StringPiece input(text); + re2::StringPiece piece; - const char* base = input.data(); + const char* base = input.data(); - while (RE2::FindAndConsume(&input, *regex_, &piece)) { - size_t pos = piece.data() - base; - result.push_back({ std::string(piece.data(), piece.size()), pos }); - } + while (RE2::FindAndConsume(&input, *regex_, &piece)) { + size_t pos = piece.data() - base; + result.push_back({std::string(piece.data(), piece.size()), pos}); + } - return result; + return result; } bool Re2Regex::ok() const { - return regex_ && regex_->ok(); + return regex_ && regex_->ok(); } const re2::RE2* Re2Regex::rawRegex() const { - return regex_.get(); + return regex_.get(); } From edd7c99cd2e40238b9683c3feb486051e0792b7a Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 15 Apr 2025 09:06:44 -0700 Subject: [PATCH 6/9] Remove code --- include/pytorch/tokenizers/regex.h | 7 ------- src/regex.cpp | 16 ---------------- 2 files changed, 23 deletions(-) diff --git a/include/pytorch/tokenizers/regex.h b/include/pytorch/tokenizers/regex.h index 0ade91a..ac6c80b 100644 --- a/include/pytorch/tokenizers/regex.h +++ b/include/pytorch/tokenizers/regex.h @@ -32,10 +32,3 @@ class IRegex { * @return A unique pointer to an IRegex-compatible object. */ std::unique_ptr createRegex(const std::string& pattern); - -// /** -// * Factory functions for creating specific regex instances, prefer using -// * createRegex unless you know you need to use a specific regex. -// */ -// std::unique_ptr MakeRe2Regex(const std::string& pattern); -// std::unique_ptr MakeStdRegex(const std::string& pattern); diff --git a/src/regex.cpp b/src/regex.cpp index 246742c..cb9df9a 100644 --- a/src/regex.cpp +++ b/src/regex.cpp @@ -35,19 +35,3 @@ std::unique_ptr createRegex(const std::string& pattern) { return nullptr; } } - -// std::unique_ptr createRe2Regex(const std::string& pattern) { -// auto re2 = std::make_unique(pattern); - -// if (re2->ok()) { -// return re2; -// } - -// std::cerr << "RE2 failed to compile pattern: " << pattern << "\n"; -// std::cerr << "Error: " << (raw ? raw->error() : "unknown") << std::endl; -// return nullptr; -// } - -// std::unique_ptr CreateStdRegex(const std::string& pattern) { -// return std::make_unique(pattern); -// } From 5d83d3e49208ac938453afb973c25ddcbd4c7368 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 15 Apr 2025 08:55:17 -0700 Subject: [PATCH 7/9] Add pcre2 as default re2 fallback --- .gitmodules | 3 + CMakeLists.txt | 21 ++++- include/pytorch/tokenizers/pcre2_regex.h | 54 +++++++++++ src/pcre2_regex.cpp | 113 +++++++++++++++++++++++ src/regex.cpp | 19 +++- test/test_pcre2_regex.cpp | 68 ++++++++++++++ third-party/pcre2 | 1 + 7 files changed, 271 insertions(+), 8 deletions(-) create mode 100644 include/pytorch/tokenizers/pcre2_regex.h create mode 100644 src/pcre2_regex.cpp create mode 100644 test/test_pcre2_regex.cpp create mode 160000 third-party/pcre2 diff --git a/.gitmodules b/.gitmodules index 2fb2537..04dde04 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "third-party/json"] path = third-party/json url = https://github.com/nlohmann/json.git +[submodule "third-party/pcre2"] + path = third-party/pcre2 + url = https://github.com/PCRE2Project/pcre2.git diff --git a/CMakeLists.txt b/CMakeLists.txt index c5eac98..f0ce71c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,6 +29,19 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/abseil-cpp) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece) + +# Configure PCRE2 +set(PCRE2_BUILD_PCRE2_8 ON) +set(PCRE2_BUILD_PCRE2_16 OFF) +set(PCRE2_BUILD_PCRE2_32 OFF) +set(PCRE2_BUILD_TESTS OFF) +set(PCRE2_BUILD_PCRE2GREP OFF) +set(PCRE2_BUILD_PCRE2TEST OFF) +set(PCRE2_BUILD_PCRE2GPERF OFF) +set(PCRE2_BUILD_DOCS OFF) +set(PCRE2_BUILD_LIBPCRE2_PDB OFF) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/pcre2) + set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag}) file(GLOB tokenizers_source_files ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) @@ -45,9 +58,10 @@ target_include_directories( ${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece/src ${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2 ${CMAKE_CURRENT_SOURCE_DIR}/third-party/json/single_include - ${CMAKE_CURRENT_SOURCE_DIR}/third-party/llama.cpp-unicode/include) + ${CMAKE_CURRENT_SOURCE_DIR}/third-party/llama.cpp-unicode/include + ${CMAKE_CURRENT_SOURCE_DIR}/third-party/pcre2/src) -target_link_libraries(tokenizers PUBLIC sentencepiece-static re2::re2) +target_link_libraries(tokenizers PUBLIC sentencepiece-static re2::re2 pcre2-8) # Build test if(TOKENIZERS_BUILD_TEST) @@ -77,7 +91,8 @@ if(TOKENIZERS_BUILD_TEST) ${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}/third-party/sentencepiece ${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2 - ${CMAKE_CURRENT_SOURCE_DIR}/third-party/json/single_include) + ${CMAKE_CURRENT_SOURCE_DIR}/third-party/json/single_include + ${CMAKE_CURRENT_SOURCE_DIR}/third-party/pcre2/src) target_link_libraries(${test_name} gtest_main GTest::gmock tokenizers) add_test(${test_name} "${test_name}") set_tests_properties(${test_name} PROPERTIES ENVIRONMENT ${test_env}) diff --git a/include/pytorch/tokenizers/pcre2_regex.h b/include/pytorch/tokenizers/pcre2_regex.h new file mode 100644 index 0000000..8d172b4 --- /dev/null +++ b/include/pytorch/tokenizers/pcre2_regex.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +// Define PCRE2 code unit width before including pcre2.h +#define PCRE2_CODE_UNIT_WIDTH 8 + +// Third Party +#include + +// Local +#include "regex.h" + +/** + * @brief PCRE2-based implementation of IRegex. + */ +class Pcre2Regex : public IRegex { + public: + /** + * @brief Construct a PCRE2 regex with the given pattern. + * + * @param pattern The regex pattern to compile. + */ + explicit Pcre2Regex(const std::string& pattern); + + /** + * @brief Destructor to clean up PCRE2 resources. + */ + ~Pcre2Regex(); + + /** + * @brief Return all non-overlapping matches found in the input string. + */ + virtual std::vector findAll(const std::string& text) const override; + + /** + * @brief Check if PCRE2 compiled the pattern successfully. + */ + bool ok() const; + + protected: + /** + * @brief Expose internal PCRE2 pointer to the factory if needed. + */ + const pcre2_code* rawRegex() const; + + private: + pcre2_code* regex_; + pcre2_match_data* match_data_; + bool is_valid_; + + friend std::unique_ptr createRegex(const std::string& pattern); +}; \ No newline at end of file diff --git a/src/pcre2_regex.cpp b/src/pcre2_regex.cpp new file mode 100644 index 0000000..65cfeef --- /dev/null +++ b/src/pcre2_regex.cpp @@ -0,0 +1,113 @@ +#include "pytorch/tokenizers/pcre2_regex.h" + +#include +#include + +Pcre2Regex::Pcre2Regex(const std::string& pattern) : regex_(nullptr), match_data_(nullptr), is_valid_(false) { + int error_code; + PCRE2_SIZE error_offset; + + // Compile the pattern + regex_ = pcre2_compile( + reinterpret_cast(pattern.c_str()), + pattern.length(), + PCRE2_UCP | PCRE2_UTF, // Enable Unicode support and UTF-8 mode + &error_code, + &error_offset, + nullptr + ); + + if (regex_ == nullptr) { + PCRE2_UCHAR error_buffer[256]; + pcre2_get_error_message(error_code, error_buffer, sizeof(error_buffer)); + std::cerr << "PCRE2 compilation failed at offset " << error_offset << ": " << error_buffer << std::endl; + return; + } + + // Create match data + match_data_ = pcre2_match_data_create_from_pattern(regex_, nullptr); + if (match_data_ == nullptr) { + pcre2_code_free(regex_); + regex_ = nullptr; + std::cerr << "Failed to create PCRE2 match data" << std::endl; + return; + } + + is_valid_ = true; +} + +Pcre2Regex::~Pcre2Regex() { + if (match_data_) { + pcre2_match_data_free(match_data_); + } + if (regex_) { + pcre2_code_free(regex_); + } +} + +std::vector Pcre2Regex::findAll(const std::string& text) const { + std::vector result; + + if (!is_valid_ || !regex_ || !match_data_) { + return result; + } + + PCRE2_SIZE* ovector; + PCRE2_SPTR subject = reinterpret_cast(text.c_str()); + PCRE2_SIZE subject_length = text.length(); + PCRE2_SIZE offset = 0; + + while (offset < subject_length) { + int rc = pcre2_match( + regex_, + subject, + subject_length, + offset, + 0, // Default options + match_data_, + nullptr + ); + + if (rc < 0) { + if (rc == PCRE2_ERROR_NOMATCH) { + break; // No more matches + } else { + // Error occurred + PCRE2_UCHAR error_buffer[256]; + pcre2_get_error_message(rc, error_buffer, sizeof(error_buffer)); + std::cerr << "PCRE2 matching error: " << error_buffer << std::endl; + break; + } + } + + ovector = pcre2_get_ovector_pointer(match_data_); + + // Extract the match + size_t match_start = ovector[0]; + size_t match_length = ovector[1] - ovector[0]; + + // Add the match to the result + result.push_back({ + text.substr(match_start, match_length), + match_start + }); + + // Move to the next position after the match + offset = ovector[1]; + + // If the match was empty, move forward by one character to avoid infinite loop + if (ovector[0] == ovector[1]) { + offset++; + } + } + + return result; +} + +bool Pcre2Regex::ok() const { + return is_valid_ && regex_ != nullptr && match_data_ != nullptr; +} + +const pcre2_code* Pcre2Regex::rawRegex() const { + return regex_; +} \ No newline at end of file diff --git a/src/regex.cpp b/src/regex.cpp index cb9df9a..c215be0 100644 --- a/src/regex.cpp +++ b/src/regex.cpp @@ -1,6 +1,7 @@ #include "pytorch/tokenizers/regex.h" #include "pytorch/tokenizers/re2_regex.h" #include "pytorch/tokenizers/std_regex.h" +#include "pytorch/tokenizers/pcre2_regex.h" #include #include @@ -8,10 +9,10 @@ /** * @brief Factory function that creates a regex object using RE2 if possible. - * Falls back to std::regex if RE2 rejects the pattern with - * ErrorBadPerlOp. + * Falls back to PCRE2 if RE2 rejects the pattern, then to std::regex if PCRE2 fails. */ std::unique_ptr createRegex(const std::string& pattern) { + // Try RE2 first auto re2 = std::make_unique("(" + pattern + ")"); if (re2->ok()) { @@ -20,10 +21,18 @@ std::unique_ptr createRegex(const std::string& pattern) { const re2::RE2* raw = re2->rawRegex(); if (raw && raw->error_code() == re2::RE2::ErrorBadPerlOp) { + // RE2 doesn't support some Perl features, try PCRE2 + auto pcre2 = std::make_unique("(" + pattern + ")"); + + if (pcre2->ok()) { + std::cout << "RE2 is unable to support things such as negative lookaheads in " + << pattern << ", using PCRE2 instead."; + return pcre2; + } + + // If PCRE2 also fails, fall back to std::regex try { - std::cout - << "RE2 is unable to support things such as negative lookaheads in " - << pattern << ", defaulting to std::regex."; + std::cout << "PCRE2 failed to compile pattern, falling back to std::regex."; return std::make_unique("(" + pattern + ")"); } catch (const std::regex_error& e) { std::cerr << "std::regex failed: " << e.what() << std::endl; diff --git a/test/test_pcre2_regex.cpp b/test/test_pcre2_regex.cpp new file mode 100644 index 0000000..0d688ab --- /dev/null +++ b/test/test_pcre2_regex.cpp @@ -0,0 +1,68 @@ +#include + +#include "pytorch/tokenizers/pcre2_regex.h" +#include "pytorch/tokenizers/regex.h" + +TEST(Pcre2RegexTest, BasicMatching) { + Pcre2Regex regex("(\\w+)"); + std::vector matches = regex.findAll("Hello world"); + + ASSERT_EQ(matches.size(), 2); + EXPECT_EQ(matches[0].text, "Hello"); + EXPECT_EQ(matches[0].position, 0); + EXPECT_EQ(matches[1].text, "world"); + EXPECT_EQ(matches[1].position, 6); +} + +TEST(Pcre2RegexTest, ComplexPatterns) { + // Test with a more complex pattern that includes lookaheads + Pcre2Regex regex("(?<=@)(\\w+)"); + std::vector matches = regex.findAll("user@example.com"); + + ASSERT_EQ(matches.size(), 1); + EXPECT_EQ(matches[0].text, "example"); + EXPECT_EQ(matches[0].position, 5); +} + +TEST(Pcre2RegexTest, UnicodeSupport) { + // Test with Unicode characters + Pcre2Regex regex("(\\p{L}+)"); + std::vector matches = regex.findAll("Hello 世界"); + + ASSERT_EQ(matches.size(), 2); + EXPECT_EQ(matches[0].text, "Hello"); + EXPECT_EQ(matches[0].position, 0); + EXPECT_EQ(matches[1].text, "世界"); + EXPECT_EQ(matches[1].position, 6); +} + +TEST(Pcre2RegexTest, InvalidPattern) { + // Test with an invalid pattern + Pcre2Regex regex("("); // Unmatched parenthesis + EXPECT_FALSE(regex.ok()); + + std::vector matches = regex.findAll("test"); + EXPECT_TRUE(matches.empty()); +} + +TEST(Pcre2RegexTest, EmptyMatches) { + // Test with a pattern that can match empty strings + Pcre2Regex regex("(a*)"); + std::vector matches = regex.findAll("b"); + + // Should find one empty match at the beginning + ASSERT_EQ(matches.size(), 1); + EXPECT_EQ(matches[0].text, ""); + EXPECT_EQ(matches[0].position, 0); +} + +TEST(Pcre2RegexTest, FactoryFunction) { + // Test the factory function with a pattern that RE2 doesn't support + auto regex = createRegex("(?<=@)(\\w+)"); + ASSERT_NE(regex, nullptr); + + std::vector matches = regex->findAll("user@example.com"); + ASSERT_EQ(matches.size(), 1); + EXPECT_EQ(matches[0].text, "example"); + EXPECT_EQ(matches[0].position, 5); +} \ No newline at end of file diff --git a/third-party/pcre2 b/third-party/pcre2 new file mode 160000 index 0000000..2e03e32 --- /dev/null +++ b/third-party/pcre2 @@ -0,0 +1 @@ +Subproject commit 2e03e323339ab692640626f02f8d8d6f95bff9c6 From c3762c89a31fe91b406c7b609d049533cdf9c521 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 15 Apr 2025 09:05:57 -0700 Subject: [PATCH 8/9] Lint --- include/pytorch/tokenizers/pcre2_regex.h | 2 +- src/pcre2_regex.cpp | 74 ++++++++++++------------ src/regex.cpp | 17 +++--- test/test_pcre2_regex.cpp | 16 ++--- 4 files changed, 55 insertions(+), 54 deletions(-) diff --git a/include/pytorch/tokenizers/pcre2_regex.h b/include/pytorch/tokenizers/pcre2_regex.h index 8d172b4..97c88d1 100644 --- a/include/pytorch/tokenizers/pcre2_regex.h +++ b/include/pytorch/tokenizers/pcre2_regex.h @@ -51,4 +51,4 @@ class Pcre2Regex : public IRegex { bool is_valid_; friend std::unique_ptr createRegex(const std::string& pattern); -}; \ No newline at end of file +}; \ No newline at end of file diff --git a/src/pcre2_regex.cpp b/src/pcre2_regex.cpp index 65cfeef..b484ed3 100644 --- a/src/pcre2_regex.cpp +++ b/src/pcre2_regex.cpp @@ -3,27 +3,28 @@ #include #include -Pcre2Regex::Pcre2Regex(const std::string& pattern) : regex_(nullptr), match_data_(nullptr), is_valid_(false) { +Pcre2Regex::Pcre2Regex(const std::string& pattern) + : regex_(nullptr), match_data_(nullptr), is_valid_(false) { int error_code; PCRE2_SIZE error_offset; - + // Compile the pattern regex_ = pcre2_compile( - reinterpret_cast(pattern.c_str()), - pattern.length(), - PCRE2_UCP | PCRE2_UTF, // Enable Unicode support and UTF-8 mode - &error_code, - &error_offset, - nullptr - ); - + reinterpret_cast(pattern.c_str()), + pattern.length(), + PCRE2_UCP | PCRE2_UTF, // Enable Unicode support and UTF-8 mode + &error_code, + &error_offset, + nullptr); + if (regex_ == nullptr) { PCRE2_UCHAR error_buffer[256]; pcre2_get_error_message(error_code, error_buffer, sizeof(error_buffer)); - std::cerr << "PCRE2 compilation failed at offset " << error_offset << ": " << error_buffer << std::endl; + std::cerr << "PCRE2 compilation failed at offset " << error_offset << ": " + << error_buffer << std::endl; return; } - + // Create match data match_data_ = pcre2_match_data_create_from_pattern(regex_, nullptr); if (match_data_ == nullptr) { @@ -32,7 +33,7 @@ Pcre2Regex::Pcre2Regex(const std::string& pattern) : regex_(nullptr), match_data std::cerr << "Failed to create PCRE2 match data" << std::endl; return; } - + is_valid_ = true; } @@ -47,30 +48,29 @@ Pcre2Regex::~Pcre2Regex() { std::vector Pcre2Regex::findAll(const std::string& text) const { std::vector result; - + if (!is_valid_ || !regex_ || !match_data_) { return result; } - + PCRE2_SIZE* ovector; PCRE2_SPTR subject = reinterpret_cast(text.c_str()); PCRE2_SIZE subject_length = text.length(); PCRE2_SIZE offset = 0; - + while (offset < subject_length) { int rc = pcre2_match( - regex_, - subject, - subject_length, - offset, - 0, // Default options - match_data_, - nullptr - ); - + regex_, + subject, + subject_length, + offset, + 0, // Default options + match_data_, + nullptr); + if (rc < 0) { if (rc == PCRE2_ERROR_NOMATCH) { - break; // No more matches + break; // No more matches } else { // Error occurred PCRE2_UCHAR error_buffer[256]; @@ -79,28 +79,26 @@ std::vector Pcre2Regex::findAll(const std::string& text) const { break; } } - + ovector = pcre2_get_ovector_pointer(match_data_); - + // Extract the match size_t match_start = ovector[0]; size_t match_length = ovector[1] - ovector[0]; - + // Add the match to the result - result.push_back({ - text.substr(match_start, match_length), - match_start - }); - + result.push_back({text.substr(match_start, match_length), match_start}); + // Move to the next position after the match offset = ovector[1]; - - // If the match was empty, move forward by one character to avoid infinite loop + + // If the match was empty, move forward by one character to avoid infinite + // loop if (ovector[0] == ovector[1]) { offset++; } } - + return result; } @@ -110,4 +108,4 @@ bool Pcre2Regex::ok() const { const pcre2_code* Pcre2Regex::rawRegex() const { return regex_; -} \ No newline at end of file +} \ No newline at end of file diff --git a/src/regex.cpp b/src/regex.cpp index c215be0..0df2895 100644 --- a/src/regex.cpp +++ b/src/regex.cpp @@ -1,7 +1,7 @@ #include "pytorch/tokenizers/regex.h" +#include "pytorch/tokenizers/pcre2_regex.h" #include "pytorch/tokenizers/re2_regex.h" #include "pytorch/tokenizers/std_regex.h" -#include "pytorch/tokenizers/pcre2_regex.h" #include #include @@ -9,7 +9,8 @@ /** * @brief Factory function that creates a regex object using RE2 if possible. - * Falls back to PCRE2 if RE2 rejects the pattern, then to std::regex if PCRE2 fails. + * Falls back to PCRE2 if RE2 rejects the pattern, then to std::regex if + * PCRE2 fails. */ std::unique_ptr createRegex(const std::string& pattern) { // Try RE2 first @@ -23,16 +24,18 @@ std::unique_ptr createRegex(const std::string& pattern) { if (raw && raw->error_code() == re2::RE2::ErrorBadPerlOp) { // RE2 doesn't support some Perl features, try PCRE2 auto pcre2 = std::make_unique("(" + pattern + ")"); - + if (pcre2->ok()) { - std::cout << "RE2 is unable to support things such as negative lookaheads in " - << pattern << ", using PCRE2 instead."; + std::cout + << "RE2 is unable to support things such as negative lookaheads in " + << pattern << ", using PCRE2 instead." << std::endl; return pcre2; } - + // If PCRE2 also fails, fall back to std::regex try { - std::cout << "PCRE2 failed to compile pattern, falling back to std::regex."; + std::cout + << "PCRE2 failed to compile pattern, falling back to std::regex."; return std::make_unique("(" + pattern + ")"); } catch (const std::regex_error& e) { std::cerr << "std::regex failed: " << e.what() << std::endl; diff --git a/test/test_pcre2_regex.cpp b/test/test_pcre2_regex.cpp index 0d688ab..3dfb83a 100644 --- a/test/test_pcre2_regex.cpp +++ b/test/test_pcre2_regex.cpp @@ -6,7 +6,7 @@ TEST(Pcre2RegexTest, BasicMatching) { Pcre2Regex regex("(\\w+)"); std::vector matches = regex.findAll("Hello world"); - + ASSERT_EQ(matches.size(), 2); EXPECT_EQ(matches[0].text, "Hello"); EXPECT_EQ(matches[0].position, 0); @@ -18,7 +18,7 @@ TEST(Pcre2RegexTest, ComplexPatterns) { // Test with a more complex pattern that includes lookaheads Pcre2Regex regex("(?<=@)(\\w+)"); std::vector matches = regex.findAll("user@example.com"); - + ASSERT_EQ(matches.size(), 1); EXPECT_EQ(matches[0].text, "example"); EXPECT_EQ(matches[0].position, 5); @@ -28,7 +28,7 @@ TEST(Pcre2RegexTest, UnicodeSupport) { // Test with Unicode characters Pcre2Regex regex("(\\p{L}+)"); std::vector matches = regex.findAll("Hello 世界"); - + ASSERT_EQ(matches.size(), 2); EXPECT_EQ(matches[0].text, "Hello"); EXPECT_EQ(matches[0].position, 0); @@ -38,9 +38,9 @@ TEST(Pcre2RegexTest, UnicodeSupport) { TEST(Pcre2RegexTest, InvalidPattern) { // Test with an invalid pattern - Pcre2Regex regex("("); // Unmatched parenthesis + Pcre2Regex regex("("); // Unmatched parenthesis EXPECT_FALSE(regex.ok()); - + std::vector matches = regex.findAll("test"); EXPECT_TRUE(matches.empty()); } @@ -49,7 +49,7 @@ TEST(Pcre2RegexTest, EmptyMatches) { // Test with a pattern that can match empty strings Pcre2Regex regex("(a*)"); std::vector matches = regex.findAll("b"); - + // Should find one empty match at the beginning ASSERT_EQ(matches.size(), 1); EXPECT_EQ(matches[0].text, ""); @@ -60,9 +60,9 @@ TEST(Pcre2RegexTest, FactoryFunction) { // Test the factory function with a pattern that RE2 doesn't support auto regex = createRegex("(?<=@)(\\w+)"); ASSERT_NE(regex, nullptr); - + std::vector matches = regex->findAll("user@example.com"); ASSERT_EQ(matches.size(), 1); EXPECT_EQ(matches[0].text, "example"); EXPECT_EQ(matches[0].position, 5); -} \ No newline at end of file +} \ No newline at end of file From 355c49a7ccc54e0e34b01f7ade7abb01dbd71766 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 15 Apr 2025 14:57:56 -0700 Subject: [PATCH 9/9] Make test general --- include/pytorch/tokenizers/pcre2_regex.h | 2 +- include/pytorch/tokenizers/re2_regex.h | 4 +- include/pytorch/tokenizers/regex.h | 7 ++ include/pytorch/tokenizers/std_regex.h | 7 ++ src/std_regex.cpp | 6 ++ test/test_pcre2_regex.cpp | 68 ------------------- test/test_regex.cpp | 83 ++++++++++++++++++++++++ 7 files changed, 106 insertions(+), 71 deletions(-) delete mode 100644 test/test_pcre2_regex.cpp create mode 100644 test/test_regex.cpp diff --git a/include/pytorch/tokenizers/pcre2_regex.h b/include/pytorch/tokenizers/pcre2_regex.h index 97c88d1..1015c43 100644 --- a/include/pytorch/tokenizers/pcre2_regex.h +++ b/include/pytorch/tokenizers/pcre2_regex.h @@ -37,7 +37,7 @@ class Pcre2Regex : public IRegex { /** * @brief Check if PCRE2 compiled the pattern successfully. */ - bool ok() const; + bool ok() const override; protected: /** diff --git a/include/pytorch/tokenizers/re2_regex.h b/include/pytorch/tokenizers/re2_regex.h index 7a3c64c..c615713 100644 --- a/include/pytorch/tokenizers/re2_regex.h +++ b/include/pytorch/tokenizers/re2_regex.h @@ -26,12 +26,12 @@ class Re2Regex : public IRegex { */ virtual std::vector findAll(const std::string& text) const override; - protected: /** * @brief Check if RE2 compiled the pattern successfully. */ - bool ok() const; + bool ok() const override; + protected: /** * @brief Expose internal RE2 pointer to the factory if needed. */ diff --git a/include/pytorch/tokenizers/regex.h b/include/pytorch/tokenizers/regex.h index ac6c80b..98dbc9f 100644 --- a/include/pytorch/tokenizers/regex.h +++ b/include/pytorch/tokenizers/regex.h @@ -23,6 +23,13 @@ class IRegex { * @return A vector of strings containing all matched substrings. */ virtual std::vector findAll(const std::string& text) const = 0; + + /** + * @brief Check if the regex pattern was compiled successfully. + * + * @return true if the pattern is valid and ready to use, false otherwise. + */ + virtual bool ok() const = 0; }; /** diff --git a/include/pytorch/tokenizers/std_regex.h b/include/pytorch/tokenizers/std_regex.h index 41828bf..e49127b 100644 --- a/include/pytorch/tokenizers/std_regex.h +++ b/include/pytorch/tokenizers/std_regex.h @@ -23,6 +23,13 @@ class StdRegex : public IRegex { */ virtual std::vector findAll(const std::string& text) const override; + /** + * @brief Check if std::regex compiled the pattern successfully. + * + * @return true if the pattern is valid, false otherwise. + */ + bool ok() const override; + private: std::regex regex_; }; diff --git a/src/std_regex.cpp b/src/std_regex.cpp index 83c8e6d..5e0e248 100644 --- a/src/std_regex.cpp +++ b/src/std_regex.cpp @@ -18,3 +18,9 @@ std::vector StdRegex::findAll(const std::string& text) const { return result; } + +bool StdRegex::ok() const { + // std::regex constructor throws if the pattern is invalid + // If we got here, the pattern is valid + return true; +} diff --git a/test/test_pcre2_regex.cpp b/test/test_pcre2_regex.cpp deleted file mode 100644 index 3dfb83a..0000000 --- a/test/test_pcre2_regex.cpp +++ /dev/null @@ -1,68 +0,0 @@ -#include - -#include "pytorch/tokenizers/pcre2_regex.h" -#include "pytorch/tokenizers/regex.h" - -TEST(Pcre2RegexTest, BasicMatching) { - Pcre2Regex regex("(\\w+)"); - std::vector matches = regex.findAll("Hello world"); - - ASSERT_EQ(matches.size(), 2); - EXPECT_EQ(matches[0].text, "Hello"); - EXPECT_EQ(matches[0].position, 0); - EXPECT_EQ(matches[1].text, "world"); - EXPECT_EQ(matches[1].position, 6); -} - -TEST(Pcre2RegexTest, ComplexPatterns) { - // Test with a more complex pattern that includes lookaheads - Pcre2Regex regex("(?<=@)(\\w+)"); - std::vector matches = regex.findAll("user@example.com"); - - ASSERT_EQ(matches.size(), 1); - EXPECT_EQ(matches[0].text, "example"); - EXPECT_EQ(matches[0].position, 5); -} - -TEST(Pcre2RegexTest, UnicodeSupport) { - // Test with Unicode characters - Pcre2Regex regex("(\\p{L}+)"); - std::vector matches = regex.findAll("Hello 世界"); - - ASSERT_EQ(matches.size(), 2); - EXPECT_EQ(matches[0].text, "Hello"); - EXPECT_EQ(matches[0].position, 0); - EXPECT_EQ(matches[1].text, "世界"); - EXPECT_EQ(matches[1].position, 6); -} - -TEST(Pcre2RegexTest, InvalidPattern) { - // Test with an invalid pattern - Pcre2Regex regex("("); // Unmatched parenthesis - EXPECT_FALSE(regex.ok()); - - std::vector matches = regex.findAll("test"); - EXPECT_TRUE(matches.empty()); -} - -TEST(Pcre2RegexTest, EmptyMatches) { - // Test with a pattern that can match empty strings - Pcre2Regex regex("(a*)"); - std::vector matches = regex.findAll("b"); - - // Should find one empty match at the beginning - ASSERT_EQ(matches.size(), 1); - EXPECT_EQ(matches[0].text, ""); - EXPECT_EQ(matches[0].position, 0); -} - -TEST(Pcre2RegexTest, FactoryFunction) { - // Test the factory function with a pattern that RE2 doesn't support - auto regex = createRegex("(?<=@)(\\w+)"); - ASSERT_NE(regex, nullptr); - - std::vector matches = regex->findAll("user@example.com"); - ASSERT_EQ(matches.size(), 1); - EXPECT_EQ(matches[0].text, "example"); - EXPECT_EQ(matches[0].position, 5); -} \ No newline at end of file diff --git a/test/test_regex.cpp b/test/test_regex.cpp new file mode 100644 index 0000000..915c62c --- /dev/null +++ b/test/test_regex.cpp @@ -0,0 +1,83 @@ +#include + +#include "pytorch/tokenizers/regex.h" +#include "pytorch/tokenizers/re2_regex.h" +#include "pytorch/tokenizers/pcre2_regex.h" + +// Test basic functionality +TEST(RegexTest, BasicMatching) { + auto regex = createRegex("\\w+"); + ASSERT_TRUE(regex->ok()); + + std::string text = "Hello world"; + auto matches = regex->findAll(text); + ASSERT_EQ(matches.size(), 2); + EXPECT_EQ(matches[0].text, "Hello"); + EXPECT_EQ(matches[0].position, 0); + EXPECT_EQ(matches[1].text, "world"); + EXPECT_EQ(matches[1].position, 6); +} + +// Test pattern that only PCRE2 supports (lookbehind) +TEST(RegexTest, Pcre2Specific) { + // First verify that RE2 cannot handle this pattern + const std::string pattern = "(?<=@)\\w+"; + Re2Regex re2_regex(pattern); + ASSERT_FALSE(re2_regex.ok()); + + // Now verify that the factory function fallsback on a PCRE2 regex + auto regex = createRegex(pattern); + ASSERT_TRUE(regex->ok()); + + std::string text = "user@example.com"; + auto matches = regex->findAll(text); + ASSERT_EQ(matches.size(), 1); + EXPECT_EQ(matches[0].text, "example"); + EXPECT_EQ(matches[0].position, 5); +} + +// Test complex pattern with negative lookahead that should fall back to PCRE2. +// This specific pattern is from the Qwen2.5 1.5B pretokenizer. +// https://huggingface.co/Qwen/Qwen2.5-1.5B/raw/main/tokenizer.json +TEST(RegexTest, ComplexPatternWithNegativeLookahead) { + const std::string complex_pattern = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; + + // First verify that RE2 cannot handle this pattern + Re2Regex re2_regex(complex_pattern); + ASSERT_FALSE(re2_regex.ok()); + + // Now verify that the factory function fallsback on a PCRE2 regex + auto regex = createRegex(complex_pattern); + ASSERT_TRUE(regex->ok()); + + // Test the pattern with some sample text + std::string text = "Hello's world\n test"; + auto matches = regex->findAll(text); + + // We expect to match: + // 1. "Hello" (word) + // 2. "'s" (contraction) + // 3. " world" (word with leading space) + // 4. "\n" (newline) + // 5. " " (whitespace) + // 6. " test" (word with leading space) + ASSERT_EQ(matches.size(), 6); + + EXPECT_EQ(matches[0].text, "Hello"); + EXPECT_EQ(matches[0].position, 0); + + EXPECT_EQ(matches[1].text, "'s"); + EXPECT_EQ(matches[1].position, 5); + + EXPECT_EQ(matches[2].text, " world"); + EXPECT_EQ(matches[2].position, 7); + + EXPECT_EQ(matches[3].text, "\n"); + EXPECT_EQ(matches[3].position, 13); + + EXPECT_EQ(matches[4].text, " "); + EXPECT_EQ(matches[4].position, 14); + + EXPECT_EQ(matches[5].text, " test"); + EXPECT_EQ(matches[5].position, 15); +} \ No newline at end of file