chore: refurbish Python code by applying refurb linter rules (#8296)

This commit is contained in:
Bowen Liang
2024-09-12 15:50:49 +08:00
committed by GitHub
parent c69f5b07ba
commit 40fb4d16ef
105 changed files with 220 additions and 276 deletions

View File

@@ -34,7 +34,7 @@ class BaseKeyword(ABC):
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
for text in texts.copy():
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:

View File

@@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,

View File

@@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0]

View File

@@ -86,8 +86,8 @@ class ElasticSearchVector(BaseVector):
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
Field.VECTOR.value: embeddings[i] or None,
Field.METADATA_KEY.value: documents[i].metadata or {},
},
)
self._client.indices.refresh(index=self._collection_name)
@@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
if score > score_threshold:
doc.metadata["score"] = score
docs.append(doc)

View File

@@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

View File

@@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") or 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
where_str = (
f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0

View File

@@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector):
metadata = {}
metadata["score"] = hit["_score"]
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
if hit["_score"] > score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

View File

@@ -200,7 +200,7 @@ class OracleVector(BaseVector):
[numpy.array(query_vector)],
)
docs = []
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
@@ -212,7 +212,7 @@ class OracleVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+")

View File

@@ -198,7 +198,7 @@ class PGVectoRS(BaseVector):
metadata = record.meta
score = 1 - dis
metadata["score"] = score
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
if score > score_threshold:
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)

View File

@@ -144,7 +144,7 @@ class PGVector(BaseVector):
(json.dumps(query_vector),),
)
docs = []
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance

View File

@@ -339,7 +339,7 @@ class QdrantVector(BaseVector):
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(

View File

@@ -230,7 +230,7 @@ class RelytVector(BaseVector):
# Organize results.
docs = []
for document, score in results:
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
if 1 - score > score_threshold:
docs.append(document)
return docs

View File

@@ -153,7 +153,7 @@ class TencentVector(BaseVector):
limit=kwargs.get("top_k", 4),
timeout=self._client_config.timeout,
)
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
return self._get_search_res(res, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:

View File

@@ -185,7 +185,7 @@ class TiDBVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
filter = kwargs.get("filter")
distance = 1 - score_threshold

View File

@@ -49,7 +49,7 @@ class BaseVector(ABC):
raise NotImplementedError
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
for text in texts.copy():
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:

View File

@@ -153,7 +153,7 @@ class Vector:
return CacheEmbedding(embedding_model)
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts[:]:
for text in texts.copy():
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:

View File

@@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
score_threshold = kwargs.get("score_threshold", 0.0)
# check score threshold
if score > score_threshold:
doc.metadata["score"] = score

View File

@@ -12,7 +12,7 @@ import mimetypes
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable, Mapping
from io import BufferedReader, BytesIO
from pathlib import PurePath
from pathlib import Path, PurePath
from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, model_validator
@@ -56,8 +56,7 @@ class Blob(BaseModel):
def as_string(self) -> str:
"""Read data as a string."""
if self.data is None and self.path:
with open(str(self.path), encoding=self.encoding) as f:
return f.read()
return Path(str(self.path)).read_text(encoding=self.encoding)
elif isinstance(self.data, bytes):
return self.data.decode(self.encoding)
elif isinstance(self.data, str):
@@ -72,8 +71,7 @@ class Blob(BaseModel):
elif isinstance(self.data, str):
return self.data.encode(self.encoding)
elif self.data is None and self.path:
with open(str(self.path), "rb") as f:
return f.read()
return Path(str(self.path)).read_bytes()
else:
raise ValueError(f"Unable to get bytes for blob {self}")

View File

@@ -68,8 +68,7 @@ class ExtractProcessor:
suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
with open(file_path, "wb") as file:
file.write(response.content)
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
if return_text:
delimiter = "\n"
@@ -111,7 +110,7 @@ class ExtractProcessor:
)
elif file_extension in [".htm", ".html"]:
extractor = HtmlExtractor(file_path)
elif file_extension in [".docx"]:
elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)
@@ -143,7 +142,7 @@ class ExtractProcessor:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in [".htm", ".html"]:
extractor = HtmlExtractor(file_path)
elif file_extension in [".docx"]:
elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)

View File

@@ -1,6 +1,7 @@
"""Document loader helpers."""
import concurrent.futures
from pathlib import Path
from typing import NamedTuple, Optional, cast
@@ -28,8 +29,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding
import chardet
def read_and_detect(file_path: str) -> list[dict]:
with open(file_path, "rb") as f:
rawdata = f.read()
rawdata = Path(file_path).read_bytes()
return cast(list[dict], chardet.detect_all(rawdata))
with concurrent.futures.ThreadPoolExecutor() as executor:

View File

@@ -1,6 +1,7 @@
"""Abstract interface for document loader implementations."""
import re
from pathlib import Path
from typing import Optional, cast
from core.rag.extractor.extractor_base import BaseExtractor
@@ -102,15 +103,13 @@ class MarkdownExtractor(BaseExtractor):
"""Parse file into tuples."""
content = ""
try:
with open(filepath, encoding=self._encoding) as f:
content = f.read()
content = Path(filepath).read_text(encoding=self._encoding)
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(filepath)
for encoding in detected_encodings:
try:
with open(filepath, encoding=encoding.encoding) as f:
content = f.read()
content = Path(filepath).read_text(encoding=encoding.encoding)
break
except UnicodeDecodeError:
continue

View File

@@ -1,5 +1,6 @@
"""Abstract interface for document loader implementations."""
from pathlib import Path
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
@@ -25,15 +26,13 @@ class TextExtractor(BaseExtractor):
"""Load from file path."""
text = ""
try:
with open(self._file_path, encoding=self._encoding) as f:
text = f.read()
text = Path(self._file_path).read_text(encoding=self._encoding)
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings:
try:
with open(self._file_path, encoding=encoding.encoding) as f:
text = f.read()
text = Path(self._file_path).read_text(encoding=encoding.encoding)
break
except UnicodeDecodeError:
continue

View File

@@ -153,7 +153,7 @@ class WordExtractor(BaseExtractor):
if col_index >= total_cols:
break
cell_content = self._parse_cell(cell, image_map).strip()
cell_colspan = cell.grid_span if cell.grid_span else 1
cell_colspan = cell.grid_span or 1
for i in range(cell_colspan):
if col_index + i < total_cols:
row_cells[col_index + i] = cell_content if i == 0 else ""

View File

@@ -256,7 +256,7 @@ class DatasetRetrieval:
# get retrieval model config
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if dataset:
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
# get top k
top_k = retrieval_model_config["top_k"]
@@ -410,7 +410,7 @@ class DatasetRetrieval:
return []
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
@@ -433,9 +433,7 @@ class DatasetRetrieval:
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode")
if retrieval_model.get("reranking_mode")
else "reranking_model",
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)
@@ -486,7 +484,7 @@ class DatasetRetrieval:
}
for dataset in available_datasets:
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
# get top k
top_k = retrieval_model_config["top_k"]