mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-25 02:33:00 +08:00
chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -31,7 +31,6 @@ class SortOrder(Enum):
|
||||
|
||||
|
||||
class MyScaleVector(BaseVector):
|
||||
|
||||
def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._config = config
|
||||
@@ -80,7 +79,7 @@ class MyScaleVector(BaseVector):
|
||||
doc_id,
|
||||
self.escape_str(doc.page_content),
|
||||
embeddings[i],
|
||||
json.dumps(doc.metadata) if doc.metadata else {}
|
||||
json.dumps(doc.metadata) if doc.metadata else {},
|
||||
)
|
||||
values.append(str(row))
|
||||
ids.append(doc_id)
|
||||
@@ -101,7 +100,8 @@ class MyScaleVector(BaseVector):
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
self._client.command(
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}")
|
||||
f"DELETE FROM {self._config.database}.{self._collection_name} WHERE id IN {str(tuple(ids))}"
|
||||
)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
rows = self._client.query(
|
||||
@@ -122,9 +122,12 @@ class MyScaleVector(BaseVector):
|
||||
|
||||
def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
|
||||
top_k = kwargs.get("top_k", 5)
|
||||
score_threshold = kwargs.get('score_threshold') or 0.0
|
||||
where_str = f"WHERE dist < {1 - score_threshold}" if \
|
||||
self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 else ""
|
||||
score_threshold = kwargs.get("score_threshold") or 0.0
|
||||
where_str = (
|
||||
f"WHERE dist < {1 - score_threshold}"
|
||||
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
|
||||
else ""
|
||||
)
|
||||
sql = f"""
|
||||
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||
@@ -133,7 +136,7 @@ class MyScaleVector(BaseVector):
|
||||
return [
|
||||
Document(
|
||||
page_content=r["text"],
|
||||
vector=r['vector'],
|
||||
vector=r["vector"],
|
||||
metadata=r["metadata"],
|
||||
)
|
||||
for r in self._client.query(sql).named_results()
|
||||
@@ -149,13 +152,12 @@ class MyScaleVector(BaseVector):
|
||||
class MyScaleVectorFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MyScaleVector:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix.lower()
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
dataset.index_struct = json.dumps(
|
||||
self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))
|
||||
|
||||
return MyScaleVector(
|
||||
collection_name=collection_name,
|
||||
|
||||
Reference in New Issue
Block a user