mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-11 03:46:52 +08:00
fix: typo in package path of core.splitter (#2411)
This commit is contained in:
116
api/core/splitter/fixed_text_splitter.py
Normal file
116
api/core/splitter/fixed_text_splitter.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Functionality for splitting text."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from langchain.text_splitter import (
|
||||
TS,
|
||||
AbstractSet,
|
||||
Collection,
|
||||
Literal,
|
||||
RecursiveCharacterTextSplitter,
|
||||
TokenTextSplitter,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||
|
||||
|
||||
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
"""
|
||||
This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_encoder(
|
||||
cls: Type[TS],
|
||||
embedding_model_instance: Optional[ModelInstance],
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
):
|
||||
def _token_encoder(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
if embedding_model_instance:
|
||||
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||
return embedding_model_type_instance.get_num_tokens(
|
||||
model=embedding_model_instance.model,
|
||||
credentials=embedding_model_instance.credentials,
|
||||
texts=[text]
|
||||
)
|
||||
else:
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
|
||||
if issubclass(cls, TokenTextSplitter):
|
||||
extra_kwargs = {
|
||||
"model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2',
|
||||
"allowed_special": allowed_special,
|
||||
"disallowed_special": disallowed_special,
|
||||
}
|
||||
kwargs = {**kwargs, **extra_kwargs}
|
||||
|
||||
return cls(length_function=_token_encoder, **kwargs)
|
||||
|
||||
|
||||
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
|
||||
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
self._fixed_separator = fixed_separator
|
||||
self._separators = separators or ["\n\n", "\n", " ", ""]
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
if self._fixed_separator:
|
||||
chunks = text.split(self._fixed_separator)
|
||||
else:
|
||||
chunks = list(text)
|
||||
|
||||
final_chunks = []
|
||||
for chunk in chunks:
|
||||
if self._length_function(chunk) > self._chunk_size:
|
||||
final_chunks.extend(self.recursive_split_text(chunk))
|
||||
else:
|
||||
final_chunks.append(chunk)
|
||||
|
||||
return final_chunks
|
||||
|
||||
def recursive_split_text(self, text: str) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
final_chunks = []
|
||||
# Get appropriate separator to use
|
||||
separator = self._separators[-1]
|
||||
for _s in self._separators:
|
||||
if _s == "":
|
||||
separator = _s
|
||||
break
|
||||
if _s in text:
|
||||
separator = _s
|
||||
break
|
||||
# Now that we have the separator, split the text
|
||||
if separator:
|
||||
splits = text.split(separator)
|
||||
else:
|
||||
splits = list(text)
|
||||
# Now go merging things, recursively splitting longer texts.
|
||||
_good_splits = []
|
||||
for s in splits:
|
||||
if self._length_function(s) < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
else:
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, separator)
|
||||
final_chunks.extend(merged_text)
|
||||
_good_splits = []
|
||||
other_info = self.recursive_split_text(s)
|
||||
final_chunks.extend(other_info)
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, separator)
|
||||
final_chunks.extend(merged_text)
|
||||
return final_chunks
|
||||
Reference in New Issue
Block a user