feat: add api-based extension & external data tool & moderation backend (#1403)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
Garfield Dai
2023-11-06 19:36:16 +08:00
committed by GitHub
parent 7699621983
commit db43ed6f41
50 changed files with 1624 additions and 273 deletions

View File

@@ -1,13 +1,25 @@
import logging
from typing import Any, Dict, List, Union
import threading
import time
from typing import Any, Dict, List, Union, Optional
from flask import Flask, current_app
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage
from pydantic import BaseModel
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.factory import ModerationFactory
class ModerationRule(BaseModel):
type: str
config: Dict[str, Any]
class LLMCallbackHandler(BaseCallbackHandler):
@@ -20,6 +32,24 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.start_at = None
self.conversation_message_task = conversation_message_task
self.output_moderation_handler = None
self.init_output_moderation()
def init_output_moderation(self):
app_model_config = self.conversation_message_task.app_model_config
sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
self.output_moderation_handler = OutputModerationHandler(
tenant_id=self.conversation_message_task.tenant_id,
app_id=self.conversation_message_task.app.id,
rule=ModerationRule(
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config")
),
on_message_replace_func=self.conversation_message_task.on_message_replace
)
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
@@ -59,10 +89,19 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(response.generations[0][0].text)
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
self.llm_message.completion = self.output_moderation_handler.moderation_completion(
completion=response.generations[0][0].text,
public_event=True if self.conversation_message_task.streaming else False
)
else:
self.llm_message.completion = response.generations[0][0].text
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(self.llm_message.completion)
if response.llm_output and 'token_usage' in response.llm_output:
if 'prompt_tokens' in response.llm_output['token_usage']:
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
@@ -79,23 +118,161 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.conversation_message_task.save_message(self.llm_message)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
try:
self.conversation_message_task.append_message_text(token)
except ConversationTaskStoppedException as ex:
if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
ex = ConversationTaskInterruptException()
self.on_llm_error(error=ex)
raise ex
self.llm_message.completion += token
try:
self.conversation_message_task.append_message_text(token)
self.llm_message.completion += token
if self.output_moderation_handler:
self.output_moderation_handler.append_new_token(token)
except ConversationTaskStoppedException as ex:
self.on_llm_error(error=ex)
raise ex
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
if isinstance(error, ConversationTaskStoppedException):
if self.conversation_message_task.streaming:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
if isinstance(error, ConversationTaskInterruptException):
self.llm_message.completion = self.output_moderation_handler.get_final_output()
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message)
else:
logging.debug("on_llm_error: %s", error)
class OutputModerationHandler(BaseModel):
DEFAULT_BUFFER_SIZE: int = 300
tenant_id: str
app_id: str
rule: ModerationRule
on_message_replace_func: Any
thread: Optional[threading.Thread] = None
thread_running: bool = True
buffer: str = ''
is_final_chunk: bool = False
final_output: Optional[str] = None
class Config:
arbitrary_types_allowed = True
def should_direct_output(self):
return self.final_output is not None
def get_final_output(self):
return self.final_output
def append_new_token(self, token: str):
self.buffer += token
if not self.thread:
self.thread = self.start_thread()
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
self.buffer = completion
self.is_final_chunk = True
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=completion
)
if not result or not result.flagged:
return completion
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
else:
final_output = result.text
if public_event:
self.on_message_replace_func(final_output)
return final_output
def start_thread(self) -> threading.Thread:
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
thread = threading.Thread(target=self.worker, kwargs={
'flask_app': current_app._get_current_object(),
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
})
thread.start()
return thread
def stop_thread(self):
if self.thread and self.thread.is_alive():
self.thread_running = False
def worker(self, flask_app: Flask, buffer_size: int):
with flask_app.app_context():
current_length = 0
while self.thread_running:
moderation_buffer = self.buffer
buffer_length = len(moderation_buffer)
if not self.is_final_chunk:
chunk_length = buffer_length - current_length
if 0 <= chunk_length < buffer_size:
time.sleep(1)
continue
current_length = buffer_length
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=moderation_buffer
)
if not result or not result.flagged:
continue
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
self.final_output = final_output
else:
final_output = result.text + self.buffer[len(moderation_buffer):]
# trigger replace event
if self.thread_running:
self.on_message_replace_func(final_output)
if result.action == ModerationAction.DIRECT_OUTPUT:
break
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
try:
moderation_factory = ModerationFactory(
name=self.rule.type,
app_id=app_id,
tenant_id=tenant_id,
config=self.rule.config
)
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
logging.error("Moderation Output error: %s", e)
return None