feat: mypy for all type check (#10921)
@@ -32,8 +32,11 @@ class Jieba(BaseKeyword):
                 keywords = keyword_table_handler.extract_keywords(
                     text.page_content, self._config.max_keywords_per_chunk
                 )
-                self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
-                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
+                if text.metadata is not None:
+                    self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
+                    keyword_table = self._add_text_to_keyword_table(
+                        keyword_table or {}, text.metadata["doc_id"], list(keywords)
+                    )

             self._save_dataset_keyword_table(keyword_table)
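This is the standard mypy narrowing idiom: `Document.metadata` is presumably typed `Optional[dict]`, so an unguarded subscript like `text.metadata["doc_id"]` fails strict checking. A minimal sketch of the pattern, with a stand-in `Document` class (hypothetical, for illustration only):

    from typing import Optional

    class Document:
        # stand-in for the real model; metadata may be absent
        def __init__(self, page_content: str, metadata: Optional[dict] = None):
            self.page_content = page_content
            self.metadata = metadata

    def collect_doc_ids(texts: list[Document]) -> list[str]:
        ids: list[str] = []
        for text in texts:
            if text.metadata is not None:
                # mypy narrows text.metadata to dict inside this branch,
                # so the subscript type-checks
                ids.append(text.metadata["doc_id"])
        return ids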
@@ -58,20 +61,26 @@ class Jieba(BaseKeyword):
                 keywords = keyword_table_handler.extract_keywords(
                     text.page_content, self._config.max_keywords_per_chunk
                 )
-                self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
-                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
+                if text.metadata is not None:
+                    self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
+                    keyword_table = self._add_text_to_keyword_table(
+                        keyword_table or {}, text.metadata["doc_id"], list(keywords)
+                    )

             self._save_dataset_keyword_table(keyword_table)

     def text_exists(self, id: str) -> bool:
         keyword_table = self._get_dataset_keyword_table()
+        if keyword_table is None:
+            return False
         return id in set.union(*keyword_table.values())

     def delete_by_ids(self, ids: list[str]) -> None:
         lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
         with redis_client.lock(lock_name, timeout=600):
             keyword_table = self._get_dataset_keyword_table()
-            keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
+            if keyword_table is not None:
+                keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

             self._save_dataset_keyword_table(keyword_table)
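Two guards in this hunk follow the same logic: `_get_dataset_keyword_table()` can evidently return None, so `text_exists` returns False early and `delete_by_ids` only rewrites the table when one exists. A sketch of the `text_exists` shape, assuming the table maps keywords to sets of segment ids; note that seeding the union with `set()` would additionally cover an empty table, which bare `set.union(*{}.values())` does not (it raises TypeError when given no arguments):

    from typing import Optional

    def text_exists(keyword_table: Optional[dict[str, set[str]]], id: str) -> bool:
        if keyword_table is None:
            # no table has been built yet, so nothing can exist in it
            return False
        # set().union(...) tolerates zero value sets, unlike set.union(*...)
        return id in set().union(*keyword_table.values())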
@@ -80,7 +89,7 @@ class Jieba(BaseKeyword):

         k = kwargs.get("top_k", 4)

-        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
+        sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)

         documents = []
         for chunk_index in sorted_chunk_indices:
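`keyword_table or {}` is the other recurring fix in this commit: `_retrieve_ids_by_query` presumably declares a non-Optional dict parameter, and collapsing None to an empty table lets the call type-check without an if-block (an empty table simply matches nothing). A toy sketch of the shape, with a hypothetical matcher:

    from typing import Optional

    def retrieve_ids_by_query(keyword_table: dict[str, set[str]], query: str, k: int) -> list[str]:
        # toy matcher: every table keyword found verbatim in the query
        # contributes its segment ids
        hits = [nid for kw, ids in keyword_table.items() if kw in query for nid in ids]
        return hits[:k]

    def search(keyword_table: Optional[dict[str, set[str]]], query: str, k: int = 4) -> list[str]:
        # "or {}" turns None into an empty table, so the callee can keep a
        # non-Optional parameter type
        return retrieve_ids_by_query(keyword_table or {}, query, k)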
@@ -137,7 +146,7 @@ class Jieba(BaseKeyword):
         if dataset_keyword_table:
             keyword_table_dict = dataset_keyword_table.keyword_table_dict
             if keyword_table_dict:
-                return keyword_table_dict["__data__"]["table"]
+                return dict(keyword_table_dict["__data__"]["table"])
         else:
             keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
             dataset_keyword_table = DatasetKeywordTable(
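Wrapping the JSON payload in `dict()` is a common way to satisfy a declared dict return type: the value dug out of `keyword_table_dict` is effectively `Any`, and `dict(...)` both copies it and hands mypy a concrete dict. A minimal sketch, with assumed types:

    from typing import Any, Optional

    def get_table(keyword_table_dict: Optional[dict[str, Any]]) -> Optional[dict[str, set[str]]]:
        if keyword_table_dict:
            # dict() re-wraps the loosely typed payload so the declared
            # return type is met with a concrete dict rather than Any
            return dict(keyword_table_dict["__data__"]["table"])
        return None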
@@ -188,8 +197,8 @@ class Jieba(BaseKeyword):

         # go through text chunks in order of most matching keywords
         chunk_indices_count: dict[str, int] = defaultdict(int)
-        keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
-        for keyword in keywords:
+        keywords_list = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
+        for keyword in keywords_list:
             for node_id in keyword_table[keyword]:
                 chunk_indices_count[node_id] += 1
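The rename looks cosmetic but is presumably type-driven: `keywords` arrives as the `set[str]` produced by `extract_keywords`, and reassigning it to a list comprehension changes its type, which mypy reports as an incompatible assignment. Binding the filtered value to a fresh name sidesteps that. A sketch mirroring the hunk:

    from collections import defaultdict

    def count_matches(keyword_table: dict[str, set[str]], keywords: set[str]) -> dict[str, int]:
        chunk_indices_count: dict[str, int] = defaultdict(int)
        # reassigning `keywords = [...]` would flip its type from set[str]
        # to list[str]; a new name keeps both bindings consistently typed
        keywords_list = [kw for kw in keywords if kw in keyword_table]
        for keyword in keywords_list:
            for node_id in keyword_table[keyword]:
                chunk_indices_count[node_id] += 1
        return chunk_indices_count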
@@ -215,7 +224,7 @@ class Jieba(BaseKeyword):
     def create_segment_keywords(self, node_id: str, keywords: list[str]):
         keyword_table = self._get_dataset_keyword_table()
         self._update_segment_keywords(self.dataset.id, node_id, keywords)
-        keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
+        keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
         self._save_dataset_keyword_table(keyword_table)

     def multi_create_segment_keywords(self, pre_segment_data_list: list):
@@ -226,17 +235,19 @@ class Jieba(BaseKeyword):
             if pre_segment_data["keywords"]:
                 segment.keywords = pre_segment_data["keywords"]
                 keyword_table = self._add_text_to_keyword_table(
-                    keyword_table, segment.index_node_id, pre_segment_data["keywords"]
+                    keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
                 )
             else:
                 keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
                 segment.keywords = list(keywords)
-                keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
+                keyword_table = self._add_text_to_keyword_table(
+                    keyword_table or {}, segment.index_node_id, list(keywords)
+                )
         self._save_dataset_keyword_table(keyword_table)

     def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
         keyword_table = self._get_dataset_keyword_table()
-        keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
+        keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
         self._save_dataset_keyword_table(keyword_table)
@@ -4,7 +4,7 @@ from typing import Optional

 class JiebaKeywordTableHandler:
     def __init__(self):
-        import jieba.analyse
+        import jieba.analyse  # type: ignore

         from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
@@ -12,7 +12,7 @@ class JiebaKeywordTableHandler:

     def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
-        import jieba
+        import jieba  # type: ignore

         keywords = jieba.analyse.extract_tags(
             sentence=text,
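jieba ships no type stubs and no py.typed marker, so strict mypy reports these imports as untyped; the inline `# type: ignore` silences that per import. A project-wide alternative is an `ignore_missing_imports = true` override for `jieba.*` in the mypy configuration. A short sketch of how the handler uses the library (`topK` is jieba's real parameter name):

    import jieba  # type: ignore
    import jieba.analyse  # type: ignore

    def top_keywords(text: str, k: int = 10) -> set[str]:
        # extract_tags ranks terms by TF-IDF and returns the top k
        return set(jieba.analyse.extract_tags(sentence=text, topK=k))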
@@ -37,6 +37,8 @@ class BaseKeyword(ABC):

     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
         for text in texts.copy():
+            if text.metadata is None:
+                continue
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:
@@ -45,4 +47,4 @@ class BaseKeyword(ABC):
         return texts

     def _get_uuids(self, texts: list[Document]) -> list[str]:
-        return [text.metadata["doc_id"] for text in texts]
+        return [text.metadata["doc_id"] for text in texts if text.metadata]
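The comprehension guard narrows and filters in one step, but note the behavioral edge it introduces: documents whose metadata is None, and also ones whose metadata is an empty dict, are now silently omitted from the returned ids rather than raising. A compact sketch, reusing the stand-in `Document` from above:

    def get_uuids(texts: list[Document]) -> list[str]:
        # `if text.metadata` narrows Optional[dict] to dict for the
        # subscript, and drops falsy metadata (None or {}) entirely
        return [text.metadata["doc_id"] for text in texts if text.metadata]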