feat: remove llm client use (#1316)

Author: takatost
Date: 2023-10-12 03:02:53 +08:00
Committed by: GitHub
Parent: c007dbdc13
Commit: cbf095465c

14 changed files with 434 additions and 353 deletions

View File

@@ -1,6 +1,6 @@
 import enum
-from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
+from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
 from pydantic import BaseModel
@@ -9,6 +9,7 @@ class LLMRunResult(BaseModel):
     prompt_tokens: int
     completion_tokens: int
     source: list = None
+    function_call: dict = None


 class MessageType(enum.Enum):
@@ -20,6 +21,7 @@ class MessageType(enum.Enum):
 class PromptMessage(BaseModel):
     type: MessageType = MessageType.HUMAN
     content: str = ''
+    function_call: dict = None


 def to_lc_messages(messages: list[PromptMessage]):
@@ -28,7 +30,10 @@ def to_lc_messages(messages: list[PromptMessage]):
         if message.type == MessageType.HUMAN:
             lc_messages.append(HumanMessage(content=message.content))
         elif message.type == MessageType.ASSISTANT:
-            lc_messages.append(AIMessage(content=message.content))
+            additional_kwargs = {}
+            if message.function_call:
+                additional_kwargs['function_call'] = message.function_call
+            lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
         elif message.type == MessageType.SYSTEM:
             lc_messages.append(SystemMessage(content=message.content))
@@ -41,9 +46,19 @@ def to_prompt_messages(messages: list[BaseMessage]):
         if isinstance(message, HumanMessage):
             prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
         elif isinstance(message, AIMessage):
-            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
+            message_kwargs = {
+                'content': message.content,
+                'type': MessageType.ASSISTANT
+            }
+
+            if 'function_call' in message.additional_kwargs:
+                message_kwargs['function_call'] = message.additional_kwargs['function_call']
+
+            prompt_messages.append(PromptMessage(**message_kwargs))
         elif isinstance(message, SystemMessage):
             prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
+        elif isinstance(message, FunctionMessage):
+            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))

     return prompt_messages
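
Taken together, the changes to this messages module make `function_call` a first-class field that survives the round trip between Dify's `PromptMessage` and LangChain's message schema, with incoming `FunctionMessage` results folded back in as human turns. A minimal sketch of that round trip against the code above; the weather-call payload is illustrative, not from this commit:

```python
from core.model_providers.models.entity.message import (
    PromptMessage, MessageType, to_lc_messages, to_prompt_messages,
)

# Hypothetical OpenAI-style function call payload, purely for illustration.
call = {'name': 'get_weather', 'arguments': '{"city": "Berlin"}'}
msg = PromptMessage(type=MessageType.ASSISTANT, content='', function_call=call)

# to_lc_messages now stashes the call in AIMessage.additional_kwargs ...
lc_messages = to_lc_messages([msg])
assert lc_messages[0].additional_kwargs['function_call'] == call

# ... and to_prompt_messages recovers it on the way back.
round_tripped = to_prompt_messages(lc_messages)
assert round_tripped[0].function_call == call
```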

View File

@@ -81,7 +81,20 @@ class AzureOpenAIModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+        generate_kwargs = {
+            'stop': stop,
+            'callbacks': callbacks
+        }
+
+        if isinstance(prompts, str):
+            generate_kwargs['prompts'] = [prompts]
+        else:
+            generate_kwargs['messages'] = [prompts]
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)

     @property
     def base_model_name(self) -> str:
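
The rewritten `_run` dispatches on what `_get_prompt_from_messages` returns: a plain string goes to a completion-style client as `prompts`, a message list goes to a chat-style client as `messages` (both wrapped in a one-element batch), and caller-supplied `functions` are forwarded only when present. The same dispatch as a standalone sketch, assuming only that the client's `generate` accepts these keyword arguments as LangChain's LLM and chat-model classes do:

```python
from typing import Any, List, Optional, Union

def build_generate_kwargs(prompts: Union[str, list],
                          stop: Optional[List[str]] = None,
                          callbacks: Any = None,
                          **kwargs: Any) -> dict:
    """Mirror the dispatch above: LangChain's generate() is batch-oriented,
    so the single request is wrapped in a one-element list either way."""
    generate_kwargs: dict = {'stop': stop, 'callbacks': callbacks}
    if isinstance(prompts, str):
        generate_kwargs['prompts'] = [prompts]    # completion model
    else:
        generate_kwargs['messages'] = [prompts]   # chat model
    if 'functions' in kwargs:
        # Forward OpenAI function definitions untouched.
        generate_kwargs['functions'] = kwargs['functions']
    return generate_kwargs
```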

View File

@@ -13,7 +13,8 @@ from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage,
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
 from core.helper import moderation
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
+    to_lc_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
 from core.prompt.prompt_builder import PromptBuilder
@@ -157,8 +158,11 @@ class BaseLLM(BaseProviderModel):
         except Exception as ex:
             raise self.handle_exceptions(ex)

+        function_call = None
         if isinstance(result.generations[0][0], ChatGeneration):
             completion_content = result.generations[0][0].message.content
+            if 'function_call' in result.generations[0][0].message.additional_kwargs:
+                function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
         else:
             completion_content = result.generations[0][0].text
@@ -191,7 +195,8 @@ class BaseLLM(BaseProviderModel):
         return LLMRunResult(
             content=completion_content,
             prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens
+            completion_tokens=completion_tokens,
+            function_call=function_call
         )

     @abstractmethod
@@ -442,16 +447,7 @@ class BaseLLM(BaseProviderModel):
         if len(messages) == 0:
             return []

-        chat_messages = []
-        for message in messages:
-            if message.type == MessageType.HUMAN:
-                chat_messages.append(HumanMessage(content=message.content))
-            elif message.type == MessageType.ASSISTANT:
-                chat_messages.append(AIMessage(content=message.content))
-            elif message.type == MessageType.SYSTEM:
-                chat_messages.append(SystemMessage(content=message.content))
-
-        return chat_messages
+        return to_lc_messages(messages)

     def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
         """

View File

@@ -106,7 +106,21 @@ class OpenAIModel(BaseLLM):
                 raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")

         prompts = self._get_prompt_from_messages(messages)
-        return self._client.generate([prompts], stop, callbacks)
+
+        generate_kwargs = {
+            'stop': stop,
+            'callbacks': callbacks
+        }
+
+        if isinstance(prompts, str):
+            generate_kwargs['prompts'] = [prompts]
+        else:
+            generate_kwargs['messages'] = [prompts]
+
+        if 'functions' in kwargs:
+            generate_kwargs['functions'] = kwargs['functions']
+
+        return self._client.generate(**generate_kwargs)

     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """