Mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git (synced 2025-12-11 11:56:53 +08:00)
chore(api/core): apply ruff reformatting (#7624)
@@ -24,37 +24,42 @@ class Jieba(BaseKeyword):
         self._config = KeywordTableConfig()

     def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
-        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
         with redis_client.lock(lock_name, timeout=600):
             keyword_table_handler = JiebaKeywordTableHandler()
             keyword_table = self._get_dataset_keyword_table()
             for text in texts:
-                keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-                self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+                keywords = keyword_table_handler.extract_keywords(
+                    text.page_content, self._config.max_keywords_per_chunk
+                )
+                self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
+                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))

             self._save_dataset_keyword_table(keyword_table)

             return self

     def add_texts(self, texts: list[Document], **kwargs):
-        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
         with redis_client.lock(lock_name, timeout=600):
             keyword_table_handler = JiebaKeywordTableHandler()

             keyword_table = self._get_dataset_keyword_table()
-            keywords_list = kwargs.get('keywords_list', None)
+            keywords_list = kwargs.get("keywords_list", None)
             for i in range(len(texts)):
                 text = texts[i]
                 if keywords_list:
                     keywords = keywords_list[i]
                     if not keywords:
-                        keywords = keyword_table_handler.extract_keywords(text.page_content,
-                                                                          self._config.max_keywords_per_chunk)
+                        keywords = keyword_table_handler.extract_keywords(
+                            text.page_content, self._config.max_keywords_per_chunk
+                        )
                 else:
-                    keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-                self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+                    keywords = keyword_table_handler.extract_keywords(
+                        text.page_content, self._config.max_keywords_per_chunk
+                    )
+                self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
+                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))

             self._save_dataset_keyword_table(keyword_table)
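Note: `_add_text_to_keyword_table` itself is outside this hunk. From the way `create` and `add_texts` call it, the keyword table appears to be a plain dict mapping each keyword to the set of segment doc_ids it occurs in. A minimal sketch under that assumption (the body below is illustrative, not the repository's implementation):

def _add_text_to_keyword_table(keyword_table: dict[str, set[str]], doc_id: str, keywords: list[str]) -> dict:
    # Register the segment's doc_id under every keyword extracted from it.
    for keyword in keywords:
        if keyword not in keyword_table:
            keyword_table[keyword] = set()
        keyword_table[keyword].add(doc_id)
    return keyword_table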
@@ -63,97 +68,91 @@ class Jieba(BaseKeyword):
         return id in set.union(*keyword_table.values())

     def delete_by_ids(self, ids: list[str]) -> None:
-        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
         with redis_client.lock(lock_name, timeout=600):
             keyword_table = self._get_dataset_keyword_table()
             keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

             self._save_dataset_keyword_table(keyword_table)

-    def search(
-            self, query: str,
-            **kwargs: Any
-    ) -> list[Document]:
+    def search(self, query: str, **kwargs: Any) -> list[Document]:
         keyword_table = self._get_dataset_keyword_table()

-        k = kwargs.get('top_k', 4)
+        k = kwargs.get("top_k", 4)

         sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)

         documents = []
         for chunk_index in sorted_chunk_indices:
-            segment = db.session.query(DocumentSegment).filter(
-                DocumentSegment.dataset_id == self.dataset.id,
-                DocumentSegment.index_node_id == chunk_index
-            ).first()
+            segment = (
+                db.session.query(DocumentSegment)
+                .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
+                .first()
+            )

             if segment:
-
-                documents.append(Document(
-                    page_content=segment.content,
-                    metadata={
-                        "doc_id": chunk_index,
-                        "doc_hash": segment.index_node_hash,
-                        "document_id": segment.document_id,
-                        "dataset_id": segment.dataset_id,
-                    }
-                ))
+                documents.append(
+                    Document(
+                        page_content=segment.content,
+                        metadata={
+                            "doc_id": chunk_index,
+                            "doc_hash": segment.index_node_hash,
+                            "document_id": segment.document_id,
+                            "dataset_id": segment.dataset_id,
+                        },
+                    )
+                )

         return documents

     def delete(self) -> None:
-        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
         with redis_client.lock(lock_name, timeout=600):
             dataset_keyword_table = self.dataset.dataset_keyword_table
             if dataset_keyword_table:
                 db.session.delete(dataset_keyword_table)
                 db.session.commit()
-                if dataset_keyword_table.data_source_type != 'database':
-                    file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
+                if dataset_keyword_table.data_source_type != "database":
+                    file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
                     storage.delete(file_key)

     def _save_dataset_keyword_table(self, keyword_table):
         keyword_table_dict = {
-            '__type__': 'keyword_table',
-            '__data__': {
-                "index_id": self.dataset.id,
-                "summary": None,
-                "table": keyword_table
-            }
+            "__type__": "keyword_table",
+            "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
         }
         dataset_keyword_table = self.dataset.dataset_keyword_table
         keyword_data_source_type = dataset_keyword_table.data_source_type
-        if keyword_data_source_type == 'database':
+        if keyword_data_source_type == "database":
             dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
             db.session.commit()
         else:
-            file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
+            file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
             if storage.exists(file_key):
                 storage.delete(file_key)
-            storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))
+            storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8"))

     def _get_dataset_keyword_table(self) -> Optional[dict]:
         dataset_keyword_table = self.dataset.dataset_keyword_table
         if dataset_keyword_table:
             keyword_table_dict = dataset_keyword_table.keyword_table_dict
             if keyword_table_dict:
-                return keyword_table_dict['__data__']['table']
+                return keyword_table_dict["__data__"]["table"]
         else:
             keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
             dataset_keyword_table = DatasetKeywordTable(
                 dataset_id=self.dataset.id,
-                keyword_table='',
+                keyword_table="",
                 data_source_type=keyword_data_source_type,
             )
-            if keyword_data_source_type == 'database':
-                dataset_keyword_table.keyword_table = json.dumps({
-                    '__type__': 'keyword_table',
-                    '__data__': {
-                        "index_id": self.dataset.id,
-                        "summary": None,
-                        "table": {}
-                    }
-                }, cls=SetEncoder)
+            if keyword_data_source_type == "database":
+                dataset_keyword_table.keyword_table = json.dumps(
+                    {
+                        "__type__": "keyword_table",
+                        "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
+                    },
+                    cls=SetEncoder,
+                )
             db.session.add(dataset_keyword_table)
             db.session.commit()
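Note: `json.dumps(keyword_table_dict, cls=SetEncoder)` is needed because the table values are sets, which the default JSON encoder refuses to serialize. `SetEncoder` is imported from elsewhere in the codebase; a minimal sketch of what such an encoder usually looks like (an assumption about its behavior, not a copy of the actual class):

import json


class SetEncoder(json.JSONEncoder):
    # Serialize sets as JSON arrays; defer everything else to the base encoder.
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)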
@@ -174,9 +173,7 @@ class Jieba(BaseKeyword):
         keywords_to_delete = set()
         for keyword, node_idxs in keyword_table.items():
             if node_idxs_to_delete.intersection(node_idxs):
-                keyword_table[keyword] = node_idxs.difference(
-                    node_idxs_to_delete
-                )
+                keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete)
                 if not keyword_table[keyword]:
                     keywords_to_delete.add(keyword)
@@ -202,13 +199,14 @@ class Jieba(BaseKeyword):
             reverse=True,
         )

-        return sorted_chunk_indices[: k]
+        return sorted_chunk_indices[:k]

     def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
-        document_segment = db.session.query(DocumentSegment).filter(
-            DocumentSegment.dataset_id == dataset_id,
-            DocumentSegment.index_node_id == node_id
-        ).first()
+        document_segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
+            .first()
+        )
         if document_segment:
             document_segment.keywords = keywords
             db.session.add(document_segment)
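Note: the `reverse=True` sort and the `sorted_chunk_indices[:k]` slice above are the tail of `_retrieve_ids_by_query`; the rest of that method is outside this hunk. The visible tail suggests chunks are ranked by how many query keywords point at them, keeping only the top k. A hedged sketch of that idea with hypothetical names (not the method as written in the repository):

def rank_chunks_by_keyword_overlap(keyword_table: dict[str, set[str]], query_keywords: set[str], k: int) -> list[str]:
    # Count, per chunk id, how many of the query's keywords reference it.
    chunk_scores: dict[str, int] = {}
    for keyword in query_keywords:
        for chunk_id in keyword_table.get(keyword, set()):
            chunk_scores[chunk_id] = chunk_scores.get(chunk_id, 0) + 1

    # Highest-scoring chunks first, then keep only the top k.
    sorted_chunk_indices = sorted(chunk_scores, key=lambda chunk_id: chunk_scores[chunk_id], reverse=True)
    return sorted_chunk_indices[:k]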
@@ -224,14 +222,14 @@ class Jieba(BaseKeyword):
         keyword_table_handler = JiebaKeywordTableHandler()
         keyword_table = self._get_dataset_keyword_table()
         for pre_segment_data in pre_segment_data_list:
-            segment = pre_segment_data['segment']
-            if pre_segment_data['keywords']:
-                segment.keywords = pre_segment_data['keywords']
-                keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
-                                                                pre_segment_data['keywords'])
+            segment = pre_segment_data["segment"]
+            if pre_segment_data["keywords"]:
+                segment.keywords = pre_segment_data["keywords"]
+                keyword_table = self._add_text_to_keyword_table(
+                    keyword_table, segment.index_node_id, pre_segment_data["keywords"]
+                )
             else:
-                keywords = keyword_table_handler.extract_keywords(segment.content,
-                                                                  self._config.max_keywords_per_chunk)
+                keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
                 segment.keywords = list(keywords)
                 keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
         self._save_dataset_keyword_table(keyword_table)
@@ -8,7 +8,6 @@ from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS


 class JiebaKeywordTableHandler:
-
     def __init__(self):
         default_tfidf.stop_words = STOPWORDS
@@ -30,4 +29,4 @@ class JiebaKeywordTableHandler:
                 if len(sub_tokens) > 1:
                     results.update({w for w in sub_tokens if w not in list(STOPWORDS)})

-        return results
+        return results
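Note: `default_tfidf.stop_words = STOPWORDS` in `__init__` above replaces jieba's built-in stop-word list with the project's own before any extraction happens. The `extract_keywords` calls seen in the Jieba class then boil down to TF-IDF tag extraction; a rough usage sketch (assuming jieba's standard `jieba.analyse.extract_tags` API, which mirrors `default_tfidf.extract_tags`; the handler's exact implementation is not shown in this hunk):

import jieba.analyse


def extract_keywords_sketch(text: str, max_keywords: int = 10) -> set[str]:
    # Return up to max_keywords terms ranked by TF-IDF weight.
    return set(jieba.analyse.extract_tags(text, topK=max_keywords))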
File diff suppressed because it is too large
@@ -8,7 +8,6 @@ from models.dataset import Dataset


 class BaseKeyword(ABC):
-
     def __init__(self, dataset: Dataset):
         self.dataset = dataset
@@ -31,15 +30,12 @@ class BaseKeyword(ABC):
     def delete(self) -> None:
         raise NotImplementedError

-    def search(
-            self, query: str,
-            **kwargs: Any
-    ) -> list[Document]:
+    def search(self, query: str, **kwargs: Any) -> list[Document]:
         raise NotImplementedError

     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
         for text in texts[:]:
-            doc_id = text.metadata['doc_id']
+            doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:
                 texts.remove(text)
@@ -47,4 +43,4 @@ class BaseKeyword(ABC):
         return texts

     def _get_uuids(self, texts: list[Document]) -> list[str]:
-        return [text.metadata['doc_id'] for text in texts]
+        return [text.metadata["doc_id"] for text in texts]
@@ -20,9 +20,7 @@ class Keyword:
             raise ValueError("Keyword store must be specified.")

         if keyword_type == "jieba":
-            return Jieba(
-                dataset=self._dataset
-            )
+            return Jieba(dataset=self._dataset)
         else:
             raise ValueError(f"Keyword store {keyword_type} is not supported.")
@@ -41,10 +39,7 @@ class Keyword:
     def delete(self) -> None:
         self._keyword_processor.delete()

-    def search(
-            self, query: str,
-            **kwargs: Any
-    ) -> list[Document]:
+    def search(self, query: str, **kwargs: Any) -> list[Document]:
         return self._keyword_processor.search(query, **kwargs)

     def __getattr__(self, name):
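Note: `Keyword` is a thin facade: the factory method partially visible in the earlier hunk (only "jieba" is accepted) picks the concrete processor, public methods such as `delete` and `search` forward to it, and `__getattr__` delegates anything else. A sketch of how calling code would typically use it, assuming the constructor takes the dataset (suggested by `self._dataset`) and that `add_texts` is forwarded like the methods shown here; `dataset` and `documents` are placeholders:

# Hypothetical usage; `dataset` is an existing Dataset record, `documents` a list of Document chunks.
keyword = Keyword(dataset)
keyword.add_texts(documents)                      # delegated to the Jieba processor
results = keyword.search("machine learning", top_k=4)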