feat: mypy for all type check (#10921)

Author: yihong
Date: 2024-12-24 18:38:51 +08:00 (committed by GitHub)
Parent: c91e8b1737
Commit: 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -32,8 +32,11 @@ class Jieba(BaseKeyword):
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, text.metadata["doc_id"], list(keywords)
)
self._save_dataset_keyword_table(keyword_table)
@@ -58,20 +61,26 @@ class Jieba(BaseKeyword):
keywords = keyword_table_handler.extract_keywords(
text.page_content, self._config.max_keywords_per_chunk
)
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata["doc_id"], list(keywords))
if text.metadata is not None:
self._update_segment_keywords(self.dataset.id, text.metadata["doc_id"], list(keywords))
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, text.metadata["doc_id"], list(keywords)
)
self._save_dataset_keyword_table(keyword_table)
def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
if keyword_table is None:
return False
return id in set.union(*keyword_table.values())
def delete_by_ids(self, ids: list[str]) -> None:
lock_name = "keyword_indexing_lock_{}".format(self.dataset.id)
with redis_client.lock(lock_name, timeout=600):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
if keyword_table is not None:
keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids)
self._save_dataset_keyword_table(keyword_table)
@@ -80,7 +89,7 @@ class Jieba(BaseKeyword):
k = kwargs.get("top_k", 4)
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table, query, k)
sorted_chunk_indices = self._retrieve_ids_by_query(keyword_table or {}, query, k)
documents = []
for chunk_index in sorted_chunk_indices:
@@ -137,7 +146,7 @@ class Jieba(BaseKeyword):
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
return keyword_table_dict["__data__"]["table"]
return dict(keyword_table_dict["__data__"]["table"])
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(
@@ -188,8 +197,8 @@ class Jieba(BaseKeyword):
# go through text chunks in order of most matching keywords
chunk_indices_count: dict[str, int] = defaultdict(int)
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords:
keywords_list = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords_list:
for node_id in keyword_table[keyword]:
chunk_indices_count[node_id] += 1
@@ -215,7 +224,7 @@ class Jieba(BaseKeyword):
def create_segment_keywords(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
self._update_segment_keywords(self.dataset.id, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
def multi_create_segment_keywords(self, pre_segment_data_list: list):
@@ -226,17 +235,19 @@ class Jieba(BaseKeyword):
if pre_segment_data["keywords"]:
segment.keywords = pre_segment_data["keywords"]
keyword_table = self._add_text_to_keyword_table(
keyword_table, segment.index_node_id, pre_segment_data["keywords"]
keyword_table or {}, segment.index_node_id, pre_segment_data["keywords"]
)
else:
keywords = keyword_table_handler.extract_keywords(segment.content, self._config.max_keywords_per_chunk)
segment.keywords = list(keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
keyword_table = self._add_text_to_keyword_table(
keyword_table or {}, segment.index_node_id, list(keywords)
)
self._save_dataset_keyword_table(keyword_table)
def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
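Note: most of the Jieba changes follow one narrowing pattern. `Document.metadata` is declared `Optional[dict]` and `_get_dataset_keyword_table()` can return `None`, so mypy rejects `text.metadata["doc_id"]` and passing the table straight through; guarding with `if text.metadata is not None:` and coercing with `keyword_table or {}` narrows both before use. A minimal sketch of that pattern, with a simplified stand-in `Document` and a plain dict table instead of the real models:

```python
from typing import Optional


class Document:
    """Simplified stand-in for the RAG Document model used in this diff."""

    def __init__(self, page_content: str, metadata: Optional[dict] = None) -> None:
        self.page_content = page_content
        self.metadata = metadata


def add_to_keyword_table(
    table: Optional[dict[str, set[str]]], doc: Document, keywords: set[str]
) -> dict[str, set[str]]:
    table = table or {}  # Optional[dict] -> dict, so the indexing below type-checks
    if doc.metadata is not None:  # mypy narrows Optional[dict] to dict inside this branch
        doc_id = doc.metadata["doc_id"]
        for keyword in keywords:
            table.setdefault(keyword, set()).add(doc_id)
    return table
```

The guard also changes behaviour slightly: chunks without metadata are now skipped instead of raising.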

View File

@@ -4,7 +4,7 @@ from typing import Optional
class JiebaKeywordTableHandler:
def __init__(self):
import jieba.analyse
import jieba.analyse # type: ignore
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
@@ -12,7 +12,7 @@ class JiebaKeywordTableHandler:
def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
import jieba
import jieba # type: ignore
keywords = jieba.analyse.extract_tags(
sentence=text,

View File

@@ -37,6 +37,8 @@ class BaseKeyword(ABC):
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
if text.metadata is None:
continue
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
@@ -45,4 +47,4 @@ class BaseKeyword(ABC):
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata["doc_id"] for text in texts]
return [text.metadata["doc_id"] for text in texts if text.metadata]

View File

@@ -6,6 +6,7 @@ from flask import Flask, current_app
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@@ -31,7 +32,7 @@ class RetrievalService:
top_k: int,
score_threshold: Optional[float] = 0.0,
reranking_model: Optional[dict] = None,
reranking_mode: Optional[str] = "reranking_model",
reranking_mode: str = "reranking_model",
weights: Optional[dict] = None,
):
if not query:
@@ -42,15 +43,15 @@ class RetrievalService:
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = []
threads = []
exceptions = []
all_documents: list[Document] = []
threads: list[threading.Thread] = []
exceptions: list[str] = []
# retrieval_model source with keyword
if retrieval_method == "keyword_search":
keyword_thread = threading.Thread(
target=RetrievalService.keyword_search,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
@@ -65,7 +66,7 @@ class RetrievalService:
embedding_thread = threading.Thread(
target=RetrievalService.embedding_search,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
@@ -84,7 +85,7 @@ class RetrievalService:
full_text_index_thread = threading.Thread(
target=RetrievalService.full_text_index_search,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset_id,
"query": query,
"retrieval_method": retrieval_method,
@@ -124,7 +125,7 @@ class RetrievalService:
if not dataset:
return []
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id, dataset_id, query, external_retrieval_model
dataset.tenant_id, dataset_id, query, external_retrieval_model or {}
)
return all_documents
@@ -135,6 +136,8 @@ class RetrievalService:
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("dataset not found")
keyword = Keyword(dataset=dataset)
@@ -159,6 +162,8 @@ class RetrievalService:
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("dataset not found")
vector = Vector(dataset=dataset)
@@ -209,6 +214,8 @@ class RetrievalService:
with flask_app.app_context():
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("dataset not found")
vector_processor = Vector(
dataset=dataset,
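Note: three typing fixes recur in RetrievalService. Empty containers get explicit element types (`all_documents: list[Document] = []`), `current_app._get_current_object()` gets an inline ignore because `current_app` is typed as `Flask` and mypy does not see the proxy-only `_get_current_object()` method, and `Query.first()` results are checked for `None` before use. A sketch of the last point, with a hypothetical dataclass standing in for the SQLAlchemy `Dataset` model:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Dataset:
    """Hypothetical stand-in for the SQLAlchemy Dataset model."""

    id: str


def keyword_search(datasets: list[Dataset], dataset_id: str) -> str:
    # Query.first() returns Optional[Dataset]; next(..., None) mirrors that here.
    dataset: Optional[Dataset] = next((d for d in datasets if d.id == dataset_id), None)
    if not dataset:
        # Raising narrows the type: after this check mypy treats `dataset` as Dataset.
        raise ValueError("dataset not found")
    return dataset.id
```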

View File

@@ -17,12 +17,19 @@ from models.dataset import Dataset
class AnalyticdbVector(BaseVector):
def __init__(
self, collection_name: str, api_config: AnalyticdbVectorOpenAPIConfig, sql_config: AnalyticdbVectorBySqlConfig
self,
collection_name: str,
api_config: AnalyticdbVectorOpenAPIConfig | None,
sql_config: AnalyticdbVectorBySqlConfig | None,
):
super().__init__(collection_name)
if api_config is not None:
self.analyticdb_vector = AnalyticdbVectorOpenAPI(collection_name, api_config)
self.analyticdb_vector: AnalyticdbVectorOpenAPI | AnalyticdbVectorBySql = AnalyticdbVectorOpenAPI(
collection_name, api_config
)
else:
if sql_config is None:
raise ValueError("Either api_config or sql_config must be provided")
self.analyticdb_vector = AnalyticdbVectorBySql(collection_name, sql_config)
def get_type(self) -> str:
@@ -33,8 +40,8 @@ class AnalyticdbVector(BaseVector):
self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
self.analyticdb_vector.add_texts(documents, embeddings)
def text_exists(self, id: str) -> bool:
return self.analyticdb_vector.text_exists(id)
@@ -68,13 +75,13 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
if dify_config.ANALYTICDB_HOST is None:
# implemented through OpenAPI
apiConfig = AnalyticdbVectorOpenAPIConfig(
access_key_id=dify_config.ANALYTICDB_KEY_ID,
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
region_id=dify_config.ANALYTICDB_REGION_ID,
instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
namespace=dify_config.ANALYTICDB_NAMESPACE,
access_key_id=dify_config.ANALYTICDB_KEY_ID or "",
access_key_secret=dify_config.ANALYTICDB_KEY_SECRET or "",
region_id=dify_config.ANALYTICDB_REGION_ID or "",
instance_id=dify_config.ANALYTICDB_INSTANCE_ID or "",
account=dify_config.ANALYTICDB_ACCOUNT or "",
account_password=dify_config.ANALYTICDB_PASSWORD or "",
namespace=dify_config.ANALYTICDB_NAMESPACE or "",
namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
)
sqlConfig = None
@@ -83,11 +90,11 @@ class AnalyticdbVectorFactory(AbstractVectorFactory):
sqlConfig = AnalyticdbVectorBySqlConfig(
host=dify_config.ANALYTICDB_HOST,
port=dify_config.ANALYTICDB_PORT,
account=dify_config.ANALYTICDB_ACCOUNT,
account_password=dify_config.ANALYTICDB_PASSWORD,
account=dify_config.ANALYTICDB_ACCOUNT or "",
account_password=dify_config.ANALYTICDB_PASSWORD or "",
min_connection=dify_config.ANALYTICDB_MIN_CONNECTION,
max_connection=dify_config.ANALYTICDB_MAX_CONNECTION,
namespace=dify_config.ANALYTICDB_NAMESPACE,
namespace=dify_config.ANALYTICDB_NAMESPACE or "",
)
apiConfig = None
return AnalyticdbVector(
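Note: without an annotation, mypy fixes the type of `self.analyticdb_vector` from its first assignment (`AnalyticdbVectorOpenAPI`) and then rejects the SQL branch. Annotating the first assignment as a union keeps both branches valid, and the explicit `ValueError` documents that at least one config is required. A reduced sketch with hypothetical client classes:

```python
class OpenAPIClient:
    def __init__(self, collection: str) -> None:
        self.collection = collection


class SqlClient:
    def __init__(self, collection: str) -> None:
        self.collection = collection


class VectorAdapter:
    def __init__(self, collection: str, api_config: object | None, sql_config: object | None) -> None:
        if api_config is not None:
            # Annotating the first assignment declares the attribute as a union,
            # so assigning SqlClient in the else branch still type-checks.
            self.client: OpenAPIClient | SqlClient = OpenAPIClient(collection)
        else:
            if sql_config is None:
                raise ValueError("Either api_config or sql_config must be provided")
            self.client = SqlClient(collection)
```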

View File

@@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, Optional
from pydantic import BaseModel, model_validator
@@ -20,7 +20,7 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
account: str
account_password: str
namespace: str = "dify"
namespace_password: str = (None,)
namespace_password: Optional[str] = None
metrics: str = "cosine"
read_timeout: int = 60000
@@ -55,8 +55,8 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
class AnalyticdbVectorOpenAPI:
def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
try:
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_gpdb20160503.client import Client # type: ignore
from alibabacloud_tea_openapi import models as open_api_models # type: ignore
except:
raise ImportError(_import_err_msg)
self._collection_name = collection_name.lower()
@@ -77,7 +77,7 @@ class AnalyticdbVectorOpenAPI:
redis_client.set(database_exist_cache_key, 1, ex=3600)
def _initialize_vector_database(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models # type: ignore
request = gpdb_20160503_models.InitVectorDatabaseRequest(
dbinstance_id=self.config.instance_id,
@@ -89,7 +89,7 @@ class AnalyticdbVectorOpenAPI:
def _create_namespace_if_not_exists(self) -> None:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException
from Tea.exceptions import TeaException # type: ignore
try:
request = gpdb_20160503_models.DescribeNamespaceRequest(
@@ -159,17 +159,18 @@ class AnalyticdbVectorOpenAPI:
rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
for doc, embedding in zip(documents, embeddings, strict=True):
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
if doc.metadata is not None:
metadata = {
"ref_doc_id": doc.metadata["doc_id"],
"page_content": doc.page_content,
"metadata_": json.dumps(doc.metadata),
}
rows.append(
gpdb_20160503_models.UpsertCollectionDataRequestRows(
vector=embedding,
metadata=metadata,
)
)
)
request = gpdb_20160503_models.UpsertCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@@ -258,7 +259,7 @@ class AnalyticdbVectorOpenAPI:
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -290,7 +291,7 @@ class AnalyticdbVectorOpenAPI:
metadata=metadata,
)
documents.append(doc)
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
def delete(self) -> None:
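Note: two distinct fixes land in this file. `namespace_password: str = (None,)` was a trailing-comma slip that made the default a one-element tuple; pydantic does not validate defaults by default, so it only surfaced once mypy compared the default against the annotation, and `Optional[str] = None` is what the callers actually expect (the same slip is fixed later in `TencentConfig`, where `shard`/`replicas` defaulted to `(1,)` and `(2,)`). The rest is the usual metadata guard before building upsert rows and in the sort keys. A minimal sketch of the corrected model field:

```python
from typing import Optional

from pydantic import BaseModel


class AnalyticdbOpenAPIConfig(BaseModel):
    namespace: str = "dify"
    # Before: `namespace_password: str = (None,)` -- the stray comma turned the default
    # into a tuple, which matched neither the annotation nor the intent.
    namespace_password: Optional[str] = None
    metrics: str = "cosine"
```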

View File

@@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
from core.rag.models.document import Document
@@ -75,6 +75,7 @@ class AnalyticdbVectorBySql:
@contextmanager
def _get_cursor(self):
assert self.pool is not None, "Connection pool is not initialized"
conn = self.pool.getconn()
cur = conn.cursor()
try:
@@ -156,16 +157,17 @@ class AnalyticdbVectorBySql:
VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
"""
for i, doc in enumerate(documents):
values.append(
(
id_prefix + str(i),
doc.metadata.get("doc_id", str(uuid.uuid4())),
embeddings[i],
doc.page_content,
json.dumps(doc.metadata),
doc.page_content,
if doc.metadata is not None:
values.append(
(
id_prefix + str(i),
doc.metadata.get("doc_id", str(uuid.uuid4())),
embeddings[i],
doc.page_content,
json.dumps(doc.metadata),
doc.page_content,
)
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_batch(cur, sql, values)
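Note: `self.pool` is created lazily, so its attribute type is Optional; the added `assert` both narrows it for mypy and fails loudly if `_get_cursor` runs before initialization. The `psycopg2` imports get inline ignores because the library ships no type information (installing the `types-psycopg2` stubs is the alternative). A reduced sketch of the assert-narrowing with a stand-in pool class:

```python
from typing import Optional


class ConnectionPool:
    """Stand-in for psycopg2.pool.SimpleConnectionPool."""

    def getconn(self) -> object:
        return object()


class SqlVector:
    def __init__(self) -> None:
        self.pool: Optional[ConnectionPool] = None  # initialized lazily elsewhere

    def _get_cursor(self) -> object:
        # The assert narrows Optional[ConnectionPool] to ConnectionPool for mypy and
        # turns a silent None dereference into a clear failure message.
        assert self.pool is not None, "Connection pool is not initialized"
        return self.pool.getconn()
```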

View File

@@ -5,13 +5,13 @@ from typing import Any
import numpy as np
from pydantic import BaseModel, model_validator
from pymochow import MochowClient
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.exception import ServerError
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
from pymochow import MochowClient # type: ignore
from pymochow.auth.bce_credentials import BceCredentials # type: ignore
from pymochow.configuration import Configuration # type: ignore
from pymochow.exception import ServerError # type: ignore
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState # type: ignore
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex # type: ignore
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -75,7 +75,7 @@ class BaiduVector(BaseVector):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
metadatas = [doc.metadata for doc in documents if doc.metadata is not None]
total_count = len(documents)
batch_size = 1000
@@ -84,6 +84,8 @@ class BaiduVector(BaseVector):
for start in range(0, total_count, batch_size):
end = min(start + batch_size, total_count)
rows = []
assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
# FIXME do you need this assert?
for i in range(start, end, 1):
row = Row(
id=metadatas[i].get("doc_id", str(uuid.uuid4())),
@@ -136,7 +138,7 @@ class BaiduVector(BaseVector):
# baidu vector database doesn't support bm25 search on current version
return []
def _get_search_res(self, res, score_threshold):
def _get_search_res(self, res, score_threshold) -> list[Document]:
docs = []
for row in res.rows:
row_data = row.get("row", {})
@@ -276,11 +278,11 @@ class BaiduVectorFactory(AbstractVectorFactory):
return BaiduVector(
collection_name=collection_name,
config=BaiduConfig(
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT or "",
connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
database=dify_config.BAIDU_VECTOR_DB_DATABASE,
account=dify_config.BAIDU_VECTOR_DB_ACCOUNT or "",
api_key=dify_config.BAIDU_VECTOR_DB_API_KEY or "",
database=dify_config.BAIDU_VECTOR_DB_DATABASE or "",
shard=dify_config.BAIDU_VECTOR_DB_SHARD,
replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
),
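Note: the factory changes across these vector stores share one idiom. `dify_config` settings are `Optional[str]`, while the pydantic config models declare plain `str`, so the call sites coerce with `or ""`; that satisfies mypy but silently turns a missing setting into an empty string, moving validation to whatever checks the values later. A sketch of the idiom with a hypothetical config model:

```python
from typing import Optional

from pydantic import BaseModel


class VectorDBConfig(BaseModel):
    endpoint: str
    account: str
    api_key: str


def build_config(endpoint: Optional[str], account: Optional[str], api_key: Optional[str]) -> VectorDBConfig:
    # `or ""` converts Optional[str] settings into the declared str fields; a missing
    # value now shows up as an empty string instead of a None type error.
    return VectorDBConfig(
        endpoint=endpoint or "",
        account=account or "",
        api_key=api_key or "",
    )
```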

View File

@@ -71,11 +71,13 @@ class ChromaVector(BaseVector):
metadatas = [d.metadata for d in documents]
collection = self._client.get_or_create_collection(self._collection_name)
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas)
# FIXME: chromadb using numpy array, fix the type error later
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
def delete_by_metadata_field(self, key: str, value: str):
collection = self._client.get_or_create_collection(self._collection_name)
collection.delete(where={key: {"$eq": value}})
# FIXME: fix the type error later
collection.delete(where={key: {"$eq": value}}) # type: ignore
def delete(self):
self._client.delete_collection(self._collection_name)
@@ -94,15 +96,19 @@ class ChromaVector(BaseVector):
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0]
metadatas: dict[str, Any] = results["metadatas"][0]
distances: list[float] = results["distances"][0]
# Check if results contain data
if not results["ids"] or not results["documents"] or not results["metadatas"] or not results["distances"]:
return []
ids = results["ids"][0]
documents = results["documents"][0]
metadatas = results["metadatas"][0]
distances = results["distances"][0]
docs = []
for index in range(len(ids)):
distance = distances[index]
metadata = metadatas[index]
metadata = dict(metadatas[index])
if distance >= score_threshold:
metadata["score"] = distance
doc = Document(
@@ -111,7 +117,7 @@ class ChromaVector(BaseVector):
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -133,7 +139,7 @@ class ChromaVectorFactory(AbstractVectorFactory):
return ChromaVector(
collection_name=collection_name,
config=ChromaConfig(
host=dify_config.CHROMA_HOST,
host=dify_config.CHROMA_HOST or "",
port=dify_config.CHROMA_PORT,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
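Note: chromadb's `QueryResult` fields are optional and its metadatas are read-only mappings, hence the added empty-result check, the `dict(...)` copy before attaching a score, and the inline ignores on `upsert`/`delete`. A small sketch of the copy-then-mutate step (the `Mapping` value type is an assumption standing in for chromadb's metadata type):

```python
from collections.abc import Mapping
from typing import Any


def attach_score(metadata: Mapping[str, Any], score: float) -> dict[str, Any]:
    # Copying the read-only mapping into a plain dict makes the "score" assignment
    # acceptable to mypy and avoids mutating the client's result object.
    enriched = dict(metadata)
    enriched["score"] = score
    return enriched
```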

View File

@@ -5,14 +5,14 @@ import uuid
from datetime import timedelta
from typing import Any
from couchbase import search
from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.management.search import SearchIndex
from couchbase import search # type: ignore
from couchbase.auth import PasswordAuthenticator # type: ignore
from couchbase.cluster import Cluster # type: ignore
from couchbase.management.search import SearchIndex # type: ignore
# needed for options -- cluster, timeout, SQL++ (N1QL) query, etc.
from couchbase.options import ClusterOptions, SearchOptions
from couchbase.vector_search import VectorQuery, VectorSearch
from couchbase.options import ClusterOptions, SearchOptions # type: ignore
from couchbase.vector_search import VectorQuery, VectorSearch # type: ignore
from flask import current_app
from pydantic import BaseModel, model_validator
@@ -231,7 +231,7 @@ class CouchbaseVector(BaseVector):
# Pass the id as a parameter to the query
result = self._cluster.query(query, named_parameters={"doc_id": id}).execute()
for row in result:
return row["count"] > 0
return bool(row["count"] > 0)
return False # Return False if no rows are returned
def delete_by_ids(self, ids: list[str]) -> None:
@@ -369,10 +369,10 @@ class CouchbaseVectorFactory(AbstractVectorFactory):
return CouchbaseVector(
collection_name=collection_name,
config=CouchbaseConfig(
connection_string=config.get("COUCHBASE_CONNECTION_STRING"),
user=config.get("COUCHBASE_USER"),
password=config.get("COUCHBASE_PASSWORD"),
bucket_name=config.get("COUCHBASE_BUCKET_NAME"),
scope_name=config.get("COUCHBASE_SCOPE_NAME"),
connection_string=config.get("COUCHBASE_CONNECTION_STRING", ""),
user=config.get("COUCHBASE_USER", ""),
password=config.get("COUCHBASE_PASSWORD", ""),
bucket_name=config.get("COUCHBASE_BUCKET_NAME", ""),
scope_name=config.get("COUCHBASE_SCOPE_NAME", ""),
),
)
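Note: `row["count"]` comes out of the query result as `Any`, so `row["count"] > 0` is `Any` as well, and a mypy run with `warn_return_any` enabled (which this change suggests) rejects returning it from a function declared `-> bool`; wrapping in `bool(...)` pins the type. The `cur.rowcount != 0` change in OceanBase below is the same fix, and the `config.get(KEY, "")` defaults mirror the `or ""` idiom above. A sketch:

```python
from typing import Any


def text_exists(row: dict[str, Any]) -> bool:
    # row["count"] is Any, so the comparison is Any too; bool(...) gives the declared
    # return type a concrete value regardless of the warn-return-any setting.
    return bool(row["count"] > 0)
```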

View File

@@ -1,7 +1,7 @@
import json
import logging
import math
from typing import Any, Optional
from typing import Any, Optional, cast
from urllib.parse import urlparse
import requests
@@ -70,7 +70,7 @@ class ElasticSearchVector(BaseVector):
def _get_version(self) -> str:
info = self._client.info()
return info["version"]["number"]
return cast(str, info["version"]["number"])
def _check_version(self):
if self._version < "8.0.0":
@@ -135,7 +135,8 @@ class ElasticSearchVector(BaseVector):
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
doc.metadata["score"] = score
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
return docs
@@ -156,12 +157,15 @@ class ElasticSearchVector(BaseVector):
return docs
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata for d in texts]
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
self,
embeddings: list[list[float]],
metadatas: Optional[list[dict[Any, Any]]] = None,
index_params: Optional[dict] = None,
):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
@@ -208,10 +212,10 @@ class ElasticSearchVectorFactory(AbstractVectorFactory):
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(
host=config.get("ELASTICSEARCH_HOST"),
port=config.get("ELASTICSEARCH_PORT"),
username=config.get("ELASTICSEARCH_USERNAME"),
password=config.get("ELASTICSEARCH_PASSWORD"),
host=config.get("ELASTICSEARCH_HOST", "localhost"),
port=config.get("ELASTICSEARCH_PORT", 9200),
username=config.get("ELASTICSEARCH_USERNAME", ""),
password=config.get("ELASTICSEARCH_PASSWORD", ""),
),
attributes=[],
)
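Note: the Elasticsearch client's `info()` payload is effectively untyped, so `info["version"]["number"]` is `Any`; `typing.cast(str, ...)` asserts the expected type to mypy without any runtime effect. A sketch:

```python
from typing import Any, cast


def get_version(info: dict[str, Any]) -> str:
    # cast() is purely a type-level assertion: it returns its argument unchanged at
    # runtime but lets the function keep its declared `-> str` signature.
    return cast(str, info["version"]["number"])
```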

View File

@@ -42,7 +42,7 @@ class LindormVectorStoreConfig(BaseModel):
return values
def to_opensearch_params(self) -> dict[str, Any]:
params = {"hosts": self.hosts}
params: dict[str, Any] = {"hosts": self.hosts}
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params
@@ -53,7 +53,7 @@ class LindormVectorStore(BaseVector):
self._routing = None
self._routing_field = None
if using_ugc:
routing_value: str = kwargs.get("routing_value")
routing_value: str | None = kwargs.get("routing_value")
if routing_value is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
self._routing = routing_value.lower()
@@ -87,14 +87,15 @@ class LindormVectorStore(BaseVector):
"_id": uuids[i],
}
}
action_values = {
action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
action_values[self._routing_field] = self._routing
if self._routing_field is not None:
action_values[self._routing_field] = self._routing
actions.append(action_header)
actions.append(action_values)
response = self._client.bulk(actions)
@@ -105,7 +106,9 @@ class LindormVectorStore(BaseVector):
self.refresh()
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
query: dict[str, Any] = {
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}
}
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
response = self._client.search(index=self._collection_name, body=query)
@@ -191,7 +194,8 @@ class LindormVectorStore(BaseVector):
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score > score_threshold:
doc.metadata["score"] = score
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
return docs
@@ -366,6 +370,7 @@ def default_text_search_query(
routing_field: Optional[str] = None,
**kwargs,
) -> dict:
query_clause: dict[str, Any] = {}
if routing is not None:
query_clause = {
"bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
@@ -386,7 +391,7 @@ def default_text_search_query(
else:
must = [query_clause]
boolean_query = {"must": must}
boolean_query: dict[str, Any] = {"must": must}
if must_not:
if not isinstance(must_not, list):
@@ -426,7 +431,7 @@ def default_vector_search_query(
filter_type = "post_filter" if filter_type is None else filter_type
if not isinstance(filters, list):
raise RuntimeError(f"unexpected filter with {type(filters)}")
final_ext = {"lvector": {}}
final_ext: dict[str, Any] = {"lvector": {}}
if min_score != "0.0":
final_ext["lvector"]["min_score"] = min_score
if ef_search:
@@ -438,7 +443,7 @@ def default_vector_search_query(
if client_refactor:
final_ext["lvector"]["client_refactor"] = client_refactor
search_query = {
search_query: dict[str, Any] = {
"size": k,
"_source": True, # force return '_source'
"query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
@@ -446,8 +451,8 @@ def default_vector_search_query(
if filters is not None:
# when using filter, transform filter from List[Dict] to Dict as valid format
filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict
filter_dict = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][vector_field]["filter"] = filter_dict # filter should be Dict
if filter_type:
final_ext["lvector"]["filter_type"] = filter_type
@@ -459,17 +464,19 @@ def default_vector_search_query(
class LindormVectorStoreFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
lindorm_config = LindormVectorStoreConfig(
hosts=dify_config.LINDORM_URL,
hosts=dify_config.LINDORM_URL or "",
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
using_ugc=dify_config.USING_UGC_INDEX,
)
using_ugc = dify_config.USING_UGC_INDEX
if using_ugc is None:
raise ValueError("USING_UGC_INDEX is not set")
routing_value = None
if dataset.index_struct:
# if an existed record's index_struct_dict doesn't contain using_ugc field,
# it actually stores in the normal index format
stored_in_ugc = dataset.index_struct_dict.get("using_ugc", False)
stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False)
using_ugc = stored_in_ugc
if stored_in_ugc:
dimension = dataset.index_struct_dict["dimension"]
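Note: most Lindorm fixes annotate dictionaries that later receive heterogeneous or dynamically keyed values (`params`, `action_values`, `query_clause`, `boolean_query`, `final_ext`, `search_query`): mypy infers a narrow value type from the first literal and then rejects later assignments, so an explicit `dict[str, Any]` keeps the bodies valid. A reduced sketch of building such a query body:

```python
from typing import Any, Optional


def build_text_query(
    text_field: str, query_text: str, routing: Optional[str], routing_field: Optional[str]
) -> dict[str, Any]:
    # Annotating up front stops mypy from pinning the value type to the first literal's shape.
    query_clause: dict[str, Any] = {"match": {text_field: query_text}}
    if routing is not None and routing_field is not None:
        query_clause = {
            "bool": {
                "must": [
                    {"match": {text_field: query_text}},
                    {"term": {f"{routing_field}.keyword": routing}},
                ]
            }
        }
    return {"size": 10, "_source": True, "query": query_clause}
```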

View File

@@ -3,8 +3,8 @@ import logging
from typing import Any, Optional
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException
from pymilvus.milvus_client import IndexParams
from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -54,14 +54,14 @@ class MilvusVector(BaseVector):
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session"
self._fields = []
self._fields: list[str] = []
def get_type(self) -> str:
return VectorType.MILVUS
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata for d in texts]
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
@@ -161,8 +161,8 @@ class MilvusVector(BaseVector):
return
# Grab the existing collection if it exists
if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.orm.types import infer_dtype_bydata
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore
# Determine embedding dim
dim = len(embeddings[0])
@@ -217,10 +217,10 @@ class MilvusVectorFactory(AbstractVectorFactory):
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
uri=dify_config.MILVUS_URI,
token=dify_config.MILVUS_TOKEN,
user=dify_config.MILVUS_USER,
password=dify_config.MILVUS_PASSWORD,
database=dify_config.MILVUS_DATABASE,
uri=dify_config.MILVUS_URI or "",
token=dify_config.MILVUS_TOKEN or "",
user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "",
),
)

View File

@@ -74,15 +74,16 @@ class MyScaleVector(BaseVector):
columns = ["id", "text", "vector", "metadata"]
values = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
row = (
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {},
)
values.append(str(row))
ids.append(doc_id)
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
row = (
doc_id,
self.escape_str(doc.page_content),
embeddings[i],
json.dumps(doc.metadata) if doc.metadata else {},
)
values.append(str(row))
ids.append(doc_id)
sql = f"""
INSERT INTO {self._config.database}.{self._collection_name}
({",".join(columns)}) VALUES {",".join(values)}

View File

@@ -4,7 +4,7 @@ import math
from typing import Any
from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient
from pyobvector import VECTOR, ObVecClient # type: ignore
from sqlalchemy import JSON, Column, String, func
from sqlalchemy.dialects.mysql import LONGTEXT
@@ -131,7 +131,7 @@ class OceanBaseVector(BaseVector):
def text_exists(self, id: str) -> bool:
cur = self._client.get(table_name=self._collection_name, id=id)
return cur.rowcount != 0
return bool(cur.rowcount != 0)
def delete_by_ids(self, ids: list[str]) -> None:
self._client.delete(table_name=self._collection_name, ids=ids)

View File

@@ -66,7 +66,7 @@ class OpenSearchVector(BaseVector):
return VectorType.OPENSEARCH
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
metadatas = [d.metadata for d in texts]
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas)
self.add_texts(texts, embeddings)
@@ -244,7 +244,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST,
host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,

View File

@@ -5,7 +5,7 @@ import uuid
from contextlib import contextmanager
from typing import Any
import jieba.posseg as pseg
import jieba.posseg as pseg # type: ignore
import numpy
import oracledb
from pydantic import BaseModel, model_validator
@@ -88,12 +88,11 @@ class OracleVector(BaseVector):
def numpy_converter_out(self, value):
if value.typecode == "b":
dtype = numpy.int8
return numpy.array(value, copy=False, dtype=numpy.int8)
elif value.typecode == "f":
dtype = numpy.float32
return numpy.array(value, copy=False, dtype=numpy.float32)
else:
dtype = numpy.float64
return numpy.array(value, copy=False, dtype=dtype)
return numpy.array(value, copy=False, dtype=numpy.float64)
def output_type_handler(self, cursor, metadata):
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
@@ -135,17 +134,18 @@ class OracleVector(BaseVector):
values = []
pks = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
# array.array("f", embeddings[i]),
numpy.array(embeddings[i]),
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
# array.array("f", embeddings[i]),
numpy.array(embeddings[i]),
)
)
)
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
with self._get_cursor() as cur:
cur.executemany(
@@ -201,8 +201,8 @@ class OracleVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# lazy import
import nltk
from nltk.corpus import stopwords
import nltk # type: ignore
from nltk.corpus import stopwords # type: ignore
top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
@@ -285,10 +285,10 @@ class OracleVectorFactory(AbstractVectorFactory):
return OracleVector(
collection_name=collection_name,
config=OracleVectorConfig(
host=dify_config.ORACLE_HOST,
host=dify_config.ORACLE_HOST or "localhost",
port=dify_config.ORACLE_PORT,
user=dify_config.ORACLE_USER,
password=dify_config.ORACLE_PASSWORD,
database=dify_config.ORACLE_DATABASE,
user=dify_config.ORACLE_USER or "system",
password=dify_config.ORACLE_PASSWORD or "oracle",
database=dify_config.ORACLE_DATABASE or "orcl",
),
)

View File

@@ -4,7 +4,7 @@ from typing import Any
from uuid import UUID, uuid4
from numpy import ndarray
from pgvecto_rs.sqlalchemy import VECTOR
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
from pydantic import BaseModel, model_validator
from sqlalchemy import Float, String, create_engine, insert, select, text
from sqlalchemy import text as sql_text
@@ -58,7 +58,7 @@ class PGVectoRS(BaseVector):
with Session(self._client) as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vectors"))
session.commit()
self._fields = []
self._fields: list[str] = []
class _Table(CollectionORM):
__tablename__ = collection_name
@@ -222,11 +222,11 @@ class PGVectoRSFactory(AbstractVectorFactory):
return PGVectoRS(
collection_name=collection_name,
config=PgvectoRSConfig(
host=dify_config.PGVECTO_RS_HOST,
port=dify_config.PGVECTO_RS_PORT,
user=dify_config.PGVECTO_RS_USER,
password=dify_config.PGVECTO_RS_PASSWORD,
database=dify_config.PGVECTO_RS_DATABASE,
host=dify_config.PGVECTO_RS_HOST or "localhost",
port=dify_config.PGVECTO_RS_PORT or 5432,
user=dify_config.PGVECTO_RS_USER or "postgres",
password=dify_config.PGVECTO_RS_PASSWORD or "",
database=dify_config.PGVECTO_RS_DATABASE or "postgres",
),
dim=dim,
)

View File

@@ -3,8 +3,8 @@ import uuid
from contextlib import contextmanager
from typing import Any
import psycopg2.extras
import psycopg2.pool
import psycopg2.extras # type: ignore
import psycopg2.pool # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
@@ -98,16 +98,17 @@ class PGVector(BaseVector):
values = []
pks = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
embeddings[i],
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
embeddings[i],
)
)
)
with self._get_cursor() as cur:
psycopg2.extras.execute_values(
cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
@@ -216,11 +217,11 @@ class PGVectorFactory(AbstractVectorFactory):
return PGVector(
collection_name=collection_name,
config=PGVectorConfig(
host=dify_config.PGVECTOR_HOST,
host=dify_config.PGVECTOR_HOST or "localhost",
port=dify_config.PGVECTOR_PORT,
user=dify_config.PGVECTOR_USER,
password=dify_config.PGVECTOR_PASSWORD,
database=dify_config.PGVECTOR_DATABASE,
user=dify_config.PGVECTOR_USER or "postgres",
password=dify_config.PGVECTOR_PASSWORD or "",
database=dify_config.PGVECTOR_DATABASE or "postgres",
min_connection=dify_config.PGVECTOR_MIN_CONNECTION,
max_connection=dify_config.PGVECTOR_MAX_CONNECTION,
),

View File

@@ -51,6 +51,8 @@ class QdrantConfig(BaseModel):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
if not self.root_path:
raise ValueError("Root path is not set")
path = os.path.join(self.root_path, path)
return {"path": path}
@@ -149,9 +151,12 @@ class QdrantVector(BaseVector):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
added_ids = []
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
# Filter out None values from metadatas list to match expected type
filtered_metadatas = [m for m in metadatas if m is not None]
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
):
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
@@ -194,7 +199,7 @@ class QdrantVector(BaseVector):
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
group_id or "", # Ensure group_id is never None
Field.GROUP_KEY.value,
),
)
@@ -337,18 +342,20 @@ class QdrantVector(BaseVector):
)
docs = []
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -432,9 +439,9 @@ class QdrantVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
group_id=dataset.id,
config=QdrantConfig(
endpoint=dify_config.QDRANT_URL,
endpoint=dify_config.QDRANT_URL or "",
api_key=dify_config.QDRANT_API_KEY,
root_path=current_app.config.root_path,
root_path=str(current_app.config.root_path),
timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.QDRANT_GRPC_PORT,
prefer_grpc=dify_config.QDRANT_GRPC_ENABLED,

View File

@@ -3,7 +3,7 @@ import uuid
from typing import Any, Optional
from pydantic import BaseModel, model_validator
from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
from sqlalchemy import Column, String, Table, create_engine, insert
from sqlalchemy import text as sql_text
from sqlalchemy.dialects.postgresql import JSON, TEXT
from sqlalchemy.orm import Session
@@ -58,14 +58,14 @@ class RelytVector(BaseVector):
f"postgresql+psycopg2://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}"
)
self.client = create_engine(self._url)
self._fields = []
self._fields: list[str] = []
self._group_id = group_id
def get_type(self) -> str:
return VectorType.RELYT
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs) -> None:
index_params: dict[str, Any] = {}
metadatas = [d.metadata for d in texts]
self.create_collection(len(embeddings[0]))
self.embedding_dimension = len(embeddings[0])
@@ -107,10 +107,10 @@ class RelytVector(BaseVector):
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
from pgvecto_rs.sqlalchemy import VECTOR
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
ids = [str(uuid.uuid1()) for _ in documents]
metadatas = [d.metadata for d in documents]
metadatas = [d.metadata for d in documents if d.metadata is not None]
for metadata in metadatas:
metadata["group_id"] = self._group_id
texts = [d.page_content for d in documents]
@@ -242,10 +242,6 @@ class RelytVector(BaseVector):
filter: Optional[dict] = None,
) -> list[tuple[Document, float]]:
# Add the filter if provided
try:
from sqlalchemy.engine import Row
except ImportError:
raise ImportError("Could not import Row from sqlalchemy.engine. Please 'pip install sqlalchemy>=1.4'.")
filter_condition = ""
if filter is not None:
@@ -275,7 +271,7 @@ class RelytVector(BaseVector):
# Execute the query and fetch the results
with self.client.connect() as conn:
results: Sequence[Row] = conn.execute(sql_text(sql_query), params).fetchall()
results = conn.execute(sql_text(sql_query), params).fetchall()
documents_with_scores = [
(
@@ -307,11 +303,11 @@ class RelytVectorFactory(AbstractVectorFactory):
return RelytVector(
collection_name=collection_name,
config=RelytConfig(
host=dify_config.RELYT_HOST,
host=dify_config.RELYT_HOST or "localhost",
port=dify_config.RELYT_PORT,
user=dify_config.RELYT_USER,
password=dify_config.RELYT_PASSWORD,
database=dify_config.RELYT_DATABASE,
user=dify_config.RELYT_USER or "",
password=dify_config.RELYT_PASSWORD or "",
database=dify_config.RELYT_DATABASE or "default",
),
group_id=dataset.id,
)

View File

@@ -2,10 +2,10 @@ import json
from typing import Any, Optional
from pydantic import BaseModel
from tcvectordb import VectorDBClient
from tcvectordb.model import document, enum
from tcvectordb.model import index as vdb_index
from tcvectordb.model.document import Filter
from tcvectordb import VectorDBClient # type: ignore
from tcvectordb.model import document, enum # type: ignore
from tcvectordb.model import index as vdb_index # type: ignore
from tcvectordb.model.document import Filter # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
@@ -25,8 +25,8 @@ class TencentConfig(BaseModel):
database: Optional[str]
index_type: str = "HNSW"
metric_type: str = "L2"
shard: int = (1,)
replicas: int = (2,)
shard: int = 1
replicas: int = 2
def to_tencent_params(self):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
@@ -120,15 +120,15 @@ class TencentVector(BaseVector):
metadatas = [doc.metadata for doc in documents]
total_count = len(embeddings)
docs = []
for id in range(0, total_count):
for i in range(0, total_count):
if metadatas is None:
continue
metadata = json.dumps(metadatas[id])
metadata = metadatas[i] or {}
doc = document.Document(
id=metadatas[id]["doc_id"],
vector=embeddings[id],
text=texts[id],
metadata=metadata,
id=metadata.get("doc_id"),
vector=embeddings[i],
text=texts[i],
metadata=json.dumps(metadata),
)
docs.append(doc)
self._db.collection(self._collection_name).upsert(docs, self._client_config.timeout)
@@ -159,8 +159,8 @@ class TencentVector(BaseVector):
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []
def _get_search_res(self, res, score_threshold):
docs = []
def _get_search_res(self, res: list | None, score_threshold: float) -> list[Document]:
docs: list[Document] = []
if res is None or len(res) == 0:
return docs
@@ -193,7 +193,7 @@ class TencentVectorFactory(AbstractVectorFactory):
return TencentVector(
collection_name=collection_name,
config=TencentConfig(
url=dify_config.TENCENT_VECTOR_DB_URL,
url=dify_config.TENCENT_VECTOR_DB_URL or "",
api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
username=dify_config.TENCENT_VECTOR_DB_USERNAME,

View File

@@ -54,7 +54,10 @@ class TidbOnQdrantConfig(BaseModel):
if self.endpoint and self.endpoint.startswith("path:"):
path = self.endpoint.replace("path:", "")
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
if self.root_path:
path = os.path.join(self.root_path, path)
else:
raise ValueError("root_path is required")
return {"path": path}
else:
@@ -157,7 +160,7 @@ class TidbOnQdrantVector(BaseVector):
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
metadatas = [d.metadata for d in documents if d.metadata is not None]
added_ids = []
for batch_ids, points in self._generate_rest_batches(texts, embeddings, metadatas, uuids, 64, self._group_id):
@@ -203,7 +206,7 @@ class TidbOnQdrantVector(BaseVector):
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
group_id or "",
Field.GROUP_KEY.value,
),
)
@@ -334,18 +337,20 @@ class TidbOnQdrantVector(BaseVector):
)
docs = []
for result in results:
if result.payload is None:
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
metadata=metadata,
)
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata["score"] if x.metadata is not None else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -427,12 +432,12 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
else:
new_cluster = TidbService.create_tidb_serverless_cluster(
dify_config.TIDB_PROJECT_ID,
dify_config.TIDB_API_URL,
dify_config.TIDB_IAM_API_URL,
dify_config.TIDB_PUBLIC_KEY,
dify_config.TIDB_PRIVATE_KEY,
dify_config.TIDB_REGION,
dify_config.TIDB_PROJECT_ID or "",
dify_config.TIDB_API_URL or "",
dify_config.TIDB_IAM_API_URL or "",
dify_config.TIDB_PUBLIC_KEY or "",
dify_config.TIDB_PRIVATE_KEY or "",
dify_config.TIDB_REGION or "",
)
new_tidb_auth_binding = TidbAuthBinding(
cluster_id=new_cluster["cluster_id"],
@@ -464,9 +469,9 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
collection_name=collection_name,
group_id=dataset.id,
config=TidbOnQdrantConfig(
endpoint=dify_config.TIDB_ON_QDRANT_URL,
endpoint=dify_config.TIDB_ON_QDRANT_URL or "",
api_key=TIDB_ON_QDRANT_API_KEY,
root_path=config.root_path,
root_path=str(config.root_path),
timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT,
grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT,
prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED,

View File

@@ -146,7 +146,7 @@ class TidbService:
iam_url: str,
public_key: str,
private_key: str,
) -> list[dict]:
):
"""
Update the status of a new TiDB Serverless cluster.
:param project_id: The project ID of the TiDB Cloud project (required).
@@ -159,7 +159,6 @@ class TidbService:
:return: The response from the API.
"""
clusters = []
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "BASIC"}
@@ -169,7 +168,6 @@ class TidbService:
if response.status_code == 200:
response_data = response.json()
cluster_infos = []
for item in response_data["clusters"]:
state = item["state"]
userPrefix = item["userPrefix"]
@@ -236,16 +234,17 @@ class TidbService:
cluster_infos = []
for item in response_data["clusters"]:
cache_key = f"tidb_serverless_cluster_password:{item['displayName']}"
password = redis_client.get(cache_key)
if not password:
cached_password = redis_client.get(cache_key)
if not cached_password:
continue
cluster_info = {
"cluster_id": item["clusterId"],
"cluster_name": item["displayName"],
"account": "root",
"password": password.decode("utf-8"),
"password": cached_password.decode("utf-8"),
}
cluster_infos.append(cluster_info)
return cluster_infos
else:
response.raise_for_status()
return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception
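Note: `response.raise_for_status()` only raises for 4xx/5xx responses and is not typed as `NoReturn`, so mypy sees a path that falls off the end of a function annotated to return a list; the trailing `return []` with the FIXME keeps the signature honest. A sketch of the same shape (URL and payload keys are illustrative only):

```python
import requests


def list_clusters(url: str) -> list[dict]:
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        return list(response.json().get("clusters", []))
    else:
        response.raise_for_status()
        # raise_for_status() only raises on error status codes, so an explicit return
        # is still needed to satisfy the declared return type.
        return []
```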

View File

@@ -49,7 +49,7 @@ class TiDBVector(BaseVector):
return VectorType.TIDB_VECTOR
def _table(self, dim: int) -> Table:
from tidb_vector.sqlalchemy import VectorType
from tidb_vector.sqlalchemy import VectorType # type: ignore
return Table(
self._collection_name,
@@ -241,11 +241,11 @@ class TiDBVectorFactory(AbstractVectorFactory):
return TiDBVector(
collection_name=collection_name,
config=TiDBVectorConfig(
host=dify_config.TIDB_VECTOR_HOST,
port=dify_config.TIDB_VECTOR_PORT,
user=dify_config.TIDB_VECTOR_USER,
password=dify_config.TIDB_VECTOR_PASSWORD,
database=dify_config.TIDB_VECTOR_DATABASE,
host=dify_config.TIDB_VECTOR_HOST or "",
port=dify_config.TIDB_VECTOR_PORT or 0,
user=dify_config.TIDB_VECTOR_USER or "",
password=dify_config.TIDB_VECTOR_PASSWORD or "",
database=dify_config.TIDB_VECTOR_DATABASE or "",
program_name=dify_config.APPLICATION_NAME,
),
)

View File

@@ -51,15 +51,16 @@ class BaseVector(ABC):
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
if text.metadata and "doc_id" in text.metadata:
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata["doc_id"] for text in texts]
return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata]
@property
def collection_name(self):

View File

@@ -193,10 +193,13 @@ class Vector:
def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
for text in texts.copy():
if text.metadata is None:
continue
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
if doc_id:
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
texts.remove(text)
return texts

View File

@@ -2,7 +2,7 @@ import json
from typing import Any
from pydantic import BaseModel
from volcengine.viking_db import (
from volcengine.viking_db import ( # type: ignore
Data,
DistanceType,
Field,
@@ -121,11 +121,12 @@ class VikingDBVector(BaseVector):
for i, page_content in enumerate(page_contents):
metadata = {}
if metadatas is not None:
for key, val in metadatas[i].items():
for key, val in (metadatas[i] or {}).items():
metadata[key] = val
# FIXME: fix the type of metadata later
doc = Data(
{
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"],
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"], # type: ignore
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY.value: page_content,
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
@@ -178,7 +179,7 @@ class VikingDBVector(BaseVector):
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(results, score_threshold)
def _get_search_res(self, results, score_threshold):
def _get_search_res(self, results, score_threshold) -> list[Document]:
if len(results) == 0:
return []
@@ -191,7 +192,7 @@ class VikingDBVector(BaseVector):
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:

View File

@@ -3,7 +3,7 @@ import json
from typing import Any, Optional
import requests
import weaviate
import weaviate # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
@@ -107,7 +107,8 @@ class WeaviateVector(BaseVector):
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY.value: text}
if metadatas is not None:
for key, val in metadatas[i].items():
# metadata maybe None
for key, val in (metadatas[i] or {}).items():
data_properties[key] = self._json_serializable(val)
batch.add_data_object(
@@ -208,10 +209,11 @@ class WeaviateVector(BaseVector):
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score > score_threshold:
doc.metadata["score"] = score
docs.append(doc)
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
docs = sorted(docs, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -275,7 +277,7 @@ class WeaviateVectorFactory(AbstractVectorFactory):
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=dify_config.WEAVIATE_ENDPOINT,
endpoint=dify_config.WEAVIATE_ENDPOINT or "",
api_key=dify_config.WEAVIATE_API_KEY,
batch_size=dify_config.WEAVIATE_BATCH_SIZE,
),

View File

@@ -83,6 +83,9 @@ class DatasetDocumentStore:
if not isinstance(doc, Document):
raise ValueError("doc must be a Document")
if doc.metadata is None:
raise ValueError("doc.metadata must be a dict")
segment_document = self.get_document_segment(doc_id=doc.metadata["doc_id"])
# NOTE: doc could already exist in the store, but we overwrite it
@@ -179,10 +182,10 @@ class DatasetDocumentStore:
if document_segment is None:
return None
data: Optional[str] = document_segment.index_node_hash
return data
return document_segment.index_node_hash
def get_document_segment(self, doc_id: str) -> DocumentSegment:
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)

View File

@@ -1,6 +1,6 @@
import base64
import logging
from typing import Optional, cast
from typing import Any, Optional, cast
import numpy as np
from sqlalchemy.exc import IntegrityError
@@ -27,7 +27,7 @@ class CacheEmbedding(Embeddings):
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs in batches of 10."""
# use doc embedding cache or store if not exists
text_embeddings = [None for _ in range(len(texts))]
text_embeddings: list[Any] = [None for _ in range(len(texts))]
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
@@ -64,7 +64,8 @@ class CacheEmbedding(Embeddings):
for vector in embedding_result.embeddings:
try:
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
# FIXME: type ignore for numpy here
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() # type: ignore
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
if np.isnan(normalized_embedding).any():
# for issue #11827 float values are not json compliant
@@ -77,8 +78,8 @@ class CacheEmbedding(Embeddings):
logging.exception("Failed transform embedding")
cache_embeddings = []
try:
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
text_embeddings[i] = embedding
for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
text_embeddings[i] = n_embedding
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(
@@ -86,7 +87,7 @@ class CacheEmbedding(Embeddings):
hash=hash,
provider_name=self._model_instance.provider,
)
embedding_cache.set_embedding(embedding)
embedding_cache.set_embedding(n_embedding)
db.session.add(embedding_cache)
cache_embeddings.append(hash)
db.session.commit()
@@ -115,7 +116,8 @@ class CacheEmbedding(Embeddings):
)
embedding_results = embedding_result.embeddings[0]
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
# FIXME: type ignore for numpy here
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() # type: ignore
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
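
A minimal sketch of the normalize-then-validate step these embedding hunks annotate (numpy only; the type: ignore in the diff exists because mypy cannot see through the ndarray division):

from typing import Any
import numpy as np

def normalize_embedding(vector: list[float]) -> list[float]:
    # divide by the L2 norm, then reject results that are not JSON-compliant (NaN)
    normalized: Any = (np.array(vector) / np.linalg.norm(vector)).tolist()
    if np.isnan(normalized).any():
        raise ValueError("Normalized embedding is nan, please try again")
    return normalized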

View File

@@ -14,7 +14,7 @@ class NotionInfo(BaseModel):
notion_workspace_id: str
notion_obj_id: str
notion_page_type: str
document: Document = None
document: Optional[Document] = None
tenant_id: str
model_config = ConfigDict(arbitrary_types_allowed=True)

View File

@@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""
import os
from typing import Optional
from typing import Optional, cast
import pandas as pd
from openpyxl import load_workbook
@@ -47,7 +47,7 @@ class ExcelExtractor(BaseExtractor):
for col_index, (k, v) in enumerate(row.items()):
if pd.notna(v):
cell = sheet.cell(
row=index + 2, column=col_index + 1
row=cast(int, index) + 2, column=col_index + 1
) # +2 to account for header and 1-based index
if cell.hyperlink:
value = f"[{v}]({cell.hyperlink.target})"
@@ -60,8 +60,8 @@ class ExcelExtractor(BaseExtractor):
elif file_extension == ".xls":
excel_file = pd.ExcelFile(self._file_path, engine="xlrd")
for sheet_name in excel_file.sheet_names:
df = excel_file.parse(sheet_name=sheet_name)
for excel_sheet_name in excel_file.sheet_names:
df = excel_file.parse(sheet_name=excel_sheet_name)
df.dropna(how="all", inplace=True)
for _, row in df.iterrows():

View File

@@ -10,6 +10,7 @@ from core.rag.extractor.csv_extractor import CSVExtractor
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.excel_extractor import ExcelExtractor
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
from core.rag.extractor.html_extractor import HtmlExtractor
from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor
@@ -66,9 +67,13 @@ class ExtractProcessor:
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
match = re.search(r"\.(\w+)$", filename)
if match:
suffix = "." + match.group(1)
else:
suffix = ""
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names'; better not to use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
if return_text:
@@ -89,15 +94,20 @@ class ExtractProcessor:
if extract_setting.datasource_type == DatasourceType.FILE.value:
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
assert extract_setting.upload_file is not None, "upload_file is required"
upload_file: UploadFile = extract_setting.upload_file
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names'; better not to use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
etl_type = dify_config.ETL_TYPE
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
assert unstructured_api_url is not None, "unstructured_api_url is required"
assert unstructured_api_key is not None, "unstructured_api_key is required"
extractor: Optional[BaseExtractor] = None
if etl_type == "Unstructured":
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
@@ -156,6 +166,7 @@ class ExtractProcessor:
extractor = TextExtractor(file_path, autodetect_encoding=True)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.NOTION.value:
assert extract_setting.notion_info is not None, "notion_info is required"
extractor = NotionExtractor(
notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
notion_obj_id=extract_setting.notion_info.notion_obj_id,
@@ -165,6 +176,7 @@ class ExtractProcessor:
)
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
assert extract_setting.website_info is not None, "website_info is required"
if extract_setting.website_info.provider == "firecrawl":
extractor = FirecrawlWebExtractor(
url=extract_setting.website_info.url,
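
A compact sketch of the None-guarded suffix extraction introduced above, using a hypothetical helper name (only re from the stdlib is assumed):

import re

def suffix_from_content_disposition(content_disposition: str) -> str:
    # re.search returns Optional[Match], so both lookups are guarded before .group() is called
    filename_match = re.search(r'filename="([^"]+)"', content_disposition)
    if filename_match is None:
        return ""
    ext_match = re.search(r"\.(\w+)$", filename_match.group(1))
    return "." + ext_match.group(1) if ext_match else ""

For example, 'attachment; filename="report.pdf"' yields ".pdf", and inputs without an extension fall back to an empty suffix.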

View File

@@ -1,5 +1,6 @@
import json
import time
from typing import cast
import requests
@@ -20,9 +21,9 @@ class FirecrawlApp:
json_data.update(params)
response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data)
if response.status_code == 200:
response = response.json()
if response["success"] == True:
data = response["data"]
response_data = response.json()
if response_data["success"] == True:
data = response_data["data"]
return {
"title": data.get("metadata").get("title"),
"description": data.get("metadata").get("description"),
@@ -30,7 +31,7 @@ class FirecrawlApp:
"markdown": data.get("markdown"),
}
else:
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}')
elif response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred")
@@ -46,9 +47,11 @@ class FirecrawlApp:
response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers)
if response.status_code == 200:
job_id = response.json().get("jobId")
return job_id
return cast(str, job_id)
else:
self._handle_error(response, "start crawl job")
# FIXME: unreachable code for mypy
return "" # unreachable
def check_crawl_status(self, job_id) -> dict:
headers = self._prepare_headers()
@@ -64,9 +67,9 @@ class FirecrawlApp:
for item in data:
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
url_data = {
"title": item.get("metadata").get("title"),
"description": item.get("metadata").get("description"),
"source_url": item.get("metadata").get("sourceURL"),
"title": item.get("metadata", {}).get("title"),
"description": item.get("metadata", {}).get("description"),
"source_url": item.get("metadata", {}).get("sourceURL"),
"markdown": item.get("markdown"),
}
url_data_list.append(url_data)
@@ -92,6 +95,8 @@ class FirecrawlApp:
else:
self._handle_error(response, "check crawl status")
# FIXME: unreachable code for mypy
return {} # unreachable
def _prepare_headers(self):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

View File

@@ -1,6 +1,6 @@
"""Abstract interface for document loader implementations."""
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup # type: ignore
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
@@ -23,6 +23,7 @@ class HtmlExtractor(BaseExtractor):
return [Document(page_content=self._load_as_text())]
def _load_as_text(self) -> str:
text: str = ""
with open(self._file_path, "rb") as fp:
soup = BeautifulSoup(fp, "html.parser")
text = soup.get_text()

View File

@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Optional
from typing import Any, Optional, cast
import requests
@@ -78,6 +78,7 @@ class NotionExtractor(BaseExtractor):
def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]:
"""Get all the pages from a Notion database."""
assert self._notion_access_token is not None, "Notion access token is required"
res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id),
headers={
@@ -96,6 +97,7 @@ class NotionExtractor(BaseExtractor):
for result in data["results"]:
properties = result["properties"]
data = {}
value: Any
for property_name, property_value in properties.items():
type = property_value["type"]
if type == "multi_select":
@@ -130,6 +132,7 @@ class NotionExtractor(BaseExtractor):
return [Document(page_content="\n".join(database_content))]
def _get_notion_block_data(self, page_id: str) -> list[str]:
assert self._notion_access_token is not None, "Notion access token is required"
result_lines_arr = []
start_cursor = None
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=page_id)
@@ -184,6 +187,7 @@ class NotionExtractor(BaseExtractor):
def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
"""Read a block."""
assert self._notion_access_token is not None, "Notion access token is required"
result_lines_arr = []
start_cursor = None
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
@@ -242,6 +246,7 @@ class NotionExtractor(BaseExtractor):
def _read_table_rows(self, block_id: str) -> str:
"""Read table rows."""
assert self._notion_access_token is not None, "Notion access token is required"
done = False
result_lines_arr = []
start_cursor = None
@@ -296,7 +301,7 @@ class NotionExtractor(BaseExtractor):
result_lines = "\n".join(result_lines_arr)
return result_lines
def update_last_edited_time(self, document_model: DocumentModel):
def update_last_edited_time(self, document_model: Optional[DocumentModel]):
if not document_model:
return
@@ -309,6 +314,7 @@ class NotionExtractor(BaseExtractor):
db.session.commit()
def get_notion_last_edited_time(self) -> str:
assert self._notion_access_token is not None, "Notion access token is required"
obj_id = self._notion_obj_id
page_type = self._notion_page_type
if page_type == "database":
@@ -330,7 +336,7 @@ class NotionExtractor(BaseExtractor):
)
data = res.json()
return data["last_edited_time"]
return cast(str, data["last_edited_time"])
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
@@ -349,4 +355,4 @@ class NotionExtractor(BaseExtractor):
f"and notion workspace {notion_workspace_id}"
)
return data_source_binding.access_token
return cast(str, data_source_binding.access_token)
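
The repeated assert on the access token is mypy's standard way of narrowing Optional[str] to str before use; a minimal sketch with a hypothetical client class:

from typing import Optional

class NotionStyleClient:
    # hypothetical class illustrating assert-based narrowing of an Optional attribute
    def __init__(self, access_token: Optional[str] = None) -> None:
        self._access_token = access_token

    def auth_headers(self) -> dict[str, str]:
        assert self._access_token is not None, "Notion access token is required"
        # after the assert, mypy treats self._access_token as str
        return {"Authorization": f"Bearer {self._access_token}"}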

View File

@@ -1,7 +1,7 @@
"""Abstract interface for document loader implementations."""
from collections.abc import Iterator
from typing import Optional
from typing import Optional, cast
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
@@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
plaintext_file_exists = False
if self._file_cache_key:
try:
text = storage.load(self._file_cache_key).decode("utf-8")
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
@@ -53,7 +53,7 @@ class PdfExtractor(BaseExtractor):
def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
import pypdfium2
import pypdfium2 # type: ignore
with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)

View File

@@ -1,7 +1,7 @@
import base64
import logging
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup # type: ignore
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

View File

@@ -30,6 +30,9 @@ class UnstructuredEpubExtractor(BaseExtractor):
if self._api_url:
from unstructured.partition.api import partition_via_api
if self._api_key is None:
raise ValueError("api_key is required")
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
from unstructured.partition.epub import partition_epub

View File

@@ -27,9 +27,11 @@ class UnstructuredPPTExtractor(BaseExtractor):
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else:
raise NotImplementedError("Unstructured API Url is not configured")
text_by_page = {}
text_by_page: dict[int, str] = {}
for element in elements:
page = element.metadata.page_number
if page is None:
continue
text = element.text
if page in text_by_page:
text_by_page[page] += "\n" + text

View File

@@ -29,14 +29,15 @@ class UnstructuredPPTXExtractor(BaseExtractor):
from unstructured.partition.pptx import partition_pptx
elements = partition_pptx(filename=self._file_path)
text_by_page = {}
text_by_page: dict[int, str] = {}
for element in elements:
page = element.metadata.page_number
text = element.text
if page in text_by_page:
text_by_page[page] += "\n" + text
else:
text_by_page[page] = text
if page is not None:
if page in text_by_page:
text_by_page[page] += "\n" + text
else:
text_by_page[page] = text
combined_texts = list(text_by_page.values())
documents = []
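
The PPT and PPTX hunks above share one pattern, an explicitly annotated accumulator plus a guard for elements whose page_number is None; sketched standalone on (page, text) tuples:

from typing import Optional

def group_text_by_page(elements: list[tuple[Optional[int], str]]) -> dict[int, str]:
    text_by_page: dict[int, str] = {}  # annotation lets mypy fix the key/value types up front
    for page, text in elements:
        if page is None:
            continue  # skip elements with no page number, as the diff does
        if page in text_by_page:
            text_by_page[page] += "\n" + text
        else:
            text_by_page[page] = text
    return text_by_page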

View File

@@ -89,6 +89,8 @@ class WordExtractor(BaseExtractor):
response = ssrf_proxy.get(url)
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
if image_ext is None:
continue
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
mime_type, _ = mimetypes.guess_type(file_key)
@@ -97,6 +99,8 @@ class WordExtractor(BaseExtractor):
continue
else:
image_ext = rel.target_ref.split(".")[-1]
if image_ext is None:
continue
# use uuid as file name
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self.tenant_id + "/" + file_uuid + "." + image_ext
@@ -226,6 +230,8 @@ class WordExtractor(BaseExtractor):
if x_child is None:
continue
if x.tag.endswith("instrText"):
if x.text is None:
continue
for i in url_pattern.findall(x.text):
hyperlinks_url = str(i)
except Exception as e:

View File

@@ -49,6 +49,7 @@ class BaseIndexProcessor(ABC):
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule["mode"] == "custom":
# The user-defined segmentation rule
rules = processing_rule["rules"]

View File

@@ -9,7 +9,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess
class IndexProcessorFactory:
"""IndexProcessorInit."""
def __init__(self, index_type: str):
def __init__(self, index_type: str | None):
self._index_type = index_type
def init_index_processor(self) -> BaseIndexProcessor:

View File

@@ -27,12 +27,13 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
# Split the text documents into nodes.
splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
processing_rule=kwargs.get("process_rule", {}),
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
all_documents = []
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule", {}))
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
@@ -41,8 +42,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
if document_node.metadata is not None:
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = remove_leading_symbols(document_node.page_content).strip()
if len(page_content) > 0:

View File

@@ -32,15 +32,16 @@ class QAIndexProcessor(BaseIndexProcessor):
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule"), embedding_model_instance=kwargs.get("embedding_model_instance")
processing_rule=kwargs.get("process_rule") or {},
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
# Split the text documents into nodes.
all_documents = []
all_qa_documents = []
all_documents: list[Document] = []
all_qa_documents: list[Document] = []
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule"))
document_text = CleanProcessor.clean(document.page_content, kwargs.get("process_rule") or {})
document.page_content = document_text
# parse document to nodes
@@ -50,8 +51,9 @@ class QAIndexProcessor(BaseIndexProcessor):
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
if document_node.metadata is not None:
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
document_node.page_content = remove_leading_symbols(page_content)
@@ -64,7 +66,7 @@ class QAIndexProcessor(BaseIndexProcessor):
document_format_thread = threading.Thread(
target=self._format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"tenant_id": kwargs.get("tenant_id"),
"document_node": doc,
"all_qa_documents": all_qa_documents,
@@ -148,11 +150,12 @@ class QAIndexProcessor(BaseIndexProcessor):
qa_documents = []
for result in document_qa_list:
qa_document = Document(page_content=result["question"], metadata=document_node.metadata.copy())
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
if qa_document.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:

View File

@@ -30,7 +30,11 @@ class RerankModelRunner(BaseRerankRunner):
doc_ids = set()
unique_documents = []
for document in documents:
if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids:
if (
document.provider == "dify"
and document.metadata is not None
and document.metadata["doc_id"] not in doc_ids
):
doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
@@ -54,7 +58,8 @@ class RerankModelRunner(BaseRerankRunner):
metadata=documents[result.index].metadata,
provider=documents[result.index].provider,
)
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
if rerank_document.metadata is not None:
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
return rerank_documents

View File

@@ -39,7 +39,7 @@ class WeightRerankRunner(BaseRerankRunner):
unique_documents = []
doc_ids = set()
for document in documents:
if document.metadata["doc_id"] not in doc_ids:
if document.metadata is not None and document.metadata["doc_id"] not in doc_ids:
doc_ids.add(document.metadata["doc_id"])
unique_documents.append(document)
@@ -56,10 +56,11 @@ class WeightRerankRunner(BaseRerankRunner):
)
if score_threshold and score < score_threshold:
continue
document.metadata["score"] = score
rerank_documents.append(document)
if document.metadata is not None:
document.metadata["score"] = score
rerank_documents.append(document)
rerank_documents.sort(key=lambda x: x.metadata["score"], reverse=True)
rerank_documents.sort(key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
@@ -76,8 +77,9 @@ class WeightRerankRunner(BaseRerankRunner):
for document in documents:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)
if document.metadata is not None:
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)
# Counter query keywords(TF)
query_keyword_counts = Counter(query_keywords)
@@ -162,7 +164,7 @@ class WeightRerankRunner(BaseRerankRunner):
query_vector = cache_embedding.embed_query(query)
for document in documents:
# calculate cosine similarity
if "score" in document.metadata:
if document.metadata and "score" in document.metadata:
query_vector_scores.append(document.metadata["score"])
else:
# transform to NumPy

View File

@@ -1,7 +1,7 @@
import math
import threading
from collections import Counter
from typing import Optional, cast
from typing import Any, Optional, cast
from flask import Flask, current_app
@@ -34,7 +34,7 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -140,12 +140,12 @@ class DatasetRetrieval:
user_from,
available_datasets,
query,
retrieve_config.top_k,
retrieve_config.score_threshold,
retrieve_config.rerank_mode,
retrieve_config.top_k or 0,
retrieve_config.score_threshold or 0,
retrieve_config.rerank_mode or "reranking_model",
retrieve_config.reranking_model,
retrieve_config.weights,
retrieve_config.reranking_enabled,
retrieve_config.reranking_enabled or True,
message_id,
)
@@ -300,10 +300,11 @@ class DatasetRetrieval:
metadata=external_document.get("metadata"),
provider="external",
)
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
if document.metadata is not None:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
results.append(document)
else:
retrieval_model_config = dataset.retrieval_model or default_retrieval_model
@@ -325,7 +326,7 @@ class DatasetRetrieval:
score_threshold = 0.0
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
score_threshold = retrieval_model_config.get("score_threshold", 0.0)
with measure_time() as timer:
results = RetrievalService.retrieve(
@@ -358,14 +359,14 @@ class DatasetRetrieval:
score_threshold: float,
reranking_mode: str,
reranking_model: Optional[dict] = None,
weights: Optional[dict] = None,
weights: Optional[dict[str, Any]] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
):
if not available_datasets:
return []
threads = []
all_documents = []
all_documents: list[Document] = []
dataset_ids = [dataset.id for dataset in available_datasets]
index_type_check = all(
item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
@@ -392,15 +393,18 @@ class DatasetRetrieval:
"The configured knowledge base list have different embedding model, please set reranking model."
)
if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
if weights is not None:
weights["vector_setting"]["embedding_provider_name"] = available_datasets[
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
for dataset in available_datasets:
index_type = dataset.indexing_technique
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(),
"flask_app": current_app._get_current_object(), # type: ignore
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
@@ -439,21 +443,22 @@ class DatasetRetrieval:
"""Handle retrieval end."""
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
if document.metadata is not None:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
db.session.commit()
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager = (
trace_manager: Optional[TraceQueueManager] = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
@@ -504,10 +509,11 @@ class DatasetRetrieval:
metadata=external_document.get("metadata"),
provider="external",
)
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
if document.metadata is not None:
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset_id
document.metadata["dataset_name"] = dataset.name
all_documents.append(document)
else:
# get retrieval model; if the model is not set, use the default
@@ -607,19 +613,20 @@ class DatasetRetrieval:
tools.append(tool)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
)
if retrieve_config.reranking_model is not None:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
)
tools.append(tool)
tools.append(tool)
return tools
@@ -635,10 +642,11 @@ class DatasetRetrieval:
query_keywords = keyword_table_handler.extract_keywords(query, None)
documents_keywords = []
for document in documents:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)
if document.metadata is not None:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)
# Counter query keywords(TF)
query_keyword_counts = Counter(query_keywords)
@@ -696,8 +704,9 @@ class DatasetRetrieval:
for document, score in zip(documents, similarities):
# format document
document.metadata["score"] = score
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
if document.metadata is not None:
document.metadata["score"] = score
documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
return documents[:top_k] if top_k else documents
def calculate_vector_score(
@@ -705,10 +714,12 @@ class DatasetRetrieval:
) -> list[Document]:
filter_documents = []
for document in all_documents:
if score_threshold is None or document.metadata["score"] >= score_threshold:
if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold):
filter_documents.append(document)
if not filter_documents:
return []
filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True)
filter_documents = sorted(
filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
)
return filter_documents[:top_k] if top_k else filter_documents
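
Both retrieval hunks that silence mypy on current_app._get_current_object() do so because the worker threads need the concrete Flask app rather than the proxy; a minimal sketch of that hand-off, with hypothetical names:

import threading
from typing import Any, Callable

from flask import Flask, current_app

def spawn_with_app_context(target: Callable[..., None], **kwargs: Any) -> threading.Thread:
    # _get_current_object() unwraps the LocalProxy; mypy cannot see it, hence the ignore in the diff
    flask_app: Flask = current_app._get_current_object()  # type: ignore

    def worker() -> None:
        with flask_app.app_context():
            target(**kwargs)

    thread = threading.Thread(target=worker)
    thread.start()
    return thread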

View File

@@ -1,7 +1,8 @@
from typing import Union
from typing import Union, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage
@@ -27,11 +28,14 @@ class FunctionCallMultiDatasetRouter:
SystemPromptMessage(content="You are a helpful AI assistant."),
UserPromptMessage(content=query),
]
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
result = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
),
)
if result.message.tool_calls:
# get retrieval model config

View File

@@ -1,9 +1,9 @@
from collections.abc import Generator, Sequence
from typing import Union
from typing import Union, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -92,6 +92,7 @@ class ReactMultiDatasetRouter:
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
) -> Union[str, None]:
prompt: Union[list[ChatModelMessage], CompletionModelPromptTemplate]
if model_config.mode == "chat":
prompt = self.create_chat_prompt(
query=query,
@@ -149,12 +150,15 @@ class ReactMultiDatasetRouter:
:param stop: stop
:return:
"""
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=completion_param,
stop=stop,
stream=True,
user=user_id,
invoke_result = cast(
Generator[LLMResult, None, None],
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=completion_param,
stop=stop,
stream=True,
user=user_id,
),
)
# handle invoke result
@@ -172,7 +176,7 @@ class ReactMultiDatasetRouter:
:return:
"""
model = None
prompt_messages = []
prompt_messages: list[PromptMessage] = []
full_text = ""
usage = None
for result in invoke_result:
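
invoke_llm returns either a full LLMResult or a streaming generator depending on the stream flag, which is why both router hunks wrap the call in cast(); the idea in isolation, with a hypothetical invoke():

from collections.abc import Generator
from typing import Union, cast

def invoke(stream: bool) -> Union[str, Generator[str, None, None]]:
    # hypothetical stand-in for invoke_llm: the return type depends on the stream flag
    if stream:
        return (chunk for chunk in ("he", "llo"))
    return "hello"

# the caller knows which branch it asked for, so cast() picks that member of the Union for mypy
streamed = cast(Generator[str, None, None], invoke(stream=True))
full = cast(str, invoke(stream=False))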

View File

@@ -26,8 +26,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
def from_encoder(
cls: type[TS],
embedding_model_instance: Optional[ModelInstance],
allowed_special: Union[Literal[all], Set[str]] = set(),
disallowed_special: Union[Literal[all], Collection[str]] = "all",
allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
**kwargs: Any,
):
def _token_encoder(text: str) -> int:

View File

@@ -92,7 +92,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
texts, metadatas = [], []
for doc in documents:
texts.append(doc.page_content)
metadatas.append(doc.metadata)
metadatas.append(doc.metadata or {})
return self.create_documents(texts, metadatas=metadatas)
def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
@@ -143,7 +143,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
"""Text splitter that uses HuggingFace tokenizer to count length."""
try:
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase # type: ignore
if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise ValueError("Tokenizer received was not an instance of PreTrainedTokenizerBase")