Feat:remove estimation of embedding cost (#7950)

Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
KVOJJJin
2024-09-04 14:41:47 +08:00
committed by GitHub
parent 83e84865be
commit 14af87527f
14 changed files with 122 additions and 162 deletions

View File

@@ -108,7 +108,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
else:
return text
def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
separator_len = self._length_function(separator)
@@ -116,8 +116,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
docs = []
current_doc: list[str] = []
total = 0
index = 0
for d in splits:
_len = self._length_function(d)
_len = lengths[index]
if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
@@ -145,6 +146,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
index += 1
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
@@ -493,11 +495,10 @@ class RecursiveCharacterTextSplitter(TextSplitter):
self._separators = separators or ["\n\n", "\n", " ", ""]
def _split_text(self, text: str, separators: list[str]) -> list[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = separators[-1]
new_separators = []
for i, _s in enumerate(separators):
if _s == "":
separator = _s
@@ -508,25 +509,31 @@ class RecursiveCharacterTextSplitter(TextSplitter):
break
splits = _split_text_with_regex(text, separator, self._keep_separator)
# Now go merging things, recursively splitting longer texts.
_good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits
_separator = "" if self._keep_separator else separator
for s in splits:
if self._length_function(s) < self._chunk_size:
s_len = self._length_function(s)
if s_len < self._chunk_size:
_good_splits.append(s)
_good_splits_lengths.append(s_len)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
final_chunks.extend(merged_text)
_good_splits = []
_good_splits_lengths = []
if not new_separators:
final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
final_chunks.extend(merged_text)
return final_chunks
def split_text(self, text: str) -> list[str]: