mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 11:26:52 +08:00
feat: [backend] vision support (#1510)
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Generator, Union, Any, Optional
|
||||
from typing import Generator, Union, Any, Optional, List
|
||||
|
||||
from flask import current_app, Flask
|
||||
from redis.client import PubSub
|
||||
@@ -12,9 +12,11 @@ from sqlalchemy import and_
|
||||
from core.completion import Completion
|
||||
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
|
||||
ConversationTaskInterruptException
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||
LLMRateLimitError, \
|
||||
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.model_providers.models.entity.message import PromptMessageFile
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
|
||||
@@ -35,6 +37,9 @@ class CompletionService:
|
||||
# is streaming mode
|
||||
inputs = args['inputs']
|
||||
query = args['query']
|
||||
files = args['files'] if 'files' in args and args['files'] else []
|
||||
auto_generate_name = args['auto_generate_name'] \
|
||||
if 'auto_generate_name' in args else True
|
||||
|
||||
if app_model.mode != 'completion' and not query:
|
||||
raise ValueError('query is required')
|
||||
@@ -132,6 +137,14 @@ class CompletionService:
|
||||
# clean input by app_model_config form rules
|
||||
inputs = cls.get_cleaned_inputs(inputs, app_model_config)
|
||||
|
||||
# parse files
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
app_model_config,
|
||||
user
|
||||
)
|
||||
|
||||
generate_task_id = str(uuid.uuid4())
|
||||
|
||||
pubsub = redis_client.pubsub()
|
||||
@@ -146,17 +159,20 @@ class CompletionService:
|
||||
'app_model_config': app_model_config.copy(),
|
||||
'query': query,
|
||||
'inputs': inputs,
|
||||
'files': file_objs,
|
||||
'detached_user': user,
|
||||
'detached_conversation': conversation,
|
||||
'streaming': streaming,
|
||||
'is_model_config_override': is_model_config_override,
|
||||
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
|
||||
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
|
||||
'auto_generate_name': auto_generate_name
|
||||
})
|
||||
|
||||
generate_worker_thread.start()
|
||||
|
||||
# wait for 10 minutes to close the thread
|
||||
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
|
||||
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
|
||||
generate_task_id)
|
||||
|
||||
return cls.compact_response(pubsub, streaming)
|
||||
|
||||
@@ -172,10 +188,12 @@ class CompletionService:
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
|
||||
query: str, inputs: dict, detached_user: Union[Account, EndUser],
|
||||
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
query: str, inputs: dict, files: List[PromptMessageFile],
|
||||
detached_user: Union[Account, EndUser],
|
||||
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
|
||||
retriever_from: str = 'dev'):
|
||||
retriever_from: str = 'dev', auto_generate_name: bool = True):
|
||||
with flask_app.app_context():
|
||||
# fixed the state of the model object when it detached from the original session
|
||||
user = db.session.merge(detached_user)
|
||||
@@ -195,10 +213,12 @@ class CompletionService:
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
user=user,
|
||||
files=files,
|
||||
conversation=conversation,
|
||||
streaming=streaming,
|
||||
is_override=is_model_config_override,
|
||||
retriever_from=retriever_from
|
||||
retriever_from=retriever_from,
|
||||
auto_generate_name=auto_generate_name
|
||||
)
|
||||
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
|
||||
pass
|
||||
@@ -215,7 +235,8 @@ class CompletionService:
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
|
||||
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
|
||||
generate_task_id) -> threading.Thread:
|
||||
# wait for 10 minutes to close the thread
|
||||
timeout = 600
|
||||
|
||||
@@ -274,6 +295,12 @@ class CompletionService:
|
||||
model_dict['completion_params'] = completion_params
|
||||
app_model_config.model = json.dumps(model_dict)
|
||||
|
||||
# parse files
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_objs = message_file_parser.transform_message_files(
|
||||
message.files, app_model_config
|
||||
)
|
||||
|
||||
generate_task_id = str(uuid.uuid4())
|
||||
|
||||
pubsub = redis_client.pubsub()
|
||||
@@ -288,11 +315,13 @@ class CompletionService:
|
||||
'app_model_config': app_model_config.copy(),
|
||||
'query': message.query,
|
||||
'inputs': message.inputs,
|
||||
'files': file_objs,
|
||||
'detached_user': user,
|
||||
'detached_conversation': None,
|
||||
'streaming': streaming,
|
||||
'is_model_config_override': True,
|
||||
'retriever_from': retriever_from
|
||||
'retriever_from': retriever_from,
|
||||
'auto_generate_name': False
|
||||
})
|
||||
|
||||
generate_worker_thread.start()
|
||||
@@ -388,7 +417,8 @@ class CompletionService:
|
||||
if event == 'message':
|
||||
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'message_replace':
|
||||
yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
|
||||
yield "data: " + json.dumps(
|
||||
cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'chain':
|
||||
yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
|
||||
elif event == 'agent_thought':
|
||||
|
||||
Reference in New Issue
Block a user