Mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git (synced 2025-12-16 06:16:53 +08:00)
Feat/blocking function call (#2247)
@@ -36,6 +36,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -80,6 +81,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -124,6 +126,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -198,6 +201,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -272,6 +276,7 @@ LLM_BASE_MODELS = [
             features=[
                 ModelFeature.AGENT_THOUGHT,
                 ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
             ],
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
@@ -324,6 +324,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                                                tools: Optional[list[PromptMessageTool]] = None) -> Generator:
         index = 0
         full_assistant_content = ''
+        delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
         real_model = model
         system_fingerprint = None
         completion = ''
@@ -333,12 +334,32 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):

             delta = chunk.choices[0]

-            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
+            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
+                    delta.delta.function_call is None:
                 continue

             # assistant_message_tool_calls = delta.delta.tool_calls
             assistant_message_function_call = delta.delta.function_call

+            # extract tool calls from response
+            if delta_assistant_message_function_call_storage is not None:
+                # handle process of stream function call
+                if assistant_message_function_call:
+                    # message has not ended ever
+                    delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
+                    continue
+                else:
+                    # message has ended
+                    assistant_message_function_call = delta_assistant_message_function_call_storage
+                    delta_assistant_message_function_call_storage = None
+            else:
+                if assistant_message_function_call:
+                    # start of stream function call
+                    delta_assistant_message_function_call_storage = assistant_message_function_call
+                    if delta_assistant_message_function_call_storage.arguments is None:
+                        delta_assistant_message_function_call_storage.arguments = ''
+                    continue
+
             # extract tool calls from response
             # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
             function_call = self._extract_response_function_call(assistant_message_function_call)
@@ -489,7 +510,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")

-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name

         return message_dict
@@ -586,7 +607,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         num_tokens = 0
         for tool in tools:
             num_tokens += len(encoding.encode('type'))
-            num_tokens += len(encoding.encode(tool.get("type")))
             num_tokens += len(encoding.encode('function'))

             # calculate num tokens for function object
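The Azure hunks above add a small accumulator so that a function call whose JSON arguments arrive split across several streamed deltas is only emitted once it is complete. A minimal standalone sketch of the same pattern, using a hypothetical FunctionCallDelta stand-in rather than the OpenAI SDK's ChoiceDeltaFunctionCall:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class FunctionCallDelta:
        # hypothetical stand-in for the SDK's streamed function_call fragment
        name: Optional[str] = None
        arguments: Optional[str] = None


    def accumulate(deltas: list[FunctionCallDelta]) -> Optional[FunctionCallDelta]:
        """Merge streamed function-call fragments into one complete call."""
        storage: Optional[FunctionCallDelta] = None
        for delta in deltas:
            if storage is None:
                # first fragment carries the name; start collecting arguments
                storage = FunctionCallDelta(name=delta.name, arguments=delta.arguments or '')
            else:
                # later fragments only append more of the JSON arguments string
                storage.arguments += delta.arguments or ''
        return storage


    # usage: '{"city": "Berlin"}' arriving in three chunks
    chunks = [
        FunctionCallDelta(name='get_weather', arguments='{"ci'),
        FunctionCallDelta(arguments='ty": "Ber'),
        FunctionCallDelta(arguments='lin"}'),
    ]
    print(accumulate(chunks))
    # FunctionCallDelta(name='get_weather', arguments='{"city": "Berlin"}')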
@@ -5,7 +5,7 @@ from typing import Generator, List, Optional, cast

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction,
-                                                           PromptMessageTool, SystemPromptMessage, UserPromptMessage)
+                                                           PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -194,6 +194,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            # check if last message is user message
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "function", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")
@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16384
@@ -4,6 +4,8 @@ label:
 model_type: llm
 features:
   - agent-thought
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768
@@ -16,7 +16,7 @@ class MinimaxChatCompletion(object):
     """
     def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
         generate chat completion
@@ -162,7 +162,6 @@ class MinimaxChatCompletion(object):
                 continue

             for choice in choices:
-                print(choice)
                 message = choice['delta']
                 yield MinimaxMessage(
                     content=message,
@@ -17,7 +17,7 @@ class MinimaxChatCompletionPro(object):
     """
     def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
         generate chat completion
@@ -82,6 +82,10 @@ class MinimaxChatCompletionPro(object):
             **extra_kwargs
         }

+        if tools:
+            body['functions'] = tools
+            body['function_call'] = { 'type': 'auto' }
+
         try:
             response = post(
                 url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
@@ -135,6 +139,7 @@ class MinimaxChatCompletionPro(object):
         """
         handle stream chat generate response
         """
+        function_call_storage = None
         for line in response.iter_lines():
             if not line:
                 continue
@@ -148,7 +153,7 @@ class MinimaxChatCompletionPro(object):
                 msg = data['base_resp']['status_msg']
                 self._handle_error(code, msg)

-            if data['reply']:
+            if data['reply'] or 'usage' in data and data['usage']:
                 total_tokens = data['usage']['total_tokens']
                 message = MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,
@@ -160,6 +165,12 @@ class MinimaxChatCompletionPro(object):
                     'total_tokens': total_tokens
                 }
                 message.stop_reason = data['choices'][0]['finish_reason']
+
+                if function_call_storage:
+                    function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+                    function_call_message.function_call = function_call_storage
+                    yield function_call_message
+
                 yield message
                 return

@@ -168,11 +179,28 @@ class MinimaxChatCompletionPro(object):
                 continue

             for choice in choices:
-                message = choice['messages'][0]['text']
-                if not message:
-                    continue
-                yield MinimaxMessage(
-                    content=message,
-                    role=MinimaxMessage.Role.ASSISTANT.value
-                )
+                message = choice['messages'][0]
+
+                if 'function_call' in message:
+                    if not function_call_storage:
+                        function_call_storage = message['function_call']
+                        if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
+                            function_call_storage['arguments'] = ''
+                        continue
+                    else:
+                        function_call_storage['arguments'] += message['function_call']['arguments']
+                        continue
+                else:
+                    if function_call_storage:
+                        message['function_call'] = function_call_storage
+                        function_call_storage = None
+
+                minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+
+                if 'function_call' in message:
+                    minimax_message.function_call = message['function_call']
+
+                if 'text' in message:
+                    minimax_message.content = message['text']
+
+                yield minimax_message
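With the MinimaxChatCompletionPro changes above, a request that carries tools also advertises them to the API through the functions and function_call keys. A rough sketch of how such a body might end up shaped, assuming the tools were already converted to plain dicts upstream (the model name, message fields and example tool are illustrative only):

    from json import dumps

    tools = [{
        'name': 'get_weather',
        'description': 'Query the current weather for a city',
        'parameters': {'type': 'object', 'properties': {'city': {'type': 'string'}}},
    }]

    body = {
        'model': 'abab5.5-chat',   # illustrative model name
        'stream': True,
        'messages': [{'sender_type': 'USER', 'sender_name': '我', 'text': 'Weather in Berlin?'}],
    }

    if tools:
        # mirrors the diff: expose the functions and let the model decide when to call one
        body['functions'] = tools
        body['function_call'] = {'type': 'auto'}

    print(dumps(body, ensure_ascii=False, indent=2))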
@@ -2,7 +2,7 @@ from typing import Generator, List

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                           SystemPromptMessage, UserPromptMessage)
+                                                           SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -84,6 +84,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
         """
         client: MinimaxChatCompletionPro = self.model_apis[model]()

+        if tools:
+            tools = [{
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters
+            } for tool in tools]
+
         response = client.generate(
             model=model,
             api_key=credentials['minimax_api_key'],
@@ -109,7 +116,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
         elif isinstance(prompt_message, UserPromptMessage):
             return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
         elif isinstance(prompt_message, AssistantPromptMessage):
+            if prompt_message.tool_calls:
+                message = MinimaxMessage(
+                    role=MinimaxMessage.Role.ASSISTANT.value,
+                    content=''
+                )
+                message.function_call={
+                    'name': prompt_message.tool_calls[0].function.name,
+                    'arguments': prompt_message.tool_calls[0].function.arguments
+                }
+                return message
             return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
+        elif isinstance(prompt_message, ToolPromptMessage):
+            return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
         else:
             raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')
@@ -151,6 +170,28 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                     finish_reason=message.stop_reason if message.stop_reason else None,
                 ),
             )
+        elif message.function_call:
+            if 'name' not in message.function_call or 'arguments' not in message.function_call:
+                continue
+
+            yield LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=0,
+                    message=AssistantPromptMessage(
+                        content='',
+                        tool_calls=[AssistantPromptMessage.ToolCall(
+                            id='',
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=message.function_call['name'],
+                                arguments=message.function_call['arguments']
+                            )
+                        )]
+                    ),
+                ),
+            )
         else:
             yield LLMResultChunk(
                 model=model,
@@ -7,13 +7,23 @@ class MinimaxMessage:
         USER = 'USER'
         ASSISTANT = 'BOT'
         SYSTEM = 'SYSTEM'
+        FUNCTION = 'FUNCTION'

     role: str = Role.USER.value
     content: str
     usage: Dict[str, int] = None
     stop_reason: str = ''
+    function_call: Dict[str, Any] = None

     def to_dict(self) -> Dict[str, Any]:
+        if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
+            return {
+                'sender_type': 'BOT',
+                'sender_name': '专家',
+                'text': '',
+                'function_call': self.function_call
+            }
+
         return {
             'sender_type': self.role,
             'sender_name': '我' if self.role == 'USER' else '专家',
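The new to_dict branch above means an assistant message that carries a function call is serialized with an empty text field and the stored function_call, while every other message keeps the plain sender/text shape. A simplified, self-contained rendition of that behaviour (a stripped-down stand-in, not the full MinimaxMessage class):

    from typing import Any, Dict, Optional


    class SimpleMinimaxMessage:
        """Stripped-down stand-in for MinimaxMessage, just enough to show to_dict()."""

        def __init__(self, role: str, content: str) -> None:
            self.role = role
            self.content = content
            self.function_call: Optional[Dict[str, Any]] = None

        def to_dict(self) -> Dict[str, Any]:
            if self.function_call and self.role == 'BOT':
                # function-call case: empty text, the call itself travels in function_call
                return {'sender_type': 'BOT', 'sender_name': '专家', 'text': '',
                        'function_call': self.function_call}
            return {'sender_type': self.role,
                    'sender_name': '我' if self.role == 'USER' else '专家',
                    'text': self.content}


    call = SimpleMinimaxMessage(role='BOT', content='')
    call.function_call = {'name': 'get_weather', 'arguments': '{"city": "Berlin"}'}
    print(call.to_dict())
    print(SimpleMinimaxMessage(role='USER', content='hello').to_dict())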
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 4096
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16385
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16385
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 16385
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 4096
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 128000
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 128000
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 32768
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 128000
@@ -6,6 +6,7 @@ model_type: llm
 features:
   - multi-tool-call
   - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 8192
@@ -671,7 +671,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         else:
             raise ValueError(f"Got unknown type {message}")

-        if message.name is not None:
+        if message.name:
             message_dict["name"] = message.name

         return message_dict
@@ -3,14 +3,14 @@ from typing import Generator, Iterator, List, Optional, Union, cast
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                           SystemPromptMessage, UserPromptMessage)
+                                                           SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
 from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
-                                                         ParameterRule, ParameterType)
+                                                         ParameterRule, ParameterType, ModelFeature)
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper,
+from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper,
                                                                               XinferenceModelExtraParameter)
 from core.model_runtime.utils import helper
 from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError,
@@ -33,6 +33,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):

         see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
         """
+        if 'temperature' in model_parameters:
+            if model_parameters['temperature'] < 0.01:
+                model_parameters['temperature'] = 0.01
+            elif model_parameters['temperature'] > 1.0:
+                model_parameters['temperature'] = 0.99
+
         return self._generate(
             model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
             tools=tools, stop=stop, stream=stream, user=user,
@@ -65,6 +71,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 credentials['completion_type'] = 'completion'
             else:
                 raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')

+            if extra_param.support_function_call:
+                credentials['support_function_call'] = True
+
         except RuntimeError as e:
             raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')
@@ -220,6 +229,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
         else:
             raise ValueError(f"Unknown message type {type(message)}")
@@ -237,7 +249,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 label=I18nObject(
                     zh_Hans='温度',
                     en_US='Temperature'
                 )
-            ),
+            ),
             ParameterRule(
                 name='top_p',
@@ -282,6 +294,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             completion_type = LLMMode.COMPLETION.value
         else:
             raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')

+        support_function_call = credentials.get('support_function_call', False)
+
         entity = AIModelEntity(
             model=model,
@@ -290,6 +304,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
+            features=[
+                ModelFeature.TOOL_CALL
+            ] if support_function_call else [],
             model_properties={
                 ModelPropertyKey.MODE: completion_type,
             },
@@ -310,6 +327,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):

         extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
         """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('server_url is required in credentials')
+
+        if credentials['server_url'].endswith('/'):
+            credentials['server_url'] = credentials['server_url'][:-1]
+
         client = OpenAI(
             base_url=f'{credentials["server_url"]}/v1',
             api_key='abc',
@@ -2,7 +2,7 @@ import time
 from typing import Optional

 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType, ModelPropertyKey
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                               InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
@@ -10,6 +10,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle

+from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper

 class XinferenceTextEmbeddingModel(TextEmbeddingModel):
     """
@@ -35,7 +36,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
         server_url = credentials['server_url']
         model_uid = credentials['model_uid']
-
+
+        if server_url.endswith('/'):
+            server_url = server_url[:-1]
+
         client = Client(base_url=server_url)

         try:
@@ -102,8 +106,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         try:
+            server_url = credentials['server_url']
+            model_uid = credentials['model_uid']
+            extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
+
+            if extra_args.max_tokens:
+                credentials['max_tokens'] = extra_args.max_tokens
+
             self._invoke(model=model, credentials=credentials, texts=['ping'])
-        except InvokeAuthorizationError:
+        except (InvokeAuthorizationError, RuntimeError):
             raise CredentialsValidateFailedError('Invalid api key')

     @property
@@ -160,6 +171,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
         """
         used to define customizable model schema
         """
+
         entity = AIModelEntity(
             model=model,
             label=I18nObject(
@@ -167,7 +179,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.TEXT_EMBEDDING,
-            model_properties={},
+            model_properties={
+                ModelPropertyKey.MAX_CHUNKS: 1,
+                ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512,
+            },
             parameter_rules=[]
         )
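The CONTEXT_SIZE value above relies on the classic and ... or ... short-circuit idiom as a fallback: when max_tokens is present and truthy in the credentials it is used, otherwise 512. A tiny worked example of how that expression evaluates:

    def context_size(credentials: dict) -> int:
        # same expression as in the diff: fall back to 512 when max_tokens is absent or falsy
        return 'max_tokens' in credentials and credentials['max_tokens'] or 512


    print(context_size({'max_tokens': 2048}))   # 2048
    print(context_size({}))                     # 512
    print(context_size({'max_tokens': 0}))      # 512 -- a stored 0 also falls back with this idiom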
@@ -1,6 +1,7 @@
 from threading import Lock
 from time import time
 from typing import List
+from os import path

 from requests import get
 from requests.adapters import HTTPAdapter
@@ -12,11 +13,16 @@ class XinferenceModelExtraParameter(object):
     model_format: str
     model_handle_type: str
     model_ability: List[str]
+    max_tokens: int = 512
+    support_function_call: bool = False

-    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None:
+    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
+                 support_function_call: bool, max_tokens: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
         self.model_ability = model_ability
+        self.support_function_call = support_function_call
+        self.max_tokens = max_tokens

 cache = {}
 cache_lock = Lock()
@@ -49,7 +55,7 @@ class XinferenceHelper:
         get xinference model extra parameter like model_format and model_handle_type
         """

-        url = f'{server_url}/v1/models/{model_uid}'
+        url = path.join(server_url, 'v1/models', model_uid)

         # this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
         session = Session()
@@ -66,10 +72,12 @@ class XinferenceHelper:

         response_json = response.json()

-        model_format = response_json['model_format']
-        model_ability = response_json['model_ability']
+        model_format = response_json.get('model_format', 'ggmlv3')
+        model_ability = response_json.get('model_ability', [])

-        if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
+        if response_json.get('model_type') == 'embedding':
+            model_handle_type = 'embedding'
+        elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
             model_handle_type = 'chatglm'
         elif 'generate' in model_ability:
             model_handle_type = 'generate'
@@ -78,8 +86,13 @@ class XinferenceHelper:
         else:
             raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')

+        support_function_call = 'tools' in model_ability
+        max_tokens = response_json.get('max_tokens', 512)
+
         return XinferenceModelExtraParameter(
             model_format=model_format,
             model_handle_type=model_handle_type,
-            model_ability=model_ability
+            model_ability=model_ability,
+            support_function_call=support_function_call,
+            max_tokens=max_tokens
         )
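Taken together, the helper now also derives support_function_call (a 'tools' entry in the model's ability list) and max_tokens from the metadata returned by the server. A condensed sketch of that parsing step applied to an illustrative payload (the sample response values are made up, not a verbatim xinference reply):

    def parse_extra_parameter(response_json: dict) -> dict:
        """Condensed rendition of the parsing in XinferenceHelper.get_xinference_extra_parameter."""
        model_format = response_json.get('model_format', 'ggmlv3')
        model_ability = response_json.get('model_ability', [])

        if response_json.get('model_type') == 'embedding':
            model_handle_type = 'embedding'
        elif model_format == 'ggmlv3' and 'chatglm' in response_json.get('model_name', ''):
            model_handle_type = 'chatglm'
        elif 'generate' in model_ability:
            model_handle_type = 'generate'
        else:
            # the real helper distinguishes further handle types; keep the sketch small
            model_handle_type = 'chat'

        return {
            'model_format': model_format,
            'model_handle_type': model_handle_type,
            'support_function_call': 'tools' in model_ability,
            'max_tokens': response_json.get('max_tokens', 512),
        }


    sample = {'model_name': 'chatglm3-ggml', 'model_format': 'ggmlv3',
              'model_ability': ['chat', 'tools'], 'max_tokens': 8192}
    print(parse_extra_parameter(sample))
    # {'model_format': 'ggmlv3', 'model_handle_type': 'chatglm', 'support_function_call': True, 'max_tokens': 8192}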
@@ -2,6 +2,10 @@ model: glm-3-turbo
 label:
   en_US: glm-3-turbo
 model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
 parameter_rules:
@@ -2,6 +2,10 @@ model: glm-4
 label:
   en_US: glm-4
 model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
 model_properties:
   mode: chat
 parameter_rules:
@@ -194,6 +194,27 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                     'content': prompt_message.content,
                     'tool_call_id': prompt_message.tool_call_id
                 })
+            elif isinstance(prompt_message, AssistantPromptMessage):
+                if prompt_message.tool_calls:
+                    params['messages'].append({
+                        'role': 'assistant',
+                        'content': prompt_message.content,
+                        'tool_calls': [
+                            {
+                                'id': tool_call.id,
+                                'type': tool_call.type,
+                                'function': {
+                                    'name': tool_call.function.name,
+                                    'arguments': tool_call.function.arguments
+                                }
+                            } for tool_call in prompt_message.tool_calls
+                        ]
+                    })
+                else:
+                    params['messages'].append({
+                        'role': 'assistant',
+                        'content': prompt_message.content
+                    })
             else:
                 params['messages'].append({
                     'role': prompt_message.role.value,
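For reference, the assistant branch added above turns an assistant message that carries tool calls into an entry of the following shape. A small sketch built from illustrative values (the id and function payload are made up):

    from json import dumps

    tool_calls = [{
        'id': 'call_0',
        'type': 'function',
        'function': {'name': 'get_weather', 'arguments': '{"city": "Berlin"}'},
    }]

    assistant_entry = {
        'role': 'assistant',
        'content': '',
        'tool_calls': [
            {
                'id': tool_call['id'],
                'type': tool_call['type'],
                'function': {
                    'name': tool_call['function']['name'],
                    'arguments': tool_call['function']['arguments'],
                },
            } for tool_call in tool_calls
        ],
    }

    print(dumps(assistant_entry, indent=2))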