feat: backend model load balancing support (#4927)

This commit is contained in:
takatost
2024-06-05 00:13:04 +08:00
committed by GitHub
parent 52ec152dd3
commit d1dbbc1e33
47 changed files with 2191 additions and 256 deletions

View File

@@ -1,11 +1,10 @@
from collections.abc import Sequence
from typing import Any, Optional, cast
from typing import Any, Optional
from sqlalchemy import func
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
@@ -95,11 +94,7 @@ class DatasetDocumentStore:
# calc embedding use tokens
if embedding_model:
model_type_instance = embedding_model.model_type_instance
model_type_instance = cast(TextEmbeddingModel, model_type_instance)
tokens = model_type_instance.get_num_tokens(
model=embedding_model.model,
credentials=embedding_model.credentials,
tokens = embedding_model.get_text_embedding_num_tokens(
texts=[doc.page_content]
)
else:

View File

@@ -1,10 +1,9 @@
"""Functionality for splitting text."""
from __future__ import annotations
from typing import Any, Optional, cast
from typing import Any, Optional
from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.rag.splitter.text_splitter import (
TS,
@@ -35,11 +34,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
return 0
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
return embedding_model_type_instance.get_num_tokens(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
return embedding_model_instance.get_text_embedding_num_tokens(
texts=[text]
)
else: