mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
chore: refurbish Python code by applying refurb linter rules (#8296)
This commit is contained in:
@@ -34,7 +34,7 @@ class BaseKeyword(ABC):
         raise NotImplementedError

     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
-        for text in texts[:]:
+        for text in texts.copy():
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:

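As an illustration only (not part of this diff): `texts[:]` and `texts.copy()` both take a shallow copy, so removing entries from the original list while iterating over the copy stays safe; the refurb linter simply prefers the explicit spelling.

    texts = ["keep", "drop", "keep", "drop"]
    for text in texts.copy():   # behaves exactly like: for text in texts[:]
        if text == "drop":
            texts.remove(text)  # mutating the original while iterating the copy is safe
    print(texts)                # ['keep', 'keep']
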
@@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         request = gpdb_20160503_models.QueryCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,

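A small sketch (illustration only, not from the repository) of why the ternary is redundant: with 0.0 already supplied as the `dict.get` default, the two forms agree for absent keys and truthy values, and only diverge if a caller explicitly passes a falsy value such as None.

    def old_style(**kwargs):
        return kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0

    def new_style(**kwargs):
        return kwargs.get("score_threshold", 0.0)

    assert old_style() == new_style() == 0.0
    assert old_style(score_threshold=0.5) == new_style(score_threshold=0.5) == 0.5
    assert old_style(score_threshold=None) == 0.0 and new_style(score_threshold=None) is None
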
@@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         request = gpdb_20160503_models.QueryCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,

@@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         collection = self._client.get_or_create_collection(self._collection_name)
         results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)

         ids: list[str] = results["ids"][0]
         documents: list[str] = results["documents"][0]

@@ -86,8 +86,8 @@ class ElasticSearchVector(BaseVector):
                 id=uuids[i],
                 document={
                     Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
-                    Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
+                    Field.VECTOR.value: embeddings[i] or None,
+                    Field.METADATA_KEY.value: documents[i].metadata or {},
                 },
             )
         self._client.indices.refresh(index=self._collection_name)

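Illustration only (not part of the commit): `x if x else y` and `x or y` select the same value; `or` just avoids writing and evaluating `x` twice.

    metadata = {}
    assert (metadata if metadata else {"k": 1}) == (metadata or {"k": 1}) == {"k": 1}
    embedding: list[float] = []
    assert (embedding if embedding else None) is (embedding or None) is None
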
@@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector):

         docs = []
         for doc, score in docs_and_scores:
-            score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if score > score_threshold:
                 doc.metadata["score"] = score
                 docs.append(doc)

@@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
         for result in results[0]:
             metadata = result["entity"].get(Field.METADATA_KEY.value)
             metadata["score"] = result["distance"]
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if result["distance"] > score_threshold:
                 doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
                 docs.append(doc)

@@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):

     def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
-        score_threshold = kwargs.get("score_threshold") or 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         where_str = (
             f"WHERE dist < {1 - score_threshold}"
             if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0

@@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector):
             metadata = {}

             metadata["score"] = hit["_score"]
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if hit["_score"] > score_threshold:
                 doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
                 docs.append(doc)

@@ -200,7 +200,7 @@ class OracleVector(BaseVector):
                 [numpy.array(query_vector)],
             )
             docs = []
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             for record in cur:
                 metadata, text, distance = record
                 score = 1 - distance

@@ -212,7 +212,7 @@ class OracleVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
         # just not implement fetch by score_threshold now, may be later
-        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         if len(query) > 0:
             # Check which language the query is in
             zh_pattern = re.compile("[\u4e00-\u9fa5]+")

@@ -198,7 +198,7 @@ class PGVectoRS(BaseVector):
             metadata = record.meta
             score = 1 - dis
             metadata["score"] = score
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if score > score_threshold:
                 doc = Document(page_content=record.text, metadata=metadata)
                 docs.append(doc)

@@ -144,7 +144,7 @@ class PGVector(BaseVector):
                 (json.dumps(query_vector),),
             )
             docs = []
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             for record in cur:
                 metadata, text, distance = record
                 score = 1 - distance

@@ -339,7 +339,7 @@ class QdrantVector(BaseVector):
         for result in results:
             metadata = result.payload.get(Field.METADATA_KEY.value) or {}
             # duplicate check score threshold
-            score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if result.score > score_threshold:
                 metadata["score"] = result.score
                 doc = Document(

@@ -230,7 +230,7 @@ class RelytVector(BaseVector):
         # Organize results.
         docs = []
         for document, score in results:
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if 1 - score > score_threshold:
                 docs.append(document)
         return docs

@@ -153,7 +153,7 @@ class TencentVector(BaseVector):
             limit=kwargs.get("top_k", 4),
             timeout=self._client_config.timeout,
         )
-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         return self._get_search_res(res, score_threshold)

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:

@@ -185,7 +185,7 @@ class TiDBVector(BaseVector):

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
-        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         filter = kwargs.get("filter")
         distance = 1 - score_threshold

@@ -49,7 +49,7 @@ class BaseVector(ABC):
         raise NotImplementedError

     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
-        for text in texts[:]:
+        for text in texts.copy():
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:

@@ -153,7 +153,7 @@ class Vector:
         return CacheEmbedding(embedding_model)

     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
-        for text in texts[:]:
+        for text in texts.copy():
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:

@@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):

         docs = []
         for doc, score in docs_and_scores:
-            score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             # check score threshold
             if score > score_threshold:
                 doc.metadata["score"] = score

@@ -12,7 +12,7 @@ import mimetypes
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Mapping
 from io import BufferedReader, BytesIO
-from pathlib import PurePath
+from pathlib import Path, PurePath
 from typing import Any, Optional, Union

 from pydantic import BaseModel, ConfigDict, model_validator

@@ -56,8 +56,7 @@ class Blob(BaseModel):
     def as_string(self) -> str:
         """Read data as a string."""
         if self.data is None and self.path:
-            with open(str(self.path), encoding=self.encoding) as f:
-                return f.read()
+            return Path(str(self.path)).read_text(encoding=self.encoding)
         elif isinstance(self.data, bytes):
             return self.data.decode(self.encoding)
         elif isinstance(self.data, str):

@@ -72,8 +71,7 @@ class Blob(BaseModel):
         elif isinstance(self.data, str):
             return self.data.encode(self.encoding)
         elif self.data is None and self.path:
-            with open(str(self.path), "rb") as f:
-                return f.read()
+            return Path(str(self.path)).read_bytes()
         else:
             raise ValueError(f"Unable to get bytes for blob {self}")

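For context, a minimal sketch (not from the repository; the /tmp path is hypothetical and used only for the demo) of the pathlib one-liners that replace the open/read blocks above:

    from pathlib import Path

    p = Path("/tmp/blob_demo.txt")              # hypothetical demo file
    p.write_text("héllo", encoding="utf-8")
    assert p.read_text(encoding="utf-8") == "héllo"
    assert p.read_bytes() == "héllo".encode("utf-8")
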
@@ -68,8 +68,7 @@ class ExtractProcessor:
             suffix = "." + re.search(r"\.(\w+)$", filename).group(1)

             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            with open(file_path, "wb") as file:
-                file.write(response.content)
+            Path(file_path).write_bytes(response.content)
             extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
             if return_text:
                 delimiter = "\n"

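Similarly, a sketch (hypothetical path and payload, not from the repository) of Path.write_bytes standing in for the two-line open("wb") block:

    from pathlib import Path

    payload = b"%PDF-1.4 demo"                  # stand-in for response.content
    target = Path("/tmp/extract_demo.bin")      # hypothetical demo path
    target.write_bytes(payload)                 # creates/truncates and writes in one call
    assert target.read_bytes() == payload
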
@@ -111,7 +110,7 @@ class ExtractProcessor:
                     )
                 elif file_extension in [".htm", ".html"]:
                     extractor = HtmlExtractor(file_path)
-                elif file_extension in [".docx"]:
+                elif file_extension == ".docx":
                     extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension == ".csv":
                     extractor = CSVExtractor(file_path, autodetect_encoding=True)

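Illustration only: membership in a single-element list is just an equality test, so the rewrite cannot change behaviour.

    for file_extension in (".docx", ".csv", ".DOCX"):
        assert (file_extension in [".docx"]) == (file_extension == ".docx")
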
@@ -143,7 +142,7 @@ class ExtractProcessor:
                     extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
                 elif file_extension in [".htm", ".html"]:
                     extractor = HtmlExtractor(file_path)
-                elif file_extension in [".docx"]:
+                elif file_extension == ".docx":
                     extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension == ".csv":
                     extractor = CSVExtractor(file_path, autodetect_encoding=True)

@@ -1,6 +1,7 @@
 """Document loader helpers."""

 import concurrent.futures
+from pathlib import Path
 from typing import NamedTuple, Optional, cast

@@ -28,8 +29,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding
     import chardet

     def read_and_detect(file_path: str) -> list[dict]:
-        with open(file_path, "rb") as f:
-            rawdata = f.read()
+        rawdata = Path(file_path).read_bytes()
         return cast(list[dict], chardet.detect_all(rawdata))

     with concurrent.futures.ThreadPoolExecutor() as executor:

@@ -1,6 +1,7 @@
 """Abstract interface for document loader implementations."""

 import re
+from pathlib import Path
 from typing import Optional, cast

 from core.rag.extractor.extractor_base import BaseExtractor

@@ -102,15 +103,13 @@ class MarkdownExtractor(BaseExtractor):
         """Parse file into tuples."""
         content = ""
         try:
-            with open(filepath, encoding=self._encoding) as f:
-                content = f.read()
+            content = Path(filepath).read_text(encoding=self._encoding)
         except UnicodeDecodeError as e:
             if self._autodetect_encoding:
                 detected_encodings = detect_file_encodings(filepath)
                 for encoding in detected_encodings:
                     try:
-                        with open(filepath, encoding=encoding.encoding) as f:
-                            content = f.read()
+                        content = Path(filepath).read_text(encoding=encoding.encoding)
                         break
                     except UnicodeDecodeError:
                         continue

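Worth noting with a small sketch (hypothetical file, not from the repository): Path.read_text raises the same UnicodeDecodeError as open(...).read(), so the autodetect-encoding fallback in the except branch keeps working after the rewrite.

    from pathlib import Path

    demo = Path("/tmp/encoding_demo.txt")       # hypothetical demo file
    demo.write_text("café", encoding="latin-1")
    try:
        content = demo.read_text(encoding="utf-8")
    except UnicodeDecodeError:
        content = demo.read_text(encoding="latin-1")   # fallback path, as in the extractors
    assert content == "café"
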
@@ -1,5 +1,6 @@
 """Abstract interface for document loader implementations."""

+from pathlib import Path
 from typing import Optional

 from core.rag.extractor.extractor_base import BaseExtractor

@@ -25,15 +26,13 @@ class TextExtractor(BaseExtractor):
         """Load from file path."""
         text = ""
         try:
-            with open(self._file_path, encoding=self._encoding) as f:
-                text = f.read()
+            text = Path(self._file_path).read_text(encoding=self._encoding)
         except UnicodeDecodeError as e:
             if self._autodetect_encoding:
                 detected_encodings = detect_file_encodings(self._file_path)
                 for encoding in detected_encodings:
                     try:
-                        with open(self._file_path, encoding=encoding.encoding) as f:
-                            text = f.read()
+                        text = Path(self._file_path).read_text(encoding=encoding.encoding)
                         break
                     except UnicodeDecodeError:
                         continue

@@ -153,7 +153,7 @@ class WordExtractor(BaseExtractor):
                 if col_index >= total_cols:
                     break
                 cell_content = self._parse_cell(cell, image_map).strip()
-                cell_colspan = cell.grid_span if cell.grid_span else 1
+                cell_colspan = cell.grid_span or 1
                 for i in range(cell_colspan):
                     if col_index + i < total_cols:
                         row_cells[col_index + i] = cell_content if i == 0 else ""

@@ -256,7 +256,7 @@ class DatasetRetrieval:
             # get retrieval model config
             dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
             if dataset:
-                retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+                retrieval_model_config = dataset.retrieval_model or default_retrieval_model

                 # get top k
                 top_k = retrieval_model_config["top_k"]

@@ -410,7 +410,7 @@ class DatasetRetrieval:
             return []

         # get retrieval model , if the model is not setting , using default
-        retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+        retrieval_model = dataset.retrieval_model or default_retrieval_model

         if dataset.indexing_technique == "economy":
             # use keyword table query

@@ -433,9 +433,7 @@ class DatasetRetrieval:
                 reranking_model=retrieval_model.get("reranking_model", None)
                 if retrieval_model["reranking_enable"]
                 else None,
-                reranking_mode=retrieval_model.get("reranking_mode")
-                if retrieval_model.get("reranking_mode")
-                else "reranking_model",
+                reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                 weights=retrieval_model.get("weights", None),
             )

@@ -486,7 +484,7 @@ class DatasetRetrieval:
         }

         for dataset in available_datasets:
-            retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
+            retrieval_model_config = dataset.retrieval_model or default_retrieval_model

             # get top k
             top_k = retrieval_model_config["top_k"]
