mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-16 06:16:53 +08:00
Feat: remove estimation of embedding cost (#7950)
Co-authored-by: jyong <718720800@qq.com>
@@ -108,7 +108,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         else:
             return text
 
-    def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
+    def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
         # We now want to combine these smaller pieces into medium size
         # chunks to send to the LLM.
         separator_len = self._length_function(separator)
@@ -116,8 +116,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         docs = []
         current_doc: list[str] = []
         total = 0
+        index = 0
         for d in splits:
-            _len = self._length_function(d)
+            _len = lengths[index]
             if (
                 total + _len + (separator_len if len(current_doc) > 0 else 0)
                 > self._chunk_size
@@ -145,6 +146,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
                 current_doc = current_doc[1:]
             current_doc.append(d)
             total += _len + (separator_len if len(current_doc) > 1 else 0)
+            index += 1
         doc = self._join_docs(current_doc, separator)
         if doc is not None:
             docs.append(doc)
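
Taken together, the three hunks above turn `_merge_splits` into a consumer of lengths the caller has already measured, instead of re-invoking `self._length_function` for every split. Below is a minimal sketch of that consumer logic as a simplified free function with hypothetical names (`merge_splits`, `chunk_size`, `length_function`); it omits the chunk-overlap trimming and `_join_docs` handling of the real method and is not the actual Dify implementation.

```python
from collections.abc import Iterable
from typing import Callable


def merge_splits(
    splits: Iterable[str],
    separator: str,
    lengths: list[int],
    chunk_size: int,
    length_function: Callable[[str], int] = len,
) -> list[str]:
    """Greedily pack pre-measured splits into chunks of at most chunk_size.

    lengths[i] is the pre-computed length of the i-th split, so the
    (possibly expensive) length_function is only needed for the separator.
    Chunk-overlap trimming from the real _merge_splits is omitted here.
    """
    separator_len = length_function(separator)

    docs: list[str] = []
    current_doc: list[str] = []
    total = 0
    index = 0
    for d in splits:
        _len = lengths[index]  # cached length, no re-measuring per split
        if current_doc and total + _len + separator_len > chunk_size:
            # Current chunk is full: emit it and start a new one.
            docs.append(separator.join(current_doc))
            current_doc = []
            total = 0
        current_doc.append(d)
        total += _len + (separator_len if len(current_doc) > 1 else 0)
        index += 1
    if current_doc:
        docs.append(separator.join(current_doc))
    return docs


# Lengths are measured once by the caller and passed alongside the splits.
splits = ["alpha", "beta", "gamma", "delta"]
lengths = [len(s) for s in splits]
print(merge_splits(splits, " ", lengths, chunk_size=12))  # ['alpha beta', 'gamma delta']
```
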
@@ -493,11 +495,10 @@ class RecursiveCharacterTextSplitter(TextSplitter):
         self._separators = separators or ["\n\n", "\n", " ", ""]
 
     def _split_text(self, text: str, separators: list[str]) -> list[str]:
         """Split incoming text and return chunks."""
         final_chunks = []
         # Get appropriate separator to use
         separator = separators[-1]
         new_separators = []
 
         for i, _s in enumerate(separators):
             if _s == "":
                 separator = _s
@@ -508,25 +509,31 @@ class RecursiveCharacterTextSplitter(TextSplitter):
                 break
 
         splits = _split_text_with_regex(text, separator, self._keep_separator)
         # Now go merging things, recursively splitting longer texts.
         _good_splits = []
+        _good_splits_lengths = []  # cache the lengths of the splits
         _separator = "" if self._keep_separator else separator
 
         for s in splits:
-            if self._length_function(s) < self._chunk_size:
+            s_len = self._length_function(s)
+            if s_len < self._chunk_size:
                 _good_splits.append(s)
+                _good_splits_lengths.append(s_len)
             else:
                 if _good_splits:
-                    merged_text = self._merge_splits(_good_splits, _separator)
+                    merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                     final_chunks.extend(merged_text)
                     _good_splits = []
+                    _good_splits_lengths = []
                 if not new_separators:
                     final_chunks.append(s)
                 else:
                     other_info = self._split_text(s, new_separators)
                     final_chunks.extend(other_info)
 
         if _good_splits:
-            merged_text = self._merge_splits(_good_splits, _separator)
+            merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
             final_chunks.extend(merged_text)
 
         return final_chunks
 
     def split_text(self, text: str) -> list[str]:
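
The `_split_text` hunk is the producer side of the same change: each split is measured once (`s_len`), the value is cached in `_good_splits_lengths`, and the list is handed to `_merge_splits`, so a tokenizer-backed length function is no longer called a second time during merging. A rough standalone sketch of the saving, using hypothetical names and a plain `len`-based stand-in for the length function (not Dify code):

```python
# Count how often a (stand-in) expensive length function runs.
calls = 0


def expensive_len(text: str) -> int:
    """Stand-in for a tokenizer/encoder-backed length function."""
    global calls
    calls += 1
    return len(text)


splits = ["chunk one", "chunk two", "chunk three"]
chunk_size = 100

# Old flow: measured in _split_text for the size check, then measured
# again inside _merge_splits.
calls = 0
good = [s for s in splits if expensive_len(s) < chunk_size]
_ = [expensive_len(s) for s in good]
print("old flow calls:", calls)  # 6 for 3 splits

# New flow (as in the diff): measure once, cache the values, reuse them.
calls = 0
good, good_lengths = [], []
for s in splits:
    s_len = expensive_len(s)
    if s_len < chunk_size:
        good.append(s)
        good_lengths.append(s_len)
print("new flow calls:", calls)  # 3 for 3 splits
```
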