chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -1,4 +1,3 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
@@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
"""
Advanced Chatbot App Config Entity.
"""
pass
class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
workflow: Workflow) -> AdvancedChatAppConfig:
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
@@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
app_id=app_model.id,
app_mode=app_mode,
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=features_dict
),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict, app_mode)
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
additional_features=cls.convert_features(features_dict, app_mode),
)
return app_config
@@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config,
is_vision=False
config=config, is_vision=False
)
related_config_keys.extend(current_related_config_keys)
@@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
config
)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
@@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id,
config=config,
only_structure_validate=only_structure_validate
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
@@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
@@ -44,7 +45,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
@@ -53,14 +55,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@@ -71,44 +73,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get('query'):
raise ValueError('query is required')
if not args.get("query"):
raise ValueError("query is required")
query = args['query']
query = args["query"]
if not isinstance(query, str):
raise ValueError('query must be a string')
raise ValueError("query must be a string")
query = query.replace('\x00', '')
inputs = args['inputs']
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', False)
}
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
# get conversation
conversation = None
conversation_id = args.get('conversation_id')
conversation_id = args.get("conversation_id")
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
conversation = self._get_conversation_by_user(
app_model=app_model, conversation_id=conversation_id, user=user
)
# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
@@ -130,7 +125,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager
trace_manager=trace_manager,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@@ -140,16 +135,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream
stream=stream,
)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@@ -161,16 +152,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param stream: is stream
"""
if not node_id:
raise ValueError('node_id is required')
raise ValueError("node_id is required")
if args.get('inputs') is None:
raise ValueError('inputs is required')
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
@@ -178,18 +166,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config,
conversation_id=None,
inputs={},
query='',
query="",
files=[],
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras={
"auto_generate_conversation_name": False
},
extras={"auto_generate_conversation_name": False},
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
)
node_id=node_id, inputs=args["inputs"]
),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@@ -199,17 +184,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
conversation=None,
stream=stream
stream=stream,
)
def _generate(self, *,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
def _generate(
self,
*,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@@ -225,10 +212,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
is_first_conversation = True
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
@@ -243,18 +227,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'context': contextvars.copy_context(),
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"context": contextvars.copy_context(),
},
)
worker_thread.start()
@@ -269,17 +256,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return AdvancedChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
@@ -302,7 +289,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
message=message,
)
runner.run()
@@ -310,14 +297,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG", "false").lower() == 'true':
if os.environ.get("DEBUG", "false").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

View File

@@ -25,10 +25,7 @@ def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(),
user="responding_tts",
tenant_id=tenant_id,
voice=voice
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
)
@@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue):
except Exception as e:
logging.getLogger(__name__).warning(e)
break
audio_queue.put(AudioTrunk("finish", b''))
audio_queue.put(AudioTrunk("finish", b""))
class AppGeneratorTTSPublisher:
def __init__(self, tenant_id: str, voice: str):
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ''
self.msg_text = ""
self._audio_queue = queue.Queue()
self._msg_queue = queue.Queue()
self.match = re.compile(r'[。.!?]')
self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.TTS
tenant_id=self.tenant_id, model_type=ModelType.TTS
)
self.voices = self.model_instance.get_tts_voices()
values = [voice.get('value') for voice in self.voices]
values = [voice.get("value") for voice in self.voices]
self.voice = voice
if not voice or voice not in values:
self.voice = self.voices[0].get('value')
self.voice = self.voices[0].get("value")
self.MAX_SENTENCE = 2
self._last_audio_event = None
self._runtime_thread = threading.Thread(target=self._runtime).start()
@@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher:
message = self._msg_queue.get()
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
self.model_instance, self.tenant_id, self.voice)
futures_result = self.executor.submit(
_invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result)
break
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
@@ -94,21 +90,20 @@ class AppGeneratorTTSPublisher:
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
elif isinstance(message.event, QueueNodeSucceededEvent):
self.msg_text += message.event.outputs.get('output', '')
self.msg_text += message.event.outputs.get("output", "")
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
self.MAX_SENTENCE += 1
text_content = ''.join(sentence_arr)
futures_result = self.executor.submit(_invoiceTTS, text_content,
self.model_instance,
self.tenant_id,
self.voice)
text_content = "".join(sentence_arr)
futures_result = self.executor.submit(
_invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result)
if text_tmp:
self.msg_text = text_tmp
else:
self.msg_text = ''
self.msg_text = ""
except Exception as e:
self.logger.warning(e)

View File

@@ -38,11 +38,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
"""
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
:param application_generate_entity: application generate entity
@@ -66,11 +66,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError('App not found')
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
raise ValueError("Workflow not initialized")
user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
@@ -81,7 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id = self.application_generate_entity.user_id
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
workflow_callbacks.append(WorkflowLoggingCallback())
if self.application_generate_entity.single_iteration_run:
@@ -89,7 +89,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
else:
inputs = self.application_generate_entity.inputs
@@ -98,26 +98,27 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# moderation
if self.handle_input_moderation(
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id,
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity,
):
return
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
ConversationVariable.app_id == self.conversation.app_id,
ConversationVariable.conversation_id == self.conversation.id,
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
@@ -190,12 +191,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._handle_event(workflow_entry, event)
def handle_input_moderation(
self,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
self,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> bool:
"""
Handle input moderation
@@ -217,18 +218,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
message_id=message_id,
)
except ModerationException as e:
self._complete_with_stream_output(
text=str(e),
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
)
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
return True
return False
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
def handle_annotation_reply(
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
) -> bool:
"""
Handle annotation reply
:param app_record: app record
@@ -246,32 +243,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
if annotation_reply:
self._publish_event(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
)
self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))
self._complete_with_stream_output(
text=annotation_reply.content,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
)
return True
return False
def _complete_with_stream_output(self,
text: str,
stopped_by: QueueStopEvent.StopBy) -> None:
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param text: text
:return:
"""
self._publish_event(
QueueTextChunkEvent(
text=text
)
)
self._publish_event(QueueTextChunkEvent(text=text))
self._publish_event(
QueueStopEvent(stopped_by=stopped_by)
)
self._publish_event(QueueStopEvent(stopped_by=stopped_by))

View File

@@ -28,15 +28,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
response = {
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@@ -50,13 +50,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -67,14 +69,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -85,7 +87,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -96,20 +100,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)

View File

@@ -65,6 +65,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
@@ -72,14 +73,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@@ -123,13 +124,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation,
self._application_generate_entity.query
self._conversation, self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
return self._to_stream_response(generator)
@@ -147,7 +145,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
extras['metadata'] = stream_response.metadata
extras["metadata"] = stream_response.metadata
return ChatbotAppBlockingResponse(
task_id=stream_response.task_id,
@@ -158,15 +156,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
message_id=self._message.id,
answer=self._task_state.answer,
created_at=int(self._message.created_at.timestamp()),
**extras
)
**extras,
),
)
else:
continue
raise Exception('Queue listening stopped unexpectedly.')
raise Exception("Queue listening stopped unexpectedly.")
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:
"""
To stream response.
:return:
@@ -176,7 +176,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation.id,
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response
stream_response=stream_response,
)
def _listenAudioMsg(self, publisher, task_id: str):
@@ -187,17 +187,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
if (
features_dict.get("text_to_speech")
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
@@ -228,12 +231,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@@ -267,22 +270,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
db.session.close()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
event=event
)
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
if response:
@@ -293,7 +292,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
if response:
@@ -304,62 +303,52 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
@@ -372,20 +361,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
@@ -399,11 +384,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):
@@ -420,8 +404,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
# Save message
@@ -434,8 +417,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
db.session.commit()
db.session.refresh(self._message)
@@ -445,8 +429,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
db.session.commit()
db.session.refresh(self._message)
@@ -472,7 +457,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
raise Exception("Graph runtime state not initialized.")
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
@@ -502,8 +487,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
@@ -523,7 +509,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras
extras=self._application_generate_entity.extras,
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@@ -533,15 +519,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata.copy()
extras["metadata"] = self._task_state.metadata.copy()
if 'annotation_reply' in extras['metadata']:
del extras['metadata']['annotation_reply']
if "annotation_reply" in extras["metadata"]:
del extras["metadata"]["annotation_reply"]
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message.id,
**extras
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
)
def _handle_output_moderation_chunk(self, text: str) -> bool:
@@ -555,14 +539,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
QueueTextChunkEvent(
text=self._task_state.answer
), PublishFrom.TASK_PIPELINE
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else: