mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
feat: support multi token count
This commit is contained in:
@ -79,7 +79,13 @@ class DatasetDocumentStore:
|
||||
model=self._dataset.embedding_model,
|
||||
)
|
||||
|
||||
for doc in docs:
|
||||
if embedding_model:
|
||||
page_content_list = [doc.page_content for doc in docs]
|
||||
tokens_list = embedding_model.get_text_embedding_num_tokens(page_content_list)
|
||||
else:
|
||||
tokens_list = [0] * len(docs)
|
||||
|
||||
for doc, tokens in zip(docs, tokens_list):
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError("doc must be a Document")
|
||||
|
||||
@ -91,12 +97,6 @@ class DatasetDocumentStore:
|
||||
f"doc_id {doc.metadata['doc_id']} already exists. Set allow_update to True to overwrite."
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
if embedding_model:
|
||||
tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content])
|
||||
else:
|
||||
tokens = 0
|
||||
|
||||
if not segment_document:
|
||||
max_position += 1
|
||||
|
||||
|
||||
@ -65,8 +65,9 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
||||
chunks = [text]
|
||||
|
||||
final_chunks = []
|
||||
for chunk in chunks:
|
||||
if self._length_function(chunk) > self._chunk_size:
|
||||
chunks_lengths = self._length_function(chunks)
|
||||
for chunk, chunk_length in zip(chunks, chunks_lengths):
|
||||
if chunk_length > self._chunk_size:
|
||||
final_chunks.extend(self.recursive_split_text(chunk))
|
||||
else:
|
||||
final_chunks.append(chunk)
|
||||
@ -93,7 +94,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
|
||||
# Now go merging things, recursively splitting longer texts.
|
||||
_good_splits = []
|
||||
_good_splits_lengths = [] # cache the lengths of the splits
|
||||
for s in splits:
|
||||
s_lens = self._length_function(splits)
|
||||
for s, s_len in zip(splits, s_lens):
|
||||
s_len = self._length_function(s)
|
||||
if s_len < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
|
||||
@ -45,7 +45,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
self,
|
||||
chunk_size: int = 4000,
|
||||
chunk_overlap: int = 200,
|
||||
length_function: Callable[[str], int] = len,
|
||||
length_function: Callable[[list[str]], list[int]] = lambda x: [len(x) for x in x],
|
||||
keep_separator: bool = False,
|
||||
add_start_index: bool = False,
|
||||
) -> None:
|
||||
@ -224,8 +224,8 @@ class CharacterTextSplitter(TextSplitter):
|
||||
splits = _split_text_with_regex(text, self._separator, self._keep_separator)
|
||||
_separator = "" if self._keep_separator else self._separator
|
||||
_good_splits_lengths = [] # cache the lengths of the splits
|
||||
for split in splits:
|
||||
_good_splits_lengths.append(self._length_function(split))
|
||||
if splits:
|
||||
_good_splits_lengths.extend(self._length_function(splits))
|
||||
return self._merge_splits(splits, _separator, _good_splits_lengths)
|
||||
|
||||
|
||||
@ -478,9 +478,8 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
_good_splits = []
|
||||
_good_splits_lengths = [] # cache the lengths of the splits
|
||||
_separator = "" if self._keep_separator else separator
|
||||
|
||||
for s in splits:
|
||||
s_len = self._length_function(s)
|
||||
s_lens = self._length_function(splits)
|
||||
for s, s_len in zip(splits, s_lens):
|
||||
if s_len < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
_good_splits_lengths.append(s_len)
|
||||
|
||||
Reference in New Issue
Block a user