Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Jyong authored on 2024-12-26 00:16:35 +08:00, committed by GitHub
parent bb35818976
commit 84ac004772
20 changed files with 264 additions and 210 deletions


@@ -258,78 +258,79 @@ class RetrievalService:
         include_segment_ids = []
         segment_child_map = {}
         for document in documents:
-            document_id = document.metadata["document_id"]
+            document_id = document.metadata.get("document_id")
             dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
-            if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                child_index_node_id = document.metadata["doc_id"]
-                result = (
-                    db.session.query(ChildChunk, DocumentSegment)
-                    .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
-                    .filter(
-                        ChildChunk.index_node_id == child_index_node_id,
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                    )
-                    .first()
-                )
-                if result:
-                    child_chunk, segment = result
-                    if not segment:
-                        continue
-                    if segment.id not in include_segment_ids:
-                        include_segment_ids.append(segment.id)
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        map_detail = {
-                            "max_score": document.metadata.get("score", 0.0),
-                            "child_chunks": [child_chunk_detail],
-                        }
-                        segment_child_map[segment.id] = map_detail
-                        record = {
-                            "segment": segment,
-                        }
-                        records.append(record)
-                    else:
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                        segment_child_map[segment.id]["max_score"] = max(
-                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                        )
-                else:
-                    continue
-            else:
-                index_node_id = document.metadata["doc_id"]
-                segment = (
-                    db.session.query(DocumentSegment)
-                    .filter(
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                        DocumentSegment.index_node_id == index_node_id,
-                    )
-                    .first()
-                )
-                if not segment:
-                    continue
-                include_segment_ids.append(segment.id)
-                record = {
-                    "segment": segment,
-                    "score": document.metadata.get("score", None),
-                }
-                records.append(record)
+            if dataset_document:
+                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    child_index_node_id = document.metadata.get("doc_id")
+                    result = (
+                        db.session.query(ChildChunk, DocumentSegment)
+                        .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
+                        .filter(
+                            ChildChunk.index_node_id == child_index_node_id,
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                        )
+                        .first()
+                    )
+                    if result:
+                        child_chunk, segment = result
+                        if not segment:
+                            continue
+                        if segment.id not in include_segment_ids:
+                            include_segment_ids.append(segment.id)
+                            child_chunk_detail = {
+                                "id": child_chunk.id,
+                                "content": child_chunk.content,
+                                "position": child_chunk.position,
+                                "score": document.metadata.get("score", 0.0),
+                            }
+                            map_detail = {
+                                "max_score": document.metadata.get("score", 0.0),
+                                "child_chunks": [child_chunk_detail],
+                            }
+                            segment_child_map[segment.id] = map_detail
+                            record = {
+                                "segment": segment,
+                            }
+                            records.append(record)
+                        else:
+                            child_chunk_detail = {
+                                "id": child_chunk.id,
+                                "content": child_chunk.content,
+                                "position": child_chunk.position,
+                                "score": document.metadata.get("score", 0.0),
+                            }
+                            segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                            segment_child_map[segment.id]["max_score"] = max(
+                                segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                            )
+                    else:
+                        continue
+                else:
+                    index_node_id = document.metadata["doc_id"]
+                    segment = (
+                        db.session.query(DocumentSegment)
+                        .filter(
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                            DocumentSegment.index_node_id == index_node_id,
+                        )
+                        .first()
+                    )
+                    if not segment:
+                        continue
+                    include_segment_ids.append(segment.id)
+                    record = {
+                        "segment": segment,
+                        "score": document.metadata.get("score", None),
+                    }
+                    records.append(record)
         for record in records:
             if record["segment"].id in segment_child_map:
                 record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)


@@ -122,26 +122,27 @@ class DatasetDocumentStore:
                 db.session.add(segment_document)
                 db.session.flush()
                 if save_child:
-                    for postion, child in enumerate(doc.children, start=1):
-                        child_segment = ChildChunk(
-                            tenant_id=self._dataset.tenant_id,
-                            dataset_id=self._dataset.id,
-                            document_id=self._document_id,
-                            segment_id=segment_document.id,
-                            position=postion,
-                            index_node_id=child.metadata["doc_id"],
-                            index_node_hash=child.metadata["doc_hash"],
-                            content=child.page_content,
-                            word_count=len(child.page_content),
-                            type="automatic",
-                            created_by=self._user_id,
-                        )
-                        db.session.add(child_segment)
+                    if doc.children:
+                        for postion, child in enumerate(doc.children, start=1):
+                            child_segment = ChildChunk(
+                                tenant_id=self._dataset.tenant_id,
+                                dataset_id=self._dataset.id,
+                                document_id=self._document_id,
+                                segment_id=segment_document.id,
+                                position=postion,
+                                index_node_id=child.metadata.get("doc_id"),
+                                index_node_hash=child.metadata.get("doc_hash"),
+                                content=child.page_content,
+                                word_count=len(child.page_content),
+                                type="automatic",
+                                created_by=self._user_id,
+                            )
+                            db.session.add(child_segment)
             else:
                 segment_document.content = doc.page_content
                 if doc.metadata.get("answer"):
                     segment_document.answer = doc.metadata.pop("answer", "")
-                segment_document.index_node_hash = doc.metadata["doc_hash"]
+                segment_document.index_node_hash = doc.metadata.get("doc_hash")
                 segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens
                 if save_child and doc.children:


@@ -4,7 +4,7 @@ import os
 from typing import Optional, cast
 import pandas as pd
-from openpyxl import load_workbook
+from openpyxl import load_workbook  # type: ignore
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document


@@ -81,4 +81,4 @@ class BaseIndexProcessor(ABC):
                 embedding_model_instance=embedding_model_instance,
             )
-        return character_splitter
+        return character_splitter  # type: ignore


@@ -30,12 +30,18 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
         if process_rule.get("mode") == "automatic":
             automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
             rules = Rule(**automatic_rule)
         else:
+            if not process_rule.get("rules"):
+                raise ValueError("No rules found in process rule.")
             rules = Rule(**process_rule.get("rules"))
         # Split the text documents into nodes.
+        if not rules.segmentation:
+            raise ValueError("No segmentation found in rules.")
         splitter = self._get_splitter(
             processing_rule_mode=process_rule.get("mode"),
             max_tokens=rules.segmentation.max_tokens,


@@ -30,8 +30,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
         rules = Rule(**process_rule.get("rules"))
-        all_documents = []
+        all_documents = []  # type: ignore
         if rules.parent_mode == ParentMode.PARAGRAPH:
             # Split the text documents into nodes.
             splitter = self._get_splitter(
@@ -161,6 +165,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
         process_rule_mode: str,
         embedding_model_instance: Optional[ModelInstance],
     ) -> list[ChildDocument]:
+        if not rules.subchunk_segmentation:
+            raise ValueError("No subchunk segmentation found in rules.")
         child_splitter = self._get_splitter(
             processing_rule_mode=process_rule_mode,
             max_tokens=rules.subchunk_segmentation.max_tokens,


@@ -37,12 +37,16 @@ class QAIndexProcessor(BaseIndexProcessor):
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         preview = kwargs.get("preview")
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
         rules = Rule(**process_rule.get("rules"))
         splitter = self._get_splitter(
             processing_rule_mode=process_rule.get("mode"),
-            max_tokens=rules.segmentation.max_tokens,
-            chunk_overlap=rules.segmentation.chunk_overlap,
-            separator=rules.segmentation.separator,
+            max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
+            chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0,
+            separator=rules.segmentation.separator if rules.segmentation else "",
             embedding_model_instance=kwargs.get("embedding_model_instance"),
         )
@@ -71,8 +75,8 @@ class QAIndexProcessor(BaseIndexProcessor):
             all_documents.extend(split_documents)
         if preview:
             self._format_qa_document(
-                current_app._get_current_object(),
-                kwargs.get("tenant_id"),
+                current_app._get_current_object(),  # type: ignore
+                kwargs.get("tenant_id"),  # type: ignore
                 all_documents[0],
                 all_qa_documents,
                 kwargs.get("doc_language", "English"),
@@ -85,8 +89,8 @@ class QAIndexProcessor(BaseIndexProcessor):
                 document_format_thread = threading.Thread(
                     target=self._format_qa_document,
                     kwargs={
-                        "flask_app": current_app._get_current_object(),
-                        "tenant_id": kwargs.get("tenant_id"),
+                        "flask_app": current_app._get_current_object(),  # type: ignore
+                        "tenant_id": kwargs.get("tenant_id"),  # type: ignore
                         "document_node": doc,
                         "all_qa_documents": all_qa_documents,
                         "document_language": kwargs.get("doc_language", "English"),


@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from typing import Any, Optional
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 class ChildDocument(BaseModel):
@@ -15,7 +15,7 @@ class ChildDocument(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: Optional[dict] = Field(default_factory=dict)
metadata: dict = {}
class Document(BaseModel):
@@ -28,7 +28,7 @@ class Document(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: Optional[dict] = Field(default_factory=dict)
metadata: dict = {}
provider: Optional[str] = "dify"
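
On the `metadata: dict = {}` change above: a minimal sketch (hypothetical `Doc` model, independent of this repository) showing that Pydantic deep-copies mutable field defaults per instance, so swapping `Field(default_factory=dict)` for a plain `{}` default keeps each model's metadata dict isolated.

```python
from pydantic import BaseModel


class Doc(BaseModel):
    # Pydantic copies mutable defaults for every new instance, unlike a plain
    # class attribute, so a bare `{}` default is not shared between objects.
    metadata: dict = {}


a, b = Doc(), Doc()
a.metadata["source"] = "upload"
assert b.metadata == {}  # b keeps its own, untouched dict
```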