Add Azure AI Studio as provider (#7549)
Co-authored-by: Hélio Lúcio <canais.hlucio@voegol.com.br>
Binary file not shown. (After: 21 KiB)
Binary file not shown. (After: 10 KiB)
@@ -0,0 +1,17 @@
import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class AzureAIStudioProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials.

        If validation fails, raise an exception.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        pass
@@ -0,0 +1,65 @@
provider: azure_ai_studio
label:
  zh_Hans: Azure AI Studio
  en_US: Azure AI Studio
icon_small:
  en_US: icon_s_en.png
icon_large:
  en_US: icon_l_en.png
description:
  en_US: Azure AI Studio
  zh_Hans: Azure AI Studio
background: "#93c5fd"
help:
  title:
    en_US: How to deploy a customized model on Azure AI Studio
    zh_Hans: 如何在 Azure AI Studio 上部署自定义模型
  url:
    en_US: https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models
    zh_Hans: https://learn.microsoft.com/zh-cn/azure/ai-studio/how-to/deploy-models
supported_model_types:
  - llm
  - rerank
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: endpoint
      label:
        en_US: Azure AI Studio Endpoint
      type: text-input
      required: true
      placeholder:
        zh_Hans: 请输入你的 Azure AI Studio 推理端点
        en_US: 'Enter your API Endpoint, e.g. https://example.com'
    - variable: api_key
      required: true
      label:
        en_US: API Key
        zh_Hans: API Key
      type: secret-input
      placeholder:
        en_US: Enter your Azure AI Studio API Key
        zh_Hans: 在此输入您的 Azure AI Studio API Key
      show_on:
        - variable: __model_type
          value: llm
    - variable: jwt_token
      required: true
      label:
        en_US: JWT Token
        zh_Hans: JWT 令牌
      type: secret-input
      placeholder:
        en_US: Enter your Azure AI Studio JWT Token
        zh_Hans: 在此输入您的 Azure AI Studio JWT Token
      show_on:
        - variable: __model_type
          value: rerank
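Note: the credential forms above reach the model classes below as plain dicts. A minimal sketch of the shapes this schema produces (the endpoints and keys are illustrative placeholders, not values from this commit):

# Hypothetical credential dicts assembled from the form above (illustrative only).
llm_credentials = {
    "endpoint": "https://example.eastus2.inference.ai.azure.com",  # placeholder endpoint
    "api_key": "<azure-ai-studio-api-key>",                        # placeholder key
}
rerank_credentials = {
    "endpoint": "https://example-rerank.eastus2.inference.ai.azure.com/score",  # placeholder endpoint
    "jwt_token": "<azure-ml-endpoint-jwt-token>",                               # placeholder token
}

The LLM model below reads "endpoint" and "api_key", while the rerank model reads "endpoint" and "jwt_token", matching the show_on / __model_type gating in the schema.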
@@ -0,0 +1,334 @@
import logging
from collections.abc import Generator
from typing import Any, Optional, Union

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import StreamingChatCompletionsUpdate
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import (
    ClientAuthenticationError,
    DecodeError,
    DeserializationError,
    HttpResponseError,
    ResourceExistsError,
    ResourceModifiedError,
    ResourceNotFoundError,
    ResourceNotModifiedError,
    SerializationError,
    ServiceRequestError,
    ServiceResponseError,
)

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
)
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    I18nObject,
    ModelType,
    ParameterRule,
    ParameterType,
)
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)


class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
    """
    Model class for Azure AI Studio large language model.
    """

    client: Any = None

    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # Lazily create the Azure AI inference client from the configured endpoint and API key
        if not self.client:
            endpoint = credentials.get("endpoint")
            api_key = credentials.get("api_key")
            self.client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))

        messages = [{"role": msg.role.value, "content": msg.content} for msg in prompt_messages]

        payload = {
            "messages": messages,
            "max_tokens": model_parameters.get("max_tokens", 4096),
            "temperature": model_parameters.get("temperature", 0),
            "top_p": model_parameters.get("top_p", 1),
            "stream": stream,
        }

        if stop:
            payload["stop"] = stop

        if tools:
            payload["tools"] = [tool.model_dump() for tool in tools]

        try:
            response = self.client.complete(**payload)

            if stream:
                return self._handle_stream_response(response, model, prompt_messages)
            else:
                return self._handle_non_stream_response(response, model, prompt_messages, credentials)
        except Exception as e:
            raise self._transform_invoke_error(e)

    def _handle_stream_response(self, response, model: str, prompt_messages: list[PromptMessage]) -> Generator:
        for chunk in response:
            if isinstance(chunk, StreamingChatCompletionsUpdate):
                if chunk.choices:
                    delta = chunk.choices[0].delta
                    if delta.content:
                        yield LLMResultChunk(
                            model=model,
                            prompt_messages=prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=delta.content, tool_calls=[]),
                            ),
                        )

    def _handle_non_stream_response(
        self, response, model: str, prompt_messages: list[PromptMessage], credentials: dict
    ) -> LLMResult:
        assistant_text = response.choices[0].message.content
        assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
        usage = self._calc_response_usage(
            model, credentials, response.usage.prompt_tokens, response.usage.completion_tokens
        )
        result = LLMResult(model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage)

        if hasattr(response, "system_fingerprint"):
            result.system_fingerprint = response.system_fingerprint

        return result

    def _invoke_result_generator(
        self,
        model: str,
        result: Generator,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
        callbacks: Optional[list[Callback]] = None,
    ) -> Generator:
        """
        Invoke result generator

        :param result: result generator
        :return: result generator
        """
        callbacks = callbacks or []
        prompt_message = AssistantPromptMessage(content="")
        usage = None
        system_fingerprint = None
        real_model = model

        try:
            for chunk in result:
                # Normalize raw dict chunks into LLMResultChunk before yielding
                if isinstance(chunk, dict):
                    content = chunk["choices"][0]["message"]["content"]
                    usage = chunk["usage"]
                    chunk = LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=AssistantPromptMessage(content=content, tool_calls=[]),
                        ),
                        system_fingerprint=chunk.get("system_fingerprint"),
                    )

                yield chunk

                self._trigger_new_chunk_callbacks(
                    chunk=chunk,
                    model=model,
                    credentials=credentials,
                    prompt_messages=prompt_messages,
                    model_parameters=model_parameters,
                    tools=tools,
                    stop=stop,
                    stream=stream,
                    user=user,
                    callbacks=callbacks,
                )

                prompt_message.content += chunk.delta.message.content
                real_model = chunk.model
                if hasattr(chunk.delta, "usage"):
                    usage = chunk.delta.usage

                if chunk.system_fingerprint:
                    system_fingerprint = chunk.system_fingerprint
        except Exception as e:
            raise self._transform_invoke_error(e)

        self._trigger_after_invoke_callbacks(
            model=model,
            result=LLMResult(
                model=real_model,
                prompt_messages=prompt_messages,
                message=prompt_message,
                usage=usage if usage else LLMUsage.empty_usage(),
                system_fingerprint=system_fingerprint,
            ),
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
            callbacks=callbacks,
        )

    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return:
        """
        # Implement token counting logic here
        # Might need to use a tokenizer specific to the Azure AI Studio model
        return 0

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            endpoint = credentials.get("endpoint")
            api_key = credentials.get("api_key")
            client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
            client.get_model_info()
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                ServiceRequestError,
            ],
            InvokeServerUnavailableError: [
                ServiceResponseError,
            ],
            InvokeAuthorizationError: [
                ClientAuthenticationError,
            ],
            InvokeBadRequestError: [
                HttpResponseError,
                DecodeError,
                ResourceExistsError,
                ResourceNotFoundError,
                ResourceModifiedError,
                ResourceNotModifiedError,
                SerializationError,
                DeserializationError,
            ],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Used to define customizable model schema
        """
        rules = [
            ParameterRule(
                name="temperature",
                type=ParameterType.FLOAT,
                use_template="temperature",
                label=I18nObject(zh_Hans="温度", en_US="Temperature"),
            ),
            ParameterRule(
                name="top_p",
                type=ParameterType.FLOAT,
                use_template="top_p",
                label=I18nObject(zh_Hans="Top P", en_US="Top P"),
            ),
            ParameterRule(
                name="max_tokens",
                type=ParameterType.INT,
                use_template="max_tokens",
                min=1,
                default=512,
                label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"),
            ),
        ]

        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.LLM,
            features=[],
            model_properties={},
            parameter_rules=rules,
        )

        return entity
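For orientation, the class above wraps the azure-ai-inference SDK call pattern shown in its own `_invoke`. Below is a minimal standalone sketch of that pattern outside Dify; the endpoint and key are placeholders, and the message/parameter values are illustrative only.

# Minimal sketch of the azure-ai-inference call pattern used by _invoke above.
# Requires the `azure-ai-inference` package; endpoint and key are placeholders.
from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential

client = ChatCompletionsClient(
    endpoint="https://example.eastus2.inference.ai.azure.com",  # placeholder endpoint
    credential=AzureKeyCredential("<api-key>"),                 # placeholder key
)

# Non-streaming: a single response object with choices and usage,
# handled above by _handle_non_stream_response.
response = client.complete(
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=256,
    temperature=0,
    top_p=1,
    stream=False,
)
print(response.choices[0].message.content)

# Streaming: iterate updates and read incremental delta content,
# as _handle_stream_response does.
for update in client.complete(messages=[{"role": "user", "content": "Hello"}], stream=True):
    if update.choices and update.choices[0].delta.content:
        print(update.choices[0].delta.content, end="")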
@@ -0,0 +1,164 @@
import json
import logging
import os
import ssl
import urllib.error
import urllib.request
from typing import Optional

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel

logger = logging.getLogger(__name__)


class AzureRerankModel(RerankModel):
    """
    Model class for Azure AI Studio rerank model.
    """

    def _allow_self_signed_https(self, allowed):
        # Bypass server certificate verification on the client side
        if allowed and not os.environ.get("PYTHONHTTPSVERIFY", "") and getattr(ssl, "_create_unverified_context", None):
            ssl._create_default_https_context = ssl._create_unverified_context

    def _azure_rerank(self, query_input: str, docs: list[str], endpoint: str, api_key: str):
        # self._allow_self_signed_https(True)  # Enable if using a self-signed certificate

        data = {"inputs": query_input, "docs": docs}

        body = json.dumps(data).encode("utf-8")
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

        req = urllib.request.Request(endpoint, body, headers)

        try:
            with urllib.request.urlopen(req) as response:
                result = response.read()
                return json.loads(result)
        except urllib.error.HTTPError as error:
            logger.error(f"The request failed with status code: {error.code}")
            logger.error(error.info())
            logger.error(error.read().decode("utf8", "ignore"))
            raise

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        try:
            if len(docs) == 0:
                return RerankResult(model=model, docs=[])

            endpoint = credentials.get("endpoint")
            api_key = credentials.get("jwt_token")

            if not endpoint or not api_key:
                raise ValueError("Azure endpoint and API key must be provided in credentials")

            result = self._azure_rerank(query, docs, endpoint, api_key)
            logger.info(f"Azure rerank result: {result}")

            rerank_documents = []
            for idx, (doc, score_dict) in enumerate(zip(docs, result)):
                score = score_dict["score"]
                rerank_document = RerankDocument(index=idx, text=doc, score=score)

                if score_threshold is None or score >= score_threshold:
                    rerank_documents.append(rerank_document)

            rerank_documents.sort(key=lambda x: x.score, reverse=True)

            if top_n:
                rerank_documents = rerank_documents[:top_n]

            return RerankResult(model=model, docs=rerank_documents)

        except Exception as e:
            logger.exception(f"Exception in Azure rerank: {e}")
            raise

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [urllib.error.URLError],
            InvokeServerUnavailableError: [urllib.error.HTTPError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError, json.JSONDecodeError],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.RERANK,
            model_properties={},
            parameter_rules=[],
        )

        return entity
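The rerank model assumes the deployed Azure ML scoring endpoint speaks a simple JSON contract: `_azure_rerank` POSTs {"inputs": query, "docs": [...]} with a Bearer JWT, and `_invoke` zips the response with the input docs and reads a "score" from each element. A minimal sketch of that exchange, with illustrative values only (not taken from this commit):

# Illustrative request/response shapes assumed by _azure_rerank and _invoke above.
request_body = {
    "inputs": "What is the capital of the United States?",
    "docs": ["Carson City is the capital city of ...", "Its capital is Saipan ..."],
}
# Expected response: one score object per input doc, in the same order as `docs`.
expected_response = [{"score": 0.27}, {"score": 0.91}]

Documents whose score clears score_threshold are kept, sorted by score descending, and truncated to top_n before being returned as a RerankResult.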