mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-24 10:13:01 +08:00
Feat/Agent-Image-Processing (#3293)
Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
@@ -547,6 +547,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
if user:
|
||||
extra_model_kwargs['user'] = user
|
||||
|
||||
# clear illegal prompt messages
|
||||
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
|
||||
|
||||
# chat model
|
||||
response = client.chat.completions.create(
|
||||
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
|
||||
@@ -757,6 +760,31 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
|
||||
|
||||
return tool_call
|
||||
|
||||
def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Clear illegal prompt messages for OpenAI API
|
||||
|
||||
:param model: model name
|
||||
:param prompt_messages: prompt messages
|
||||
:return: cleaned prompt messages
|
||||
"""
|
||||
checklist = ['gpt-4-turbo', 'gpt-4-turbo-2024-04-09']
|
||||
|
||||
if model in checklist:
|
||||
# count how many user messages are there
|
||||
user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
|
||||
if user_message_count > 1:
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = '\n'.join([
|
||||
item.data if item.type == PromptMessageContentType.TEXT else
|
||||
'[IMAGE]' if item.type == PromptMessageContentType.IMAGE else ''
|
||||
for item in prompt_message.content
|
||||
])
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for OpenAI API
|
||||
|
||||
Reference in New Issue
Block a user