feat: [backend] vision support (#1510)

Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
This commit is contained in:
takatost
2023-11-13 22:05:46 +08:00
committed by GitHub
parent d0e1ea8f06
commit 41d0a8b295
61 changed files with 1563 additions and 300 deletions

View File

@@ -3,7 +3,7 @@ import logging
import threading
import time
import uuid
from typing import Generator, Union, Any, Optional
from typing import Generator, Union, Any, Optional, List
from flask import current_app, Flask
from redis.client import PubSub
@@ -12,9 +12,11 @@ from sqlalchemy import and_
from core.completion import Completion
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.file.message_file_parser import MessageFileParser
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, \
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.models.entity.message import PromptMessageFile
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
@@ -35,6 +37,9 @@ class CompletionService:
# is streaming mode
inputs = args['inputs']
query = args['query']
files = args['files'] if 'files' in args and args['files'] else []
auto_generate_name = args['auto_generate_name'] \
if 'auto_generate_name' in args else True
if app_model.mode != 'completion' and not query:
raise ValueError('query is required')
@@ -132,6 +137,14 @@ class CompletionService:
# clean input by app_model_config form rules
inputs = cls.get_cleaned_inputs(inputs, app_model_config)
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
app_model_config,
user
)
generate_task_id = str(uuid.uuid4())
pubsub = redis_client.pubsub()
@@ -146,17 +159,20 @@ class CompletionService:
'app_model_config': app_model_config.copy(),
'query': query,
'inputs': inputs,
'files': file_objs,
'detached_user': user,
'detached_conversation': conversation,
'streaming': streaming,
'is_model_config_override': is_model_config_override,
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
'auto_generate_name': auto_generate_name
})
generate_worker_thread.start()
# wait for 10 minutes to close the thread
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
generate_task_id)
return cls.compact_response(pubsub, streaming)
@@ -172,10 +188,12 @@ class CompletionService:
return user
@classmethod
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
query: str, inputs: dict, detached_user: Union[Account, EndUser],
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
app_model_config: AppModelConfig,
query: str, inputs: dict, files: List[PromptMessageFile],
detached_user: Union[Account, EndUser],
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
retriever_from: str = 'dev'):
retriever_from: str = 'dev', auto_generate_name: bool = True):
with flask_app.app_context():
# fixed the state of the model object when it detached from the original session
user = db.session.merge(detached_user)
@@ -195,10 +213,12 @@ class CompletionService:
query=query,
inputs=inputs,
user=user,
files=files,
conversation=conversation,
streaming=streaming,
is_override=is_model_config_override,
retriever_from=retriever_from
retriever_from=retriever_from,
auto_generate_name=auto_generate_name
)
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
pass
@@ -215,7 +235,8 @@ class CompletionService:
db.session.commit()
@classmethod
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
generate_task_id) -> threading.Thread:
# wait for 10 minutes to close the thread
timeout = 600
@@ -274,6 +295,12 @@ class CompletionService:
model_dict['completion_params'] = completion_params
app_model_config.model = json.dumps(model_dict)
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_objs = message_file_parser.transform_message_files(
message.files, app_model_config
)
generate_task_id = str(uuid.uuid4())
pubsub = redis_client.pubsub()
@@ -288,11 +315,13 @@ class CompletionService:
'app_model_config': app_model_config.copy(),
'query': message.query,
'inputs': message.inputs,
'files': file_objs,
'detached_user': user,
'detached_conversation': None,
'streaming': streaming,
'is_model_config_override': True,
'retriever_from': retriever_from
'retriever_from': retriever_from,
'auto_generate_name': False
})
generate_worker_thread.start()
@@ -388,7 +417,8 @@ class CompletionService:
if event == 'message':
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
elif event == 'message_replace':
yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
yield "data: " + json.dumps(
cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
elif event == 'chain':
yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
elif event == 'agent_thought':