mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
feat: server multi models support (#799)
This commit is contained in:
@@ -2,16 +2,17 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import flask
|
||||
from flask_login import login_required, current_user
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
|
||||
from werkzeug.exceptions import Unauthorized, Forbidden
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from constants.model_template import model_templates, demo_model_templates
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from events.app_event import app_was_created, app_was_deleted
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
@@ -126,9 +127,9 @@ class AppListApi(Resource):
|
||||
if args['model_config'] is not None:
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=args['model_config'],
|
||||
mode=args['mode']
|
||||
config=args['model_config']
|
||||
)
|
||||
|
||||
app = App(
|
||||
@@ -164,6 +165,21 @@ class AppListApi(Resource):
|
||||
app = App(**model_config_template['app'])
|
||||
app_model_config = AppModelConfig(**model_config_template['model_config'])
|
||||
|
||||
default_model = ModelFactory.get_default_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_type=ModelType.TEXT_GENERATION
|
||||
)
|
||||
|
||||
if default_model:
|
||||
model_dict = app_model_config.model_dict
|
||||
model_dict['provider'] = default_model.provider_name
|
||||
model_dict['name'] = default_model.model_name
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
else:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Text Generation Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
app.name = args['name']
|
||||
app.mode = args['mode']
|
||||
app.icon = args['icon']
|
||||
|
||||
@@ -14,7 +14,7 @@ from controllers.console.app.error import AppUnavailableError, \
|
||||
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from flask_restful import Resource
|
||||
from services.audio_service import AudioService
|
||||
|
||||
@@ -17,7 +17,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value
|
||||
from flask_restful import Resource, reqparse
|
||||
@@ -41,8 +41,11 @@ class CompletionMessageApi(Resource):
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, location='json')
|
||||
parser.add_argument('model_config', type=dict, required=True, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] != 'blocking'
|
||||
|
||||
account = flask_login.current_user
|
||||
|
||||
try:
|
||||
@@ -51,7 +54,7 @@ class CompletionMessageApi(Resource):
|
||||
user=account,
|
||||
args=args,
|
||||
from_source='console',
|
||||
streaming=True,
|
||||
streaming=streaming,
|
||||
is_model_config_override=True
|
||||
)
|
||||
|
||||
@@ -111,8 +114,11 @@ class ChatMessageApi(Resource):
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('model_config', type=dict, required=True, location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] != 'blocking'
|
||||
|
||||
account = flask_login.current_user
|
||||
|
||||
try:
|
||||
@@ -121,7 +127,7 @@ class ChatMessageApi(Resource):
|
||||
user=account,
|
||||
args=args,
|
||||
from_source='console',
|
||||
streaming=True,
|
||||
streaming=streaming,
|
||||
is_model_config_override=True
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.generator.llm_generator import LLMGenerator
|
||||
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
|
||||
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
|
||||
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
|
||||
@@ -28,9 +28,9 @@ class ModelConfigResource(Resource):
|
||||
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
account=current_user,
|
||||
config=request.json,
|
||||
mode=app_model.mode
|
||||
config=request.json
|
||||
)
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
|
||||
Reference in New Issue
Block a user