feat: mypy for all type check (#10921)
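This commit applies mypy-driven type fixes across the RAG keyword and vector-store modules. Three patterns recur throughout the diff: third-party imports without type stubs are marked `# type: ignore`, `Optional` values from `dify_config` are coerced to concrete defaults (`or ""`, `or "localhost"`), and `Document.metadata` is narrowed with an explicit `None` check before it is subscripted. The sketch below is illustrative only, not repository code; `Document` here is a simplified stand-in for `core.rag.models.document.Document`, whose `metadata` field is optional.

from typing import Any, Optional

# Pattern 1: silence missing stubs for untyped third-party libraries, e.g.
# import jieba  # type: ignore


class Document:
    # Simplified stand-in for core.rag.models.document.Document.
    def __init__(self, page_content: str, metadata: Optional[dict[str, Any]] = None):
        self.page_content = page_content
        self.metadata = metadata


def collect_doc_ids(docs: list[Document]) -> list[str]:
    # Pattern 2: narrow Optional metadata before indexing into it,
    # instead of assuming it is always a dict.
    return [doc.metadata["doc_id"] for doc in docs if doc.metadata is not None]


def build_host(host: Optional[str]) -> str:
    # Pattern 3: give Optional config values a concrete default so that
    # typed constructors accept a plain str.
    return host or "localhost"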
@@ -32,8 +32,11 @@ class Jieba(BaseKeyword):
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, text.metadata["doc_id"], list(keywords)
)

self._save_dataset_keyword_table(keyword_table)

@@ -58,20 +61,26 @@ class Jieba(BaseKeyword):
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, text.metadata["doc_id"], list(keywords)
)

self._save_dataset_keyword_table(keyword_table)

def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
if keyword_table is None:
return False
return id in set.union(*keyword_table.values())

def delete_by_ids(self, ids: list[str]) -> None:
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
if keyword_table is not None:
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)

self._save_dataset_keyword_table(keyword_table)

@@ -80,7 +89,7 @@ class Jieba(BaseKeyword):

k = kwargs.get("top_k", 4)

sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)

documents = []
for chunk_index in sorted_chunk_indices:
@@ -137,7 +146,7 @@ class Jieba(BaseKeyword):
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return keyword_table_dict["__data__"]["table"]
return dict(keyword_table_dict["__data__"]["table"])
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(
@@ -188,8 +197,8 @@ class Jieba(BaseKeyword):

# go through text chunks in order of most matching keywords
chunk_indices_count: dict[str, int] = defaultdict(int)
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords:
keywords_list = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords_list:
for node_id in keyword_table[keyword]:
chunk_indices_count[node_id] += 1

@@ -215,7 +224,7 @@ class Jieba(BaseKeyword):
def create_segment_keywords(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
self._update_segment_keywords(self.dataset.id, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)

def multi_create_segment_keywords(self, pre_segment_data_list: list):
@@ -226,17 +235,19 @@ class Jieba(BaseKeyword):
if pre_segment_data["keywords"]:
segment.keywords = pre_segment_data["keywords"]
keyword_table = self._add_text_to_keyword_table(
keyword_table, segment.index_node_id, pre_segment_data["keywords"]
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
segment.keywords = list(keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, segment.index_node_id, list(keywords)
)
self._save_dataset_keyword_table(keyword_table)

def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)



@@ -4,7 +4,7 @@ from typing import Optional

class JiebaKeywordTableHandler:
def __init__(self):
import jieba.analyse
import jieba.analyse # type: ignore

from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS

@@ -12,7 +12,7 @@ class JiebaKeywordTableHandler:

def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
import jieba
import jieba # type: ignore

keywords = jieba.analyse.extract_tags(
sentence=text,

@@ -37,6 +37,8 @@ class BaseKeyword(ABC):

def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
if text.metadata is None:
continue
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
@@ -45,4 +47,4 @@ class BaseKeyword(ABC):
return texts

def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata["doc_id"] for text in texts]
return [text.metadata["doc_id"] for text in texts if text.metadata]

@@ -6,6 +6,7 @@ from flask import Flask, current_app
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@@ -31,7 +32,7 @@ class RetrievalService:
top_k: int,
score_threshold: Optional[float] = 0.0,
reranking_model: Optional[dict] = None,
reranking_mode: Optional[str] = "reranking_model",
reranking_mode: str = "reranking_model",
weights: Optional[dict] = None,
):
if not query:
@@ -42,15 +43,15 @@ class RetrievalService:

if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []
threads = []
exceptions = []
all_documents: list[Document] = []
threads: list[threading.Thread] = []
exceptions: list[str] = []
# retrieval_model source with keyword
if retrieval_method == "keyword_search":
keyword_thread = threading.Thread(
target=RetrievalService.keyword_search,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
@@ -65,7 +66,7 @@ class RetrievalService:
embedding_thread = threading.Thread(
target=RetrievalService.embedding_search,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
@@ -84,7 +85,7 @@ class RetrievalService:
full_text_index_thread = threading.Thread(
target=RetrievalService.full_text_index_search,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"retrieval_method": retrieval_method,
@@ -124,7 +125,7 @@ class RetrievalService:
if not dataset:
return []
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id, dataset_id, query, external_retrieval_model
dataset.tenant_id, dataset_id, query, external_retrieval_model or {}
)
return all_documents

@@ -135,6 +136,8 @@ class RetrievalService:
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("dataset not found")

keyword = Keyword(dataset=dataset)

@@ -159,6 +162,8 @@ class RetrievalService:
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("dataset not found")

vector = Vector(dataset=dataset)

@@ -209,6 +214,8 @@ class RetrievalService:
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("dataset not found")

vector_processor = Vector(
dataset=dataset,

@@ -17,12 +17,19 @@ from models.dataset import Dataset

class AnalyticdbVector(BaseVector):
def __init__(
self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
self,
collection_name: str,
api_config: AnalyticdbVectorOpenAPIConfig | None,
sql_config: AnalyticdbVectorBySqlConfig | None,
):
super().__init__(collection_name)
if api_config is not None:
self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI(
collection_name, api_config
)
else:
if sql_config is None:
raise ValueError("Either api_config or sql_config must be provided")
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)

def get_type(self) -> str:
@@ -33,8 +40,8 @@ class AnalyticdbVector(BaseVector):
self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)

def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
self.analyticdb_vector.add_texts(documents, embeddings)

def text_exists(self, id: str) -> bool:
return self.analyticdb_vector.text_exists(id)
@@ -68,13 +75,13 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
if dify_config.ANALYTICDB_HOST is None:
# implemented through OpenAPI
apiConfig = AnalyticdbVectorOpenAPIConfig(
access_key_id=dify_config.ANALYTICDB_KEY_ID,
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
region_id=dify_config.ANALYTICDB_REGION_ID,
instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
namespace=dify_config.ANALYTICDB_NAMESPACE,
access_key_id=dify_config.ANALYTICDB_KEY_ID or "",
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "",
region_id=dify_config.ANALYTICDB_REGION_ID or "",
instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "",
account=dify_config.ANALYTICDB_ACCOUNT or "",
account_password=dify_config.ANALYTICDB_PASSWORD or "",
namespace=dify_config.ANALYTICDB_NAMESPACE or "",
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
)
sqlConfig = None
@@ -83,11 +90,11 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
sqlConfig = AnalyticdbVectorBySqlConfig(
host=dify_config.ANALYTICDB_HOST,
port=dify_config.ANALYTICDB_PORT,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
account=dify_config.ANALYTICDB_ACCOUNT or "",
account_password=dify_config.ANALYTICDB_PASSWORD or "",
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
namespace=dify_config.ANALYTICDB_NAMESPACE,
namespace=dify_config.ANALYTICDB_NAMESPACE or "",
)
apiConfig = None
return AnalyticdbVector(

@@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, Optional

from pydantic import BaseModel, model_validator

@@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
account: str
account_password: str
namespace: str = "dify"
namespace_password: str = (None,)
namespace_password: Optional[str] = None
metrics: str = "cosine"
read_timeout: int = 60000

@@ -55,8 +55,8 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
class AnalyticdbVectorOpenAPI:
def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
try:
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_gpdb20160503.client import Client # type: ignore
from alibabacloud_tea_openapi import models as open_api_models # type: ignore
except:
raise ImportError(_import_err_msg)
self._collection_name = collection_name.lower()
@@ -77,7 +77,7 @@ class AnalyticdbVectorOpenAPI:
redis_client.set(database_exist_cache_key, 1, ex=3600)

def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models # type: ignore

request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
@@ -89,7 +89,7 @@ class AnalyticdbVectorOpenAPI:

def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
from Tea.exceptions import TeaException # type: ignore

try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
@@ -159,17 +159,18 @@ class AnalyticdbVectorOpenAPI:

rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
if doc.metadata is not None:
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@@ -258,7 +259,7 @@ class AnalyticdbVectorOpenAPI:
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -290,7 +291,7 @@ class AnalyticdbVectorOpenAPI:
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents

def delete(self) -> None:

@@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any

import psycopg2.extras
import psycopg2.pool
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator

from core.rag.models.document import Document
@@ -75,6 +75,7 @@ class AnalyticdbVectorBySql:

@contextmanager
def _get_cursor(self):
assert self.pool is not None, "Connection pool is not initialized"
conn = self.pool.getconn()
cur = conn.cursor()
try:
@@ -156,16 +157,17 @@ class AnalyticdbVectorBySql:
VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
"""
for i, doc in enumerate(documents):
values.append(
(
id_prefix + str(i),
doc.metadata.get("doc_id", str(uuid.uuid4())),
embeddings[i],
doc.page_content,
json.dumps(doc.metadata),
doc.page_content,
if doc.metadata is not None:
values.append(
(
id_prefix + str(i),
doc.metadata.get("doc_id", str(uuid.uuid4())),
embeddings[i],
doc.page_content,
json.dumps(doc.metadata),
doc.page_content,
)
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_batch(cur, sql, values)


@@ -5,13 +5,13 @@ from typing import Any

import numpy as np
from pydantic import BaseModel, model_validator
from pymochow import MochowClient
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.exception import ServerError
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
from pymochow import MochowClient # type: ignore
from pymochow.auth.bce_credentials import BceCredentials # type: ignore
from pymochow.configuration import Configuration # type: ignore
from pymochow.exception import ServerError # type: ignore
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore

from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -75,7 +75,7 @@ class BaiduVector(BaseVector):

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
total_count = len(documents)
batch_size = 1000

@@ -84,6 +84,8 @@ class BaiduVector(BaseVector):
for start in range(0, total_count, batch_size):
end = min(start + batch_size, total_count)
rows = []
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
# FIXME do you need this assert?
for i in range(start, end, 1):
row = Row(
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
@@ -136,7 +138,7 @@ class BaiduVector(BaseVector):
# baidu vector database doesn't support bm25 search on current version
return []

def _get_search_res(self, res, score_threshold):
def _get_search_res(self, res, score_threshold) -> list[Document]:
docs = []
for row in res.rows:
row_data = row.get("row", {})
@@ -276,11 +278,11 @@ class BaiduVectorFactory(AbstractVectorFactory):
return BaiduVector(
collection_name=collection_name,
config=BaiduConfig(
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "",
connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
database=dify_config.BAIDU_VECTOR_DB_DATABASE,
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "",
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "",
database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
),

@@ -71,11 +71,13 @@ class ChromaVector(BaseVector):
metadatas = [d.metadata for d in documents]

collection = self._client.get_or_create_collection(self._collection_name)
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas)
# FIXME: chromadb using numpy array, fix the type error later
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore

def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(where={key: {"$eq": value}})
# FIXME: fix the type error later
collection.delete(where={key: {"$eq": value}}) # type: ignore

def delete(self):
self._client.delete_collection(self._collection_name)
@@ -94,15 +96,19 @@ class ChromaVector(BaseVector):
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)

ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0]
metadatas: dict[str, Any] = results["metadatas"][0]
distances: list[float] = results["distances"][0]
# Check if results contain data
if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]:
return []

ids = results["ids"][0]
documents = results["documents"][0]
metadatas = results["metadatas"][0]
distances = results["distances"][0]

docs = []
for index in range(len(ids)):
distance = distances[index]
metadata = metadatas[index]
metadata = dict(metadatas[index])
if distance >= score_threshold:
metadata["score"] = distance
doc = Document(
@@ -111,7 +117,7 @@ class ChromaVector(BaseVector):
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -133,7 +139,7 @@ class ChromaVectorFactory(AbstractVectorFactory):
return ChromaVector(
collection_name=collection_name,
config=ChromaConfig(
host=dify_config.CHROMA_HOST,
host=dify_config.CHROMA_HOST or "",
port=dify_config.CHROMA_PORT,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,

@@ -5,14 +5,14 @@ import uuid
from datetime import timedelta
from typing import Any

from couchbase import search
from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.management.search import SearchIndex
from couchbase import search # type: ignore
from couchbase.auth import PasswordAuthenticator # type: ignore
from couchbase.cluster import Cluster # type: ignore
from couchbase.management.search import SearchIndex # type: ignore

# needed for options -- cluster, timeout, SQL++ (N1QL) query, etc.
from couchbase.options import ClusterOptions, SearchOptions
from couchbase.vector_search import VectorQuery, VectorSearch
from couchbase.options import ClusterOptions, SearchOptions # type: ignore
from couchbase.vector_search import VectorQuery, VectorSearch # type: ignore
from flask import current_app
from pydantic import BaseModel, model_validator

@@ -231,7 +231,7 @@ class CouchbaseVector(BaseVector):
# Pass the id as a parameter to the query
result = self._cluster.query(query, named_parameters={"doc_id": id}).execute()
for row in result:
return row["count"] > 0
return bool(row["count"] > 0)
return False  # Return False if no rows are returned

def delete_by_ids(self, ids: list[str]) -> None:
@@ -369,10 +369,10 @@ class CouchbaseVectorFactory(AbstractVectorFactory):
return CouchbaseVector(
collection_name=collection_name,
config=CouchbaseConfig(
connection_string=config.get("COUCHBASE_CONNECTION_STRING"),
user=config.get("COUCHBASE_USER"),
password=config.get("COUCHBASE_PASSWORD"),
bucket_name=config.get("COUCHBASE_BUCKET_NAME"),
scope_name=config.get("COUCHBASE_SCOPE_NAME"),
connection_string=config.get("COUCHBASE_CONNECTION_STRING", ""),
user=config.get("COUCHBASE_USER", ""),
password=config.get("COUCHBASE_PASSWORD", ""),
bucket_name=config.get("COUCHBASE_BUCKET_NAME", ""),
scope_name=config.get("COUCHBASE_SCOPE_NAME", ""),
),
)

@@ -1,7 +1,7 @@
import json
import logging
import math
from typing import Any, Optional
from typing import Any, Optional, cast
from urllib.parse import urlparse

import requests
@@ -70,7 +70,7 @@ class ElasticSearchVector(BaseVector):

def _get_version(self) -> str:
info = self._client.info()
return info["version"]["number"]
return cast(str, info["version"]["number"])

def _check_version(self):
if self._version < "8.0.0":
@@ -135,7 +135,8 @@ class ElasticSearchVector(BaseVector):
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
doc.metadata["score"] = score
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

return docs
@@ -156,12 +157,15 @@ class ElasticSearchVector(BaseVector):
return docs

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata for d in texts]
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings, **kwargs)

def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self,
embeddings: list[list[float]],
metadatas: Optional[list[dict[Any, Any]]] = None,
index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
@@ -208,10 +212,10 @@ class ElasticSearchVectorFactory(AbstractVectorFactory):
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST"),
port=config.get("ELASTICSEARCH_PORT"),
username=config.get("ELASTICSEARCH_USERNAME"),
password=config.get("ELASTICSEARCH_PASSWORD"),
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)

@@ -42,7 +42,7 @@ class LindormVectorStoreConfig(BaseModel):
return values

def to_opensearch_params(self) -> dict[str, Any]:
params = {"hosts": self.hosts}
params: dict[str, Any] = {"hosts": self.hosts}
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params
@@ -53,7 +53,7 @@ class LindormVectorStore(BaseVector):
self._routing = None
self._routing_field = None
if using_ugc:
routing_value: str = kwargs.get("routing_value")
routing_value: str | None = kwargs.get("routing_value")
if routing_value is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
self._routing = routing_value.lower()
@@ -87,14 +87,15 @@ class LindormVectorStore(BaseVector):
"_id": uuids[i],
}
}
action_values = {
action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
action_values[self._routing_field] = self._routing
if self._routing_field is not None:
action_values[self._routing_field] = self._routing
actions.append(action_header)
actions.append(action_values)
response = self._client.bulk(actions)
@@ -105,7 +106,9 @@ class LindormVectorStore(BaseVector):
self.refresh()

def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
query: dict[str, Any] = {
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
}
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
response = self._client.search(index=self._collection_name, body=query)
@@ -191,7 +194,8 @@ class LindormVectorStore(BaseVector):
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score > score_threshold:
doc.metadata["score"] = score
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

return docs
@@ -366,6 +370,7 @@ def default_text_search_query(
routing_field: Optional[str] = None,
**kwargs,
) -> dict:
query_clause: dict[str, Any] = {}
if routing is not None:
query_clause = {
"bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
@@ -386,7 +391,7 @@ def default_text_search_query(
else:
must = [query_clause]

boolean_query = {"must": must}
boolean_query: dict[str, Any] = {"must": must}

if must_not:
if not isinstance(must_not, list):
@@ -426,7 +431,7 @@ def default_vector_search_query(
filter_type = "post_filter" if filter_type is None else filter_type
if not isinstance(filters, list):
raise RuntimeError(f"unexpected filter with {type(filters)}")
final_ext = {"lvector": {}}
final_ext: dict[str, Any] = {"lvector": {}}
if min_score != "0.0":
final_ext["lvector"]["min_score"] = min_score
if ef_search:
@@ -438,7 +443,7 @@ def default_vector_search_query(
if client_refactor:
final_ext["lvector"]["client_refactor"] = client_refactor

search_query = {
search_query: dict[str, Any] = {
"size": k,
"_source": True,  # force return '_source'
"query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
@@ -446,8 +451,8 @@ def default_vector_search_query(

if filters is not None:
# when using filter, transform filter from List[Dict] to Dict as valid format
filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][vector_field]["filter"] = filters  # filter should be Dict
filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][vector_field]["filter"] = filter_dict  # filter should be Dict
if filter_type:
final_ext["lvector"]["filter_type"] = filter_type

@@ -459,17 +464,19 @@ def default_vector_search_query(
class LindormVectorStoreFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
lindorm_config = LindormVectorStoreConfig(
hosts=dify_config.LINDORM_URL,
hosts=dify_config.LINDORM_URL or "",
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
using_ugc=dify_config.USING_UGC_INDEX,
)
using_ugc = dify_config.USING_UGC_INDEX
if using_ugc is None:
raise ValueError("USING_UGC_INDEX is not set")
routing_value = None
if dataset.index_struct:
# if an existed record's index_struct_dict doesn't contain using_ugc field,
# it actually stores in the normal index format
stored_in_ugc = dataset.index_struct_dict.get("using_ugc", False)
stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False)
using_ugc = stored_in_ugc
if stored_in_ugc:
dimension = dataset.index_struct_dict["dimension"]

@@ -3,8 +3,8 @@ import logging
from typing import Any, Optional

from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException
from pymilvus.milvus_client import IndexParams
from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore

from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -54,14 +54,14 @@ class MilvusVector(BaseVector):
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session"
self._fields = []
self._fields: list[str] = []

def get_type(self) -> str:
return VectorType.MILVUS

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata for d in texts]
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)

@@ -161,8 +161,8 @@ class MilvusVector(BaseVector):
return
# Grab the existing collection if it exists
if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.orm.types import infer_dtype_bydata
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore

# Determine embedding dim
dim = len(embeddings[0])
@@ -217,10 +217,10 @@ class MilvusVectorFactory(AbstractVectorFactory):
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
uri=dify_config.MILVUS_URI,
token=dify_config.MILVUS_TOKEN,
user=dify_config.MILVUS_USER,
password=dify_config.MILVUS_PASSWORD,
database=dify_config.MILVUS_DATABASE,
uri=dify_config.MILVUS_URI or "",
token=dify_config.MILVUS_TOKEN or "",
user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "",
),
)

@@ -74,15 +74,16 @@ class MyScaleVector(BaseVector):
columns = ["id", "text", "vector", "metadata"]
values = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
row = (
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {},
)
values.append(str(row))
ids.append(doc_id)
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
row = (
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {},
)
values.append(str(row))
ids.append(doc_id)
sql = f"""
INSERT INTO {self._config.database}.{self._collection_name}
({",".join(columns)}) VALUES {",".join(values)}

@@ -4,7 +4,7 @@ import math
from typing import Any

from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient
from pyobvector import VECTOR, ObVecClient # type: ignore
from sqlalchemy import JSON, Column, String, func
from sqlalchemy.dialects.mysql import LONGTEXT

@@ -131,7 +131,7 @@ class OceanBaseVector(BaseVector):

def text_exists(self, id: str) -> bool:
cur = self._client.get(table_name=self._collection_name, id=id)
return cur.rowcount != 0
return bool(cur.rowcount != 0)

def delete_by_ids(self, ids: list[str]) -> None:
self._client.delete(table_name=self._collection_name, ids=ids)

@@ -66,7 +66,7 @@ class OpenSearchVector(BaseVector):
return VectorType.OPENSEARCH

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata for d in texts]
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings)

@@ -244,7 +244,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))

open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST,
host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,

@@ -5,7 +5,7 @@ import uuid
from contextlib import contextmanager
from typing import Any

import jieba.posseg as pseg
import jieba.posseg as pseg # type: ignore
import numpy
import oracledb
from pydantic import BaseModel, model_validator
@@ -88,12 +88,11 @@ class OracleVector(BaseVector):

def numpy_converter_out(self, value):
if value.typecode == "b":
dtype = numpy.int8
return numpy.array(value, copy=False, dtype=numpy.int8)
elif value.typecode == "f":
dtype = numpy.float32
return numpy.array(value, copy=False, dtype=numpy.float32)
else:
dtype = numpy.float64
return numpy.array(value, copy=False, dtype=dtype)
return numpy.array(value, copy=False, dtype=numpy.float64)

def output_type_handler(self, cursor, metadata):
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
@@ -135,17 +134,18 @@ class OracleVector(BaseVector):
values = []
pks = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
# array.array("f", embeddings[i]),
numpy.array(embeddings[i]),
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
# array.array("f", embeddings[i]),
numpy.array(embeddings[i]),
)
)
)
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
with self._get_cursor() as cur:
cur.executemany(
@@ -201,8 +201,8 @@ class OracleVector(BaseVector):

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# lazy import
import nltk
from nltk.corpus import stopwords
import nltk # type: ignore
from nltk.corpus import stopwords # type: ignore

top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
@@ -285,10 +285,10 @@ class OracleVectorFactory(AbstractVectorFactory):
return OracleVector(
collection_name=collection_name,
config=OracleVectorConfig(
host=dify_config.ORACLE_HOST,
host=dify_config.ORACLE_HOST or "localhost",
port=dify_config.ORACLE_PORT,
user=dify_config.ORACLE_USER,
password=dify_config.ORACLE_PASSWORD,
database=dify_config.ORACLE_DATABASE,
user=dify_config.ORACLE_USER or "system",
password=dify_config.ORACLE_PASSWORD or "oracle",
database=dify_config.ORACLE_DATABASE or "orcl",
),
)

@@ -4,7 +4,7 @@ from typing import Any
from uuid import UUID, uuid4

from numpy import ndarray
from pgvecto_rs.sqlalchemy import VECTOR
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
from pydantic import BaseModel, model_validator
from sqlalchemy import Float, String, create_engine, insert, select, text
from sqlalchemy import text as sql_text
@@ -58,7 +58,7 @@ class PGVectoRS(BaseVector):
with Session(self._client) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
session.commit()
self._fields = []
self._fields: list[str] = []

class _Table(CollectionORM):
__tablename__ = collection_name
@@ -222,11 +222,11 @@ class PGVectoRSFactory(AbstractVectorFactory):
return PGVectoRS(
collection_name=collection_name,
config=PgvectoRSConfig(
host=dify_config.PGVECTO_RS_HOST,
port=dify_config.PGVECTO_RS_PORT,
user=dify_config.PGVECTO_RS_USER,
password=dify_config.PGVECTO_RS_PASSWORD,
database=dify_config.PGVECTO_RS_DATABASE,
host=dify_config.PGVECTO_RS_HOST or "localhost",
port=dify_config.PGVECTO_RS_PORT or 5432,
user=dify_config.PGVECTO_RS_USER or "postgres",
password=dify_config.PGVECTO_RS_PASSWORD or "",
database=dify_config.PGVECTO_RS_DATABASE or "postgres",
),
dim=dim,
)

@@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any

import psycopg2.extras
import psycopg2.pool
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator

from configs import dify_config
@@ -98,16 +98,17 @@ class PGVector(BaseVector):
values = []
pks = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
embeddings[i],
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
embeddings[i],
)
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_values(
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
@@ -216,11 +217,11 @@ class PGVectorFactory(AbstractVectorFactory):
return PGVector(
collection_name=collection_name,
config=PGVectorConfig(
host=dify_config.PGVECTOR_HOST,
host=dify_config.PGVECTOR_HOST or "localhost",
port=dify_config.PGVECTOR_PORT,
user=dify_config.PGVECTOR_USER,
password=dify_config.PGVECTOR_PASSWORD,
database=dify_config.PGVECTOR_DATABASE,
user=dify_config.PGVECTOR_USER or "postgres",
password=dify_config.PGVECTOR_PASSWORD or "",
database=dify_config.PGVECTOR_DATABASE or "postgres",
min_connection=dify_config.PGVECTOR_MIN_CONNECTION,
max_connection=dify_config.PGVECTOR_MAX_CONNECTION,
),

@@ -51,6 +51,8 @@ class QdrantConfig(BaseModel):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
if not self.root_path:
raise ValueError("Root path is not set")
path = os.path.join(self.root_path, path)

return {"path": path}
@@ -149,9 +151,12 @@ class QdrantVector(BaseVector):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]

added_ids = []
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
# Filter out None values from metadatas list to match expected type
filtered_metadatas = [m for m in metadatas if m is not None]
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)

@@ -194,7 +199,7 @@ class QdrantVector(BaseVector):
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
group_id or "",  # Ensure group_id is never None
Field.GROUP_KEY.value,
),
)
@@ -337,18 +342,20 @@ class QdrantVector(BaseVector):
)
docs = []
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -432,9 +439,9 @@ class QdrantVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
group_id=dataset.id,
config=QdrantConfig(
endpoint=dify_config.QDRANT_URL,
endpoint=dify_config.QDRANT_URL or "",
api_key=dify_config.QDRANT_API_KEY,
root_path=current_app.config.root_path,
root_path=str(current_app.config.root_path),
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,

@@ -3,7 +3,7 @@ import uuid
from typing import Any, Optional

from pydantic import BaseModel, model_validator
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
from sqlalchemy import Column, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
@@ -58,14 +58,14 @@ class RelytVector(BaseVector):
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self.client = create_engine(self._url)
self._fields = []
self._fields: list[str] = []
self._group_id = group_id

def get_type(self) -> str:
return VectorType.RELYT

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None:
index_params: dict[str, Any] = {}
metadatas = [d.metadata for d in texts]
self.create_collection(len(embeddings[0]))
self.embedding_dimension = len(embeddings[0])
@@ -107,10 +107,10 @@ class RelytVector(BaseVector):
redis_client.set(collection_exist_cache_key, 1, ex=3600)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from pgvecto_rs.sqlalchemy import VECTOR
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore

ids = [str(uuid.uuid1()) for _ in documents]
metadatas = [d.metadata for d in documents]
metadatas = [d.metadata for d in documents if d.metadata is not None]
for metadata in metadatas:
metadata["group_id"] = self._group_id
texts = [d.page_content for d in documents]
@@ -242,10 +242,6 @@ class RelytVector(BaseVector):
filter: Optional[dict] = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
try:
from sqlalchemy.engine import Row
except ImportError:
raise ImportError("Could not import Row from sqlalchemy.engine. Please 'pip install sqlalchemy>=1.4'.")

filter_condition = ""
if filter is not None:
@@ -275,7 +271,7 @@ class RelytVector(BaseVector):

# Execute the query and fetch the results
with self.client.connect() as conn:
results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()
results = conn.execute(sql_text(sql_query), params).fetchall()

documents_with_scores = [
(
@@ -307,11 +303,11 @@ class RelytVectorFactory(AbstractVectorFactory):
return RelytVector(
collection_name=collection_name,
config=RelytConfig(
host=dify_config.RELYT_HOST,
host=dify_config.RELYT_HOST or "localhost",
port=dify_config.RELYT_PORT,
user=dify_config.RELYT_USER,
password=dify_config.RELYT_PASSWORD,
database=dify_config.RELYT_DATABASE,
user=dify_config.RELYT_USER or "",
password=dify_config.RELYT_PASSWORD or "",
database=dify_config.RELYT_DATABASE or "default",
),
group_id=dataset.id,
)

@@ -2,10 +2,10 @@ import json
from typing import Any, Optional

from pydantic import BaseModel
from tcvectordb import VectorDBClient
from tcvectordb.model import document, enum
from tcvectordb.model import index as vdb_index
from tcvectordb.model.document import Filter
from tcvectordb import VectorDBClient # type: ignore
from tcvectordb.model import document, enum # type: ignore
from tcvectordb.model import index as vdb_index # type: ignore
from tcvectordb.model.document import Filter # type: ignore

from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -25,8 +25,8 @@ class TencentConfig(BaseModel):
database: Optional[str]
index_type: str = "HNSW"
metric_type: str = "L2"
shard: int = (1,)
replicas: int = (2,)
shard: int = 1
replicas: int = 2

def to_tencent_params(self):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
@@ -120,15 +120,15 @@ class TencentVector(BaseVector):
metadatas = [doc.metadata for doc in documents]
total_count = len(embeddings)
docs = []
for id in range(0, total_count):
for i in range(0, total_count):
if metadatas is None:
continue
metadata = json.dumps(metadatas[id])
metadata = metadatas[i] or {}
doc = document.Document(
id=metadatas[id]["doc_id"],
vector=embeddings[id],
text=texts[id],
metadata=metadata,
id=metadata.get("doc_id"),
vector=embeddings[i],
text=texts[i],
metadata=json.dumps(metadata),
)
docs.append(doc)
self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)
@@ -159,8 +159,8 @@ class TencentVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []

def _get_search_res(self, res, score_threshold):
docs = []
def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]:
docs: list[Document] = []
if res is None or len(res) == 0:
return docs

@@ -193,7 +193,7 @@ class TencentVectorFactory(AbstractVectorFactory):
return TencentVector(
collection_name=collection_name,
config=TencentConfig(
url=dify_config.TENCENT_VECTOR_DB_URL,
url=dify_config.TENCENT_VECTOR_DB_URL or "",
api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
username=dify_config.TENCENT_VECTOR_DB_USERNAME,

@@ -54,7 +54,10 @@ class TidbOnQdrantConfig(BaseModel):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
if self.root_path:
path = os.path.join(self.root_path, path)
else:
raise ValueError("root_path is required")

return {"path": path}
else:
@@ -157,7 +160,7 @@ class TidbOnQdrantVector(BaseVector):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
metadatas = [d.metadata for d in documents if d.metadata is not None]

added_ids = []
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
@@ -203,7 +206,7 @@ class TidbOnQdrantVector(BaseVector):
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
group_id or "",
Field.GROUP_KEY.value,
),
)
@@ -334,18 +337,20 @@ class TidbOnQdrantVector(BaseVector):
)
docs = []
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -427,12 +432,12 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):

else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID,
dify_config.TIDB_API_URL,
dify_config.TIDB_IAM_API_URL,
dify_config.TIDB_PUBLIC_KEY,
dify_config.TIDB_PRIVATE_KEY,
dify_config.TIDB_REGION,
dify_config.TIDB_PROJECT_ID or "",
dify_config.TIDB_API_URL or "",
dify_config.TIDB_IAM_API_URL or "",
dify_config.TIDB_PUBLIC_KEY or "",
dify_config.TIDB_PRIVATE_KEY or "",
dify_config.TIDB_REGION or "",
)
new_tidb_auth_binding = TidbAuthBinding(
cluster_id=new_cluster["cluster_id"],
@@ -464,9 +469,9 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
group_id=dataset.id,
config=TidbOnQdrantConfig(
endpoint=dify_config.TIDB_ON_QDRANT_URL,
endpoint=dify_config.TIDB_ON_QDRANT_URL or "",
api_key=TIDB_ON_QDRANT_API_KEY,
root_path=config.root_path,
root_path=str(config.root_path),
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,

@@ -146,7 +146,7 @@ class TidbService:
iam_url: str,
public_key: str,
private_key: str,
) -> list[dict]:
):
"""
Update the status of a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
@@ -159,7 +159,6 @@ class TidbService:

:return: The response from the API.
"""
clusters = []
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "BASIC"}
@@ -169,7 +168,6 @@ class TidbService:

if response.status_code == 200:
response_data = response.json()
cluster_infos = []
for item in response_data["clusters"]:
state = item["state"]
userPrefix = item["userPrefix"]
@@ -236,16 +234,17 @@ class TidbService:
cluster_infos = []
for item in response_data["clusters"]:
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
password = redis_client.get(cache_key)
if not password:
cached_password = redis_client.get(cache_key)
if not cached_password:
continue
cluster_info = {
"cluster_id": item["clusterId"],
"cluster_name": item["displayName"],
"account": "root",
"password": password.decode("utf-8"),
"password": cached_password.decode("utf-8"),
}
cluster_infos.append(cluster_info)
return cluster_infos
else:
response.raise_for_status()
return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception
@@ -49,7 +49,7 @@ class TiDBVector(BaseVector):
return VectorType.TIDB_VECTOR

def _table(self, dim: int) -> Table:
from tidb_vector.sqlalchemy import VectorType
from tidb_vector.sqlalchemy import VectorType # type: ignore

return Table(
self._collection_name,
@@ -241,11 +241,11 @@ class TiDBVectorFactory(AbstractVectorFactory):
return TiDBVector(
collection_name=collection_name,
config=TiDBVectorConfig(
host=dify_config.TIDB_VECTOR_HOST,
port=dify_config.TIDB_VECTOR_PORT,
user=dify_config.TIDB_VECTOR_USER,
password=dify_config.TIDB_VECTOR_PASSWORD,
database=dify_config.TIDB_VECTOR_DATABASE,
host=dify_config.TIDB_VECTOR_HOST or "",
port=dify_config.TIDB_VECTOR_PORT or 0,
user=dify_config.TIDB_VECTOR_USER or "",
password=dify_config.TIDB_VECTOR_PASSWORD or "",
database=dify_config.TIDB_VECTOR_DATABASE or "",
program_name=dify_config.APPLICATION_NAME,
),
)
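The factory hunks above replace direct use of Optional config values with explicit fallbacks so callers satisfy constructors that declare plain str/int fields. A minimal sketch of the pattern, using hypothetical setting names rather than the real dify_config fields:

from typing import Optional

from pydantic import BaseModel


class ExampleVectorConfig(BaseModel):
    # The target model declares non-optional fields, so None must never be passed in.
    host: str
    port: int
    user: str
    password: str


class ExampleSettings(BaseModel):
    # Values loaded from the environment are typically Optional.
    VECTOR_HOST: Optional[str] = None
    VECTOR_PORT: Optional[int] = None
    VECTOR_USER: Optional[str] = None
    VECTOR_PASSWORD: Optional[str] = None


def build_config(settings: ExampleSettings) -> ExampleVectorConfig:
    # "or" fallbacks narrow Optional[...] to the concrete type mypy expects.
    return ExampleVectorConfig(
        host=settings.VECTOR_HOST or "",
        port=settings.VECTOR_PORT or 0,
        user=settings.VECTOR_USER or "",
        password=settings.VECTOR_PASSWORD or "",
    )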
@@ -51,15 +51,16 @@ class BaseVector(ABC):

def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
if text.metadata and "doc_id" in text.metadata:
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)

return texts

def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata["doc_id"] for text in texts]
return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata]

@property
def collection_name(self):

@@ -193,10 +193,13 @@ class Vector:

def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
if text.metadata is None:
continue
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
if doc_id:
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)

return texts
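The BaseVector and Vector hunks above guard every access to Document.metadata, which is typed as Optional. A minimal, self-contained sketch of the same guard pattern, assuming a simplified Document stand-in rather than the project's model:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Doc:
    page_content: str
    metadata: Optional[dict] = None


def filter_duplicates(docs: list[Doc], existing_ids: set[str]) -> list[Doc]:
    # Iterate over a copy so removal is safe, and only index metadata after the guard.
    for doc in docs.copy():
        if doc.metadata and "doc_id" in doc.metadata:
            if doc.metadata["doc_id"] in existing_ids:
                docs.remove(doc)
    return docs


def collect_ids(docs: list[Doc]) -> list[str]:
    # Skip documents without metadata instead of indexing into None.
    return [d.metadata["doc_id"] for d in docs if d.metadata and "doc_id" in d.metadata]


if __name__ == "__main__":
    docs = [Doc("a", {"doc_id": "1"}), Doc("b", None), Doc("c", {"doc_id": "2"})]
    print(collect_ids(filter_duplicates(docs, {"1"})))  # ['2']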
@@ -2,7 +2,7 @@ import json
from typing import Any

from pydantic import BaseModel
from volcengine.viking_db import (
from volcengine.viking_db import ( # type: ignore
Data,
DistanceType,
Field,
@@ -121,11 +121,12 @@ class VikingDBVector(BaseVector):
for i, page_content in enumerate(page_contents):
metadata = {}
if metadatas is not None:
for key, val in metadatas[i].items():
for key, val in (metadatas[i] or {}).items():
metadata[key] = val
# FIXME: fix the type of metadata later
doc = Data(
{
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"],
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY.value: page_content,
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
@@ -178,7 +179,7 @@ class VikingDBVector(BaseVector):
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(results, score_threshold)

def _get_search_res(self, results, score_threshold):
def _get_search_res(self, results, score_threshold) -> list[Document]:
if len(results) == 0:
return []

@@ -191,7 +192,7 @@ class VikingDBVector(BaseVector):
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -3,7 +3,7 @@ import json
from typing import Any, Optional

import requests
import weaviate
import weaviate # type: ignore
from pydantic import BaseModel, model_validator

from configs import dify_config
@@ -107,7 +107,8 @@ class WeaviateVector(BaseVector):
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY.value: text}
if metadatas is not None:
for key, val in metadatas[i].items():
# metadata maybe None
for key, val in (metadatas[i] or {}).items():
data_properties[key] = self._json_serializable(val)

batch.add_data_object(
@@ -208,10 +209,11 @@ class WeaviateVector(BaseVector):
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score > score_threshold:
doc.metadata["score"] = score
docs.append(doc)
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -275,7 +277,7 @@ class WeaviateVectorFactory(AbstractVectorFactory):
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT,
endpoint=dify_config.WEAVIATE_ENDPOINT or "",
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),
@@ -83,6 +83,9 @@ class DatasetDocumentStore:
if not isinstance(doc, Document):
raise ValueError("doc must be a Document")

if doc.metadata is None:
raise ValueError("doc.metadata must be a dict")

segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"])

# NOTE: doc could already exist in the store, but we overwrite it
@@ -179,10 +182,10 @@ class DatasetDocumentStore:

if document_segment is None:
return None
data: Optional[str] = document_segment.index_node_hash
return data

return document_segment.index_node_hash

def get_document_segment(self, doc_id: str) -> DocumentSegment:
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
@@ -1,6 +1,6 @@
import base64
import logging
from typing import Optional, cast
from typing import Any, Optional, cast

import numpy as np
from sqlalchemy.exc import IntegrityError
@@ -27,7 +27,7 @@ class CacheEmbedding(Embeddings):
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
# use doc embedding cache or store if not exists
text_embeddings = [None for _ in range(len(texts))]
text_embeddings: list[Any] = [None for _ in range(len(texts))]
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
@@ -64,7 +64,8 @@ class CacheEmbedding(Embeddings):

for vector in embedding_result.embeddings:
try:
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
# FIXME: type ignore for numpy here
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
if np.isnan(normalized_embedding).any():
# for issue #11827 float values are not json compliant
@@ -77,8 +78,8 @@ class CacheEmbedding(Embeddings):
logging.exception("Failed transform embedding")
cache_embeddings = []
try:
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
text_embeddings[i] = embedding
for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
text_embeddings[i] = n_embedding
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(
@@ -86,7 +87,7 @@ class CacheEmbedding(Embeddings):
hash=hash,
provider_name=self._model_instance.provider,
)
embedding_cache.set_embedding(embedding)
embedding_cache.set_embedding(n_embedding)
db.session.add(embedding_cache)
cache_embeddings.append(hash)
db.session.commit()
@@ -115,7 +116,8 @@ class CacheEmbedding(Embeddings):
)

embedding_results = embedding_result.embeddings[0]
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
# FIXME: type ignore for numpy here
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
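The CacheEmbedding hunks above normalize vectors with NumPy and reject NaN results; the type: ignore is needed because the ndarray arithmetic is untyped from mypy's point of view. A small sketch of the same normalization and NaN check, independent of the embedding cache:

import numpy as np


def normalize(vector: list[float]) -> list[float]:
    arr = np.asarray(vector, dtype=float)
    norm = np.linalg.norm(arr)
    if norm == 0:
        raise ValueError("Cannot normalize a zero vector")
    normalized = (arr / norm).tolist()
    # NaN values are not JSON compliant, so fail early instead of caching them.
    if np.isnan(normalized).any():
        raise ValueError("Normalized embedding is nan, please try again")
    return normalized


if __name__ == "__main__":
    print(normalize([3.0, 4.0]))  # [0.6, 0.8]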
@@ -14,7 +14,7 @@ class NotionInfo(BaseModel):
notion_workspace_id: str
notion_obj_id: str
notion_page_type: str
document: Document = None
document: Optional[Document] = None
tenant_id: str
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""

import os
from typing import Optional
from typing import Optional, cast

import pandas as pd
from openpyxl import load_workbook
@@ -47,7 +47,7 @@ class ExcelExtractor(BaseExtractor):
for col_index, (k, v) in enumerate(row.items()):
if pd.notna(v):
cell = sheet.cell(
row=index + 2, column=col_index + 1
row=cast(int, index) + 2, column=col_index + 1
) # +2 to account for header and 1-based index
if cell.hyperlink:
value = f"[{v}]({cell.hyperlink.target})"
@@ -60,8 +60,8 @@ class ExcelExtractor(BaseExtractor):

elif file_extension == ".xls":
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
for sheet_name in excel_file.sheet_names:
df = excel_file.parse(sheet_name=sheet_name)
for excel_sheet_name in excel_file.sheet_names:
df = excel_file.parse(sheet_name=excel_sheet_name)
df.dropna(how="all", inplace=True)

for _, row in df.iterrows():
@@ -10,6 +10,7 @@ from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.excel_extractor import ExcelExtractor
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
from core.rag.extractor.html_extractor import HtmlExtractor
from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor
@@ -66,9 +67,13 @@ class ExtractProcessor:
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
suffix = "." + re.search(r"\.(\w+)$", filename).group(1)

file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
match = re.search(r"\.(\w+)$", filename)
if match:
suffix = "." + match.group(1)
else:
suffix = ""
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
if return_text:
@@ -89,15 +94,20 @@ class ExtractProcessor:
if extract_setting.datasource_type == DatasourceType.FILE.value:
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
assert extract_setting.upload_file is not None, "upload_file is required"
upload_file: UploadFile = extract_setting.upload_file
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
etl_type = dify_config.ETL_TYPE
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
assert unstructured_api_url is not None, "unstructured_api_url is required"
assert unstructured_api_key is not None, "unstructured_api_key is required"
extractor: Optional[BaseExtractor] = None
if etl_type == "Unstructured":
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
@@ -156,6 +166,7 @@ class ExtractProcessor:
extractor = TextExtractor(file_path, autodetect_encoding=True)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.NOTION.value:
assert extract_setting.notion_info is not None, "notion_info is required"
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_obj_id=extract_setting.notion_info.notion_obj_id,
@@ -165,6 +176,7 @@ class ExtractProcessor:
)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
assert extract_setting.website_info is not None, "website_info is required"
if extract_setting.website_info.provider == "firecrawl":
extractor = FirecrawlWebExtractor(
url=extract_setting.website_info.url,
@@ -1,5 +1,6 @@
import json
import time
from typing import cast

import requests

@@ -20,9 +21,9 @@ class FirecrawlApp:
json_data.update(params)
response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data)
if response.status_code == 200:
response = response.json()
if response["success"] == True:
data = response["data"]
response_data = response.json()
if response_data["success"] == True:
data = response_data["data"]
return {
"title": data.get("metadata").get("title"),
"description": data.get("metadata").get("description"),
@@ -30,7 +31,7 @@ class FirecrawlApp:
"markdown": data.get("markdown"),
}
else:
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}')

elif response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
@@ -46,9 +47,11 @@ class FirecrawlApp:
response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers)
if response.status_code == 200:
job_id = response.json().get("jobId")
return job_id
return cast(str, job_id)
else:
self._handle_error(response, "start crawl job")
# FIXME: unreachable code for mypy
return "" # unreachable

def check_crawl_status(self, job_id) -> dict:
headers = self._prepare_headers()
@@ -64,9 +67,9 @@ class FirecrawlApp:
for item in data:
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
url_data = {
"title": item.get("metadata").get("title"),
"description": item.get("metadata").get("description"),
"source_url": item.get("metadata").get("sourceURL"),
"title": item.get("metadata", {}).get("title"),
"description": item.get("metadata", {}).get("description"),
"source_url": item.get("metadata", {}).get("sourceURL"),
"markdown": item.get("markdown"),
}
url_data_list.append(url_data)
@@ -92,6 +95,8 @@ class FirecrawlApp:

else:
self._handle_error(response, "check crawl status")
# FIXME: unreachable code for mypy
return {} # unreachable

def _prepare_headers(self):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
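The FirecrawlApp hunk above stops rebinding the requests Response object to its parsed JSON, since reusing one name for two different types defeats mypy's inference. A hedged sketch of the same shape with a hypothetical endpoint, not the real Firecrawl API:

import requests


def fetch_payload(url: str) -> dict:
    response = requests.get(url, timeout=10)
    if response.status_code != 200:
        response.raise_for_status()
    # Keep the parsed body in its own name so `response` stays a requests.Response.
    response_data = response.json()
    if not response_data.get("success"):
        raise Exception(f"Request failed: {response_data.get('error', 'unknown error')}")
    return dict(response_data["data"])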
@@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""

from bs4 import BeautifulSoup
from bs4 import BeautifulSoup # type: ignore

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
@@ -23,6 +23,7 @@ class HtmlExtractor(BaseExtractor):
return [Document(page_content=self._load_as_text())]

def _load_as_text(self) -> str:
text: str = ""
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, "html.parser")
text = soup.get_text()
@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Optional
from typing import Any, Optional, cast

import requests

@@ -78,6 +78,7 @@ class NotionExtractor(BaseExtractor):

def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]:
"""Get all the pages from a Notion database."""
assert self._notion_access_token is not None, "Notion access token is required"
res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id),
headers={
@@ -96,6 +97,7 @@ class NotionExtractor(BaseExtractor):
for result in data["results"]:
properties = result["properties"]
data = {}
value: Any
for property_name, property_value in properties.items():
type = property_value["type"]
if type == "multi_select":
@@ -130,6 +132,7 @@ class NotionExtractor(BaseExtractor):
return [Document(page_content="\n".join(database_content))]

def _get_notion_block_data(self, page_id: str) -> list[str]:
assert self._notion_access_token is not None, "Notion access token is required"
result_lines_arr = []
start_cursor = None
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id)
@@ -184,6 +187,7 @@ class NotionExtractor(BaseExtractor):

def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
"""Read a block."""
assert self._notion_access_token is not None, "Notion access token is required"
result_lines_arr = []
start_cursor = None
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
@@ -242,6 +246,7 @@ class NotionExtractor(BaseExtractor):

def _read_table_rows(self, block_id: str) -> str:
"""Read table rows."""
assert self._notion_access_token is not None, "Notion access token is required"
done = False
result_lines_arr = []
start_cursor = None
@@ -296,7 +301,7 @@ class NotionExtractor(BaseExtractor):
result_lines = "\n".join(result_lines_arr)
return result_lines

def update_last_edited_time(self, document_model: DocumentModel):
def update_last_edited_time(self, document_model: Optional[DocumentModel]):
if not document_model:
return

@@ -309,6 +314,7 @@ class NotionExtractor(BaseExtractor):
db.session.commit()

def get_notion_last_edited_time(self) -> str:
assert self._notion_access_token is not None, "Notion access token is required"
obj_id = self._notion_obj_id
page_type = self._notion_page_type
if page_type == "database":
@@ -330,7 +336,7 @@ class NotionExtractor(BaseExtractor):
)

data = res.json()
return data["last_edited_time"]
return cast(str, data["last_edited_time"])

@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
@@ -349,4 +355,4 @@ class NotionExtractor(BaseExtractor):
f"and notion workspace {notion_workspace_id}"
)

return data_source_binding.access_token
return cast(str, data_source_binding.access_token)
@@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""

from collections.abc import Iterator
from typing import Optional
from typing import Optional, cast

from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
@@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
plaintext_file_exists = False
if self._file_cache_key:
try:
text = storage.load(self._file_cache_key).decode("utf-8")
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
@@ -53,7 +53,7 @@ class PdfExtractor(BaseExtractor):

def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
import pypdfium2
import pypdfium2 # type: ignore

with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
@@ -1,7 +1,7 @@
import base64
import logging

from bs4 import BeautifulSoup
from bs4 import BeautifulSoup # type: ignore

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

@@ -30,6 +30,9 @@ class UnstructuredEpubExtractor(BaseExtractor):
if self._api_url:
from unstructured.partition.api import partition_via_api

if self._api_key is None:
raise ValueError("api_key is required")

elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
from unstructured.partition.epub import partition_epub

@@ -27,9 +27,11 @@ class UnstructuredPPTExtractor(BaseExtractor):
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
raise NotImplementedError("Unstructured API Url is not configured")
text_by_page = {}
text_by_page: dict[int, str] = {}
for element in elements:
page = element.metadata.page_number
if page is None:
continue
text = element.text
if page in text_by_page:
text_by_page[page] += "\n" + text

@@ -29,14 +29,15 @@ class UnstructuredPPTXExtractor(BaseExtractor):
from unstructured.partition.pptx import partition_pptx

elements = partition_pptx(filename=self._file_path)
text_by_page = {}
text_by_page: dict[int, str] = {}
for element in elements:
page = element.metadata.page_number
text = element.text
if page in text_by_page:
text_by_page[page] += "\n" + text
else:
text_by_page[page] = text
if page is not None:
if page in text_by_page:
text_by_page[page] += "\n" + text
else:
text_by_page[page] = text

combined_texts = list(text_by_page.values())
documents = []
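The PPT and PPTX extractor hunks above annotate text_by_page and skip elements without a page number so the dict keys stay int. A minimal sketch of that accumulation, using plain (page, text) tuples in place of the unstructured element objects:

from typing import Optional


def group_text_by_page(elements: list[tuple[Optional[int], str]]) -> list[str]:
    # Annotate the mapping so mypy knows keys are ints, and skip unknown pages.
    text_by_page: dict[int, str] = {}
    for page, text in elements:
        if page is None:
            continue
        if page in text_by_page:
            text_by_page[page] += "\n" + text
        else:
            text_by_page[page] = text
    return list(text_by_page.values())


if __name__ == "__main__":
    print(group_text_by_page([(1, "title"), (None, "orphan"), (1, "body"), (2, "next")]))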
@@ -89,6 +89,8 @@ class WordExtractor(BaseExtractor):
response = ssrf_proxy.get(url)
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
if image_ext is None:
continue
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
mime_type, _ = mimetypes.guess_type(file_key)
@@ -97,6 +99,8 @@ class WordExtractor(BaseExtractor):
continue
else:
image_ext = rel.target_ref.split(".")[-1]
if image_ext is None:
continue
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
@@ -226,6 +230,8 @@ class WordExtractor(BaseExtractor):
if x_child is None:
continue
if x.tag.endswith("instrText"):
if x.text is None:
continue
for i in url_pattern.findall(x.text):
hyperlinks_url = str(i)
except Exception as e:
@@ -49,6 +49,7 @@ class BaseIndexProcessor(ABC):
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule["mode"] == "custom":
# The user-defined segmentation rule
rules = processing_rule["rules"]

@@ -9,7 +9,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess
class IndexProcessorFactory:
"""IndexProcessorInit."""

def __init__(self, index_type: str):
def __init__(self, index_type: str | None):
self._index_type = index_type

def init_index_processor(self) -> BaseIndexProcessor:

@@ -27,12 +27,13 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
# Split the text documents into nodes.
splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
processing_rule=kwargs.get("process_rule", {}),
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
all_documents = []
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule", {}))
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
@@ -41,8 +42,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
if document_node.metadata is not None:
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = remove_leading_symbols(document_node.page_content).strip()
if len(page_content) > 0:
@@ -32,15 +32,16 @@ class QAIndexProcessor(BaseIndexProcessor):

def transform(self, documents: list[Document], **kwargs) -> list[Document]:
splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
processing_rule=kwargs.get("process_rule") or {},
embedding_model_instance=kwargs.get("embedding_model_instance"),
)

# Split the text documents into nodes.
all_documents = []
all_qa_documents = []
all_documents: list[Document] = []
all_qa_documents: list[Document] = []
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule") or {})
document.page_content = document_text

# parse document to nodes
@@ -50,8 +51,9 @@ class QAIndexProcessor(BaseIndexProcessor):
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
if document_node.metadata is not None:
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
document_node.page_content = remove_leading_symbols(page_content)
@@ -64,7 +66,7 @@ class QAIndexProcessor(BaseIndexProcessor):
document_format_thread = threading.Thread(
target=self._format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"tenant_id": kwargs.get("tenant_id"),
"document_node": doc,
"all_qa_documents": all_qa_documents,
@@ -148,11 +150,12 @@ class QAIndexProcessor(BaseIndexProcessor):
qa_documents = []
for result in document_qa_list:
qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy())
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
if qa_document.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
@@ -30,7 +30,11 @@ class RerankModelRunner(BaseRerankRunner):
doc_ids = set()
unique_documents = []
for document in documents:
if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids:
if (
document.provider == "dify"
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
@@ -54,7 +58,8 @@ class RerankModelRunner(BaseRerankRunner):
metadata=documents[result.index].metadata,
provider=documents[result.index].provider,
)
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
if rerank_document.metadata is not None:
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)

return rerank_documents

@@ -39,7 +39,7 @@ class WeightRerankRunner(BaseRerankRunner):
unique_documents = []
doc_ids = set()
for document in documents:
if document.metadata["doc_id"] not in doc_ids:
if document.metadata is not None and document.metadata["doc_id"] not in doc_ids:
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)

@@ -56,10 +56,11 @@ class WeightRerankRunner(BaseRerankRunner):
)
if score_threshold and score < score_threshold:
continue
document.metadata["score"] = score
rerank_documents.append(document)
if document.metadata is not None:
document.metadata["score"] = score
rerank_documents.append(document)

rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True)
rerank_documents.sort(key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents

def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
@@ -76,8 +77,9 @@ class WeightRerankRunner(BaseRerankRunner):
for document in documents:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)
if document.metadata is not None:
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)

# Counter query keywords(TF)
query_keyword_counts = Counter(query_keywords)
@@ -162,7 +164,7 @@ class WeightRerankRunner(BaseRerankRunner):
query_vector = cache_embedding.embed_query(query)
for document in documents:
# calculate cosine similarity
if "score" in document.metadata:
if document.metadata and "score" in document.metadata:
query_vector_scores.append(document.metadata["score"])
else:
# transform to NumPy
@@ -1,7 +1,7 @@
import math
import threading
from collections import Counter
from typing import Optional, cast
from typing import Any, Optional, cast

from flask import Flask, current_app

@@ -34,7 +34,7 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService

default_retrieval_model = {
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -140,12 +140,12 @@ class DatasetRetrieval:
user_from,
available_datasets,
query,
retrieve_config.top_k,
retrieve_config.score_threshold,
retrieve_config.rerank_mode,
retrieve_config.top_k or 0,
retrieve_config.score_threshold or 0,
retrieve_config.rerank_mode or "reranking_model",
retrieve_config.reranking_model,
retrieve_config.weights,
retrieve_config.reranking_enabled,
retrieve_config.reranking_enabled or True,
message_id,
)

@@ -300,10 +300,11 @@ class DatasetRetrieval:
metadata=external_document.get("metadata"),
provider="external",
)
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
if document.metadata is not None:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
results.append(document)
else:
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
@@ -325,7 +326,7 @@ class DatasetRetrieval:
score_threshold = 0.0
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
score_threshold = retrieval_model_config.get("score_threshold", 0.0)

with measure_time() as timer:
results = RetrievalService.retrieve(
@@ -358,14 +359,14 @@ class DatasetRetrieval:
score_threshold: float,
reranking_mode: str,
reranking_model: Optional[dict] = None,
weights: Optional[dict] = None,
weights: Optional[dict[str, Any]] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
):
if not available_datasets:
return []
threads = []
all_documents = []
all_documents: list[Document] = []
dataset_ids = [dataset.id for dataset in available_datasets]
index_type_check = all(
item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
@@ -392,15 +393,18 @@ class DatasetRetrieval:
"The configured knowledge base list have different embedding model, please set reranking model."
)
if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
if weights is not None:
weights["vector_setting"]["embedding_provider_name"] = available_datasets[
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model

for dataset in available_datasets:
index_type = dataset.indexing_technique
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
@@ -439,21 +443,22 @@ class DatasetRetrieval:
"""Handle retrieval end."""
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
if document.metadata is not None:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)

# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])

# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)

db.session.commit()
db.session.commit()

# get tracing instance
trace_manager: TraceQueueManager = (
trace_manager: Optional[TraceQueueManager] = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
@@ -504,10 +509,11 @@ class DatasetRetrieval:
metadata=external_document.get("metadata"),
provider="external",
)
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
if document.metadata is not None:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
all_documents.append(document)
else:
# get retrieval model , if the model is not setting , using default
@@ -607,19 +613,20 @@ class DatasetRetrieval:

tools.append(tool)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
)
if retrieve_config.reranking_model is not None:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
)

tools.append(tool)
tools.append(tool)

return tools

@@ -635,10 +642,11 @@ class DatasetRetrieval:
query_keywords = keyword_table_handler.extract_keywords(query, None)
documents_keywords = []
for document in documents:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)
if document.metadata is not None:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)

# Counter query keywords(TF)
query_keyword_counts = Counter(query_keywords)
@@ -696,8 +704,9 @@ class DatasetRetrieval:

for document, score in zip(documents, similarities):
# format document
document.metadata["score"] = score
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
if document.metadata is not None:
document.metadata["score"] = score
documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return documents[:top_k] if top_k else documents

def calculate_vector_score(
@@ -705,10 +714,12 @@ class DatasetRetrieval:
) -> list[Document]:
filter_documents = []
for document in all_documents:
if score_threshold is None or document.metadata["score"] >= score_threshold:
if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold):
filter_documents.append(document)

if not filter_documents:
return []
filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True)
filter_documents = sorted(
filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
)
return filter_documents[:top_k] if top_k else filter_documents
@@ -1,7 +1,8 @@
from typing import Union
from typing import Union, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage

@@ -27,11 +28,14 @@ class FunctionCallMultiDatasetRouter:
SystemPromptMessage(content="You are a helpful AI assistant."),
UserPromptMessage(content=query),
]
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
result = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
),
)
if result.message.tool_calls:
# get retrieval model config

@@ -1,9 +1,9 @@
from collections.abc import Generator, Sequence
from typing import Union
from typing import Union, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -92,6 +92,7 @@ class ReactMultiDatasetRouter:
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> Union[str, None]:
prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
if model_config.mode == "chat":
prompt = self.create_chat_prompt(
query=query,
@@ -149,12 +150,15 @@ class ReactMultiDatasetRouter:
:param stop: stop
:return:
"""
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=completion_param,
stop=stop,
stream=True,
user=user_id,
invoke_result = cast(
Generator[LLMResult, None, None],
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=completion_param,
stop=stop,
stream=True,
user=user_id,
),
)

# handle invoke result
@@ -172,7 +176,7 @@ class ReactMultiDatasetRouter:
:return:
"""
model = None
prompt_messages = []
prompt_messages: list[PromptMessage] = []
full_text = ""
usage = None
for result in invoke_result:
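The router hunks above wrap invoke_llm in typing.cast because its return type is a union of a full result and a streaming generator, and each caller knows which mode it requested. A small self-contained sketch of that narrowing, using a hypothetical client rather than the project's ModelInstance:

from collections.abc import Generator
from typing import Union, cast


class FakeLLMClient:
    def invoke(self, prompt: str, stream: bool) -> Union[str, Generator[str, None, None]]:
        if stream:
            return (chunk for chunk in prompt.split())
        return prompt.upper()


def complete(client: FakeLLMClient, prompt: str) -> str:
    # stream=False always takes the non-generator branch, so cast() documents that for mypy.
    return cast(str, client.invoke(prompt, stream=False))


def stream(client: FakeLLMClient, prompt: str) -> Generator[str, None, None]:
    return cast(Generator[str, None, None], client.invoke(prompt, stream=True))


if __name__ == "__main__":
    client = FakeLLMClient()
    print(complete(client, "hello world"))
    print(list(stream(client, "hello world")))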
@@ -26,8 +26,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
def from_encoder(
cls: type[TS],
embedding_model_instance: Optional[ModelInstance],
allowed_special: Union[Literal[all], Set[str]] = set(),
disallowed_special: Union[Literal[all], Collection[str]] = "all",
allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
**kwargs: Any,
):
def _token_encoder(text: str) -> int:

@@ -92,7 +92,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
texts, metadatas = [], []
for doc in documents:
texts.append(doc.page_content)
metadatas.append(doc.metadata)
metadatas.append(doc.metadata or {})
return self.create_documents(texts, metadatas=metadatas)

def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
@@ -143,7 +143,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
"""Text splitter that uses HuggingFace tokenizer to count length."""
try:
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase # type: ignore

if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")