Add Azure AI Studio as provider (#7549)

Author: Hélio Lúcio
Co-authored-by: Hélio Lúcio <canais.hlucio@voegol.com.br>
Date: 2024-08-26 22:52:59 -03:00
Committed by: GitHub
Parent: 162faee4f2
Commit: 7b7576ad55
15 changed files with 1157 additions and 1 deletion

(Binary files not shown: two provider icon images added, 21 KiB and 10 KiB.)

View File

@@ -0,0 +1,17 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class AzureAIStudioProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
If validation fails, raise an exception.
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
pass
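Since this provider only exposes customizable models, provider-level validation is deferred to each model's own validate_credentials, so the stub above is a deliberate no-op. For reference, a minimal standalone credential check mirroring llm.py's validate_credentials might look like the sketch below; the endpoint and key values are placeholders, not real values:

# Hedged sketch: probe the endpoint the same way llm.py's validate_credentials does.
from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential

client = ChatCompletionsClient(
    endpoint="https://example.eastus2.inference.ai.azure.com",  # placeholder
    credential=AzureKeyCredential("<api-key>"),  # placeholder
)
client.get_model_info()  # raises on an invalid endpoint or key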

View File

@@ -0,0 +1,65 @@
provider: azure_ai_studio
label:
zh_Hans: Azure AI Studio
en_US: Azure AI Studio
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
description:
en_US: Azure AI Studio
zh_Hans: Azure AI Studio
background: "#93c5fd"
help:
title:
en_US: How to deploy a customized model on Azure AI Studio
zh_Hans: 如何在 Azure AI Studio 上部署自定义模型
url:
en_US: https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models
zh_Hans: https://learn.microsoft.com/zh-cn/azure/ai-studio/how-to/deploy-models
supported_model_types:
- llm
- rerank
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: endpoint
label:
en_US: Azure AI Studio Endpoint
type: text-input
required: true
placeholder:
zh_Hans: 请输入你的Azure AI Studio推理端点
en_US: 'Enter your API Endpoint, e.g. https://example.com'
- variable: api_key
required: true
label:
en_US: API Key
zh_Hans: API Key
type: secret-input
placeholder:
en_US: Enter your Azure AI Studio API Key
zh_Hans: 在此输入您的 Azure AI Studio API Key
show_on:
- variable: __model_type
value: llm
- variable: jwt_token
required: true
label:
en_US: JWT Token
zh_Hans: JWT令牌
type: secret-input
placeholder:
en_US: Enter your Azure AI Studio JWT Token
zh_Hans: 在此输入您的 Azure AI Studio JWT 令牌
show_on:
- variable: __model_type
value: rerank
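Given this schema, the credentials dict handed to the model classes always carries endpoint, plus api_key for LLM deployments or jwt_token for rerank deployments (the show_on blocks switch the visible field on __model_type). A sketch of both shapes, with placeholder values:

# Credential shapes produced by this schema (all values are placeholders).
llm_credentials = {
    "endpoint": "https://example.eastus2.inference.ai.azure.com",
    "api_key": "<api-key>",  # shown only when __model_type == "llm"
}
rerank_credentials = {
    "endpoint": "https://example.inference.ml.azure.com/score",
    "jwt_token": "<jwt-token>",  # shown only when __model_type == "rerank"
}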

View File

@@ -0,0 +1,334 @@
import logging
from collections.abc import Generator
from typing import Any, Optional, Union
from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import StreamingChatCompletionsUpdate
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import (
ClientAuthenticationError,
DecodeError,
DeserializationError,
HttpResponseError,
ResourceExistsError,
ResourceModifiedError,
ResourceNotFoundError,
ResourceNotModifiedError,
SerializationError,
ServiceRequestError,
ServiceResponseError,
)
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
I18nObject,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
"""
Model class for Azure AI Studio large language model.
"""
client: Any = None
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
if not self.client:
endpoint = credentials.get("endpoint")
api_key = credentials.get("api_key")
self.client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
messages = [{"role": msg.role.value, "content": msg.content} for msg in prompt_messages]
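# Note: this flattens each prompt message to {role, content}; tool-call fields
# carried on the message entities are not forwarded to the endpoint.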
payload = {
"messages": messages,
"max_tokens": model_parameters.get("max_tokens", 4096),
"temperature": model_parameters.get("temperature", 0),
"top_p": model_parameters.get("top_p", 1),
"stream": stream,
}
if stop:
payload["stop"] = stop
if tools:
payload["tools"] = [tool.model_dump() for tool in tools]
try:
response = self.client.complete(**payload)
if stream:
return self._handle_stream_response(response, model, prompt_messages)
else:
return self._handle_non_stream_response(response, model, prompt_messages, credentials)
except Exception as e:
raise self._transform_invoke_error(e)
def _handle_stream_response(self, response, model: str, prompt_messages: list[PromptMessage]) -> Generator:
for chunk in response:
if isinstance(chunk, StreamingChatCompletionsUpdate):
if chunk.choices:
delta = chunk.choices[0].delta
if delta.content:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=delta.content, tool_calls=[]),
),
)
def _handle_non_stream_response(
self, response, model: str, prompt_messages: list[PromptMessage], credentials: dict
) -> LLMResult:
assistant_text = response.choices[0].message.content
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
usage = self._calc_response_usage(
model, credentials, response.usage.prompt_tokens, response.usage.completion_tokens
)
result = LLMResult(model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage)
if hasattr(response, "system_fingerprint"):
result.system_fingerprint = response.system_fingerprint
return result
def _invoke_result_generator(
self,
model: str,
result: Generator,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Generator:
"""
Invoke result generator
:param result: result generator
:return: result generator
"""
callbacks = callbacks or []
prompt_message = AssistantPromptMessage(content="")
usage = None
system_fingerprint = None
real_model = model
try:
for chunk in result:
if isinstance(chunk, dict):
content = chunk["choices"][0]["message"]["content"]
usage = chunk["usage"]
chunk = LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=content, tool_calls=[]),
),
system_fingerprint=chunk.get("system_fingerprint"),
)
yield chunk
self._trigger_new_chunk_callbacks(
chunk=chunk,
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
prompt_message.content += chunk.delta.message.content
real_model = chunk.model
if hasattr(chunk.delta, "usage"):
usage = chunk.delta.usage
if chunk.system_fingerprint:
system_fingerprint = chunk.system_fingerprint
except Exception as e:
raise self._transform_invoke_error(e)
self._trigger_after_invoke_callbacks(
model=model,
result=LLMResult(
model=real_model,
prompt_messages=prompt_messages,
message=prompt_message,
usage=usage if usage else LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint,
),
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
# Implement token counting logic here
# Might need to use a tokenizer specific to the Azure AI Studio model
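# A rough fallback (assumption): the runtime's base class ships a GPT-2
# approximation, e.g. self._get_num_tokens_by_gpt2(text); returning 0 here
# simply disables usage accounting for this provider.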
return 0
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
endpoint = credentials.get("endpoint")
api_key = credentials.get("api_key")
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
client.get_model_info()
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
ServiceRequestError,
],
InvokeServerUnavailableError: [
ServiceResponseError,
],
InvokeAuthorizationError: [
ClientAuthenticationError,
],
InvokeBadRequestError: [
HttpResponseError,
DecodeError,
ResourceExistsError,
ResourceNotFoundError,
ResourceModifiedError,
ResourceNotModifiedError,
SerializationError,
DeserializationError,
],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Used to define customizable model schema
"""
rules = [
ParameterRule(
name="temperature",
type=ParameterType.FLOAT,
use_template="temperature",
label=I18nObject(zh_Hans="温度", en_US="Temperature"),
),
ParameterRule(
name="top_p",
type=ParameterType.FLOAT,
use_template="top_p",
label=I18nObject(zh_Hans="Top P", en_US="Top P"),
),
ParameterRule(
name="max_tokens",
type=ParameterType.INT,
use_template="max_tokens",
min=1,
default=512,
label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"),
),
]
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
features=[],
model_properties={},
parameter_rules=rules,
)
return entity
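A hedged usage sketch for the class above; the deployment name, endpoint, and key are placeholders, and UserPromptMessage is assumed importable from the same message_entities module as the entities imported at the top of this file:

# Hypothetical invocation of the Azure AI Studio LLM (placeholder values).
from core.model_runtime.entities.message_entities import UserPromptMessage

llm = AzureAIStudioLargeLanguageModel()
result = llm._invoke(
    model="my-deployment",  # placeholder deployment name
    credentials={
        "endpoint": "https://example.eastus2.inference.ai.azure.com",
        "api_key": "<api-key>",
    },
    prompt_messages=[UserPromptMessage(content="Hello")],
    model_parameters={"temperature": 0.7, "max_tokens": 256},
    stream=False,
)
print(result.message.content)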

View File

@@ -0,0 +1,164 @@
import json
import logging
import os
import ssl
import urllib.request
from typing import Optional
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
logger = logging.getLogger(__name__)
class AzureRerankModel(RerankModel):
"""
Model class for Azure AI Studio rerank model.
"""
def _allow_self_signed_https(self, allowed):
# bypass the server certificate verification on client side
if allowed and not os.environ.get("PYTHONHTTPSVERIFY", "") and getattr(ssl, "_create_unverified_context", None):
ssl._create_default_https_context = ssl._create_unverified_context
def _azure_rerank(self, query_input: str, docs: list[str], endpoint: str, api_key: str):
# self._allow_self_signed_https(True) # Enable if using self-signed certificate
data = {"inputs": query_input, "docs": docs}
body = json.dumps(data).encode("utf-8")
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
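# Assumed contract: the endpoint accepts {"inputs": <query>, "docs": [...]} and
# returns a JSON list of {"score": float} objects, one per doc (consumed in _invoke below).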
req = urllib.request.Request(endpoint, body, headers)
try:
with urllib.request.urlopen(req) as response:
result = response.read()
return json.loads(result)
except urllib.error.HTTPError as error:
logger.error(f"The request failed with status code: {error.code}")
logger.error(error.info())
logger.error(error.read().decode("utf8", "ignore"))
raise
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
try:
if len(docs) == 0:
return RerankResult(model=model, docs=[])
endpoint = credentials.get("endpoint")
api_key = credentials.get("jwt_token")
if not endpoint or not api_key:
raise ValueError("Azure endpoint and API key must be provided in credentials")
result = self._azure_rerank(query, docs, endpoint, api_key)
logger.info(f"Azure rerank result: {result}")
rerank_documents = []
for idx, (doc, score_dict) in enumerate(zip(docs, result)):
score = score_dict["score"]
rerank_document = RerankDocument(index=idx, text=doc, score=score)
if score_threshold is None or score >= score_threshold:
rerank_documents.append(rerank_document)
rerank_documents.sort(key=lambda x: x.score, reverse=True)
if top_n:
rerank_documents = rerank_documents[:top_n]
return RerankResult(model=model, docs=rerank_documents)
except Exception as e:
logger.exception(f"Exception in Azure rerank: {e}")
raise
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [urllib.error.URLError],
InvokeServerUnavailableError: [urllib.error.HTTPError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError, json.JSONDecodeError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK,
model_properties={},
parameter_rules=[],
)
return entity
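And a matching hedged usage sketch for the rerank model; the endpoint, token, and deployment name are placeholders:

# Hypothetical invocation of the Azure AI Studio rerank model (placeholder values).
reranker = AzureRerankModel()
result = reranker._invoke(
    model="my-rerank-deployment",  # placeholder deployment name
    credentials={
        "endpoint": "https://example.inference.ml.azure.com/score",
        "jwt_token": "<jwt-token>",
    },
    query="What is the capital of the United States?",
    docs=[
        "Carson City is the capital city of the American state of Nevada.",
        "Washington, D.C. is the capital of the United States.",
    ],
    top_n=1,
)
for doc in result.docs:
    print(doc.index, doc.score, doc.text)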