feat: mypy for all type check (#10921)

Author: yihong
Date: 2024-12-24 18:38:51 +08:00
Committed by: GitHub
Parent: c91e8b1737
Commit: 56e15d09a9
584 changed files with 3975 additions and 2826 deletions


@@ -49,6 +49,7 @@ class BaseIndexProcessor(ABC):
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule["mode"] == "custom":
# The user-defined segmentation rule
rules = processing_rule["rules"]
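
The added `character_splitter: TextSplitter` annotation illustrates a common way to satisfy mypy when a variable is bound to a different concrete class in each branch: declare it once with the shared base type. A minimal, self-contained sketch of the pattern (the splitter subclasses here are hypothetical stand-ins, not the project's real classes):

from typing import Any


class TextSplitter:
    """Stand-in base class for the splitter hierarchy."""


class CustomRuleSplitter(TextSplitter):
    """Hypothetical splitter built from user-defined rules."""


class AutomaticSplitter(TextSplitter):
    """Hypothetical splitter for the automatic mode."""


def get_splitter(processing_rule: dict[str, Any]) -> TextSplitter:
    # Annotating the variable with the common base type up front lets mypy
    # accept a different subclass being assigned in each branch below.
    character_splitter: TextSplitter
    if processing_rule["mode"] == "custom":
        character_splitter = CustomRuleSplitter()
    else:
        character_splitter = AutomaticSplitter()
    return character_splitter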


@@ -9,7 +9,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
 class IndexProcessorFactory:
     """IndexProcessorInit."""
 
-    def __init__(self, index_type: str):
+    def __init__(self, index_type: str | None):
         self._index_type = index_type
 
     def init_index_processor(self) -> BaseIndexProcessor:
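
Widening `index_type` to `str | None` matches how callers may construct the factory with an optional value; a runtime check inside the factory can then narrow it back to `str` for mypy. A small sketch of that narrowing, assuming the factory validates the value before use (the error message and `str` return type are illustrative only):

class IndexProcessorFactory:
    """Minimal sketch: accept an optional index type and validate at runtime."""

    def __init__(self, index_type: str | None):
        self._index_type = index_type

    def init_index_processor(self) -> str:
        # Raising on a missing value narrows self._index_type to str below,
        # so string-only operations type-check cleanly.
        if not self._index_type:
            raise ValueError("Index type must be specified.")
        return self._index_type.lower()


factory = IndexProcessorFactory(None)  # accepted by the checker under str | None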


@@ -27,12 +27,13 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         # Split the text documents into nodes.
         splitter = self._get_splitter(
-            processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
+            processing_rule=kwargs.get("process_rule", {}),
+            embedding_model_instance=kwargs.get("embedding_model_instance"),
         )
         all_documents = []
         for document in documents:
             # document clean
-            document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
+            document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule", {}))
             document.page_content = document_text
             # parse document to nodes
             document_nodes = splitter.split_documents([document])
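
Both changes in this hunk follow the same rule: `dict.get(key)` with no default is Optional as far as mypy is concerned, so passing it straight into a parameter annotated as a plain mapping is rejected; supplying a `{}` default keeps the value non-Optional. An isolated illustration, with a hypothetical helper standing in for `_get_splitter` / `CleanProcessor.clean` and an explicitly typed options dict:

from typing import Any


def clean(text: str, process_rule: dict[str, Any]) -> str:
    # Hypothetical helper that requires a real mapping, never None.
    return text.strip() if process_rule.get("strip") else text


def transform(text: str, options: dict[str, dict[str, Any]]) -> str:
    # options.get("process_rule") is typed dict[str, Any] | None, which mypy
    # rejects for the dict-typed parameter above; the {} default (or `or {}`)
    # removes the None branch.
    return clean(text, options.get("process_rule", {}))
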
@@ -41,8 +42,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
                 if document_node.page_content.strip():
                     doc_id = str(uuid.uuid4())
                     hash = helper.generate_text_hash(document_node.page_content)
-                    document_node.metadata["doc_id"] = doc_id
-                    document_node.metadata["doc_hash"] = hash
+                    if document_node.metadata is not None:
+                        document_node.metadata["doc_id"] = doc_id
+                        document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
                     page_content = remove_leading_symbols(document_node.page_content).strip()
                     if len(page_content) > 0:
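
The new `is not None` guard is the usual narrowing idiom for an Optional mapping: if `Document.metadata` can be None, mypy rejects indexed assignment on it until the None case is excluded. A self-contained version of the pattern (this `Document` is a stand-in dataclass, not the project's model, and the hash here is a placeholder for `helper.generate_text_hash`):

import uuid
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Document:
    page_content: str
    metadata: Optional[dict[str, Any]] = None


def tag_node(node: Document) -> None:
    doc_id = str(uuid.uuid4())
    # Without the guard, mypy flags the indexed assignment because
    # node.metadata may be None; inside the block it is narrowed to dict.
    if node.metadata is not None:
        node.metadata["doc_id"] = doc_id
        node.metadata["doc_hash"] = str(hash(node.page_content))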


@@ -32,15 +32,16 @@ class QAIndexProcessor(BaseIndexProcessor):
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         splitter = self._get_splitter(
-            processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
+            processing_rule=kwargs.get("process_rule") or {},
+            embedding_model_instance=kwargs.get("embedding_model_instance"),
         )
         # Split the text documents into nodes.
-        all_documents = []
-        all_qa_documents = []
+        all_documents: list[Document] = []
+        all_qa_documents: list[Document] = []
         for document in documents:
             # document clean
-            document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
+            document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule") or {})
             document.page_content = document_text
             # parse document to nodes
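
Two further patterns appear here: `kwargs.get("process_rule") or {}` coerces both a missing and a falsy value to an empty dict, and the empty list literals gain explicit element types, since `all_documents = []` on its own gives mypy nothing to infer the element type from. A brief sketch of the annotated-empty-list half (the `Document` stand-in is again hypothetical):

class Document:
    """Stand-in for the real Document model."""

    def __init__(self, page_content: str) -> None:
        self.page_content = page_content


def collect(pages: list[str]) -> list[Document]:
    # The annotation fixes the element type of the empty list; without it,
    # mypy must infer the type from later use and may ask for an annotation.
    all_documents: list[Document] = []
    for page in pages:
        all_documents.append(Document(page))
    return all_documents
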
@@ -50,8 +51,9 @@ class QAIndexProcessor(BaseIndexProcessor):
                 if document_node.page_content.strip():
                     doc_id = str(uuid.uuid4())
                     hash = helper.generate_text_hash(document_node.page_content)
-                    document_node.metadata["doc_id"] = doc_id
-                    document_node.metadata["doc_hash"] = hash
+                    if document_node.metadata is not None:
+                        document_node.metadata["doc_id"] = doc_id
+                        document_node.metadata["doc_hash"] = hash
                     # delete Splitter character
                     page_content = document_node.page_content
                     document_node.page_content = remove_leading_symbols(page_content)
@@ -64,7 +66,7 @@ class QAIndexProcessor(BaseIndexProcessor):
                 document_format_thread = threading.Thread(
                     target=self._format_qa_document,
                     kwargs={
-                        "flask_app": current_app._get_current_object(),
+                        "flask_app": current_app._get_current_object(),  # type: ignore
                         "tenant_id": kwargs.get("tenant_id"),
                         "document_node": doc,
                         "all_qa_documents": all_qa_documents,
@@ -148,11 +150,12 @@ class QAIndexProcessor(BaseIndexProcessor):
                 qa_documents = []
                 for result in document_qa_list:
                     qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy())
-                    doc_id = str(uuid.uuid4())
-                    hash = helper.generate_text_hash(result["question"])
-                    qa_document.metadata["answer"] = result["answer"]
-                    qa_document.metadata["doc_id"] = doc_id
-                    qa_document.metadata["doc_hash"] = hash
+                    if qa_document.metadata is not None:
+                        doc_id = str(uuid.uuid4())
+                        hash = helper.generate_text_hash(result["question"])
+                        qa_document.metadata["answer"] = result["answer"]
+                        qa_document.metadata["doc_id"] = doc_id
+                        qa_document.metadata["doc_hash"] = hash
                     qa_documents.append(qa_document)
                 format_documents.extend(qa_documents)
             except Exception as e: