feat: add zhipuai (#1188)

This commit is contained in:
takatost
2023-09-18 17:32:31 +08:00
committed by GitHub
parent c8bd76cd66
commit 827c97f0d3
36 changed files with 1087 additions and 111 deletions

View File

@@ -0,0 +1,22 @@
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings
class ZhipuAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = ZhipuAIEmbeddings(
model=name,
**credentials,
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"ZhipuAI embedding: {str(ex)}")

View File

@@ -49,6 +49,7 @@ class KwargRule(Generic[T], BaseModel):
max: Optional[T] = None
default: Optional[T] = None
alias: Optional[str] = None
precision: Optional[int] = None
class ModelKwargsRules(BaseModel):

View File

@@ -0,0 +1,61 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
class ZhipuAIModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ZhipuAIChatLLM(
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"ZhipuAI: {str(ex)}")
@property
def support_streaming(self):
return True

View File

@@ -23,14 +23,18 @@ class OpenAIModeration(BaseModeration):
# 2000 text per chunk
length = 2000
chunks = [text[i:i + length] for i in range(0, len(text), length)]
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
moderation_result = self._client.create(input=chunks,
api_key=credentials['openai_api_key'])
max_text_chunks = 32
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
for result in moderation_result.results:
if result['flagged'] is True:
return False
for text_chunk in chunks:
moderation_result = self._client.create(input=text_chunk,
api_key=credentials['openai_api_key'])
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True