chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -21,8 +21,9 @@ class TokenBufferMemory:
self.conversation = conversation
self.model_instance = model_instance
def get_history_prompt_messages(self, max_token_limit: int = 2000,
message_limit: Optional[int] = None) -> list[PromptMessage]:
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> list[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: max token limit
@@ -31,16 +32,11 @@ class TokenBufferMemory:
app_record = self.conversation.app
# fetch limited messages, and return reversed
query = db.session.query(
Message.id,
Message.query,
Message.answer,
Message.created_at,
Message.workflow_run_id
).filter(
Message.conversation_id == self.conversation.id,
Message.answer != ''
).order_by(Message.created_at.desc())
query = (
db.session.query(Message.id, Message.query, Message.answer, Message.created_at, Message.workflow_run_id)
.filter(Message.conversation_id == self.conversation.id, Message.answer != "")
.order_by(Message.created_at.desc())
)
if message_limit and message_limit > 0:
message_limit = message_limit if message_limit <= 500 else 500
@@ -50,10 +46,7 @@ class TokenBufferMemory:
messages = query.limit(message_limit).all()
messages = list(reversed(messages))
message_file_parser = MessageFileParser(
tenant_id=app_record.tenant_id,
app_id=app_record.id
)
message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id)
prompt_messages = []
for message in messages:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
@@ -63,20 +56,17 @@ class TokenBufferMemory:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
else:
if message.workflow_run_id:
workflow_run = (db.session.query(WorkflowRun)
.filter(WorkflowRun.id == message.workflow_run_id).first())
workflow_run = (
db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
)
if workflow_run:
file_extra_config = FileUploadConfigManager.convert(
workflow_run.workflow.features_dict,
is_vision=False
workflow_run.workflow.features_dict, is_vision=False
)
if file_extra_config:
file_objs = message_file_parser.transform_message_files(
files,
file_extra_config
)
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
else:
file_objs = []
@@ -97,24 +87,23 @@ class TokenBufferMemory:
return []
# prune the chat message if it exceeds the max token limit
curr_message_tokens = self.model_instance.get_llm_num_tokens(
prompt_messages
)
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
if curr_message_tokens > max_token_limit:
pruned_memory = []
while curr_message_tokens > max_token_limit and len(prompt_messages)>1:
while curr_message_tokens > max_token_limit and len(prompt_messages) > 1:
pruned_memory.append(prompt_messages.pop(0))
curr_message_tokens = self.model_instance.get_llm_num_tokens(
prompt_messages
)
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
return prompt_messages
def get_history_prompt_text(self, human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: Optional[int] = None) -> str:
def get_history_prompt_text(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: Optional[int] = None,
) -> str:
"""
Get history prompt text.
:param human_prefix: human prefix
@@ -123,10 +112,7 @@ class TokenBufferMemory:
:param message_limit: message limit
:return:
"""
prompt_messages = self.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=message_limit
)
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
string_messages = []
for m in prompt_messages: