Files
ragflow/internal/cpp/rag_analyzer.cpp
Jin Hai 70e9743ef1 RAGFlow go API server (#13240)
# RAGFlow Go Implementation Plan 🚀

This repository tracks the progress of porting RAGFlow to Go. We'll
implement core features and provide performance comparisons between
Python and Go versions.

## Implementation Checklist

- [x] User Management APIs
- [x] Dataset Management Operations
- [x] Retrieval Test
- [x] Chat Management Operations
- [x] Infinity Go SDK

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
Co-authored-by: Yingfeng Zhang <yingfeng.zhang@gmail.com>
2026-03-04 19:17:16 +08:00

2431 lines
92 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright(C) 2024 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#define PCRE2_CODE_UNIT_WIDTH 8
#include "opencc/openccxx.h"
#include "pcre2.h"
#include "string_utils.h"
#include "rag_analyzer.h"
#include "re2/re2.h"
#include <cassert>
#include <cstdint>
#include <filesystem>
#include <iostream>
#include <cmath>
#include <fstream>
// import :term;
// import :stemmer;
// import :analyzer;
// import :darts_trie;
// import :wordnet_lemmatizer;
// import :stemmer;
// import :term;
//
// import std.compat;
namespace fs = std::filesystem;
// Resource locations, all resolved relative to the dictionary root directory
// handed to RAGAnalyzer (see RAGAnalyzer::Load).
static const std::string DICT_PATH = "rag/huqie.txt";     // plain-text dictionary: term / freq / POS per line
static const std::string POS_DEF_PATH = "rag/pos-id.def"; // POS tag -> index definitions
static const std::string TRIE_PATH = "rag/huqie.trie";    // cached binary trie image (built on first load)
static const std::string WORDNET_PATH = "wordnet";        // WordNet data dir for the lemmatizer
static const std::string OPENCC_PATH = "opencc";          // OpenCC data dir (traditional/simplified conversion)
// Splits text into runs: delimiter/punctuation runs (ASCII + full-width CJK
// punctuation), latin runs, and numeric runs.
static const std::string REGEX_SPLIT_CHAR =
    R"#(([ ,\.<>/?;'\[\]\`!@#$%^&*$$\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z\.-]+|[0-9,\.-]+))#";
// PCRE2 pattern mirroring NLTK's word-tokenization regex (dashes/ellipses,
// words up to trailing punctuation, then any single non-space character).
static const std::string NLTK_TOKENIZE_PATTERN =
    R"((?:\-{2,}|\.{2,}|(?:\.\s){2,}\.)|(?=[^\(\"\`{\[:;&\#\*@\)}\]\-,])\S+?(?=\s|$|(?:[)\";}\]\*:@\'\({\[\?!])|(?:\-{2,}|\.{2,}|(?:\.\s){2,}\.)|,(?=$|\s|(?:[)\";}\]\*:@\'\({\[\?!])|(?:\-{2,}|\.{2,}|(?:\.\s){2,}\.)))|\S)";
// Maximum sentence length used by splitting logic elsewhere in this file.
static constexpr std::size_t MAX_SENTENCE_LEN = 100;
// Packs a term frequency and a POS-table index into one int32 trie value.
// Layout: bits 0-22 = |freq| magnitude, bit 23 = sign flag (set iff freq < 0),
// bits 24-31 = POS index. Inverse of DecodeFreq() / DecodePOSIndex().
static inline int32_t Encode(int32_t freq, int32_t idx) {
    uint32_t encoded_value = 0;
    if (freq < 0) {
        // Mask the magnitude: previously an out-of-range |freq| could clobber
        // the sign flag and the POS byte.
        encoded_value = static_cast<uint32_t>(-freq) & 0x7FFFFF;
        encoded_value |= (1U << 23);
    } else {
        encoded_value = static_cast<uint32_t>(freq) & 0x7FFFFF;
    }
    encoded_value |= static_cast<uint32_t>(idx) << 24;
    return static_cast<int32_t>(encoded_value);
}
// Extracts the sign-magnitude frequency from a value packed by Encode():
// low 23 bits are the magnitude, bit 23 is the sign flag.
static inline int32_t DecodeFreq(int32_t value) {
    uint32_t v1 = static_cast<uint32_t>(value) & 0xFFFFFF;
    if (v1 & (1U << 23)) {
        // Negative: strip the sign flag and negate the magnitude.
        v1 &= 0x7FFFFF;
        return -static_cast<int32_t>(v1);
    }
    // Positive: the magnitude is the value itself (the original code carried a
    // dead self-assignment cast here).
    return static_cast<int32_t>(v1);
}
// Extracts the POS-table index from a value packed by Encode().
// The index occupies the high 8 bits (bits 24-31).
static inline int32_t DecodePOSIndex(int32_t value) {
    const uint32_t bits = static_cast<uint32_t>(value);
    return static_cast<int32_t>(bits >> 24);
}
// Splits `input` on every match of `split_pattern`, appending the non-empty
// pieces to `result`. When `keep_delim` is true the matched delimiters are
// appended as tokens too. The trailing piece after the last match is kept.
void Split(const std::string &input, const std::string &split_pattern, std::vector<std::string> &result, bool keep_delim = false) {
    re2::RE2 pattern(split_pattern);
    // Construct the view from the string itself, not input.data(): the char*
    // constructor measures with strlen and would truncate at an embedded '\0'.
    re2::StringPiece leftover(input);
    re2::StringPiece last_end = leftover;
    re2::StringPiece extracted_delim_token;
    while (RE2::FindAndConsume(&leftover, pattern, &extracted_delim_token)) {
        // Text between the previous match end and this delimiter.
        std::string_view token(last_end.data(), extracted_delim_token.data() - last_end.data());
        if (!token.empty()) {
            result.emplace_back(token.data(), token.size());
        }
        if (keep_delim)
            result.emplace_back(extracted_delim_token.data(), extracted_delim_token.size());
        last_end = leftover;
    }
    if (!leftover.empty()) {
        result.emplace_back(leftover.data(), leftover.size());
    }
}
// Overload taking a precompiled RE2 pattern; semantics identical to the
// string-pattern Split() above.
void Split(const std::string &input, const re2::RE2 &pattern, std::vector<std::string> &result, bool keep_delim = false) {
    // Construct the view from the string itself, not input.data(): the char*
    // constructor measures with strlen and would truncate at an embedded '\0'.
    re2::StringPiece leftover(input);
    re2::StringPiece last_end = leftover;
    re2::StringPiece extracted_delim_token;
    while (RE2::FindAndConsume(&leftover, pattern, &extracted_delim_token)) {
        // Text between the previous match end and this delimiter.
        std::string_view token(last_end.data(), extracted_delim_token.data() - last_end.data());
        if (!token.empty()) {
            result.emplace_back(token.data(), token.size());
        }
        if (keep_delim)
            result.emplace_back(extracted_delim_token.data(), extracted_delim_token.size());
        last_end = leftover;
    }
    if (!leftover.empty()) {
        result.emplace_back(leftover.data(), leftover.size());
    }
}
// Returns a copy of `input` with every match of `re` replaced by `replacement`.
std::string Replace(const re2::RE2 &re, const std::string &replacement, const std::string &input) {
    std::string result{input};
    re2::RE2::GlobalReplace(&result, re, replacement);
    return result;
}
// Joins tokens[start, end) into one string with `delim` between elements.
template <typename T>
std::string Join(const std::vector<T> &tokens, int start, int end, const std::string &delim = " ") {
    std::ostringstream joined;
    bool first = true;
    for (int idx = start; idx < end; ++idx) {
        if (!first)
            joined << delim;
        joined << tokens[idx];
        first = false;
    }
    return std::move(joined).str();
}
// Joins tokens[start, tokens.size()) with `delim`.
template <typename T>
std::string Join(const std::vector<T> &tokens, int start, const std::string &delim = " ") {
    // Explicit cast: tokens.size() is size_t while the target parameter is
    // int; the previous implicit narrowing conversion triggered warnings.
    return Join(tokens, start, static_cast<int>(tokens.size()), delim);
}
// Joins the text of tokens[start, end) with `delim` between elements.
std::string Join(const TermList &tokens, int start, int end, const std::string &delim = " ") {
    std::ostringstream joined;
    for (int idx = start; idx < end; ++idx) {
        if (idx != start)
            joined << delim;
        joined << tokens[idx].text_;
    }
    return std::move(joined).str();
}
// Returns true if `str` contains at least one UTF-8 sequence whose lead byte
// is 0xE4-0xE9 followed by two continuation bytes (CJK Unified Ideographs).
bool IsChinese(const std::string &str) {
    const std::size_t n = str.length();
    for (std::size_t pos = 0; pos < n; ++pos) {
        const unsigned char lead = static_cast<unsigned char>(str[pos]);
        if (lead >= 0xE4 && lead <= 0xE9 && pos + 2 < n) {
            const unsigned char b2 = static_cast<unsigned char>(str[pos + 1]);
            const unsigned char b3 = static_cast<unsigned char>(str[pos + 2]);
            if (b2 >= 0x80 && b2 <= 0xBF && b3 >= 0x80 && b3 <= 0xBF) {
                return true;
            }
        }
    }
    return false;
}
// Returns true when every byte of `str` is 7-bit ASCII (no multi-byte UTF-8).
bool IsAlphabet(const std::string &str) {
    for (const char ch : str) {
        if (static_cast<unsigned char>(ch) > 0x7F) {
            return false;
        }
    }
    return true;
}
// Returns true if `str` contains a Hangul Jamo character: UTF-8 lead byte
// 0xE1, second byte 0x84-0x87, plus a continuation byte.
bool IsKorean(const std::string &str) {
    const std::size_t n = str.length();
    for (std::size_t pos = 0; pos < n; ++pos) {
        if (static_cast<unsigned char>(str[pos]) != 0xE1 || pos + 2 >= n) {
            continue;
        }
        const unsigned char b2 = static_cast<unsigned char>(str[pos + 1]);
        const unsigned char b3 = static_cast<unsigned char>(str[pos + 2]);
        if (b2 >= 0x84 && b2 <= 0x87 && b3 >= 0x80 && b3 <= 0xBF) {
            return true;
        }
    }
    return false;
}
// Returns true if `str` contains a kana character: UTF-8 lead byte 0xE3,
// second byte 0x81-0x83, third byte 0x81-0xBF.
bool IsJapanese(const std::string &str) {
    const std::size_t n = str.length();
    for (std::size_t pos = 0; pos < n; ++pos) {
        if (static_cast<unsigned char>(str[pos]) != 0xE3 || pos + 2 >= n) {
            continue;
        }
        const unsigned char b2 = static_cast<unsigned char>(str[pos + 1]);
        const unsigned char b3 = static_cast<unsigned char>(str[pos + 2]);
        if (b2 >= 0x81 && b2 <= 0x83 && b3 >= 0x81 && b3 <= 0xBF) {
            return true;
        }
    }
    return false;
}
// Returns true if `str` contains any Chinese, Japanese, or Korean character —
// the union of the IsChinese / IsJapanese / IsKorean byte-range checks.
bool IsCJK(const std::string &str) {
    const std::size_t n = str.length();
    // A match always needs a lead byte plus two more bytes, so positions with
    // fewer than three remaining bytes can never match.
    for (std::size_t pos = 0; pos + 2 < n; ++pos) {
        const unsigned char lead = static_cast<unsigned char>(str[pos]);
        const unsigned char b2 = static_cast<unsigned char>(str[pos + 1]);
        const unsigned char b3 = static_cast<unsigned char>(str[pos + 2]);
        const bool cont3 = (b3 >= 0x80 && b3 <= 0xBF);
        // Chinese: CJK Unified Ideographs (lead byte 0xE4..0xE9).
        if (lead >= 0xE4 && lead <= 0xE9 && b2 >= 0x80 && b2 <= 0xBF && cont3) {
            return true;
        }
        // Japanese kana: 0xE3 0x81..0x83, third byte 0x81..0xBF.
        if (lead == 0xE3 && b2 >= 0x81 && b2 <= 0x83 && b3 >= 0x81 && b3 <= 0xBF) {
            return true;
        }
        // Korean Hangul Jamo: 0xE1 0x84..0x87 plus a continuation byte.
        if (lead == 0xE1 && b2 >= 0x84 && b2 <= 0x87 && cont3) {
            return true;
        }
    }
    return false;
}
class RegexTokenizer {
public:
RegexTokenizer() {
int errorcode = 0;
PCRE2_SIZE erroffset = 0;
re_ = pcre2_compile((PCRE2_SPTR)(NLTK_TOKENIZE_PATTERN.c_str()),
PCRE2_ZERO_TERMINATED,
PCRE2_MULTILINE | PCRE2_UTF,
&errorcode,
&erroffset,
nullptr);
}
~RegexTokenizer() {
pcre2_code_free(re_);
}
void RegexTokenize(const std::string &input, TermList &tokens) {
PCRE2_SPTR subject = (PCRE2_SPTR)input.c_str();
PCRE2_SIZE subject_length = input.length();
pcre2_match_data_8 *match_data = pcre2_match_data_create_8(1024, nullptr);
PCRE2_SIZE start_offset = 0;
while (start_offset < subject_length) {
int res = pcre2_match(re_, subject, subject_length, start_offset, 0, match_data, nullptr);
if (res < 0) {
if (res == PCRE2_ERROR_NOMATCH) {
break; // No more matches
} else {
std::cerr << "Matching error code: " << res << std::endl;
break; // Other error
}
}
// Extract matched substring
PCRE2_SIZE *ovector = pcre2_get_ovector_pointer(match_data);
for (int i = 0; i < res; ++i) {
PCRE2_SIZE start = ovector[2 * i];
PCRE2_SIZE end = ovector[2 * i + 1];
tokens.Add(input.c_str() + start, end - start, start, end);
}
// Update the start offset for the next search
start_offset = ovector[1]; // Move to the end of the last match
}
// Free memory
pcre2_match_data_free(match_data);
}
private:
pcre2_code_8 *re_{nullptr};
};
// English contraction patterns used by the Treebank-style tokenizer; each
// regex captures the two (or three) parts the contraction is split into.
class MacIntyreContractions {
public:
    // List of contractions adapted from Robert MacIntyre's tokenizer.
    // Two-part contractions, e.g. "cannot" -> "can" + "not".
    std::vector<std::string> CONTRACTIONS2 = {R"((?i)\b(can)(?#X)(not)\b)",
                                              R"((?i)\b(d)(?#X)('ye)\b)",
                                              R"((?i)\b(gim)(?#X)(me)\b)",
                                              R"((?i)\b(gon)(?#X)(na)\b)",
                                              R"((?i)\b(got)(?#X)(ta)\b)",
                                              R"((?i)\b(lem)(?#X)(me)\b)",
                                              R"((?i)\b(more)(?#X)('n)\b)",
                                              R"((?i)\b(wan)(?#X)(na)(?=\s))"};
    // Leading-apostrophe contractions, e.g. "'tis" -> "'t" + "is".
    std::vector<std::string> CONTRACTIONS3 = {R"((?i) ('t)(?#X)(is)\b)", R"((?i) ('t)(?#X)(was)\b)"};
    // Three-part contractions ("whaddya", "whatcha"); declared but not
    // compiled in NLTKWordTokenizer::CompileRegexPatterns below.
    std::vector<std::string> CONTRACTIONS4 = {R"((?i)\b(whad)(dd)(ya)\b)", R"((?i)\b(wha)(t)(cha)\b)"};
};
// RAII wrapper owning a compiled PCRE2 pattern together with the substitution
// text to apply on matches. Move-only: copying would double-free the code.
struct CompiledRegex {
    pcre2_code *re{nullptr};  // owned; freed in the destructor
    std::string substitution; // replacement string for pcre2_substitute
    // Takes ownership of `r` (which may be nullptr if compilation failed).
    CompiledRegex(pcre2_code *r, std::string sub) : re(r), substitution(std::move(sub)) {
    }
    CompiledRegex(const CompiledRegex &) = delete;
    CompiledRegex &operator=(const CompiledRegex &) = delete;
    // Move transfers ownership and leaves the source empty.
    CompiledRegex(CompiledRegex &&other) noexcept : re(other.re), substitution(std::move(other.substitution)) { other.re = nullptr; }
    CompiledRegex &operator=(CompiledRegex &&other) noexcept {
        if (this != &other) {
            // Release any pattern we already own before adopting the new one.
            if (re)
                pcre2_code_free(re);
            re = other.re;
            substitution = std::move(other.substitution);
            other.re = nullptr;
        }
        return *this;
    }
    ~CompiledRegex() {
        if (re) {
            pcre2_code_free(re);
        }
    }
};
// Port of NLTK's Treebank-style word tokenizer: a fixed pipeline of PCRE2
// substitutions (quotes, punctuation, brackets, dashes, contractions)
// followed by a whitespace split. Exposed as a process-wide singleton.
class NLTKWordTokenizer {
    MacIntyreContractions contractions_;
    // Static singleton instance
    static std::unique_ptr<NLTKWordTokenizer> instance_;
    static std::once_flag init_flag_;
public:
    // Returns the lazily-created singleton (thread-safe via std::call_once).
    static NLTKWordTokenizer &GetInstance() {
        std::call_once(init_flag_, []() { instance_ = std::make_unique<NLTKWordTokenizer>(); });
        return *instance_;
    }
    // Starting quotes.
    std::vector<std::pair<std::string, std::string>> STARTING_QUOTES = {
        {std::string(R"(([«“‘„]|[`]+))"), std::string(R"( $1 )")},
        {std::string(R"(^\")"), std::string(R"(``)")},
        {std::string(R"((``))"), std::string(R"( $1 )")},
        {std::string(R"(([ \(\[{<])(\"|\'{2}))"), std::string(R"($1 `` )")},
        {std::string(R"((?i)(\')(?!re|ve|ll|m|t|s|d|n)(\w)\b)"), std::string(R"($1 $2)")}};
    // Ending quotes.
    std::vector<std::pair<std::string, std::string>> ENDING_QUOTES = {
        {std::string(R"(([»”’]))"), std::string(R"( $1 )")},
        {std::string(R"('')"), std::string(R"( '' )")},
        {std::string(R"(")"), std::string(R"( '' )")},
        {std::string(R"(\s+)"), std::string(R"( )")},
        {std::string(R"(([^' ])('[sS]|'[mM]|'[dD]|') )"), std::string(R"($1 $2 )")},
        {std::string(R"(([^' ])('ll|'LL|'re|'RE|'ve|'VE|n't|N'T) )"), std::string(R"($1 $2 )")}};
    // Punctuation.
    std::vector<std::pair<std::string, std::string>> PUNCTUATION = {
        {std::string(R"(([^\.])(\.)([\]\)}>"\'» ]*)\s*$)"), std::string(R"($1 $2 $3 )")},
        {std::string(R"(([:,])([^\d]))"), std::string(R"( $1 $2)")},
        {std::string(R"(([:,])$)"), std::string(R"($1 )")},
        {std::string(R"(\.{2,})"), std::string(R"($0 )")},
        {std::string(R"([;@#$%&])"), std::string(R"($0 )")},
        {std::string(R"(([^\.])(\.)([\]\)}>"\']*)\s*$)"), std::string(R"($1 $2 $3 )")},
        {std::string(R"([?!])"), std::string(R"($0 )")},
        {std::string(R"(([^'])' )"), std::string(R"($1 ' )")},
        {std::string(R"([*])"), std::string(R"($0 )")}};
    // Pads parentheses
    std::pair<std::string, std::string> PARENS_BRACKETS = {std::string(R"([\]\[\(\)\{\}\<\>])"), std::string(R"( $0 )")};
    // Optional Penn-Treebank bracket tokens (-LRB-, -RRB-, ...).
    std::vector<std::pair<std::string, std::string>> CONVERT_PARENTHESES = {{std::string(R"(\()"), std::string("-LRB-")},
                                                                           {std::string(R"(\))"), std::string("-RRB-")},
                                                                           {std::string(R"(\[)"), std::string("-LSB-")},
                                                                           {std::string(R"(\])"), std::string("-RSB-")},
                                                                           {std::string(R"(\{)"), std::string("-LCB-")},
                                                                           {std::string(R"(\})"), std::string("-RCB-")}};
    std::pair<std::string, std::string> DOUBLE_DASHES = {std::string(R"(--)"), std::string(R"( -- )")};
    // Cache for compiled regex patterns
    std::vector<CompiledRegex> compiled_starting_quotes_;
    std::vector<CompiledRegex> compiled_ending_quotes_;
    std::vector<CompiledRegex> compiled_punctuation_;
    CompiledRegex compiled_parens_brackets_;
    std::vector<CompiledRegex> compiled_convert_parentheses_;
    CompiledRegex compiled_double_dashes_;
    std::vector<CompiledRegex> compiled_contractions2_;
    std::vector<CompiledRegex> compiled_contractions3_;
    // Constructor that precompiles all regex patterns
    NLTKWordTokenizer() : compiled_parens_brackets_(nullptr, ""), compiled_double_dashes_(nullptr, "") { CompileRegexPatterns(); }
    // Tokenizes `text`, appending tokens to `tokens`. When
    // `convert_parentheses` is true, brackets become -LRB-/-RRB-/... tokens.
    void Tokenize(const std::string &text, std::vector<std::string> &tokens, bool convert_parentheses = false) {
        std::string result = text;
        for (const auto &compiled : compiled_starting_quotes_) {
            result = ApplyRegex(result, compiled);
        }
        for (const auto &compiled : compiled_punctuation_) {
            result = ApplyRegex(result, compiled);
        }
        // Handles parentheses.
        result = ApplyRegex(result, compiled_parens_brackets_);
        // Optionally convert parentheses
        if (convert_parentheses) {
            for (const auto &compiled : compiled_convert_parentheses_) {
                result = ApplyRegex(result, compiled);
            }
        }
        // Handles double dash.
        result = ApplyRegex(result, compiled_double_dashes_);
        // Add extra space to make things easier
        result = " " + result + " ";
        for (const auto &compiled : compiled_ending_quotes_) {
            result = ApplyRegex(result, compiled);
        }
        for (const auto &compiled : compiled_contractions2_) {
            result = ApplyRegex(result, compiled);
        }
        for (const auto &compiled : compiled_contractions3_) {
            result = ApplyRegex(result, compiled);
        }
        // Split the substituted text on single spaces; EmitToken handles the
        // underscore special-casing (previously duplicated inline twice).
        size_t start = 0;
        size_t end = result.find(' ');
        while (end != std::string::npos) {
            if (end != start) {
                EmitToken(result.substr(start, end - start), tokens);
            }
            start = end + 1;
            end = result.find(' ', start);
        }
        if (start != result.length()) {
            EmitToken(result.substr(start), tokens);
        }
    }
private:
    // Appends `token` to `tokens`, splitting on '_' so that every underscore
    // becomes a stand-alone "_" token and the pieces between them are kept.
    static void EmitToken(const std::string &token, std::vector<std::string> &tokens) {
        if (token == "_") {
            // Single underscore token
            tokens.push_back("_");
            return;
        }
        if (token.find('_') == std::string::npos) {
            tokens.push_back(token);
            return;
        }
        // Split tokens containing underscores, keeping underscores as tokens.
        std::stringstream ss(token);
        std::string sub_token;
        bool first = true;
        while (std::getline(ss, sub_token, '_')) {
            if (!first) {
                tokens.push_back("_");
            }
            if (!sub_token.empty()) {
                tokens.push_back(sub_token);
            }
            first = false;
        }
        // getline does not report a trailing empty field, so a trailing
        // underscore must be emitted explicitly.
        if (token.back() == '_') {
            tokens.push_back("_");
        }
    }
    // Precompiles every substitution table into CompiledRegex caches.
    void CompileRegexPatterns() {
        compiled_starting_quotes_.reserve(STARTING_QUOTES.size());
        for (const auto &[pattern, substitution] : STARTING_QUOTES) {
            compiled_starting_quotes_.emplace_back(CompilePattern(pattern), substitution);
        }
        compiled_ending_quotes_.reserve(ENDING_QUOTES.size());
        for (const auto &[pattern, substitution] : ENDING_QUOTES) {
            compiled_ending_quotes_.emplace_back(CompilePattern(pattern), substitution);
        }
        compiled_punctuation_.reserve(PUNCTUATION.size());
        for (const auto &[pattern, substitution] : PUNCTUATION) {
            compiled_punctuation_.emplace_back(CompilePattern(pattern), substitution);
        }
        compiled_parens_brackets_ = CompiledRegex(CompilePattern(PARENS_BRACKETS.first), PARENS_BRACKETS.second);
        compiled_convert_parentheses_.reserve(CONVERT_PARENTHESES.size());
        for (const auto &[pattern, substitution] : CONVERT_PARENTHESES) {
            compiled_convert_parentheses_.emplace_back(CompilePattern(pattern), substitution);
        }
        compiled_double_dashes_ = CompiledRegex(CompilePattern(DOUBLE_DASHES.first), DOUBLE_DASHES.second);
        compiled_contractions2_.reserve(contractions_.CONTRACTIONS2.size());
        for (const auto &pattern : contractions_.CONTRACTIONS2) {
            compiled_contractions2_.emplace_back(CompilePattern(pattern), R"( $1 $2 )");
        }
        compiled_contractions3_.reserve(contractions_.CONTRACTIONS3.size());
        for (const auto &pattern : contractions_.CONTRACTIONS3) {
            compiled_contractions3_.emplace_back(CompilePattern(pattern), R"( $1 $2 )");
        }
    }
    // Compiles one pattern; returns nullptr (and logs) on failure, which
    // ApplyRegex treats as an identity substitution.
    pcre2_code *CompilePattern(const std::string &pattern) {
        int errorcode = 0;
        PCRE2_SIZE erroffset = 0;
        pcre2_code *re = pcre2_compile(reinterpret_cast<PCRE2_SPTR>(pattern.c_str()),
                                       PCRE2_ZERO_TERMINATED,
                                       PCRE2_MULTILINE | PCRE2_UTF,
                                       &errorcode,
                                       &erroffset,
                                       nullptr);
        if (re == nullptr) {
            PCRE2_UCHAR buffer[256];
            pcre2_get_error_message(errorcode, buffer, sizeof(buffer));
            std::cerr << "PCRE2 compilation failed at offset " << erroffset << ": " << buffer << std::endl;
            return nullptr;
        }
        return re;
    }
    // Applies a global substitution, returning the rewritten text (or `text`
    // unchanged when the pattern is null or substitution fails).
    std::string ApplyRegex(const std::string &text, const CompiledRegex &compiled) {
        if (compiled.re == nullptr) {
            return text;
        }
        PCRE2_SPTR pcre2_subject = reinterpret_cast<PCRE2_SPTR>(text.c_str());
        PCRE2_SPTR pcre2_replacement = reinterpret_cast<PCRE2_SPTR>(compiled.substitution.c_str());
        size_t outlength = text.length() * 2 < 1024 ? 1024 : text.length() * 2;
        auto buffer = std::make_unique<PCRE2_UCHAR[]>(outlength);
        int rc = pcre2_substitute(compiled.re,
                                  pcre2_subject,
                                  text.length(),
                                  0,
                                  PCRE2_SUBSTITUTE_GLOBAL | PCRE2_SUBSTITUTE_OVERFLOW_LENGTH,
                                  nullptr,
                                  nullptr,
                                  pcre2_replacement,
                                  PCRE2_ZERO_TERMINATED,
                                  buffer.get(),
                                  &outlength);
        if (rc == PCRE2_ERROR_NOMEMORY) {
            // With PCRE2_SUBSTITUTE_OVERFLOW_LENGTH, outlength now holds the
            // required size; previously an overflowing result was silently
            // dropped and the unsubstituted text returned. Retry once.
            buffer = std::make_unique<PCRE2_UCHAR[]>(outlength);
            rc = pcre2_substitute(compiled.re,
                                  pcre2_subject,
                                  text.length(),
                                  0,
                                  PCRE2_SUBSTITUTE_GLOBAL,
                                  nullptr,
                                  nullptr,
                                  pcre2_replacement,
                                  PCRE2_ZERO_TERMINATED,
                                  buffer.get(),
                                  &outlength);
        }
        if (rc < 0) {
            return text;
        }
        return std::string(reinterpret_cast<char *>(buffer.get()), outlength);
    }
};
// Static member definitions for NLTKWordTokenizer singleton; the instance is
// created on first GetInstance() call via init_flag_.
std::unique_ptr<NLTKWordTokenizer> NLTKWordTokenizer::instance_ = nullptr;
std::once_flag NLTKWordTokenizer::init_flag_;
void SentenceSplitter(const std::string &text, std::vector<std::string> &result) {
int error_code;
PCRE2_SIZE error_offset;
const char *pattern = R"( *[\.\?!]['"\)\]]* *)";
pcre2_code *re = pcre2_compile((PCRE2_SPTR)pattern, PCRE2_ZERO_TERMINATED, PCRE2_MULTILINE | PCRE2_UTF, &error_code, &error_offset, nullptr);
if (re == nullptr) {
PCRE2_UCHAR buffer[256];
pcre2_get_error_message(error_code, buffer, sizeof(buffer));
std::cerr << "PCRE2 compilation failed at offset " << error_offset << ": " << buffer << std::endl;
return;
}
pcre2_match_data *match_data = pcre2_match_data_create_from_pattern(re, nullptr);
PCRE2_SIZE start_offset = 0;
while (start_offset < text.size()) {
int rc = pcre2_match(re, (PCRE2_SPTR)text.c_str(), text.size(), start_offset, 0, match_data, nullptr);
if (rc < 0) {
result.push_back(text.substr(start_offset));
break;
}
PCRE2_SIZE *ovector = pcre2_get_ovector_pointer(match_data);
PCRE2_SIZE match_start = ovector[0];
PCRE2_SIZE match_end = ovector[1];
if (match_start > start_offset) {
result.push_back(text.substr(start_offset, match_end - start_offset));
}
start_offset = match_end;
}
pcre2_match_data_free(match_data);
pcre2_code_free(re);
}
// Constructs an analyzer rooted at `path`. Dictionary/trie resources are not
// read here — Load() does that; only the English stemmer is initialized.
RAGAnalyzer::RAGAnalyzer(const std::string &path)
    : dict_path_(path), stemmer_(std::make_unique<Stemmer>()) {
    InitStemmer(STEM_LANG_ENGLISH);
}
// Copy constructor: shares the source's dictionary resources (trie, POS
// table, lemmatizer, OpenCC) without taking ownership — own_dict_ is false so
// the copy's destructor never deletes them. Gets its own stemmer instance.
RAGAnalyzer::RAGAnalyzer(const RAGAnalyzer &other)
    : own_dict_(false), trie_(other.trie_), pos_table_(other.pos_table_), wordnet_lemma_(other.wordnet_lemma_), stemmer_(std::make_unique<Stemmer>()),
      opencc_(other.opencc_), fine_grained_(other.fine_grained_) {
    InitStemmer(STEM_LANG_ENGLISH);
}
// Frees the dictionary resources only when this instance loaded them itself
// (own_dict_ is set by Load()); copies borrow the pointers and skip deletion.
RAGAnalyzer::~RAGAnalyzer() {
    if (own_dict_) {
        delete trie_;
        delete pos_table_;
        delete wordnet_lemma_;
        delete opencc_;
    }
}
// Loads all analyzer resources relative to dict_path_: the term dictionary,
// the POS definition table, the (cached or freshly built) trie, the WordNet
// lemmatizer data, and the OpenCC conversion data.
// Returns 0 on success, -1 on any missing or invalid resource.
int32_t RAGAnalyzer::Load() {
    fs::path root(dict_path_);
    fs::path dict_path(root / DICT_PATH);
    if (!fs::exists(dict_path)) {
        printf("Invalid analyzer file: %s\n", dict_path.string().c_str());
        return -1;
    }
    fs::path pos_def_path(root / POS_DEF_PATH);
    if (!fs::exists(pos_def_path)) {
        // Message previously said "post file" and lacked a newline.
        printf("Invalid POS definition file: %s\n", pos_def_path.string().c_str());
        return -1;
    }
    own_dict_ = true;
    trie_ = new DartsTrie();
    pos_table_ = new POSTable(pos_def_path.string());
    if (pos_table_->Load() != 0) {
        printf("Fail to load POS table: %s\n", pos_def_path.string().c_str());
        return -1;
    }
    fs::path trie_path(root / TRIE_PATH);
    if (fs::exists(trie_path)) {
        // Cached binary trie image available: skip the expensive build.
        trie_->Load(trie_path.string());
    } else {
        // Build trie from the plain-text dictionary: each line holds
        // "<term> <freq> <pos-tag>" separated by space or tab.
        try {
            std::ifstream from(dict_path.string());
            std::string line;
            re2::RE2 re_pattern(R"([\r\n]+)");
            std::string split_pattern("([ \t])");
            while (getline(from, line)) {
                line = line.substr(0, line.find('\r'));
                if (line.empty())
                    continue;
                line = Replace(re_pattern, "", line);
                std::vector<std::string> results;
                Split(line, split_pattern, results);
                if (results.size() != 3)
                    throw std::runtime_error("Invalid dictionary format");
                int32_t freq = std::stoi(results[1]);
                // Compress the raw count to a small log-scale integer.
                freq = int32_t(std::log(float(freq) / DENOMINATOR) + 0.5);
                int32_t pos_idx = pos_table_->GetPOSIndex(results[2]);
                int value = Encode(freq, pos_idx);
                trie_->Add(results[0], value);
                // Also index the reversed key so backward matching can walk
                // the trie from the right (see RKey / MaxBackward).
                std::string rkey = RKey(results[0]);
                trie_->Add(rkey, Encode(1, 0));
            }
            trie_->Build();
        } catch (const std::exception &) {
            // Covers malformed lines and std::stoi failures alike.
            return -1;
        }
        // Persist the built trie so subsequent loads take the fast path.
        trie_->Save(trie_path.string());
    }
    fs::path lemma_path(root / WORDNET_PATH);
    if (!fs::exists(lemma_path)) {
        printf("Fail to load wordnet: %s\n", lemma_path.string().c_str());
        return -1;
    }
    wordnet_lemma_ = new WordNetLemmatizer(lemma_path.string());
    fs::path opencc_path(root / OPENCC_PATH);
    if (!fs::exists(opencc_path)) {
        printf("Fail to load opencc_path: %s\n", opencc_path.string().c_str());
        return -1;
    }
    try {
        opencc_ = new ::OpenCC(opencc_path.string());
    } catch (const std::exception &) {
        return -1;
    }
    return 0;
}
// Builds a byte-level mapping from `converted` (e.g. OpenCC output) back to
// `original`, assuming the two strings align character-by-character. After
// the call, pos_mapping[i] is the byte offset in `original` of the character
// that produced byte i of `converted`; positions past the aligned region
// (including the final sentinel slot) map to original.size().
void RAGAnalyzer::BuildPositionMapping(const std::string &original, const std::string &converted, std::vector<unsigned> &pos_mapping) {
    pos_mapping.clear();
    pos_mapping.resize(converted.size() + 1);
    size_t orig_pos = 0;
    size_t conv_pos = 0;
    // Map each character position from converted string to original string
    while (orig_pos < original.size() && conv_pos < converted.size()) {
        // Get character lengths
        // NOTE(review): assumes UTF8_BYTE_LENGTH_TABLE maps a lead byte to its
        // sequence length (>= 1); a 0 entry for a malformed byte would stall
        // this loop — confirm the table's semantics.
        size_t orig_char_len = UTF8_BYTE_LENGTH_TABLE[static_cast<uint8_t>(original[orig_pos])];
        size_t conv_char_len = UTF8_BYTE_LENGTH_TABLE[static_cast<uint8_t>(converted[conv_pos])];
        // Map all bytes of current converted character to current original position
        for (size_t i = 0; i < conv_char_len && conv_pos + i < pos_mapping.size(); ++i) {
            pos_mapping[conv_pos + i] = static_cast<unsigned>(orig_pos);
        }
        // Move to next character in both strings
        orig_pos += orig_char_len;
        conv_pos += conv_char_len;
    }
    // Fill any remaining positions
    for (size_t i = conv_pos; i < pos_mapping.size(); ++i) {
        pos_mapping[i] = static_cast<unsigned>(original.size());
    }
}
// Converts full-width ASCII (U+FF01..U+FF5E) and the ideographic space
// (U+3000) in a UTF-8 string to their half-width equivalents; every other
// character is copied through unchanged.
std::string RAGAnalyzer::StrQ2B(const std::string &input) {
    std::string output;
    output.reserve(input.size());
    const size_t n = input.size();
    size_t i = 0;
    while (i < n) {
        unsigned char c = input[i];
        uint32_t codepoint = 0;
        // Decode one UTF-8 scalar. Bounds checks on the continuation bytes
        // fix an out-of-bounds read the original had on a truncated sequence
        // at the very end of the input.
        if (c < 0x80) {
            codepoint = c;
            i += 1;
        } else if ((c & 0xE0) == 0xC0 && i + 1 < n) {
            codepoint = (c & 0x1F) << 6;
            codepoint |= (input[i + 1] & 0x3F);
            i += 2;
        } else if ((c & 0xF0) == 0xE0 && i + 2 < n) {
            codepoint = (c & 0x0F) << 12;
            codepoint |= (input[i + 1] & 0x3F) << 6;
            codepoint |= (input[i + 2] & 0x3F);
            i += 3;
        } else {
            // 4-byte sequences, stray continuation bytes, and truncated tails
            // are passed through byte-by-byte unchanged.
            output += c;
            i += 1;
            continue;
        }
        if (codepoint >= 0xFF01 && codepoint <= 0xFF5E) {
            // Full-width ASCII -> half-width (constant offset 0xFEE0).
            output += static_cast<char>(codepoint - 0xFEE0);
        } else if (codepoint == 0x3000) {
            // Ideographic space -> plain space.
            output += ' ';
        } else if (codepoint < 0x80) {
            output += static_cast<char>(codepoint);
        } else if (codepoint < 0x800) {
            // Re-encode the unchanged codepoint as UTF-8.
            output += static_cast<char>(0xC0 | (codepoint >> 6));
            output += static_cast<char>(0x80 | ((codepoint >> 6) & 0x3F) & 0x3F | (codepoint & 0x3F));
            // NOTE: expression simplifies to (0x80 | (codepoint & 0x3F)).
        } else if (codepoint < 0x10000) {
            output += static_cast<char>(0xE0 | (codepoint >> 12));
            output += static_cast<char>(0x80 | ((codepoint >> 6) & 0x3F));
            output += static_cast<char>(0x80 | (codepoint & 0x3F));
        }
    }
    return output;
}
// Returns the approximate corpus frequency for `key` by inverting the
// log-scale compression applied in Load(): exp(stored) * DENOMINATOR.
// A missing key makes trie_->Get return -1, which DecodeFreq turns into a
// large negative value, so exp() underflows and the result is 0.
int32_t RAGAnalyzer::Freq(const std::string_view key) const {
    int32_t v = trie_->Get(key);
    v = DecodeFreq(v);
    return static_cast<int32_t>(std::exp(v) * DENOMINATOR + 0.5);
}
// Returns the POS tag stored in the trie for `key` (case-insensitive lookup),
// or an empty string when the key is unknown or no POS table is available.
std::string RAGAnalyzer::Tag(std::string_view key) const {
    const std::string lookup = Key(std::string(key));
    const int32_t encoded = trie_->Get(lookup);
    if (encoded == -1 || pos_table_ == nullptr) {
        return "";
    }
    const char *tag = pos_table_->GetPOS(DecodePOSIndex(encoded));
    return tag != nullptr ? std::string(tag) : "";
}
// Normalizes a term for forward trie lookup: simple lower-casing.
std::string RAGAnalyzer::Key(const std::string_view line) { return ToLowerString(line); }
// Builds the reversed lookup key used by backward matching: the "DD" marker
// followed by the UTF-8 characters of `line` in reverse order, lower-cased.
std::string RAGAnalyzer::RKey(const std::string_view line) {
    std::string reversed;
    reversed.reserve(line.size() + 2);
    reversed += "DD";
    // Walk backwards one UTF-8 character at a time: skip over continuation
    // bytes (10xxxxxx) to find each lead byte, then append that character.
    for (size_t i = line.size(); i > 0;) {
        size_t start = i - 1;
        while (start > 0 && (line[start] & 0xC0) == 0x80) {
            --start;
        }
        reversed += line.substr(start, i - start);
        i = start;
    }
    // Lower-case everything after the "DD" prefix, in place.
    ToLower(reversed.data() + 2, reversed.size() - 2);
    return reversed;
}
// Scores a candidate tokenization. Fewer tokens (B/N term) and more
// multi-character tokens (L/N term) are rewarded; the summed log-scale corpus
// frequencies F dominate. Returns the token texts together with the score.
std::pair<std::vector<std::string>, double> RAGAnalyzer::Score(const std::vector<std::pair<std::string, int>> &token_freqs) {
    constexpr int64_t B = 30;
    int64_t F = 0, L = 0;
    std::vector<std::string> tokens;
    tokens.reserve(token_freqs.size());
    for (const auto &[token, freq_tag] : token_freqs) {
        F += DecodeFreq(freq_tag);
        L += (UTF8Length(token) < 2) ? 0 : 1;
        tokens.push_back(token);
    }
    // Guard the degenerate empty input: the unguarded division by
    // tokens.size() previously produced an infinite score.
    if (tokens.empty()) {
        return {std::move(tokens), 0.0};
    }
    const auto score = B / static_cast<double>(tokens.size()) + L / static_cast<double>(tokens.size()) + F;
    return {std::move(tokens), score};
}
// Scores every candidate tokenization in `token_list`, appends the results to
// `res`, and sorts `res` by descending score.
void RAGAnalyzer::SortTokens(const std::vector<std::vector<std::pair<std::string, int>>> &token_list,
                             std::vector<std::pair<std::vector<std::string>, double>> &res) {
    res.reserve(res.size() + token_list.size());
    for (const auto &candidate : token_list) {
        res.emplace_back(Score(candidate));
    }
    const auto by_score_desc = [](const auto &lhs, const auto &rhs) { return lhs.second > rhs.second; };
    std::sort(res.begin(), res.end(), by_score_desc);
}
// Greedy forward maximum matching over `line` (UTF-8 character indices):
// from each start position, grow the span while some dictionary key still has
// it as a prefix, then shrink back to the longest span that is an actual
// entry. Unknown single characters are kept with weight 0.
std::pair<std::vector<std::string>, double> RAGAnalyzer::MaxForward(const std::string &line) const {
    std::vector<std::pair<std::string, int>> res;
    std::size_t s = 0;
    std::size_t len = UTF8Length(line);
    while (s < len) {
        std::size_t e = s + 1;
        std::string t = UTF8Substr(line, s, e - s);
        // Extend while the lower-cased span is still a live trie prefix.
        while (e < len && trie_->HasKeysWithPrefix(Key(t))) {
            e += 1;
            t = UTF8Substr(line, s, e - s);
        }
        // Back off until the span is an exact entry (or a single character).
        while (e - 1 > s && trie_->Get(Key(t)) == -1) {
            e -= 1;
            t = UTF8Substr(line, s, e - s);
        }
        int v = trie_->Get(Key(t));
        if (v != -1) {
            res.emplace_back(std::move(t), v);
        } else {
            res.emplace_back(std::move(t), 0);
        }
        s = e;
    }
    return Score(res);
}
// Greedy backward maximum matching over `line`: tokens are matched from the
// right using the reversed ("DD"-prefixed) trie keys added in Load(), then
// the collected tokens are reversed back into reading order before scoring.
// Unknown single characters are kept with weight 0.
std::pair<std::vector<std::string>, double> RAGAnalyzer::MaxBackward(const std::string &line) const {
    std::vector<std::pair<std::string, int>> res;
    int s = UTF8Length(line) - 1;
    while (s >= 0) {
        const int e = s + 1; // fixed right edge of the current span
        std::string t = UTF8Substr(line, s, e - s);
        // Extend leftward while the reversed span is still a live trie prefix.
        while (s > 0 && trie_->HasKeysWithPrefix(RKey(t))) {
            s -= 1;
            t = UTF8Substr(line, s, e - s);
        }
        // Back off until the forward key is an exact dictionary entry.
        while (s + 1 < e && trie_->Get(Key(t)) == -1) {
            s += 1;
            t = UTF8Substr(line, s, e - s);
        }
        int v = trie_->Get(Key(t));
        if (v != -1) {
            res.emplace_back(std::move(t), v);
        } else {
            res.emplace_back(std::move(t), 0);
        }
        s -= 1;
    }
    // Tokens were collected right-to-left; restore reading order.
    std::reverse(res.begin(), res.end());
    return Score(res);
}
// Pruned depth-first enumeration of dictionary segmentations of `chars`,
// starting at UTF-8 character index `s`. `pre_tokens` is the (token,
// trie-value) chain built so far. When memo_all is true every complete
// segmentation is appended to token_list; otherwise only the best-scoring
// token vector is kept in best_tokens / max_score.
// Returns the furthest character index reached by any explored branch.
int RAGAnalyzer::DFS(const std::string &chars,
                     const int s,
                     std::vector<std::pair<std::string, int>> &pre_tokens,
                     std::vector<std::vector<std::pair<std::string, int>>> &token_list,
                     std::vector<std::string> &best_tokens,
                     double &max_score,
                     const bool memo_all) const {
    int res = s;
    const int len = UTF8Length(chars);
    if (s >= len) {
        // Complete segmentation: either memoize it or keep the best score.
        if (memo_all) {
            token_list.push_back(pre_tokens);
        } else if (auto [vec_str, current_score] = Score(pre_tokens); current_score > max_score) {
            best_tokens = std::move(vec_str);
            max_score = current_score;
        }
        return res;
    }
    // pruning
    int S = s + 1;
    if (s + 2 <= len) {
        // If the 1-char prefix is productive but the 2-char prefix is not,
        // start candidate spans at length 2 (the 1-char branch is covered by
        // the fallback at the bottom).
        std::string t1 = UTF8Substr(chars, s, 1);
        std::string t2 = UTF8Substr(chars, s, 2);
        if (trie_->HasKeysWithPrefix(Key(t1)) && !trie_->HasKeysWithPrefix(Key(t2))) {
            S = s + 2;
        }
    }
    if (pre_tokens.size() > 2 && UTF8Length(pre_tokens[pre_tokens.size() - 1].first) == 1 &&
        UTF8Length(pre_tokens[pre_tokens.size() - 2].first) == 1 && UTF8Length(pre_tokens[pre_tokens.size() - 3].first) == 1) {
        // Three single-char tokens in a row: if the previous char combined
        // with the current one could start an entry, skip 1-char candidates
        // to avoid degenerate all-singles segmentations.
        std::string t1 = pre_tokens[pre_tokens.size() - 1].first + UTF8Substr(chars, s, 1);
        if (trie_->HasKeysWithPrefix(Key(t1))) {
            S = s + 2;
        }
    }
    for (int e = S; e <= len; ++e) {
        std::string t = UTF8Substr(chars, s, e - s);
        std::string k = Key(t);
        // No key extends this span, so longer spans cannot match either.
        if (e > s + 1 && !trie_->HasKeysWithPrefix(k)) {
            break;
        }
        if (const int v = trie_->Get(k); v != -1) {
            auto pretks = pre_tokens;
            pretks.emplace_back(std::move(t), v);
            res = std::max(res, DFS(chars, e, pretks, token_list, best_tokens, max_score, memo_all));
        }
    }
    if (res > s) {
        return res;
    }
    // Fallback: consume a single character; characters missing from the
    // dictionary are recorded with the penalty frequency -12.
    std::string t = UTF8Substr(chars, s, 1);
    if (const int v = trie_->Get(Key(t)); v != -1) {
        pre_tokens.emplace_back(std::move(t), v);
    } else {
        pre_tokens.emplace_back(std::move(t), Encode(-12, 0));
    }
    return DFS(chars, s + 1, pre_tokens, token_list, best_tokens, max_score, memo_all);
}
// Immutable parent-linked token chain used by the tokenization DP: each node
// extends the candidate ending at `prev` by one token (a view into the
// input, so the input must outlive the chain).
struct TokensList {
    const TokensList *prev = nullptr;
    std::string_view token = {};
};
// One dynamic-programming candidate for GetBestTokensTopN: the token chain
// built so far plus the aggregates needed to compute a score equivalent to
// RAGAnalyzer::Score (F + (B + L) / N == B/N + L/N + F).
struct BestTokenCandidate {
    static constexpr int64_t B = 30; // same smoothing constant as Score()
    TokensList tl{};
    // N: token num
    // L: num of tokens with length >= 2
    // F: sum of freq
    uint32_t N{};
    uint32_t L{};
    int64_t F{};
    // Bucketing key for the top-n table; layout depends on DIVIDE_F_BY_N.
    auto k() const {
#ifdef DIVIDE_F_BY_N
        return N;
#else
        return std::make_pair(N, L);
#endif
    }
    // Tie-break value within a bucket: larger frequency sum wins.
    auto v() const { return F; }
    // Final score for ranking completed candidates.
    auto score() const {
#ifdef DIVIDE_F_BY_N
        return static_cast<double>(B + L + F) / N;
#else
        return F + (static_cast<double>(B + L) / N);
#endif
    }
    // Returns a new candidate extended by one token; add_l is 1 when the new
    // token has length >= 2, else 0.
    BestTokenCandidate update(const std::string_view new_token_sv, const int32_t key_f, const uint32_t add_l) const {
        return {{&tl, new_token_sv}, N + 1, L + add_l, F + key_f};
    }
};
// Top-n table for the tokenization DP: `candidates` stays sorted by k()
// (equal_range below relies on this invariant) and holds at most `top_n`
// entries per distinct key — the ones with the largest v() (frequency sum).
struct GrowingBestTokenCandidatesTopN {
    int32_t top_n{};
    std::vector<BestTokenCandidate> candidates{};
    explicit GrowingBestTokenCandidatesTopN(const int32_t top_n) : top_n(top_n) {
    }
    // Inserts add_candidate, or — when its key bucket is already full —
    // replaces the bucket's weakest entry if the newcomer is strictly better.
    void AddBestTokenCandidateTopN(const BestTokenCandidate &add_candidate) {
        // Locate the run of existing candidates sharing add_candidate's key.
        const auto [it_b, it_e] =
            std::equal_range(candidates.begin(), candidates.end(), add_candidate, [](const auto &a, const auto &b) { return a.k() < b.k(); });
        auto target_it = it_b;
        bool do_replace = false;
        if (const auto match_cnt = std::distance(it_b, it_e); match_cnt >= top_n) {
            assert(match_cnt == top_n);
            // Bucket full: find the weakest entry; keep it unless the new
            // candidate is strictly better.
            const auto it = std::min_element(it_b, it_e, [](const auto &a, const auto &b) { return a.v() < b.v(); });
            if (it->v() >= add_candidate.v()) {
                return;
            }
            target_it = it;
            do_replace = true;
        }
        if (do_replace) {
            *target_it = add_candidate;
        } else {
            candidates.insert(target_it, add_candidate);
        }
    }
};
// Dynamic-programming counterpart of the recursive DFS scorer: computes the
// top-n highest-scoring segmentations of `chars` in a single forward pass.
// dp_vec[p] holds candidate segmentations of the first p UTF-8 characters;
// from each position we extend by every dictionary word (found via
// incremental trie traversal) plus a penalized single-character fallback.
// Returned token views point into `chars` and share its lifetime.
std::vector<std::pair<std::vector<std::string_view>, double>> RAGAnalyzer::GetBestTokensTopN(const std::string_view chars, const uint32_t n) const {
const auto utf8_len = UTF8Length(chars);
std::vector<GrowingBestTokenCandidatesTopN> dp_vec(utf8_len + 1, GrowingBestTokenCandidatesTopN(n));
// Seed with one empty candidate for the empty prefix.
dp_vec[0].candidates.resize(1);
const char *current_utf8_ptr = chars.data();
uint32_t current_left_chars = chars.size();
std::string growing_key; // in lower case
for (uint32_t i = 0; i < utf8_len; ++i) {
const std::string_view current_chars{current_utf8_ptr, current_left_chars};
const uint32_t left_utf8_cnt = utf8_len - i;
growing_key.clear();
const char *lookup_until = current_utf8_ptr;
uint32_t lookup_left_chars = current_left_chars;
// Trie cursor state reused across successive Traverse calls so the
// lookup for "abc" continues from where "ab" left off.
std::size_t reuse_node_pos = 0;
std::size_t reuse_key_pos = 0;
for (uint32_t j = 1; j <= left_utf8_cnt; ++j) {
{
// handle growing_key
const auto next_one_utf8 = UTF8Substrview({lookup_until, lookup_left_chars}, 0, 1);
// ASCII-only lowercasing; byte count is unchanged, which keeps
// growing_key.size() equal to the original span's byte length.
if (next_one_utf8.size() == 1 && next_one_utf8[0] >= 'A' && next_one_utf8[0] <= 'Z') {
growing_key.push_back(next_one_utf8[0] - 'A' + 'a');
} else {
growing_key.append(next_one_utf8);
}
lookup_until += next_one_utf8.size();
lookup_left_chars -= next_one_utf8.size();
}
// Extend every candidate at position i to position i + j with the
// original-cased view of the current span (sizes match: see above).
auto dp_f = [&dp_vec, i, j, original_sv = std::string_view{current_utf8_ptr, growing_key.size()}](
const int32_t key_f,
const uint32_t add_l) {
auto &target_dp = dp_vec[i + j];
for (const auto &c : dp_vec[i].candidates) {
target_dp.AddBestTokenCandidateTopN(c.update(original_sv, key_f, add_l));
}
};
if (const auto traverse_result = trie_->Traverse(growing_key.data(), reuse_node_pos, reuse_key_pos, growing_key.size());
traverse_result >= 0) {
// in dictionary
const int32_t key_f = DecodeFreq(traverse_result);
const auto add_l = static_cast<uint32_t>(j >= 2);
dp_f(key_f, add_l);
} else {
// not in dictionary
if (j == 1) {
// also give a score: -12
dp_f(-12, 0);
}
if (traverse_result == -2) {
// no more results
break;
}
}
}
// update current_utf8_ptr and current_left_chars
const auto forward_cnt = UTF8Substrview(current_chars, 0, 1).size();
current_utf8_ptr += forward_cnt;
current_left_chars -= forward_cnt;
}
// Collect the top-n scoring full segmentations, kept sorted descending.
std::vector<std::pair<const TokensList *, double>> mid_result;
mid_result.reserve(n);
for (const auto &c : dp_vec.back().candidates) {
const auto new_pair = std::make_pair(&(c.tl), c.score());
if (mid_result.size() < n) {
mid_result.push_back(new_pair);
} else {
assert(mid_result.size() == n);
if (new_pair.second > mid_result.back().second) {
mid_result.pop_back();
const auto insert_pos = std::lower_bound(mid_result.begin(),
mid_result.end(),
new_pair,
[](const auto &a, const auto &b) {
return a.second > b.second;
});
mid_result.insert(insert_pos, new_pair);
}
}
}
// Walks a back-linked TokensList, counts its length on the way down, then
// appends tokens front-to-back while unwinding the recursion.
class HelperFunc {
uint32_t cnt = 0;
std::vector<std::string_view> result{};
void GetTokensInner(const TokensList *tl) {
if (!tl->prev) {
result.reserve(cnt);
return;
}
++cnt;
GetTokensInner(tl->prev);
result.push_back(tl->token);
}
public:
std::vector<std::string_view> GetTokens(const TokensList *tl) {
GetTokensInner(tl);
return std::move(result);
}
};
std::vector<std::pair<std::vector<std::string_view>, double>> result;
result.reserve(mid_result.size());
for (const auto [tl, score] : mid_result) {
result.emplace_back(HelperFunc{}.GetTokens(tl), score);
}
return result;
}
// TODO: for test
// #ifndef INFINITY_DEBUG
// #define INFINITY_DEBUG 1
// #endif
#ifdef INFINITY_DEBUG
// Debug-only helpers that cross-check the DP segmenter (GetBestTokensTopN)
// against the reference DFS segmenter: compares scores, token sequences and
// wall-clock timing, printing a report to stderr.
// NOTE(review): this namespace mixes fmt::format and std::format, and
// streams chrono durations directly (needs C++20 operator<<) — confirm the
// debug build's toolchain supports both.
namespace dp_debug {
// Renders tokens as "#tok1# #tok2# ..." for side-by-side diffing.
template <typename T>
std::string TestPrintTokens(const std::vector<T> &tokens) {
std::ostringstream oss;
for (std::size_t i = 0; i < tokens.size(); ++i) {
oss << (i ? " #" : "#") << tokens[i] << "#";
}
return std::move(oss).str();
}
// Marker strings for pass/fail output (original literals preserved verbatim).
auto print_1 = [](const bool b) { return b ? "" : ""; };
auto print_2 = [](const bool b) { return b ? "equal" : "not equal"; };
// Prints whether DFS and DP produced the same score and token sequence.
void compare_score_and_tokens(const std::vector<std::string> &dfs_tokens,
const double dfs_score,
const std::vector<std::string_view> &dp_tokens,
const double dp_score,
const std::string &prefix) {
std::ostringstream oss;
const auto b_score_eq = dp_score == dfs_score;
oss << fmt::format("\n{} {} DFS and DP score {}:\nDFS: {}\nDP : {}\n", print_1(b_score_eq), prefix, print_2(b_score_eq), dfs_score, dp_score);
bool vec_equal = true;
if (dp_tokens.size() != dfs_tokens.size()) {
vec_equal = false;
} else {
for (std::size_t k = 0; k < dp_tokens.size(); ++k) {
if (dp_tokens[k] != dfs_tokens[k]) {
vec_equal = false;
break;
}
}
}
oss << fmt::format("{} {} DFS and DP result {}:\nDFS: {}\nDP : {}\n",
print_1(vec_equal),
prefix,
print_2(vec_equal),
TestPrintTokens(dfs_tokens),
TestPrintTokens(dp_tokens));
std::cerr << std::move(oss).str() << std::endl;
}
// Re-runs the input through the DP segmenter (top-1) and compares against
// the DFS result computed by the caller between t0 and t1.
inline void CheckDP(const RAGAnalyzer *this_ptr,
const std::string_view input_str,
const std::vector<std::string> &dfs_tokens,
const double dfs_score,
const auto t0,
const auto t1) {
const auto dp_result = this_ptr->GetBestTokensTopN(input_str, 1);
const auto t2 = std::chrono::high_resolution_clock::now();
const auto dfs_duration = std::chrono::duration_cast<std::chrono::duration<float, std::milli>>(t1 - t0);
const auto dp_duration = std::chrono::duration_cast<std::chrono::duration<float, std::milli>>(t2 - t1);
const auto dp_faster = dp_duration < dfs_duration;
std::cerr << "\n!!! " << print_1(dp_faster) << "\nTOP1 DFS duration: " << dfs_duration << " \nDP duration: " << dp_duration;
const auto &[dp_vec, dp_score] = dp_result[0];
compare_score_and_tokens(dfs_tokens, dfs_score, dp_vec, dp_score, "[1 in top1]");
}
// Same as CheckDP but compares the top-2 segmentations; DFS results are
// produced lazily via the get_dfs_sorted_tokens callback.
inline void CheckDP2(const RAGAnalyzer *this_ptr, const std::string_view input_str, auto get_dfs_sorted_tokens, const auto t0, const auto t1) {
constexpr int topn = 2;
const auto dp_result = this_ptr->GetBestTokensTopN(input_str, topn);
const auto t2 = std::chrono::high_resolution_clock::now();
const auto dfs_duration = std::chrono::duration_cast<std::chrono::duration<float, std::milli>>(t1 - t0);
const auto dp_duration = std::chrono::duration_cast<std::chrono::duration<float, std::milli>>(t2 - t1);
const auto dp_faster = dp_duration < dfs_duration;
std::cerr << "\n!!! " << print_1(dp_faster) << "\nTOP2 DFS duration: " << dfs_duration << " \nTOP2 DP duration: " << dp_duration;
const auto dfs_sorted_tokens = get_dfs_sorted_tokens();
for (int i = 0; i < std::min(topn, (int)dfs_sorted_tokens.size()); ++i) {
compare_score_and_tokens(dfs_sorted_tokens[i].first,
dfs_sorted_tokens[i].second,
dp_result[i].first,
dp_result[i].second,
std::format("[{} in top{}]", i + 1, topn));
}
}
} // namespace dp_debug
#endif
// Post-tokenization merge pass: scans the space-separated token stream and
// greedily re-joins adjacent tokens (windows of up to 5) whose concatenation
// matches the split pattern and is a known dictionary entry (Freq > 0).
std::string RAGAnalyzer::Merge(const std::string &tks_str) const {
    std::string normalized = Replace(replace_space_pattern_, " ", tks_str);
    std::vector<std::string> tokens;
    Split(normalized, blank_pattern_, tokens);
    std::vector<std::string> merged;
    std::size_t begin = 0;
    while (begin < tokens.size()) {
        std::size_t end = begin + 1;
        const std::size_t probe_limit = std::min(tokens.size() + 1, begin + 6);
        for (std::size_t probe = begin + 2; probe < probe_limit; ++probe) {
            std::string joined = Join(tokens, begin, probe, "");
            // Keep the longest window that is both pattern-eligible and a
            // dictionary word.
            if (re2::RE2::PartialMatch(joined, regex_split_pattern_) && Freq(joined) > 0) {
                end = probe;
            }
        }
        merged.push_back(Join(tokens, begin, end, ""));
        begin = end;
    }
    return Join(merged, 0, merged.size());
}
// Position-aware variant of Merge(): filters out empty / single-space tokens
// (mirroring Merge's behavior), then greedily re-joins adjacent tokens whose
// concatenation is a known dictionary entry, carrying each merged token's
// [start, end) range from the first/last constituent token.
void RAGAnalyzer::MergeWithPosition(const std::vector<std::string> &tokens,
                                    const std::vector<std::pair<unsigned, unsigned>> &positions,
                                    std::vector<std::string> &merged_tokens,
                                    std::vector<std::pair<unsigned, unsigned>> &merged_positions) const {
    // Keep only non-empty, non-space tokens together with their positions.
    std::vector<std::string> kept_tokens;
    std::vector<std::pair<unsigned, unsigned>> kept_positions;
    for (std::size_t idx = 0; idx < tokens.size(); ++idx) {
        if (tokens[idx].empty() || tokens[idx] == " ") {
            continue;
        }
        kept_tokens.push_back(tokens[idx]);
        kept_positions.push_back(positions[idx]);
    }
    std::vector<std::string> out_tokens;
    std::vector<std::pair<unsigned, unsigned>> out_positions;
    std::size_t begin = 0;
    while (begin < kept_tokens.size()) {
        std::size_t end = begin + 1;
        const std::size_t probe_limit = std::min(kept_tokens.size() + 1, begin + 6);
        for (std::size_t probe = begin + 2; probe < probe_limit; ++probe) {
            std::string joined = Join(kept_tokens, begin, probe, "");
            // Keep the longest window that is pattern-eligible and in the
            // dictionary.
            if (re2::RE2::PartialMatch(joined, regex_split_pattern_) && Freq(joined) > 0) {
                end = probe;
            }
        }
        out_tokens.push_back(Join(kept_tokens, begin, end, ""));
        out_positions.emplace_back(kept_positions[begin].first, kept_positions[end - 1].second);
        begin = end;
    }
    merged_tokens = std::move(out_tokens);
    merged_positions = std::move(out_positions);
}
void RAGAnalyzer::EnglishNormalize(const std::vector<std::string> &tokens, std::vector<std::string> &res) const {
for (auto &t : tokens) {
if (re2::RE2::PartialMatch(t, pattern1_)) {
//"[a-zA-Z_-]+$"
std::string lemma_term = wordnet_lemma_->Lemmatize(t);
std::vector<char> lowercase_buffer(term_string_buffer_limit_);
char *lowercase_term = lowercase_buffer.data();
ToLower(lemma_term.c_str(), lemma_term.size(), lowercase_term, term_string_buffer_limit_);
std::string stem_term;
stemmer_->Stem(lowercase_term, stem_term);
res.push_back(stem_term);
} else {
res.push_back(t);
}
}
}
// Splits `line` into maximal same-script runs; each output pair is
// (segment, is_chinese). The line is first split on regex_split_pattern_
// (the `true` flag presumably keeps delimiter segments — confirm against
// Split's definition), then each piece is scanned character by character.
void RAGAnalyzer::SplitByLang(const std::string &line, std::vector<std::pair<std::string, bool>> &txt_lang_pairs) const {
std::vector<std::string> arr;
Split(line, regex_split_pattern_, arr, true);
for (const auto &a : arr) {
if (a.empty()) {
continue;
}
// [s, e) is the current run of same-script characters (UTF-8 indices).
std::size_t s = 0;
std::size_t e = s + 1;
bool zh = IsChinese(UTF8Substr(a, s, 1));
while (e < UTF8Length(a)) {
bool _zh = IsChinese(UTF8Substr(a, e, 1));
if (_zh == zh) {
e++;
continue;
}
// Script changed: emit the finished run and start a new one at e.
std::string segment = UTF8Substr(a, s, e - s);
txt_lang_pairs.emplace_back(segment, zh);
s = e;
e = s + 1;
zh = _zh;
}
// Defensive guard against emitting an empty trailing run.
if (s >= UTF8Length(a)) {
continue;
}
// Emit the final run, which extends to the end of the piece.
std::string segment = UTF8Substr(a, s, e - s);
txt_lang_pairs.emplace_back(segment, zh);
}
}
// Tokenizes a (mostly Chinese) span: runs maximum-forward and
// maximum-backward matching, emits spans where both agree verbatim, and
// re-segments disagreeing spans with the exhaustive DFS scorer.
//
// Fix: the previous version carried a full `#if 0` copy of the function
// body that was byte-for-byte identical to the live `#else` branch; the
// dead duplicate has been removed.
void RAGAnalyzer::TokenizeInner(std::vector<std::string> &res, const std::string &L) const {
    auto [tks, s] = MaxForward(L);
    auto [tks1, s1] = MaxBackward(L);
    std::size_t i = 0, j = 0, _i = 0, _j = 0, same = 0;
    // Leading run where forward and backward segmentations agree.
    while ((i + same < tks1.size()) && (j + same < tks.size()) && tks1[i + same] == tks[j + same]) {
        same++;
    }
    if (same > 0) {
        res.push_back(Join(tks, j, j + same));
    }
    _i = i + same;
    _j = j + same;
    j = _j + 1;
    i = _i + 1;
    while (i < tks1.size() && j < tks.size()) {
        std::string tk1 = Join(tks1, _i, i, "");
        std::string tk = Join(tks, _j, j, "");
        if (tk1 != tk) {
            // Advance whichever side currently covers fewer bytes until both
            // candidate spans cover the same substring again.
            if (tk1.length() > tk.length()) {
                j++;
            } else {
                i++;
            }
            continue;
        }
        if (tks1[i] != tks[j]) {
            i++;
            j++;
            continue;
        }
        // The two segmentations disagree on tks[_j, j): re-segment via DFS.
        std::vector<std::pair<std::string, int>> pre_tokens;
        std::vector<std::vector<std::pair<std::string, int>>> token_list;
        std::vector<std::string> best_tokens;
        double max_score = std::numeric_limits<double>::lowest();
        const auto str_for_dfs = Join(tks, _j, j, "");
#ifdef INFINITY_DEBUG
        const auto t0 = std::chrono::high_resolution_clock::now();
#endif
        DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false);
#ifdef INFINITY_DEBUG
        const auto t1 = std::chrono::high_resolution_clock::now();
        dp_debug::CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1);
#endif
        res.push_back(Join(best_tokens, 0));
        // Collect the following run where both segmentations agree again.
        same = 1;
        while (i + same < tks1.size() && j + same < tks.size() && tks1[i + same] == tks[j + same])
            same++;
        res.push_back(Join(tks, j, j + same));
        _i = i + same;
        _j = j + same;
        j = _j + 1;
        i = _i + 1;
    }
    // Trailing disagreeing span, if any.
    if (_i < tks1.size()) {
        std::vector<std::pair<std::string, int>> pre_tokens;
        std::vector<std::vector<std::pair<std::string, int>>> token_list;
        std::vector<std::string> best_tokens;
        double max_score = std::numeric_limits<double>::lowest();
        const auto str_for_dfs = Join(tks, _j, tks.size(), "");
#ifdef INFINITY_DEBUG
        const auto t0 = std::chrono::high_resolution_clock::now();
#endif
        DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false);
#ifdef INFINITY_DEBUG
        const auto t1 = std::chrono::high_resolution_clock::now();
        dp_debug::CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1);
#endif
        res.push_back(Join(best_tokens, 0));
    }
}
// Splits a long UTF-8 text L (UTF-8 length `length`) into chunks of roughly
// MAX_SENTENCE_LEN characters, nudging each cut point to a position where
// forward and backward maximum matching agree so no token is cut in half.
//
// Fix: the backward scan previously used `std::size_t end` with the test
// `end >= 0`, which is always true for an unsigned type; when every diff[]
// entry was 0 (or tks was empty, making `tks.size() - 1` wrap), `end--`
// underflowed and `diff[end]` read far out of bounds. The index is now
// signed so the loop terminates at -1.
void RAGAnalyzer::SplitLongText(const std::string &L, uint32_t length, std::vector<std::string> &sublines) const {
    uint32_t slice_count = length / MAX_SENTENCE_LEN + 1;
    sublines.reserve(slice_count);
    std::size_t last_sentence_start = 0;
    std::size_t next_sentence_start = 0;
    for (unsigned i = 0; i < slice_count; ++i) {
        // Probe a small window (up to 10 chars) around the nominal cut point.
        next_sentence_start = MAX_SENTENCE_LEN * (i + 1) - 5;
        if (next_sentence_start + 5 < length) {
            std::size_t sentence_length = MAX_SENTENCE_LEN * (i + 1) + 5 > length ? length - next_sentence_start : 10;
            std::string substr = UTF8Substr(L, next_sentence_start, sentence_length);
            auto [tks, s] = MaxForward(substr);
            auto [tks1, s1] = MaxBackward(substr);
            // diff[k] == 1 where the two segmentations disagree at index k.
            std::vector<int> diff(std::max(tks.size(), tks1.size()), 0);
            for (std::size_t j = 0; j < std::min(tks.size(), tks1.size()); ++j) {
                if (tks[j] != tks1[j]) {
                    diff[j] = 1;
                }
            }
            // Prefer the segmentation with the higher score.
            if (s1 > s) {
                tks = tks1;
            }
            // Extend the cut forward over the leading agreeing tokens.
            std::size_t start = 0;
            std::size_t forward_same_len = 0;
            while (start < tks.size() && diff[start] == 0) {
                forward_same_len += UTF8Length(tks[start]);
                start++;
            }
            if (forward_same_len == 0) {
                // No agreement at the front: back off over the trailing
                // agreeing tokens instead. Signed index — see fix note above.
                int64_t end = static_cast<int64_t>(tks.size()) - 1;
                std::size_t backward_same_len = 0;
                while (end >= 0 && diff[end] == 0) {
                    backward_same_len += UTF8Length(tks[end]);
                    end--;
                }
                next_sentence_start += sentence_length - backward_same_len;
            } else {
                next_sentence_start += forward_same_len;
            }
        } else {
            next_sentence_start = length;
        }
        if (next_sentence_start == last_sentence_start) {
            continue;
        }
        std::string str = UTF8Substr(L, last_sentence_start, next_sentence_start - last_sentence_start);
        sublines.push_back(str);
        last_sentence_start = next_sentence_start;
    }
}
// PCRE2-based replacement function to match Python's re.sub behavior
// Returns processed string and position mapping from processed to original
// The mapping has one (orig_start, orig_start) entry per byte of the result:
// copied bytes map to their own original offset, replacement bytes all map
// to the start of the match they replaced.
std::pair<std::string, std::vector<std::pair<unsigned, unsigned>>>
PCRE2GlobalReplaceWithPosition(const std::string &text, const std::string &pattern, const std::string &replacement) {
std::vector<std::pair<unsigned, unsigned>> pos_mapping;
std::string result;
pcre2_code *re;
PCRE2_SPTR pcre2_pattern = reinterpret_cast<PCRE2_SPTR>(pattern.c_str());
PCRE2_SPTR pcre2_subject = reinterpret_cast<PCRE2_SPTR>(text.c_str());
// Note: pcre2_replacement is used in the replacement logic below
int errorcode;
PCRE2_SIZE erroroffset;
// Compile the pattern with UTF and UCP flags for Unicode support
re = pcre2_compile(pcre2_pattern, PCRE2_ZERO_TERMINATED, PCRE2_UCP | PCRE2_UTF, &errorcode, &erroroffset, nullptr);
if (re == nullptr) {
PCRE2_UCHAR buffer[256];
pcre2_get_error_message(errorcode, buffer, sizeof(buffer));
std::cerr << "PCRE2 compilation failed at offset " << erroroffset << ": " << buffer << std::endl;
// On a bad pattern the input is returned unchanged with an EMPTY mapping;
// callers that index the mapping must tolerate this.
return {text, {}};
}
pcre2_match_data *match_data = pcre2_match_data_create_from_pattern(re, nullptr);
PCRE2_SIZE current_pos = 0;
PCRE2_SIZE last_match_end = 0;
// Process the string match by match
while (current_pos < text.length()) {
int rc = pcre2_match(re, pcre2_subject, text.length(), current_pos, 0, match_data, nullptr);
if (rc < 0) {
// No more matches, copy remaining text
if (last_match_end < text.length()) {
std::string remaining = text.substr(last_match_end);
result += remaining;
// Map each character in remaining text
for (size_t i = 0; i < remaining.length(); ++i) {
pos_mapping.emplace_back(last_match_end + i, last_match_end + i);
}
}
break;
}
// ovector[0]/[1] are the byte offsets of the whole-match start/end.
PCRE2_SIZE *ovector = pcre2_get_ovector_pointer(match_data);
PCRE2_SIZE match_start = ovector[0];
PCRE2_SIZE match_end = ovector[1];
// Copy text before the match
if (last_match_end < match_start) {
std::string before_match = text.substr(last_match_end, match_start - last_match_end);
result += before_match;
// Map each character in before_match
for (size_t i = 0; i < before_match.length(); ++i) {
pos_mapping.emplace_back(last_match_end + i, last_match_end + i);
}
}
// Add the replacement string
result += replacement;
// Map each character in replacement to the start of the match
for (size_t i = 0; i < replacement.length(); ++i) {
pos_mapping.emplace_back(match_start, match_start);
}
last_match_end = match_end;
current_pos = match_end;
// If the match was zero-length, move forward one character to avoid infinite loop
// NOTE(review): this advances one BYTE; with PCRE2_UTF a mid-character
// offset can make the next pcre2_match fail with a bad-UTF-offset error,
// which ends the loop via the rc < 0 branch — confirm this is acceptable.
if (match_start == match_end) {
if (current_pos < text.length()) {
current_pos++;
} else {
break;
}
}
}
pcre2_match_data_free(match_data);
pcre2_code_free(re);
return {result, pos_mapping};
}
// Original PCRE2GlobalReplace for backward compatibility
// Convenience wrapper: performs the same global replacement but discards
// the processed-to-original position mapping.
std::string PCRE2GlobalReplace(const std::string &text, const std::string &pattern, const std::string &replacement) {
    return PCRE2GlobalReplaceWithPosition(text, pattern, replacement).first;
}
// Full tokenization pipeline for one line of text. Steps:
//   1. collapse non-word runs to single spaces (Python-style re.sub(r"\W+", " ")),
//   2. full-width -> half-width normalization (StrQ2B),
//   3. OpenCC conversion (presumably traditional -> simplified — confirm
//      against the OpenCC config used at construction),
//   4. split into same-script runs; English runs get NLTK tokenize +
//      lemmatize + lowercase + stem, Chinese runs go through TokenizeInner,
//   5. join and run the dictionary-based Merge pass.
// Returns the space-joined token string.
std::string RAGAnalyzer::Tokenize(const std::string &line) const {
// Python-style simple tokenization: re.sub(r"\\W+", " ", line)
std::string processed_line = PCRE2GlobalReplace(line, R"#(\W+)#", " ");
std::string str1 = StrQ2B(processed_line);
std::string strline;
opencc_->convert(str1, strline);
std::vector<std::string> res;
// Use SplitByLang to separate by language
std::vector<std::pair<std::string, bool>> arr;
SplitByLang(strline, arr);
for (const auto &[L, lang] : arr) {
if (!lang) {
// Non-Chinese text: use NLTK tokenizer, lemmatize and stem
std::vector<std::string> term_list;
std::vector<std::string> sentences;
SentenceSplitter(L, sentences);
for (auto &sentence : sentences) {
NLTKWordTokenizer::GetInstance().Tokenize(sentence, term_list);
}
for (unsigned i = 0; i < term_list.size(); ++i) {
std::string t = wordnet_lemma_->Lemmatize(term_list[i]);
std::vector<char> lowercase_buffer(term_string_buffer_limit_);
char *lowercase_term = lowercase_buffer.data();
ToLower(t.c_str(), t.size(), lowercase_term, term_string_buffer_limit_);
std::string stem_term;
stemmer_->Stem(lowercase_term, stem_term);
res.push_back(stem_term);
}
continue;
}
auto length = UTF8Length(L);
// Short runs and pure alpha/numeric runs are emitted unsegmented.
if (length < 2 || re2::RE2::PartialMatch(L, pattern2_) || re2::RE2::PartialMatch(L, pattern3_)) {
//[a-z\\.-]+$ [0-9\\.-]+$
res.push_back(L);
continue;
}
// Chinese processing: use TokenizeInner
#if 0
if (length > MAX_SENTENCE_LEN) {
std::vector<std::string> sublines;
SplitLongText(L, length, sublines);
for (auto &l : sublines) {
TokenizeInner(res, l);
}
} else
#endif
TokenizeInner(res, L);
}
// std::vector<std::string> normalize_res;
// EnglishNormalize(res, normalize_res);
std::string r = Join(res, 0);
std::string ret = Merge(r);
return ret;
}
// Position-aware variant of Tokenize(): returns each token together with its
// [start, end) byte range in the ORIGINAL input line. The pipeline performs
// three text transformations (\W+ -> " ", StrQ2B, OpenCC); a per-byte
// mapping is built across each stage and composed into final_pos_mapping
// (strline byte offset -> original line byte offset) so token positions can
// be traced back through all three.
std::pair<std::vector<std::string>, std::vector<std::pair<unsigned, unsigned>>> RAGAnalyzer::TokenizeWithPosition(const std::string &line) const {
// Python-style simple tokenization: re.sub(r"\W+", " ", line)
// Get processed line and position mapping from PCRE2GlobalReplace
auto [processed_line, pcre2_pos_mapping] = PCRE2GlobalReplaceWithPosition(line, R"#(\W+)#", " ");
std::string str1 = StrQ2B(processed_line);
std::string strline;
opencc_->convert(str1, strline);
std::vector<std::string> tokens;
std::vector<std::pair<unsigned, unsigned>> positions;
// Build character position mapping from StrQ2B conversion
std::vector<unsigned> strq2b_pos_mapping;
BuildPositionMapping(processed_line, str1, strq2b_pos_mapping);
// Build character position mapping from OpenCC conversion
std::vector<unsigned> opencc_pos_mapping;
BuildPositionMapping(str1, strline, opencc_pos_mapping);
// Combine all position mappings: strline -> str1 -> processed_line -> line
// Offsets that fall outside any stage's mapping saturate to line.size().
std::vector<unsigned> final_pos_mapping;
final_pos_mapping.resize(strline.size() + 1);
for (size_t i = 0; i < strline.size(); ++i) {
if (i < opencc_pos_mapping.size()) {
unsigned str1_pos = opencc_pos_mapping[i];
if (str1_pos < strq2b_pos_mapping.size()) {
unsigned processed_pos = strq2b_pos_mapping[str1_pos];
if (processed_pos < pcre2_pos_mapping.size()) {
final_pos_mapping[i] = pcre2_pos_mapping[processed_pos].first;
} else {
final_pos_mapping[i] = static_cast<unsigned>(line.size());
}
} else {
final_pos_mapping[i] = static_cast<unsigned>(line.size());
}
} else {
final_pos_mapping[i] = static_cast<unsigned>(line.size());
}
}
// Fill the last position
if (strline.size() < final_pos_mapping.size()) {
final_pos_mapping[strline.size()] = static_cast<unsigned>(line.size());
}
// Use SplitByLang to separate by language
std::vector<std::pair<std::string, bool>> arr;
SplitByLang(strline, arr);
// current_pos tracks each segment's byte offset within strline.
unsigned current_pos = 0;
for (const auto &[L, lang] : arr) {
if (L.empty()) {
continue;
}
// Locate the segment inside strline (defensive; segments are emitted
// in order, so this should always succeed at current_pos).
std::size_t processed_pos = strline.find(L, current_pos);
if (processed_pos == std::string::npos) {
continue;
}
unsigned original_start = current_pos;
current_pos = original_start + static_cast<unsigned>(L.size());
if (!lang) {
// Non-Chinese text: use NLTK tokenizer, lemmatize and stem
std::vector<std::string> term_list;
std::vector<std::string> sentences;
SentenceSplitter(L, sentences);
unsigned sentence_start_pos = original_start;
for (auto &sentence : sentences) {
std::vector<std::string> sentence_terms;
NLTKWordTokenizer::GetInstance().Tokenize(sentence, sentence_terms);
unsigned current_search_pos = 0;
for (auto &term : sentence_terms) {
// Find the raw term inside the sentence to recover its offset;
// note the stored token is the stemmed form but the position
// covers the ORIGINAL term's bytes.
size_t pos_in_sentence = sentence.find(term, current_search_pos);
if (pos_in_sentence != std::string::npos) {
unsigned start_pos = sentence_start_pos + static_cast<unsigned>(pos_in_sentence);
unsigned end_pos = start_pos + static_cast<unsigned>(term.size());
std::string t = wordnet_lemma_->Lemmatize(term);
std::vector<char> lowercase_buffer(term_string_buffer_limit_);
char *lowercase_term = lowercase_buffer.data();
ToLower(t.c_str(), t.size(), lowercase_term, term_string_buffer_limit_);
std::string stem_term;
stemmer_->Stem(lowercase_term, stem_term);
tokens.push_back(stem_term);
// Map positions back to original string using final_pos_mapping
if (start_pos < final_pos_mapping.size()) {
positions.emplace_back(final_pos_mapping[start_pos], final_pos_mapping[end_pos]);
} else {
positions.emplace_back(static_cast<unsigned>(line.size()), static_cast<unsigned>(line.size()));
}
current_search_pos = pos_in_sentence + term.size();
}
}
sentence_start_pos += static_cast<unsigned>(sentence.size());
}
continue;
}
auto length = UTF8Length(L);
// Short runs and pure alpha/numeric runs are emitted unsegmented.
if (length < 2 || re2::RE2::PartialMatch(L, pattern2_) || re2::RE2::PartialMatch(L, pattern3_)) {
tokens.push_back(L);
// Map positions back to original string using final_pos_mapping
unsigned start_pos = original_start;
unsigned end_pos = original_start + static_cast<unsigned>(L.size());
if (start_pos < final_pos_mapping.size() && end_pos < final_pos_mapping.size()) {
positions.emplace_back(final_pos_mapping[start_pos], final_pos_mapping[end_pos]);
} else {
positions.emplace_back(static_cast<unsigned>(line.size()), static_cast<unsigned>(line.size()));
}
continue;
}
// Chinese processing: use TokenizeInnerWithPosition
#if 0
if (length > MAX_SENTENCE_LEN) {
std::vector<std::string> sublines;
SplitLongText(L, length, sublines);
unsigned subline_start_pos = original_start;
for (auto &l : sublines) {
TokenizeInnerWithPosition(l, tokens, positions, subline_start_pos, &final_pos_mapping);
subline_start_pos += static_cast<unsigned>(l.size());
}
} else
#endif
TokenizeInnerWithPosition(L, tokens, positions, original_start, &final_pos_mapping);
}
// std::vector<std::string> normalize_tokens;
// std::vector<std::pair<unsigned, unsigned>> normalize_positions;
// EnglishNormalizeWithPosition(tokens, positions, normalize_tokens, normalize_positions);
// Apply MergeWithPosition to match Tokenize behavior
std::vector<std::string> merged_tokens;
std::vector<std::pair<unsigned, unsigned>> merged_positions;
MergeWithPosition(tokens, positions, merged_tokens, merged_positions);
tokens = std::move(merged_tokens);
positions = std::move(merged_positions);
return {std::move(tokens), std::move(positions)};
}
// Translates a processed-string offset back to an original-string offset via
// (original, processed) pairs; returns the input offset unchanged when no
// pair's processed offset matches.
unsigned RAGAnalyzer::MapToOriginalPosition(unsigned processed_pos, const std::vector<std::pair<unsigned, unsigned>> &mapping) const {
    for (std::size_t idx = 0; idx < mapping.size(); ++idx) {
        if (mapping[idx].second == processed_pos) {
            return mapping[idx].first;
        }
    }
    return processed_pos;
}
// Sum of the byte lengths of tokens[start..end) (half-open index range).
static unsigned CalculateTokensLength(const std::vector<std::string> &tokens, int start, int end) {
    unsigned sum = 0;
    for (auto it = tokens.begin() + start, stop = tokens.begin() + end; it != stop; ++it) {
        sum += static_cast<unsigned>(it->size());
    }
    return sum;
}
// Position-aware variant of TokenizeInner: reconciles maximum-forward and
// maximum-backward segmentations, resolves disagreeing spans via DFS, and
// emits every token with its [start, end) byte range. `base_pos` is the byte
// offset of L within the processed string; `pos_mapping` (optional) maps
// processed-string offsets back to original-input offsets.
//
// Refactor: the "emit token(s), splitting on blanks and mapping positions"
// logic was copy-pasted four times (~150 duplicated lines); it is now
// factored into local lambdas with behavior preserved. The lambdas take the
// token vector as a parameter (not a capture) because C++17 forbids
// capturing structured bindings.
void RAGAnalyzer::TokenizeInnerWithPosition(const std::string &L,
                                            std::vector<std::string> &tokens,
                                            std::vector<std::pair<unsigned, unsigned>> &positions,
                                            unsigned base_pos,
                                            const std::vector<unsigned> *pos_mapping) const {
    // Append one token with its positions, translated through pos_mapping
    // when provided (out-of-range offsets map to 0, matching the old code).
    auto push_token = [&](const std::string &tok, unsigned start, unsigned end) {
        tokens.push_back(tok);
        if (pos_mapping) {
            const unsigned mapped_start = start < pos_mapping->size() ? (*pos_mapping)[start] : 0;
            const unsigned mapped_end = end < pos_mapping->size() ? (*pos_mapping)[end] : 0;
            positions.emplace_back(mapped_start, mapped_end);
        } else {
            positions.emplace_back(start, end);
        }
    };
    // Emit a token string starting at start_pos. A string containing spaces
    // is split on blanks and each non-empty piece gets its own consecutive
    // range; otherwise the whole token spans [start_pos, end_pos).
    auto emit = [&](const std::string &token_str, unsigned start_pos, unsigned end_pos) {
        if (token_str.find(' ') != std::string::npos) {
            std::vector<std::string> pieces;
            Split(token_str, blank_pattern_, pieces, false);
            unsigned piece_start = start_pos;
            for (const auto &piece : pieces) {
                if (piece.empty()) {
                    continue;
                }
                const unsigned piece_len = static_cast<unsigned>(piece.size());
                push_token(piece, piece_start, piece_start + piece_len);
                piece_start += piece_len;
            }
        } else {
            push_token(token_str, start_pos, end_pos);
        }
    };
    // Emit tokens [from, from + count) of tks_ref joined with the default
    // separator, positioned at their cumulative byte offset.
    auto emit_agreed_run = [&](const std::vector<std::string> &tks_ref, std::size_t from, std::size_t count) {
        const std::string token_str = Join(tks_ref, from, from + count);
        const unsigned token_len = static_cast<unsigned>(token_str.size());
        const unsigned start_pos = base_pos + CalculateTokensLength(tks_ref, 0, static_cast<int>(from));
        emit(token_str, start_pos, start_pos + token_len);
    };
    // Re-segment tks_ref[seg_begin, seg_end) with DFS and emit the best
    // segmentation; the position range covers the original span's bytes.
    auto emit_dfs_span = [&](const std::vector<std::string> &tks_ref, std::size_t seg_begin, std::size_t seg_end) {
        std::vector<std::pair<std::string, int>> pre_tokens;
        std::vector<std::vector<std::pair<std::string, int>>> token_list;
        std::vector<std::string> best_tokens;
        double max_score = std::numeric_limits<double>::lowest();
        const auto str_for_dfs = Join(tks_ref, seg_begin, seg_end, "");
#ifdef INFINITY_DEBUG
        const auto t0 = std::chrono::high_resolution_clock::now();
#endif
        DFS(str_for_dfs, 0, pre_tokens, token_list, best_tokens, max_score, false);
#ifdef INFINITY_DEBUG
        const auto t1 = std::chrono::high_resolution_clock::now();
        dp_debug::CheckDP(this, str_for_dfs, best_tokens, max_score, t0, t1);
#endif
        const std::string best_token_str = Join(best_tokens, 0);
        const unsigned start_pos = base_pos + CalculateTokensLength(tks_ref, 0, static_cast<int>(seg_begin));
        const unsigned end_pos = start_pos + static_cast<unsigned>(str_for_dfs.size());
        emit(best_token_str, start_pos, end_pos);
    };
    auto [tks, s] = MaxForward(L);
    auto [tks1, s1] = MaxBackward(L);
    // Use the same algorithm as Python version
    std::size_t i = 0, j = 0, _i = 0, _j = 0, same = 0;
    // Leading run where forward and backward segmentations agree.
    while ((i + same < tks1.size()) && (j + same < tks.size()) && tks1[i + same] == tks[j + same]) {
        same++;
    }
    if (same > 0) {
        emit_agreed_run(tks, j, same);
    }
    _i = i + same;
    _j = j + same;
    j = _j + 1;
    i = _i + 1;
    while (i < tks1.size() && j < tks.size()) {
        std::string tk1 = Join(tks1, _i, i, "");
        std::string tk = Join(tks, _j, j, "");
        if (tk1 != tk) {
            // Advance whichever side currently covers fewer bytes until both
            // candidate spans cover the same substring again.
            if (tk1.length() > tk.length()) {
                j++;
            } else {
                i++;
            }
            continue;
        }
        if (tks1[i] != tks[j]) {
            i++;
            j++;
            continue;
        }
        // Disagreeing span tks[_j, j): resolve with DFS.
        emit_dfs_span(tks, _j, j);
        // Following run where both segmentations agree again.
        same = 1;
        while (i + same < tks1.size() && j + same < tks.size() && tks1[i + same] == tks[j + same])
            same++;
        emit_agreed_run(tks, j, same);
        _i = i + same;
        _j = j + same;
        j = _j + 1;
        i = _i + 1;
    }
    // Trailing disagreeing span, if any.
    if (_i < tks1.size()) {
        emit_dfs_span(tks, _j, tks.size());
    }
}
void RAGAnalyzer::EnglishNormalizeWithPosition(const std::vector<std::string> &tokens,
const std::vector<std::pair<unsigned, unsigned>> &positions,
std::vector<std::string> &normalize_tokens,
std::vector<std::pair<unsigned, unsigned>> &normalize_positions) const {
for (size_t i = 0; i < tokens.size(); ++i) {
const auto &token = tokens[i];
const auto &[start_pos, end_pos] = positions[i];
if (re2::RE2::PartialMatch(token, pattern1_)) {
//"[a-zA-Z_-]+$"
std::string lemma_term = wordnet_lemma_->Lemmatize(token);
std::vector<char> lowercase_buffer(term_string_buffer_limit_);
char *lowercase_term = lowercase_buffer.data();
ToLower(lemma_term.c_str(), lemma_term.size(), lowercase_term, term_string_buffer_limit_);
std::string stem_term;
stemmer_->Stem(lowercase_term, stem_term);
normalize_tokens.push_back(stem_term);
normalize_positions.emplace_back(start_pos, end_pos);
} else {
normalize_tokens.push_back(token);
normalize_positions.emplace_back(start_pos, end_pos);
}
}
}
// Fine-grained (second-pass) tokenization that additionally tracks each
// produced token's [start, end) position. Mirrors FineGrainedTokenize():
// mostly-English input (Chinese characters in < 20% of tokens) is only split
// on '/', while Chinese/mixed input is re-segmented via DFS when a finer
// segmentation exists.
//
// @param tokens_str     blank-separated coarse tokens (output of Tokenize()).
// @param positions      per-coarse-token positions; assumed aligned one-to-one
//                       with the blank-split tokens of tokens_str.
// @param fine_tokens    [out] fine-grained tokens (appended).
// @param fine_positions [out] position for each entry of fine_tokens.
void RAGAnalyzer::FineGrainedTokenizeWithPosition(const std::string &tokens_str,
                                                  const std::vector<std::pair<unsigned, unsigned>> &positions,
                                                  std::vector<std::string> &fine_tokens,
                                                  std::vector<std::pair<unsigned, unsigned>> &fine_positions) const {
    std::vector<std::string> tks;
    Split(tokens_str, blank_pattern_, tks);
    // Count Chinese characters to decide between the English and the
    // Chinese/mixed processing paths.
    std::size_t zh_num = 0;
    for (auto &token : tks) {
        int len = UTF8Length(token);
        for (int i = 0; i < len; ++i) {
            std::string t = UTF8Substr(token, i, 1);
            if (IsChinese(t)) {
                zh_num++;
            }
        }
    }
    if (zh_num < tks.size() * 0.2) {
        // English text processing: split each token on '/' and derive the
        // sub-token positions from the enclosing token's start offset.
        for (size_t i = 0; i < tks.size(); ++i) {
            const auto &token = tks[i];
            std::istringstream iss(token);
            std::string sub_token;
            unsigned sub_start = positions[i].first;
            while (std::getline(iss, sub_token, '/')) {
                if (!sub_token.empty()) {
                    fine_tokens.push_back(sub_token);
                    fine_positions.emplace_back(sub_start, sub_start + static_cast<unsigned>(sub_token.size()));
                }
                // Advance past this piece and its trailing '/' even when the
                // piece is empty, so later pieces keep correct offsets
                // (previously offsets drifted on inputs like "a//b").
                sub_start += static_cast<unsigned>(sub_token.size()) + 1;
            }
        }
    } else {
        // Chinese or mixed text processing - match FineGrainedTokenize behavior
        for (size_t i = 0; i < tks.size(); ++i) {
            const auto &token = tks[i];
            const auto &[start_pos, end_pos] = positions[i];
            const auto token_len = UTF8Length(token);
            // Short tokens and numeric tokens ("[0-9,\\.-]+$") are kept whole.
            if (token_len < 3 || re2::RE2::PartialMatch(token, pattern4_)) {
                fine_tokens.push_back(token);
                fine_positions.emplace_back(start_pos, end_pos);
                continue;
            }
            std::vector<std::vector<std::pair<std::string, int>>> token_list;
            if (token_len > 10) {
                // Too long for DFS; keep a single whole-token candidate.
                std::vector<std::pair<std::string, int>> tk;
                tk.emplace_back(token, Encode(-1, 0));
                token_list.push_back(tk);
            } else {
                std::vector<std::pair<std::string, int>> pre_tokens;
                std::vector<std::string> best_tokens;
                double max_score = 0.0;
                DFS(token, 0, pre_tokens, token_list, best_tokens, max_score, true);
            }
            if (token_list.size() < 2) {
                // No alternative segmentation found; keep the token as-is.
                fine_tokens.push_back(token);
                fine_positions.emplace_back(start_pos, end_pos);
                continue;
            }
            std::vector<std::pair<std::vector<std::string>, double>> sorted_tokens;
            SortTokens(token_list, sorted_tokens);
            // Use the second-best segmentation, as FineGrainedTokenize() does.
            const auto &stk = sorted_tokens[1].first;
            if (stk.size() == token_len) {
                // Degenerate one-character-per-piece split; keep the original.
                fine_tokens.push_back(token);
                fine_positions.emplace_back(start_pos, end_pos);
            } else if (re2::RE2::PartialMatch(token, pattern5_)) {
                // "[a-z\\.-]+": accept the split only if every piece has >= 3 chars.
                bool need_append_stk = true;
                for (auto &t : stk) {
                    if (UTF8Length(t) < 3) {
                        fine_tokens.push_back(token);
                        fine_positions.emplace_back(start_pos, end_pos);
                        need_append_stk = false;
                        break;
                    }
                }
                if (need_append_stk) {
                    unsigned sub_pos = start_pos;
                    for (auto &t : stk) {
                        // NOTE(review): advances by UTF-8 length while the branch
                        // below advances by byte length; for pattern5_ (ASCII)
                        // tokens the two coincide.
                        unsigned sub_end = sub_pos + UTF8Length(t);
                        fine_tokens.push_back(t);
                        fine_positions.emplace_back(sub_pos, sub_end);
                        sub_pos = sub_end;
                    }
                }
            } else {
                unsigned sub_pos = start_pos;
                for (auto &t : stk) {
                    unsigned sub_end = sub_pos + static_cast<unsigned>(t.size());
                    fine_tokens.push_back(t);
                    fine_positions.emplace_back(sub_pos, sub_end);
                    sub_pos = sub_end;
                }
            }
        }
    }
}
// Fine-grained (second-pass) tokenization of an already-tokenized string.
// `tokens` is the blank-separated output of Tokenize(). Mostly-English input
// (Chinese characters in < 20% of the coarse tokens) is only split on '/';
// Chinese/mixed input is re-segmented per token via DFS when a finer
// segmentation exists. The collected tokens are English-normalized
// (lemmatize/lowercase/stem) into `result`.
void RAGAnalyzer::FineGrainedTokenize(const std::string &tokens, std::vector<std::string> &result) const {
    std::vector<std::string> tks;
    Split(tokens, blank_pattern_, tks);
    std::vector<std::string> res;
    // Count Chinese characters across all tokens to pick the processing path.
    std::size_t zh_num = 0;
    for (auto &token : tks) {
        int len = UTF8Length(token);
        for (int i = 0; i < len; ++i) {
            std::string t = UTF8Substr(token, i, 1);
            if (IsChinese(t)) {
                zh_num++;
            }
        }
    }
    if (zh_num < tks.size() * 0.2) {
        // Mostly-English input: only split each coarse token on '/'.
        // Note: empty pieces (e.g. from "a//b") are pushed as-is; callers
        // such as AnalyzeImpl skip empty tokens downstream.
        for (auto &token : tks) {
            std::istringstream iss(token);
            std::string sub_token;
            while (std::getline(iss, sub_token, '/')) {
                result.push_back(sub_token);
            }
        }
        // std::string ret = Join(res, 0);
        return;
    }
    // Chinese or mixed input: try to re-segment each coarse token.
    for (auto &token : tks) {
        const auto token_len = UTF8Length(token);
        // Short tokens and numeric tokens are kept whole.
        if (token_len < 3 || re2::RE2::PartialMatch(token, pattern4_)) {
            //[0-9,\\.-]+$
            res.push_back(token);
            continue;
        }
        std::vector<std::vector<std::pair<std::string, int>>> token_list;
        if (token_len > 10) {
            // Too long for DFS; keep a single whole-token candidate.
            std::vector<std::pair<std::string, int>> tk;
            tk.emplace_back(token, Encode(-1, 0));
            token_list.push_back(tk);
        } else {
            // Enumerate candidate segmentations of this token via DFS.
            std::vector<std::pair<std::string, int>> pre_tokens;
            std::vector<std::string> best_tokens;
            double max_score = 0.0F;
#ifdef INFINITY_DEBUG
            const auto t0 = std::chrono::high_resolution_clock::now();
#endif
            DFS(token, 0, pre_tokens, token_list, best_tokens, max_score, true);
#ifdef INFINITY_DEBUG
            const auto t1 = std::chrono::high_resolution_clock::now();
            auto get_dfs_sorted_tokens = [&]() {
                std::vector<std::pair<std::vector<std::string>, double>> sorted_tokens;
                SortTokens(token_list, sorted_tokens);
                return sorted_tokens;
            };
            // Cross-check the DFS result against the DP implementation.
            dp_debug::CheckDP2(this, token, get_dfs_sorted_tokens, t0, t1);
#endif
        }
        if (token_list.size() < 2) {
            // No alternative segmentation found; keep the token as-is.
            res.push_back(token);
            continue;
        }
        std::vector<std::pair<std::vector<std::string>, double>> sorted_tokens;
        SortTokens(token_list, sorted_tokens);
        // Use the second-best segmentation as the fine-grained split.
        const auto &stk = sorted_tokens[1].first;
        if (stk.size() == token_len) {
            // Degenerate one-character-per-piece split; keep the original.
            res.push_back(token);
        } else if (re2::RE2::PartialMatch(token, pattern5_)) {
            // [a-z\\.-]+
            // Accept the split only if every piece has >= 3 characters.
            bool need_append_stk = true;
            for (auto &t : stk) {
                if (UTF8Length(t) < 3) {
                    res.push_back(token);
                    need_append_stk = false;
                    break;
                }
            }
            if (need_append_stk) {
                for (auto &t : stk) {
                    res.push_back(t);
                }
            }
        } else {
            for (auto &t : stk) {
                res.push_back(t);
            }
        }
    }
    EnglishNormalize(res, result);
    // std::string ret = Join(normalize_res, 0);
    // return ret;
}
// Tokenize `input` and invoke the hook `func` once per non-empty token.
// With enable_position the hook receives each token's [start, end) span from
// TokenizeWithPosition(); otherwise it receives a running token index as the
// offset. fine_grained selects the second-pass (finer) tokenization.
// Always returns 0.
int RAGAnalyzer::AnalyzeImpl(const Term &input, void *data, bool fine_grained, bool enable_position, HookType func) const {
    if (!enable_position) {
        // Offset-only path: emit a running index instead of byte positions.
        const std::string coarse = Tokenize(input.text_);
        std::vector<std::string> terms;
        if (fine_grained) {
            FineGrainedTokenize(coarse, terms);
        } else {
            Split(coarse, blank_pattern_, terms);
        }
        unsigned term_index = 0;
        for (const auto &term : terms) {
            if (!term.empty()) {
                func(data, term.c_str(), term.size(), term_index, 0, false, 0);
                ++term_index;
            }
        }
        return 0;
    }
    // Position-aware path: every token carries its span in the original text.
    auto [terms, spans] = TokenizeWithPosition(input.text_);
    if (fine_grained) {
        std::vector<std::string> fg_terms;
        std::vector<std::pair<unsigned, unsigned>> fg_spans;
        FineGrainedTokenizeWithPosition(Join(terms, 0), spans, fg_terms, fg_spans);
        terms.swap(fg_terms);
        spans.swap(fg_spans);
    }
    for (size_t idx = 0; idx < terms.size(); ++idx) {
        const auto &term = terms[idx];
        if (term.empty())
            continue;
        func(data, term.c_str(), term.size(), spans[idx].first, spans[idx].second, false, 0);
    }
    return 0;
}