diff --git a/api/core/embedding/embedding_constant.py b/api/core/entities/embedding_type.py similarity index 100% rename from api/core/embedding/embedding_constant.py rename to api/core/entities/embedding_type.py diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 74b445236..e394233d2 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -3,7 +3,7 @@ import os from collections.abc import Callable, Generator, Sequence from typing import IO, Optional, Union, cast -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index a948dca20..2d38fba95 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -4,7 +4,7 @@ from typing import Optional from pydantic import ConfigDict -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.model_providers.__base.ai_model import AIModel diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index 8701a3805..c45ce87ea 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ import numpy as np import tiktoken from openai import AzureOpenAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import AIModelEntity, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 56b9be1c3..1ace68d2b 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index d9c572659..2f998d8bd 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -13,7 +13,7 @@ from botocore.exceptions import ( UnknownServiceError, ) -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 4da208069..5fd4d637b 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ import cohere import numpy as np from cohere.core import RequestOptions -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py index cdce69ff3..c745a7e97 100644 --- a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional, Union import numpy as np from openai import OpenAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index b2e6d1b65..8278d1e64 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ import numpy as np import requests from huggingface_hub import HfApi, InferenceClient -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index b8ff3ca54..6b4393453 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index 75701ebc5..b6d857cb3 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -9,7 +9,7 @@ from tencentcloud.common.profile.client_profile import ClientProfile from tencentcloud.common.profile.http_profile import HttpProfile from tencentcloud.hunyuan.v20230901 import hunyuan_client, models -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index b39712951..49c558f4a 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index ab8ca76c2..b4dfc1a4d 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from requests import post from yarl import URL -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index d031bfa04..29be5888a 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py index 68b7b448b..ca949cb95 100644 --- a/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py index 857dfb5f4..56a707333 100644 --- a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from nomic import embed from nomic import login as nomic_login -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import ( EmbeddingUsage, diff --git a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py index 936ceb8dd..04363e11b 100644 --- a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py index 4de9296cc..50fa63768 100644 --- a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ from typing import Optional import numpy as np import oci -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 5cf3f1c6f..a16c91cd7 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -8,7 +8,7 @@ from urllib.parse import urljoin import numpy as np import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index 16f1a0cfa..bec01fe67 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ import numpy as np import tiktoken from openai import OpenAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 64fa6aaa3..c2b7297aa 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ from urllib.parse import urljoin import numpy as np import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index c5d433091..43a2e948e 100644 --- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from requests import post from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 1e86f351c..d78bdaa75 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ from urllib.parse import urljoin import numpy as np import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 9f724a77a..c4e9d0b9c 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from replicate import Client as ReplicateClient -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index 8f993ce67..ae7d805b4 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ from typing import Any, Optional import boto3 -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py index c5dcc1261..5e29a4827 100644 --- a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py @@ -1,6 +1,6 @@ from typing import Optional -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( OAICompatEmbeddingModel, diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 736cd44df..2ef7f3f57 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import dashscope import numpy as np -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import ( EmbeddingUsage, diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index b6509cd26..7dd495b55 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ import numpy as np from openai import OpenAI from tokenizers import Tokenizer -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index fce9544df..43233e612 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -9,7 +9,7 @@ from google.cloud import aiplatform from google.oauth2 import service_account from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index 0dd4037c9..4d13e4708 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -2,7 +2,7 @@ import time from decimal import Decimal from typing import Optional -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py index a8a4d3c15..e69c9fccb 100644 --- a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index c21d0c055..19135deb2 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ from typing import Any, Optional import numpy as np from requests import Response, post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import InvokeError diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index ddc21b365..f64b9c50a 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -3,7 +3,7 @@ from typing import Optional from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 5a34a3d59..f629b62fd 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -3,7 +3,7 @@ from typing import Optional from zhipuai import ZhipuAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index b1d6f93cf..992415657 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,14 +1,14 @@ from typing import Optional -from core.model_manager import ModelManager +from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.models.document import Document -from core.rag.rerank.constants.rerank_mode import RerankMode from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights -from core.rag.rerank.rerank_model import RerankModelRunner -from core.rag.rerank.weight_rerank import WeightRerankRunner +from core.rag.rerank.rerank_base import BaseRerankRunner +from core.rag.rerank.rerank_factory import RerankRunnerFactory +from core.rag.rerank.rerank_type import RerankMode class DataPostProcessor: @@ -47,11 +47,12 @@ class DataPostProcessor: tenant_id: str, reranking_model: Optional[dict] = None, weights: Optional[dict] = None, - ) -> Optional[RerankModelRunner | WeightRerankRunner]: + ) -> Optional[BaseRerankRunner]: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: - return WeightRerankRunner( - tenant_id, - Weights( + runner = RerankRunnerFactory.create_rerank_runner( + runner_type=reranking_mode, + tenant_id=tenant_id, + weights=Weights( vector_setting=VectorSetting( vector_weight=weights["vector_setting"]["vector_weight"], embedding_provider_name=weights["vector_setting"]["embedding_provider_name"], @@ -62,23 +63,33 @@ class DataPostProcessor: ), ), ) + return runner elif reranking_mode == RerankMode.RERANKING_MODEL.value: - if reranking_model: - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - provider=reranking_model["reranking_provider_name"], - model_type=ModelType.RERANK, - model=reranking_model["reranking_model_name"], - ) - except InvokeAuthorizationError: - return None - return RerankModelRunner(rerank_model_instance) - return None + rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model) + if rerank_model_instance is None: + return None + runner = RerankRunnerFactory.create_rerank_runner( + runner_type=reranking_mode, rerank_model_instance=rerank_model_instance + ) + return runner return None def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: if reorder_enabled: return ReorderRunner() return None + + def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None: + if reranking_model: + try: + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=reranking_model["reranking_provider_name"], + model_type=ModelType.RERANK, + model=reranking_model["reranking_model_name"], + ) + return rerank_model_instance + except InvokeAuthorizationError: + return None + return None diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index d3fd0c672..3affbd2d0 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,7 +6,7 @@ from flask import Flask, current_app from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.rerank.constants.rerank_mode import RerankMode +from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index 6dcd98dcf..c77cb8737 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -9,10 +9,10 @@ _import_err_msg = ( ) from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 543cfa67b..1d4bfef76 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -12,10 +12,10 @@ from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 610aa498a..a9e1486ed 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -6,10 +6,10 @@ from chromadb import QueryResult, Settings from pydantic import BaseModel from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index f420373d5..052a18722 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -9,11 +9,11 @@ from elasticsearch import Elasticsearch from flask import current_app from pydantic import BaseModel, model_validator -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index bdca59f86..080a1ef56 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -7,11 +7,11 @@ from pymilvus import MilvusClient, MilvusException from pymilvus.milvus_client import IndexParams from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b30aa7ca2..1fca926a2 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -8,10 +8,10 @@ from clickhouse_connect import get_client from pydantic import BaseModel from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 8d2e0a86a..0e0f10726 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -9,11 +9,11 @@ from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 84a4381cd..4ced5d61e 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -13,10 +13,10 @@ from nltk.corpus import stopwords from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index a82a9b96d..9233cd63d 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -12,11 +12,11 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 6f336d27e..40a9cdd13 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -8,10 +8,10 @@ import psycopg2.pool from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index f418e3ca0..69d2aa4f7 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -20,11 +20,11 @@ from qdrant_client.http.models import ( from qdrant_client.local.qdrant_local import QdrantLocal from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 13a63784b..f373dcfea 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -8,9 +8,9 @@ from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.orm import Session -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from models.dataset import Dataset try: diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 39e3a7f6c..f971a9c5e 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -8,10 +8,10 @@ from tcvectordb.model import index as vdb_index from tcvectordb.model.document import Filter from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 7837c5a4a..1147e35ce 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -9,10 +9,10 @@ from sqlalchemy import text as sql_text from sqlalchemy.orm import Session, declarative_base from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 873b28902..fb956a16e 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -2,12 +2,12 @@ from abc import ABC, abstractmethod from typing import Any, Optional from configs import dify_config -from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.cached_embedding import CacheEmbedding +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 5f60f10ac..4f927f289 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -14,11 +14,11 @@ from volcengine.viking_db import ( ) from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field as vdb_Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 4009efe7a..649cfbfea 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -7,11 +7,11 @@ import weaviate from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/embedding/__init__.py b/api/core/rag/embedding/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py similarity index 97% rename from api/core/embedding/cached_embedding.py rename to api/core/rag/embedding/cached_embedding.py index 31d2171e7..b3e93ce76 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -6,11 +6,11 @@ import numpy as np from sqlalchemy.exc import IntegrityError from configs import dify_config -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.rag.datasource.entity.embedding import Embeddings +from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper diff --git a/api/core/rag/datasource/entity/embedding.py b/api/core/rag/embedding/embedding_base.py similarity index 90% rename from api/core/rag/datasource/entity/embedding.py rename to api/core/rag/embedding/embedding_base.py index 126c1a372..9f232ab91 100644 --- a/api/core/rag/datasource/entity/embedding.py +++ b/api/core/rag/embedding/embedding_base.py @@ -7,10 +7,12 @@ class Embeddings(ABC): @abstractmethod def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs.""" + raise NotImplementedError @abstractmethod def embed_query(self, text: str) -> list[float]: """Embed query text.""" + raise NotImplementedError async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Asynchronous Embed search docs.""" diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py new file mode 100644 index 000000000..818b04b2f --- /dev/null +++ b/api/core/rag/rerank/rerank_base.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from core.rag.models.document import Document + + +class BaseRerankRunner(ABC): + @abstractmethod + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: + """ + Run rerank model + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :return: + """ + raise NotImplementedError diff --git a/api/core/rag/rerank/rerank_factory.py b/api/core/rag/rerank/rerank_factory.py new file mode 100644 index 000000000..1a3cf8573 --- /dev/null +++ b/api/core/rag/rerank/rerank_factory.py @@ -0,0 +1,16 @@ +from core.rag.rerank.rerank_base import BaseRerankRunner +from core.rag.rerank.rerank_model import RerankModelRunner +from core.rag.rerank.rerank_type import RerankMode +from core.rag.rerank.weight_rerank import WeightRerankRunner + + +class RerankRunnerFactory: + @staticmethod + def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner: + match runner_type: + case RerankMode.RERANKING_MODEL.value: + return RerankModelRunner(*args, **kwargs) + case RerankMode.WEIGHTED_SCORE.value: + return WeightRerankRunner(*args, **kwargs) + case _: + raise ValueError(f"Unknown runner type: {runner_type}") diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 27f86aed3..40ebf0bef 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -2,9 +2,10 @@ from typing import Optional from core.model_manager import ModelInstance from core.rag.models.document import Document +from core.rag.rerank.rerank_base import BaseRerankRunner -class RerankModelRunner: +class RerankModelRunner(BaseRerankRunner): def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance diff --git a/api/core/rag/rerank/constants/rerank_mode.py b/api/core/rag/rerank/rerank_type.py similarity index 100% rename from api/core/rag/rerank/constants/rerank_mode.py rename to api/core/rag/rerank/rerank_type.py diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 16d6b879a..2e3fbe04e 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -4,15 +4,16 @@ from typing import Optional import numpy as np -from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights +from core.rag.rerank.rerank_base import BaseRerankRunner -class WeightRerankRunner: +class WeightRerankRunner(BaseRerankRunner): def __init__(self, tenant_id: str, weights: Weights) -> None: self.tenant_id = tenant_id self.weights = weights