mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-24 18:23:07 +08:00
modify rerank and splitter code directory (#4924)
This commit is contained in:
@@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rag.data_post_processor.reorder import ReorderRunner
|
||||
from core.rag.models.document import Document
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from core.rag.rerank.rerank import RerankRunner
|
||||
|
||||
|
||||
class DataPostProcessor:
|
||||
|
||||
201
api/core/rag/docstore/dataset_docstore.py
Normal file
201
api/core/rag/docstore/dataset_docstore.py
Normal file
@@ -0,0 +1,201 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import func
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
|
||||
class DatasetDocumentStore:
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
user_id: str,
|
||||
document_id: Optional[str] = None,
|
||||
):
|
||||
self._dataset = dataset
|
||||
self._user_id = user_id
|
||||
self._document_id = document_id
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
|
||||
return cls(**config_dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialize to dict."""
|
||||
return {
|
||||
"dataset_id": self._dataset.id,
|
||||
}
|
||||
|
||||
@property
|
||||
def dateset_id(self) -> Any:
|
||||
return self._dataset.id
|
||||
|
||||
@property
|
||||
def user_id(self) -> Any:
|
||||
return self._user_id
|
||||
|
||||
@property
|
||||
def docs(self) -> dict[str, Document]:
|
||||
document_segments = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self._dataset.id
|
||||
).all()
|
||||
|
||||
output = {}
|
||||
for document_segment in document_segments:
|
||||
doc_id = document_segment.index_node_id
|
||||
output[doc_id] = Document(
|
||||
page_content=document_segment.content,
|
||||
metadata={
|
||||
"doc_id": document_segment.index_node_id,
|
||||
"doc_hash": document_segment.index_node_hash,
|
||||
"document_id": document_segment.document_id,
|
||||
"dataset_id": document_segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def add_documents(
|
||||
self, docs: Sequence[Document], allow_update: bool = True
|
||||
) -> None:
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == self._document_id
|
||||
).scalar()
|
||||
|
||||
if max_position is None:
|
||||
max_position = 0
|
||||
embedding_model = None
|
||||
if self._dataset.indexing_technique == 'high_quality':
|
||||
model_manager = ModelManager()
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
provider=self._dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=self._dataset.embedding_model
|
||||
)
|
||||
|
||||
for doc in docs:
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError("doc must be a Document")
|
||||
|
||||
segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id'])
|
||||
|
||||
# NOTE: doc could already exist in the store, but we overwrite it
|
||||
if not allow_update and segment_document:
|
||||
raise ValueError(
|
||||
f"doc_id {doc.metadata['doc_id']} already exists. "
|
||||
"Set allow_update to True to overwrite."
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
if embedding_model:
|
||||
model_type_instance = embedding_model.model_type_instance
|
||||
model_type_instance = cast(TextEmbeddingModel, model_type_instance)
|
||||
tokens = model_type_instance.get_num_tokens(
|
||||
model=embedding_model.model,
|
||||
credentials=embedding_model.credentials,
|
||||
texts=[doc.page_content]
|
||||
)
|
||||
else:
|
||||
tokens = 0
|
||||
|
||||
if not segment_document:
|
||||
max_position += 1
|
||||
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
dataset_id=self._dataset.id,
|
||||
document_id=self._document_id,
|
||||
index_node_id=doc.metadata['doc_id'],
|
||||
index_node_hash=doc.metadata['doc_hash'],
|
||||
position=max_position,
|
||||
content=doc.page_content,
|
||||
word_count=len(doc.page_content),
|
||||
tokens=tokens,
|
||||
enabled=False,
|
||||
created_by=self._user_id,
|
||||
)
|
||||
if doc.metadata.get('answer'):
|
||||
segment_document.answer = doc.metadata.pop('answer', '')
|
||||
|
||||
db.session.add(segment_document)
|
||||
else:
|
||||
segment_document.content = doc.page_content
|
||||
if doc.metadata.get('answer'):
|
||||
segment_document.answer = doc.metadata.pop('answer', '')
|
||||
segment_document.index_node_hash = doc.metadata['doc_hash']
|
||||
segment_document.word_count = len(doc.page_content)
|
||||
segment_document.tokens = tokens
|
||||
|
||||
db.session.commit()
|
||||
|
||||
def document_exists(self, doc_id: str) -> bool:
|
||||
"""Check if document exists."""
|
||||
result = self.get_document_segment(doc_id)
|
||||
return result is not None
|
||||
|
||||
def get_document(
|
||||
self, doc_id: str, raise_error: bool = True
|
||||
) -> Optional[Document]:
|
||||
document_segment = self.get_document_segment(doc_id)
|
||||
|
||||
if document_segment is None:
|
||||
if raise_error:
|
||||
raise ValueError(f"doc_id {doc_id} not found.")
|
||||
else:
|
||||
return None
|
||||
|
||||
return Document(
|
||||
page_content=document_segment.content,
|
||||
metadata={
|
||||
"doc_id": document_segment.index_node_id,
|
||||
"doc_hash": document_segment.index_node_hash,
|
||||
"document_id": document_segment.document_id,
|
||||
"dataset_id": document_segment.dataset_id,
|
||||
}
|
||||
)
|
||||
|
||||
def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
|
||||
document_segment = self.get_document_segment(doc_id)
|
||||
|
||||
if document_segment is None:
|
||||
if raise_error:
|
||||
raise ValueError(f"doc_id {doc_id} not found.")
|
||||
else:
|
||||
return None
|
||||
|
||||
db.session.delete(document_segment)
|
||||
db.session.commit()
|
||||
|
||||
def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
|
||||
"""Set the hash for a given doc_id."""
|
||||
document_segment = self.get_document_segment(doc_id)
|
||||
|
||||
if document_segment is None:
|
||||
return None
|
||||
|
||||
document_segment.index_node_hash = doc_hash
|
||||
db.session.commit()
|
||||
|
||||
def get_document_hash(self, doc_id: str) -> Optional[str]:
|
||||
"""Get the stored hash for a document, if it exists."""
|
||||
document_segment = self.get_document_segment(doc_id)
|
||||
|
||||
if document_segment is None:
|
||||
return None
|
||||
|
||||
return document_segment.index_node_hash
|
||||
|
||||
def get_document_segment(self, doc_id: str) -> DocumentSegment:
|
||||
document_segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == self._dataset.id,
|
||||
DocumentSegment.index_node_id == doc_id
|
||||
).first()
|
||||
|
||||
return document_segment
|
||||
@@ -7,8 +7,11 @@ from flask import current_app
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.models.document import Document
|
||||
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
|
||||
from core.splitter.text_splitter import TextSplitter
|
||||
from core.rag.splitter.fixed_text_splitter import (
|
||||
EnhanceRecursiveCharacterTextSplitter,
|
||||
FixedRecursiveCharacterTextSplitter,
|
||||
)
|
||||
from core.rag.splitter.text_splitter import TextSplitter
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
|
||||
|
||||
|
||||
0
api/core/rag/rerank/__init__.py
Normal file
0
api/core/rag/rerank/__init__.py
Normal file
57
api/core/rag/rerank/rerank.py
Normal file
57
api/core/rag/rerank/rerank.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.models.document import Document
|
||||
|
||||
|
||||
class RerankRunner:
|
||||
def __init__(self, rerank_model_instance: ModelInstance) -> None:
|
||||
self.rerank_model_instance = rerank_model_instance
|
||||
|
||||
def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
|
||||
"""
|
||||
Run rerank model
|
||||
:param query: search query
|
||||
:param documents: documents for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n
|
||||
:param user: unique user id if needed
|
||||
:return:
|
||||
"""
|
||||
docs = []
|
||||
doc_id = []
|
||||
unique_documents = []
|
||||
for document in documents:
|
||||
if document.metadata['doc_id'] not in doc_id:
|
||||
doc_id.append(document.metadata['doc_id'])
|
||||
docs.append(document.page_content)
|
||||
unique_documents.append(document)
|
||||
|
||||
documents = unique_documents
|
||||
|
||||
rerank_result = self.rerank_model_instance.invoke_rerank(
|
||||
query=query,
|
||||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
user=user
|
||||
)
|
||||
|
||||
rerank_documents = []
|
||||
|
||||
for result in rerank_result.docs:
|
||||
# format document
|
||||
rerank_document = Document(
|
||||
page_content=result.text,
|
||||
metadata={
|
||||
"doc_id": documents[result.index].metadata['doc_id'],
|
||||
"doc_hash": documents[result.index].metadata['doc_hash'],
|
||||
"document_id": documents[result.index].metadata['document_id'],
|
||||
"dataset_id": documents[result.index].metadata['dataset_id'],
|
||||
'score': result.score
|
||||
}
|
||||
)
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return rerank_documents
|
||||
@@ -14,9 +14,9 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank import RerankRunner
|
||||
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
from core.rerank.rerank import RerankRunner
|
||||
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
114
api/core/rag/splitter/fixed_text_splitter.py
Normal file
114
api/core/rag/splitter/fixed_text_splitter.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Functionality for splitting text."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||
from core.rag.splitter.text_splitter import (
|
||||
TS,
|
||||
Collection,
|
||||
Literal,
|
||||
RecursiveCharacterTextSplitter,
|
||||
Set,
|
||||
TokenTextSplitter,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
||||
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
"""
|
||||
This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_encoder(
|
||||
cls: type[TS],
|
||||
embedding_model_instance: Optional[ModelInstance],
|
||||
allowed_special: Union[Literal[all], Set[str]] = set(),
|
||||
disallowed_special: Union[Literal[all], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
):
|
||||
def _token_encoder(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
if embedding_model_instance:
|
||||
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||
return embedding_model_type_instance.get_num_tokens(
|
||||
model=embedding_model_instance.model,
|
||||
credentials=embedding_model_instance.credentials,
|
||||
texts=[text]
|
||||
)
|
||||
else:
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
|
||||
if issubclass(cls, TokenTextSplitter):
|
||||
extra_kwargs = {
|
||||
"model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2',
|
||||
"allowed_special": allowed_special,
|
||||
"disallowed_special": disallowed_special,
|
||||
}
|
||||
kwargs = {**kwargs, **extra_kwargs}
|
||||
|
||||
return cls(length_function=_token_encoder, **kwargs)
|
||||
|
||||
|
||||
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
|
||||
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any):
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
self._fixed_separator = fixed_separator
|
||||
self._separators = separators or ["\n\n", "\n", " ", ""]
|
||||
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
if self._fixed_separator:
|
||||
chunks = text.split(self._fixed_separator)
|
||||
else:
|
||||
chunks = list(text)
|
||||
|
||||
final_chunks = []
|
||||
for chunk in chunks:
|
||||
if self._length_function(chunk) > self._chunk_size:
|
||||
final_chunks.extend(self.recursive_split_text(chunk))
|
||||
else:
|
||||
final_chunks.append(chunk)
|
||||
|
||||
return final_chunks
|
||||
|
||||
def recursive_split_text(self, text: str) -> list[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
final_chunks = []
|
||||
# Get appropriate separator to use
|
||||
separator = self._separators[-1]
|
||||
for _s in self._separators:
|
||||
if _s == "":
|
||||
separator = _s
|
||||
break
|
||||
if _s in text:
|
||||
separator = _s
|
||||
break
|
||||
# Now that we have the separator, split the text
|
||||
if separator:
|
||||
splits = text.split(separator)
|
||||
else:
|
||||
splits = list(text)
|
||||
# Now go merging things, recursively splitting longer texts.
|
||||
_good_splits = []
|
||||
for s in splits:
|
||||
if self._length_function(s) < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
else:
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, separator)
|
||||
final_chunks.extend(merged_text)
|
||||
_good_splits = []
|
||||
other_info = self.recursive_split_text(s)
|
||||
final_chunks.extend(other_info)
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, separator)
|
||||
final_chunks.extend(merged_text)
|
||||
return final_chunks
|
||||
534
api/core/rag/splitter/text_splitter.py
Normal file
534
api/core/rag/splitter/text_splitter.py
Normal file
@@ -0,0 +1,534 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Collection, Iterable, Sequence, Set
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from core.rag.models.document import BaseDocumentTransformer, Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TS = TypeVar("TS", bound="TextSplitter")
|
||||
|
||||
|
||||
def _split_text_with_regex(
|
||||
text: str, separator: str, keep_separator: bool
|
||||
) -> list[str]:
|
||||
# Now that we have the separator, split the text
|
||||
if separator:
|
||||
if keep_separator:
|
||||
# The parentheses in the pattern keep the delimiters in the result.
|
||||
_splits = re.split(f"({re.escape(separator)})", text)
|
||||
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
|
||||
if len(_splits) % 2 == 0:
|
||||
splits += _splits[-1:]
|
||||
splits = [_splits[0]] + splits
|
||||
else:
|
||||
splits = re.split(separator, text)
|
||||
else:
|
||||
splits = list(text)
|
||||
return [s for s in splits if s != ""]
|
||||
|
||||
|
||||
class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
"""Interface for splitting text into chunks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 4000,
|
||||
chunk_overlap: int = 200,
|
||||
length_function: Callable[[str], int] = len,
|
||||
keep_separator: bool = False,
|
||||
add_start_index: bool = False,
|
||||
) -> None:
|
||||
"""Create a new TextSplitter.
|
||||
|
||||
Args:
|
||||
chunk_size: Maximum size of chunks to return
|
||||
chunk_overlap: Overlap in characters between chunks
|
||||
length_function: Function that measures the length of given chunks
|
||||
keep_separator: Whether to keep the separator in the chunks
|
||||
add_start_index: If `True`, includes chunk's start index in metadata
|
||||
"""
|
||||
if chunk_overlap > chunk_size:
|
||||
raise ValueError(
|
||||
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
|
||||
f"({chunk_size}), should be smaller."
|
||||
)
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
self._length_function = length_function
|
||||
self._keep_separator = keep_separator
|
||||
self._add_start_index = add_start_index
|
||||
|
||||
@abstractmethod
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
"""Split text into multiple components."""
|
||||
|
||||
def create_documents(
|
||||
self, texts: list[str], metadatas: Optional[list[dict]] = None
|
||||
) -> list[Document]:
|
||||
"""Create documents from a list of texts."""
|
||||
_metadatas = metadatas or [{}] * len(texts)
|
||||
documents = []
|
||||
for i, text in enumerate(texts):
|
||||
index = -1
|
||||
for chunk in self.split_text(text):
|
||||
metadata = copy.deepcopy(_metadatas[i])
|
||||
if self._add_start_index:
|
||||
index = text.find(chunk, index + 1)
|
||||
metadata["start_index"] = index
|
||||
new_doc = Document(page_content=chunk, metadata=metadata)
|
||||
documents.append(new_doc)
|
||||
return documents
|
||||
|
||||
def split_documents(self, documents: Iterable[Document]) -> list[Document]:
|
||||
"""Split documents."""
|
||||
texts, metadatas = [], []
|
||||
for doc in documents:
|
||||
texts.append(doc.page_content)
|
||||
metadatas.append(doc.metadata)
|
||||
return self.create_documents(texts, metadatas=metadatas)
|
||||
|
||||
def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
|
||||
text = separator.join(docs)
|
||||
text = text.strip()
|
||||
if text == "":
|
||||
return None
|
||||
else:
|
||||
return text
|
||||
|
||||
def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
|
||||
# We now want to combine these smaller pieces into medium size
|
||||
# chunks to send to the LLM.
|
||||
separator_len = self._length_function(separator)
|
||||
|
||||
docs = []
|
||||
current_doc: list[str] = []
|
||||
total = 0
|
||||
for d in splits:
|
||||
_len = self._length_function(d)
|
||||
if (
|
||||
total + _len + (separator_len if len(current_doc) > 0 else 0)
|
||||
> self._chunk_size
|
||||
):
|
||||
if total > self._chunk_size:
|
||||
logger.warning(
|
||||
f"Created a chunk of size {total}, "
|
||||
f"which is longer than the specified {self._chunk_size}"
|
||||
)
|
||||
if len(current_doc) > 0:
|
||||
doc = self._join_docs(current_doc, separator)
|
||||
if doc is not None:
|
||||
docs.append(doc)
|
||||
# Keep on popping if:
|
||||
# - we have a larger chunk than in the chunk overlap
|
||||
# - or if we still have any chunks and the length is long
|
||||
while total > self._chunk_overlap or (
|
||||
total + _len + (separator_len if len(current_doc) > 0 else 0)
|
||||
> self._chunk_size
|
||||
and total > 0
|
||||
):
|
||||
total -= self._length_function(current_doc[0]) + (
|
||||
separator_len if len(current_doc) > 1 else 0
|
||||
)
|
||||
current_doc = current_doc[1:]
|
||||
current_doc.append(d)
|
||||
total += _len + (separator_len if len(current_doc) > 1 else 0)
|
||||
doc = self._join_docs(current_doc, separator)
|
||||
if doc is not None:
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@classmethod
|
||||
def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
|
||||
"""Text splitter that uses HuggingFace tokenizer to count length."""
|
||||
try:
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerBase):
|
||||
raise ValueError(
|
||||
"Tokenizer received was not an instance of PreTrainedTokenizerBase"
|
||||
)
|
||||
|
||||
def _huggingface_tokenizer_length(text: str) -> int:
|
||||
return len(tokenizer.encode(text))
|
||||
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import transformers python package. "
|
||||
"Please install it with `pip install transformers`."
|
||||
)
|
||||
return cls(length_function=_huggingface_tokenizer_length, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_tiktoken_encoder(
|
||||
cls: type[TS],
|
||||
encoding_name: str = "gpt2",
|
||||
model_name: Optional[str] = None,
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
) -> TS:
|
||||
"""Text splitter that uses tiktoken encoder to count length."""
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import tiktoken python package. "
|
||||
"This is needed in order to calculate max_tokens_for_prompt. "
|
||||
"Please install it with `pip install tiktoken`."
|
||||
)
|
||||
|
||||
if model_name is not None:
|
||||
enc = tiktoken.encoding_for_model(model_name)
|
||||
else:
|
||||
enc = tiktoken.get_encoding(encoding_name)
|
||||
|
||||
def _tiktoken_encoder(text: str) -> int:
|
||||
return len(
|
||||
enc.encode(
|
||||
text,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=disallowed_special,
|
||||
)
|
||||
)
|
||||
|
||||
if issubclass(cls, TokenTextSplitter):
|
||||
extra_kwargs = {
|
||||
"encoding_name": encoding_name,
|
||||
"model_name": model_name,
|
||||
"allowed_special": allowed_special,
|
||||
"disallowed_special": disallowed_special,
|
||||
}
|
||||
kwargs = {**kwargs, **extra_kwargs}
|
||||
|
||||
return cls(length_function=_tiktoken_encoder, **kwargs)
|
||||
|
||||
def transform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
"""Transform sequence of documents by splitting them."""
|
||||
return self.split_documents(list(documents))
|
||||
|
||||
async def atransform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
"""Asynchronously transform a sequence of documents by splitting them."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CharacterTextSplitter(TextSplitter):
|
||||
"""Splitting text that looks at characters."""
|
||||
|
||||
def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
self._separator = separator
|
||||
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
# First we naively split the large input into a bunch of smaller ones.
|
||||
splits = _split_text_with_regex(text, self._separator, self._keep_separator)
|
||||
_separator = "" if self._keep_separator else self._separator
|
||||
return self._merge_splits(splits, _separator)
|
||||
|
||||
|
||||
class LineType(TypedDict):
|
||||
"""Line type as typed dict."""
|
||||
|
||||
metadata: dict[str, str]
|
||||
content: str
|
||||
|
||||
|
||||
class HeaderType(TypedDict):
|
||||
"""Header type as typed dict."""
|
||||
|
||||
level: int
|
||||
name: str
|
||||
data: str
|
||||
|
||||
|
||||
class MarkdownHeaderTextSplitter:
|
||||
"""Splitting markdown files based on specified headers."""
|
||||
|
||||
def __init__(
|
||||
self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False
|
||||
):
|
||||
"""Create a new MarkdownHeaderTextSplitter.
|
||||
|
||||
Args:
|
||||
headers_to_split_on: Headers we want to track
|
||||
return_each_line: Return each line w/ associated headers
|
||||
"""
|
||||
# Output line-by-line or aggregated into chunks w/ common headers
|
||||
self.return_each_line = return_each_line
|
||||
# Given the headers we want to split on,
|
||||
# (e.g., "#, ##, etc") order by length
|
||||
self.headers_to_split_on = sorted(
|
||||
headers_to_split_on, key=lambda split: len(split[0]), reverse=True
|
||||
)
|
||||
|
||||
def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
|
||||
"""Combine lines with common metadata into chunks
|
||||
Args:
|
||||
lines: Line of text / associated header metadata
|
||||
"""
|
||||
aggregated_chunks: list[LineType] = []
|
||||
|
||||
for line in lines:
|
||||
if (
|
||||
aggregated_chunks
|
||||
and aggregated_chunks[-1]["metadata"] == line["metadata"]
|
||||
):
|
||||
# If the last line in the aggregated list
|
||||
# has the same metadata as the current line,
|
||||
# append the current content to the last lines's content
|
||||
aggregated_chunks[-1]["content"] += " \n" + line["content"]
|
||||
else:
|
||||
# Otherwise, append the current line to the aggregated list
|
||||
aggregated_chunks.append(line)
|
||||
|
||||
return [
|
||||
Document(page_content=chunk["content"], metadata=chunk["metadata"])
|
||||
for chunk in aggregated_chunks
|
||||
]
|
||||
|
||||
def split_text(self, text: str) -> list[Document]:
|
||||
"""Split markdown file
|
||||
Args:
|
||||
text: Markdown file"""
|
||||
|
||||
# Split the input text by newline character ("\n").
|
||||
lines = text.split("\n")
|
||||
# Final output
|
||||
lines_with_metadata: list[LineType] = []
|
||||
# Content and metadata of the chunk currently being processed
|
||||
current_content: list[str] = []
|
||||
current_metadata: dict[str, str] = {}
|
||||
# Keep track of the nested header structure
|
||||
# header_stack: List[Dict[str, Union[int, str]]] = []
|
||||
header_stack: list[HeaderType] = []
|
||||
initial_metadata: dict[str, str] = {}
|
||||
|
||||
for line in lines:
|
||||
stripped_line = line.strip()
|
||||
# Check each line against each of the header types (e.g., #, ##)
|
||||
for sep, name in self.headers_to_split_on:
|
||||
# Check if line starts with a header that we intend to split on
|
||||
if stripped_line.startswith(sep) and (
|
||||
# Header with no text OR header is followed by space
|
||||
# Both are valid conditions that sep is being used a header
|
||||
len(stripped_line) == len(sep)
|
||||
or stripped_line[len(sep)] == " "
|
||||
):
|
||||
# Ensure we are tracking the header as metadata
|
||||
if name is not None:
|
||||
# Get the current header level
|
||||
current_header_level = sep.count("#")
|
||||
|
||||
# Pop out headers of lower or same level from the stack
|
||||
while (
|
||||
header_stack
|
||||
and header_stack[-1]["level"] >= current_header_level
|
||||
):
|
||||
# We have encountered a new header
|
||||
# at the same or higher level
|
||||
popped_header = header_stack.pop()
|
||||
# Clear the metadata for the
|
||||
# popped header in initial_metadata
|
||||
if popped_header["name"] in initial_metadata:
|
||||
initial_metadata.pop(popped_header["name"])
|
||||
|
||||
# Push the current header to the stack
|
||||
header: HeaderType = {
|
||||
"level": current_header_level,
|
||||
"name": name,
|
||||
"data": stripped_line[len(sep):].strip(),
|
||||
}
|
||||
header_stack.append(header)
|
||||
# Update initial_metadata with the current header
|
||||
initial_metadata[name] = header["data"]
|
||||
|
||||
# Add the previous line to the lines_with_metadata
|
||||
# only if current_content is not empty
|
||||
if current_content:
|
||||
lines_with_metadata.append(
|
||||
{
|
||||
"content": "\n".join(current_content),
|
||||
"metadata": current_metadata.copy(),
|
||||
}
|
||||
)
|
||||
current_content.clear()
|
||||
|
||||
break
|
||||
else:
|
||||
if stripped_line:
|
||||
current_content.append(stripped_line)
|
||||
elif current_content:
|
||||
lines_with_metadata.append(
|
||||
{
|
||||
"content": "\n".join(current_content),
|
||||
"metadata": current_metadata.copy(),
|
||||
}
|
||||
)
|
||||
current_content.clear()
|
||||
|
||||
current_metadata = initial_metadata.copy()
|
||||
|
||||
if current_content:
|
||||
lines_with_metadata.append(
|
||||
{"content": "\n".join(current_content), "metadata": current_metadata}
|
||||
)
|
||||
|
||||
# lines_with_metadata has each line with associated header metadata
|
||||
# aggregate these into chunks based on common metadata
|
||||
if not self.return_each_line:
|
||||
return self.aggregate_lines_to_chunks(lines_with_metadata)
|
||||
else:
|
||||
return [
|
||||
Document(page_content=chunk["content"], metadata=chunk["metadata"])
|
||||
for chunk in lines_with_metadata
|
||||
]
|
||||
|
||||
|
||||
# should be in newer Python versions (3.10+)
|
||||
# @dataclass(frozen=True, kw_only=True, slots=True)
|
||||
@dataclass(frozen=True)
|
||||
class Tokenizer:
|
||||
chunk_overlap: int
|
||||
tokens_per_chunk: int
|
||||
decode: Callable[[list[int]], str]
|
||||
encode: Callable[[str], list[int]]
|
||||
|
||||
|
||||
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
|
||||
"""Split incoming text and return chunks using tokenizer."""
|
||||
splits: list[str] = []
|
||||
input_ids = tokenizer.encode(text)
|
||||
start_idx = 0
|
||||
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
while start_idx < len(input_ids):
|
||||
splits.append(tokenizer.decode(chunk_ids))
|
||||
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
|
||||
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
return splits
|
||||
|
||||
|
||||
class TokenTextSplitter(TextSplitter):
|
||||
"""Splitting text to tokens using model tokenizer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoding_name: str = "gpt2",
|
||||
model_name: Optional[str] = None,
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import tiktoken python package. "
|
||||
"This is needed in order to for TokenTextSplitter. "
|
||||
"Please install it with `pip install tiktoken`."
|
||||
)
|
||||
|
||||
if model_name is not None:
|
||||
enc = tiktoken.encoding_for_model(model_name)
|
||||
else:
|
||||
enc = tiktoken.get_encoding(encoding_name)
|
||||
self._tokenizer = enc
|
||||
self._allowed_special = allowed_special
|
||||
self._disallowed_special = disallowed_special
|
||||
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
def _encode(_text: str) -> list[int]:
|
||||
return self._tokenizer.encode(
|
||||
_text,
|
||||
allowed_special=self._allowed_special,
|
||||
disallowed_special=self._disallowed_special,
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer(
|
||||
chunk_overlap=self._chunk_overlap,
|
||||
tokens_per_chunk=self._chunk_size,
|
||||
decode=self._tokenizer.decode,
|
||||
encode=_encode,
|
||||
)
|
||||
|
||||
return split_text_on_tokens(text=text, tokenizer=tokenizer)
|
||||
|
||||
|
||||
class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
"""Splitting text by recursively look at characters.
|
||||
|
||||
Recursively tries to split by different characters to find one
|
||||
that works.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
separators: Optional[list[str]] = None,
|
||||
keep_separator: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(keep_separator=keep_separator, **kwargs)
|
||||
self._separators = separators or ["\n\n", "\n", " ", ""]
|
||||
|
||||
def _split_text(self, text: str, separators: list[str]) -> list[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
final_chunks = []
|
||||
# Get appropriate separator to use
|
||||
separator = separators[-1]
|
||||
new_separators = []
|
||||
for i, _s in enumerate(separators):
|
||||
if _s == "":
|
||||
separator = _s
|
||||
break
|
||||
if re.search(_s, text):
|
||||
separator = _s
|
||||
new_separators = separators[i + 1:]
|
||||
break
|
||||
|
||||
splits = _split_text_with_regex(text, separator, self._keep_separator)
|
||||
# Now go merging things, recursively splitting longer texts.
|
||||
_good_splits = []
|
||||
_separator = "" if self._keep_separator else separator
|
||||
for s in splits:
|
||||
if self._length_function(s) < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
else:
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, _separator)
|
||||
final_chunks.extend(merged_text)
|
||||
_good_splits = []
|
||||
if not new_separators:
|
||||
final_chunks.append(s)
|
||||
else:
|
||||
other_info = self._split_text(s, new_separators)
|
||||
final_chunks.extend(other_info)
|
||||
if _good_splits:
|
||||
merged_text = self._merge_splits(_good_splits, _separator)
|
||||
final_chunks.extend(merged_text)
|
||||
return final_chunks
|
||||
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
return self._split_text(text, self._separators)
|
||||
Reference in New Issue
Block a user