diff --git a/internal/cpp/rag_analyzer.cpp b/internal/cpp/rag_analyzer.cpp index 5f7799bb1..9584b2c06 100644 --- a/internal/cpp/rag_analyzer.cpp +++ b/internal/cpp/rag_analyzer.cpp @@ -964,15 +964,33 @@ std::pair, double> RAGAnalyzer::MaxBackward(const std:: return Score(res); } +static constexpr int MAX_DFS_DEPTH = 10; int RAGAnalyzer::DFS(const std::string &chars, const int s, std::vector> &pre_tokens, std::vector>> &token_list, std::vector &best_tokens, double &max_score, - const bool memo_all) const { + const bool memo_all, + const int depth) const { int res = s; const int len = UTF8Length(chars); + + // Check max recursion depth - graceful degradation like Python version + if (depth > MAX_DFS_DEPTH) { + if (s < len) { + auto pretks = pre_tokens; + std::string remaining = UTF8Substr(chars, s, len - s); + pretks.emplace_back(std::move(remaining), Encode(-12, 0)); + if (memo_all) { + token_list.push_back(std::move(pretks)); + } else if (auto [vec_str, current_score] = Score(pretks); current_score > max_score) { + best_tokens = std::move(vec_str); + max_score = current_score; + } + } + return len; + } if (s >= len) { if (memo_all) { token_list.push_back(pre_tokens); @@ -1011,7 +1029,7 @@ int RAGAnalyzer::DFS(const std::string &chars, if (const int v = trie_->Get(k); v != -1) { auto pretks = pre_tokens; pretks.emplace_back(std::move(t), v); - res = std::max(res, DFS(chars, e, pretks, token_list, best_tokens, max_score, memo_all)); + res = std::max(res, DFS(chars, e, pretks, token_list, best_tokens, max_score, memo_all, depth + 1)); } } @@ -1026,7 +1044,7 @@ int RAGAnalyzer::DFS(const std::string &chars, pre_tokens.emplace_back(std::move(t), Encode(-12, 0)); } - return DFS(chars, s + 1, pre_tokens, token_list, best_tokens, max_score, memo_all); + return DFS(chars, s + 1, pre_tokens, token_list, best_tokens, max_score, memo_all, depth + 1); } struct TokensList { diff --git a/internal/cpp/rag_analyzer.h b/internal/cpp/rag_analyzer.h index 70331445d..78a75d713 100644 --- a/internal/cpp/rag_analyzer.h +++ b/internal/cpp/rag_analyzer.h @@ -121,7 +121,8 @@ private: std::vector>>& token_list, std::vector& best_tokens, double& max_score, - bool memo_all) const; + bool memo_all, + int depth = 0) const; void TokenizeInner(std::vector& res, const std::string& L) const;