chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -31,7 +31,6 @@ class SortOrder(Enum):
class MyScaleVector(BaseVector):
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
super().__init__(collection_name)
self._config = config
@@ -80,7 +79,7 @@ class MyScaleVector(BaseVector):
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {}
json.dumps(doc.metadata) if doc.metadata else {},
)
values.append(str(row))
ids.append(doc_id)
@@ -101,7 +100,8 @@ class MyScaleVector(BaseVector):
def delete_by_ids(self, ids: list[str]) -> None:
self._client.command(
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}")
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
)
def get_ids_by_metadata_field(self, key: str, value: str):
rows = self._client.query(
@@ -122,9 +122,12 @@ class MyScaleVector(BaseVector):
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
score_threshold = kwargs.get('score_threshold') or 0.0
where_str = f"WHERE dist < {1 - score_threshold}" if \
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
score_threshold = kwargs.get("score_threshold") or 0.0
where_str = (
f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
else ""
)
sql = f"""
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
@@ -133,7 +136,7 @@ class MyScaleVector(BaseVector):
return [
Document(
page_content=r["text"],
vector=r['vector'],
vector=r["vector"],
metadata=r["metadata"],
)
for r in self._client.query(sql).named_results()
@@ -149,13 +152,12 @@ class MyScaleVector(BaseVector):
class MyScaleVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
return MyScaleVector(
collection_name=collection_name,