chore(api/core): apply ruff reformatting (#7624)

Author: Bowen Liang
Date: 2024-09-10 17:00:20 +08:00
Committed by: GitHub
Parent: 178730266d
Commit: 2cf1187b32
724 changed files with 21180 additions and 21123 deletions
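
The hunks below are formatter output only: string literals are normalized to double quotes, calls that overrun the line-length limit are re-wrapped with their arguments on an indented continuation line, and chained SQLAlchemy query calls are parenthesized one call per line. No behavior changes. As a minimal, self-contained sketch of the quote rule (the identifiers and the value here are stand-ins, not code from this commit):

# Old style, as the pre-commit code was written:
#   lock_name = 'keyword_indexing_lock_{}'.format(dataset_id)
# After running ruff's formatter, the literal uses double quotes:
dataset_id = "example"  # stand-in value for the sketch
lock_name = "keyword_indexing_lock_{}".format(dataset_id)
assert lock_name == "keyword_indexing_lock_example"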


@@ -24,37 +24,42 @@ class Jieba(BaseKeyword):
         self._config = KeywordTableConfig()
 
     def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
-        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
         with redis_client.lock(lock_name, timeout=600):
             keyword_table_handler = JiebaKeywordTableHandler()
             keyword_table = self._get_dataset_keyword_table()
             for text in texts:
-                keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-                self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+                keywords = keyword_table_handler.extract_keywords(
+                    text.page_content, self._config.max_keywords_per_chunk
+                )
+                self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
+                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
 
             self._save_dataset_keyword_table(keyword_table)
 
         return self
 
     def add_texts(self, texts: list[Document], **kwargs):
-        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
         with redis_client.lock(lock_name, timeout=600):
             keyword_table_handler = JiebaKeywordTableHandler()
             keyword_table = self._get_dataset_keyword_table()
-            keywords_list = kwargs.get('keywords_list', None)
+            keywords_list = kwargs.get("keywords_list", None)
             for i in range(len(texts)):
                 text = texts[i]
                 if keywords_list:
                     keywords = keywords_list[i]
                     if not keywords:
-                        keywords = keyword_table_handler.extract_keywords(text.page_content,
-                                                                          self._config.max_keywords_per_chunk)
+                        keywords = keyword_table_handler.extract_keywords(
+                            text.page_content, self._config.max_keywords_per_chunk
+                        )
                 else:
-                    keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
-                self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
-                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
+                    keywords = keyword_table_handler.extract_keywords(
+                        text.page_content, self._config.max_keywords_per_chunk
+                    )
+                self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
+                keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
 
             self._save_dataset_keyword_table(keyword_table)
@@ -63,97 +68,91 @@ class Jieba(BaseKeyword):
         return id in set.union(*keyword_table.values())
 
     def delete_by_ids(self, ids: list[str]) -> None:
-        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
         with redis_client.lock(lock_name, timeout=600):
             keyword_table = self._get_dataset_keyword_table()
             keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
 
             self._save_dataset_keyword_table(keyword_table)
 
-    def search(
-            self, query: str,
-            **kwargs: Any
-    ) -> list[Document]:
+    def search(self, query: str, **kwargs: Any) -> list[Document]:
         keyword_table = self._get_dataset_keyword_table()
 
-        k = kwargs.get('top_k', 4)
+        k = kwargs.get("top_k", 4)
 
         sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
 
         documents = []
         for chunk_index in sorted_chunk_indices:
-            segment = db.session.query(DocumentSegment).filter(
-                DocumentSegment.dataset_id == self.dataset.id,
-                DocumentSegment.index_node_id == chunk_index
-            ).first()
+            segment = (
+                db.session.query(DocumentSegment)
+                .filter(DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index)
+                .first()
+            )
 
             if segment:
-                documents.append(Document(
-                    page_content=segment.content,
-                    metadata={
-                        "doc_id": chunk_index,
-                        "doc_hash": segment.index_node_hash,
-                        "document_id": segment.document_id,
-                        "dataset_id": segment.dataset_id,
-                    }
-                ))
+                documents.append(
+                    Document(
+                        page_content=segment.content,
+                        metadata={
+                            "doc_id": chunk_index,
+                            "doc_hash": segment.index_node_hash,
+                            "document_id": segment.document_id,
+                            "dataset_id": segment.dataset_id,
+                        },
+                    )
+                )
 
         return documents
 
     def delete(self) -> None:
-        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
+        lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
         with redis_client.lock(lock_name, timeout=600):
             dataset_keyword_table = self.dataset.dataset_keyword_table
             if dataset_keyword_table:
                 db.session.delete(dataset_keyword_table)
                 db.session.commit()
-                if dataset_keyword_table.data_source_type != 'database':
-                    file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
+                if dataset_keyword_table.data_source_type != "database":
+                    file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
                     storage.delete(file_key)
 
     def _save_dataset_keyword_table(self, keyword_table):
         keyword_table_dict = {
-            '__type__': 'keyword_table',
-            '__data__': {
-                "index_id": self.dataset.id,
-                "summary": None,
-                "table": keyword_table
-            }
+            "__type__": "keyword_table",
+            "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
         }
         dataset_keyword_table = self.dataset.dataset_keyword_table
         keyword_data_source_type = dataset_keyword_table.data_source_type
-        if keyword_data_source_type == 'database':
+        if keyword_data_source_type == "database":
             dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
             db.session.commit()
         else:
-            file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
+            file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
             if storage.exists(file_key):
                 storage.delete(file_key)
-            storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))
+            storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8"))
 
     def _get_dataset_keyword_table(self) -> Optional[dict]:
         dataset_keyword_table = self.dataset.dataset_keyword_table
         if dataset_keyword_table:
             keyword_table_dict = dataset_keyword_table.keyword_table_dict
             if keyword_table_dict:
-                return keyword_table_dict['__data__']['table']
+                return keyword_table_dict["__data__"]["table"]
         else:
             keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
             dataset_keyword_table = DatasetKeywordTable(
                 dataset_id=self.dataset.id,
-                keyword_table='',
+                keyword_table="",
                 data_source_type=keyword_data_source_type,
             )
-            if keyword_data_source_type == 'database':
-                dataset_keyword_table.keyword_table = json.dumps({
-                    '__type__': 'keyword_table',
-                    '__data__': {
-                        "index_id": self.dataset.id,
-                        "summary": None,
-                        "table": {}
-                    }
-                }, cls=SetEncoder)
+            if keyword_data_source_type == "database":
+                dataset_keyword_table.keyword_table = json.dumps(
+                    {
+                        "__type__": "keyword_table",
+                        "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
+                    },
+                    cls=SetEncoder,
+                )
 
             db.session.add(dataset_keyword_table)
             db.session.commit()
@@ -174,9 +173,7 @@ class Jieba(BaseKeyword):
         keywords_to_delete = set()
         for keyword, node_idxs in keyword_table.items():
             if node_idxs_to_delete.intersection(node_idxs):
-                keyword_table[keyword] = node_idxs.difference(
-                    node_idxs_to_delete
-                )
+                keyword_table[keyword] = node_idxs.difference(node_idxs_to_delete)
 
                 if not keyword_table[keyword]:
                     keywords_to_delete.add(keyword)
@@ -202,13 +199,14 @@ class Jieba(BaseKeyword):
             reverse=True,
         )
 
-        return sorted_chunk_indices[: k]
+        return sorted_chunk_indices[:k]
 
     def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
-        document_segment = db.session.query(DocumentSegment).filter(
-            DocumentSegment.dataset_id == dataset_id,
-            DocumentSegment.index_node_id == node_id
-        ).first()
+        document_segment = (
+            db.session.query(DocumentSegment)
+            .filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
+            .first()
+        )
         if document_segment:
             document_segment.keywords = keywords
             db.session.add(document_segment)
@@ -224,14 +222,14 @@ class Jieba(BaseKeyword):
         keyword_table_handler = JiebaKeywordTableHandler()
         keyword_table = self._get_dataset_keyword_table()
         for pre_segment_data in pre_segment_data_list:
-            segment = pre_segment_data['segment']
-            if pre_segment_data['keywords']:
-                segment.keywords = pre_segment_data['keywords']
-                keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
-                                                                pre_segment_data['keywords'])
+            segment = pre_segment_data["segment"]
+            if pre_segment_data["keywords"]:
+                segment.keywords = pre_segment_data["keywords"]
+                keyword_table = self._add_text_to_keyword_table(
+                    keyword_table, segment.index_node_id, pre_segment_data["keywords"]
+                )
             else:
-                keywords = keyword_table_handler.extract_keywords(segment.content,
-                                                                  self._config.max_keywords_per_chunk)
+                keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
                 segment.keywords = list(keywords)
                 keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
         self._save_dataset_keyword_table(keyword_table)


@@ -8,7 +8,6 @@ from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
 class JiebaKeywordTableHandler:
-
     def __init__(self):
         default_tfidf.stop_words = STOPWORDS
@@ -30,4 +29,4 @@ class JiebaKeywordTableHandler:
             if len(sub_tokens) > 1:
                 results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
 
-        return results
\ No newline at end of file
+        return results

File diff suppressed because it is too large


@@ -8,7 +8,6 @@ from models.dataset import Dataset
 class BaseKeyword(ABC):
-
     def __init__(self, dataset: Dataset):
         self.dataset = dataset
@@ -31,15 +30,12 @@ class BaseKeyword(ABC):
     def delete(self) -> None:
         raise NotImplementedError
 
-    def search(
-            self, query: str,
-            **kwargs: Any
-    ) -> list[Document]:
+    def search(self, query: str, **kwargs: Any) -> list[Document]:
         raise NotImplementedError
 
     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
         for text in texts[:]:
-            doc_id = text.metadata['doc_id']
+            doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:
                 texts.remove(text)
@@ -47,4 +43,4 @@ class BaseKeyword(ABC):
         return texts
 
     def _get_uuids(self, texts: list[Document]) -> list[str]:
-        return [text.metadata['doc_id'] for text in texts]
\ No newline at end of file
+        return [text.metadata["doc_id"] for text in texts]


@@ -20,9 +20,7 @@ class Keyword:
             raise ValueError("Keyword store must be specified.")
 
         if keyword_type == "jieba":
-            return Jieba(
-                dataset=self._dataset
-            )
+            return Jieba(dataset=self._dataset)
         else:
             raise ValueError(f"Keyword store {keyword_type} is not supported.")
@@ -41,10 +39,7 @@ class Keyword:
     def delete(self) -> None:
         self._keyword_processor.delete()
 
-    def search(
-            self, query: str,
-            **kwargs: Any
-    ) -> list[Document]:
+    def search(self, query: str, **kwargs: Any) -> list[Document]:
         return self._keyword_processor.search(query, **kwargs)
 
     def __getattr__(self, name):