Mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git (synced 2025-12-10 03:16:51 +08:00)
Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
@@ -1,74 +0,0 @@
from typing import Any, Dict, List, Optional
import json
import numpy as np

from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from huggingface_hub import InferenceClient

HOSTED_INFERENCE_API = 'hosted_inference_api'
INFERENCE_ENDPOINTS = 'inference_endpoints'


class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
    client: Any
    model: str

    huggingface_namespace: Optional[str] = None
    task_type: Optional[str] = None
    huggingfacehub_api_type: Optional[str] = None
    huggingfacehub_api_token: Optional[str] = None
    huggingfacehub_endpoint_url: Optional[str] = None

    class Config:
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        values['huggingfacehub_api_token'] = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )

        values['client'] = InferenceClient(token=values['huggingfacehub_api_token'])

        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        model = ''

        if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
            model = self.model
        else:
            model = self.huggingfacehub_endpoint_url

        output = self.client.post(
            json={
                "inputs": texts,
                "options": {
                    "wait_for_model": False,
                    "use_cache": False
                }
            }, model=model)

        embeddings = json.loads(output.decode())
        return self.mean_pooling(embeddings)

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]

    # https://huggingface.co/docs/api-inference/detailed_parameters#feature-extraction-task
    # Returned values are a list of floats, or a list of lists of floats
    # (depending on whether you sent a string or a list of strings,
    # and whether an automatic reduction, usually mean pooling, was applied for you or not.
    # This should be explained on the model's README.)
    def mean_pooling(self, embeddings: List) -> List[float]:
        # If the model already applied automatic reduction, no mean pooling is needed.
        # Example one: List[List[float]]
        if not isinstance(embeddings[0][0], list):
            return embeddings

        # Example two: List[List[List[float]]], which needs mean pooling.
        sentence_embeddings = [np.mean(embedding[0], axis=0).tolist() for embedding in embeddings]
        return sentence_embeddings
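For reference, a minimal usage sketch of the deleted class above (not part of the original diff): it assumes HuggingfaceHubEmbeddings and HOSTED_INFERENCE_API are importable from this module and that HUGGINGFACEHUB_API_TOKEN is set in the environment; the model id is only a placeholder.

# Hypothetical usage sketch; the model id below is a placeholder.
embeddings = HuggingfaceHubEmbeddings(
    model='sentence-transformers/all-MiniLM-L6-v2',  # placeholder model id
    huggingfacehub_api_type=HOSTED_INFERENCE_API,
)
doc_vectors = embeddings.embed_documents(["first text", "second text"])
query_vector = embeddings.embed_query("first text")
print(len(doc_vectors), len(query_vector))  # one vector per input text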
@@ -1,69 +0,0 @@
"""Wrapper around Jina embedding models."""
from typing import Any, List

import requests
from pydantic import BaseModel, Extra

from langchain.embeddings.base import Embeddings


class JinaEmbeddings(BaseModel, Embeddings):
    """Wrapper around Jina embedding models."""

    client: Any  #: :meta private:
    api_key: str
    model: str

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Jina's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        for text in texts:
            result = self.invoke_embedding(text=text)
            embeddings.append(result)

        return [list(map(float, e)) for e in embeddings]

    def invoke_embedding(self, text):
        params = {
            "model": self.model,
            "input": [
                text
            ]
        }

        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
        response = requests.post(
            'https://api.jina.ai/v1/embeddings',
            headers=headers,
            json=params
        )

        if not response.ok:
            raise ValueError(f"Jina HTTP {response.status_code} error: {response.text}")

        json_response = response.json()
        return json_response["data"][0]["embedding"]

    def embed_query(self, text: str) -> List[float]:
        """Call out to Jina's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
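A similarly hedged usage sketch for the Jina wrapper above (not in the original diff); the API key and model name are placeholders.

# Hypothetical usage sketch; replace the key and model name with real values.
jina = JinaEmbeddings(api_key="my-jina-api-key", model="jina-embeddings-v2-base-en")
doc_vectors = jina.embed_documents(["first document", "second document"])
query_vector = jina.embed_query("a search query")
assert len(doc_vectors) == 2 and isinstance(query_vector[0], float)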
@@ -1,67 +0,0 @@
"""Wrapper around OpenLLM embedding models."""
from typing import Any, List, Optional

import requests
from pydantic import BaseModel, Extra

from langchain.embeddings.base import Embeddings


class OpenLLMEmbeddings(BaseModel, Embeddings):
    """Wrapper around OpenLLM embedding models."""

    client: Any  #: :meta private:

    server_url: Optional[str] = None
    """Optional server URL that currently runs a LLMServer with 'openllm start'."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to OpenLLM's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        for text in texts:
            result = self.invoke_embedding(text=text)
            embeddings.append(result)

        return [list(map(float, e)) for e in embeddings]

    def invoke_embedding(self, text):
        params = [
            text
        ]

        headers = {"Content-Type": "application/json"}
        response = requests.post(
            f'{self.server_url}/v1/embeddings',
            headers=headers,
            json=params
        )

        if not response.ok:
            raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")

        json_response = response.json()
        return json_response[0]["embeddings"][0]

    def embed_query(self, text: str) -> List[float]:
        """Call out to OpenLLM's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
@@ -1,99 +0,0 @@
"""Wrapper around Replicate embedding models."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env


class ReplicateEmbeddings(BaseModel, Embeddings):
    """Wrapper around Replicate embedding models.

    To use, you should have the ``replicate`` python package installed.
    """

    client: Any  #: :meta private:
    model: str
    """Model name to use."""

    replicate_api_token: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        replicate_api_token = get_from_dict_or_env(
            values, "replicate_api_token", "REPLICATE_API_TOKEN"
        )
        try:
            import replicate as replicate_python

            values["client"] = replicate_python.Client(api_token=replicate_api_token)
        except ImportError:
            raise ImportError(
                "Could not import replicate python package. "
                "Please install it with `pip install replicate`."
            )
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Replicate's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # get the model and version
        model_str, version_str = self.model.split(":")
        model = self.client.models.get(model_str)
        version = model.versions.get(version_str)

        # sort through the openapi schema to get the name of the first input
        input_properties = sorted(
            version.openapi_schema["components"]["schemas"]["Input"][
                "properties"
            ].items(),
            key=lambda item: item[1].get("x-order", 0),
        )
        first_input_name = input_properties[0][0]

        embeddings = []
        for text in texts:
            result = self.client.run(self.model, input={first_input_name: text})
            embeddings.append(result[0].get('embedding'))

        return [list(map(float, e)) for e in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Call out to Replicate's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        # get the model and version
        model_str, version_str = self.model.split(":")
        model = self.client.models.get(model_str)
        version = model.versions.get(version_str)

        # sort through the openapi schema to get the name of the first input
        input_properties = sorted(
            version.openapi_schema["components"]["schemas"]["Input"][
                "properties"
            ].items(),
            key=lambda item: item[1].get("x-order", 0),
        )
        first_input_name = input_properties[0][0]
        result = self.client.run(self.model, input={first_input_name: text})
        embedding = result[0].get('embedding')

        return list(map(float, embedding))
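The Replicate wrapper above discovers which input field to feed the text into by sorting the model version's OpenAPI input properties by their `x-order`. A self-contained illustration of that lookup on a made-up schema fragment:

# Toy illustration of the "first input name" lookup used above; the schema dict
# here is invented for the example (real values come from version.openapi_schema).
toy_properties = {
    "text": {"type": "string", "x-order": 0},
    "batch_size": {"type": "integer", "x-order": 1},
}
input_properties = sorted(toy_properties.items(), key=lambda item: item[1].get("x-order", 0))
first_input_name = input_properties[0][0]
print(first_input_name)  # -> "text": the property the input text is passed under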
@@ -1,54 +0,0 @@
from typing import List, Optional, Any

import numpy as np
from langchain.embeddings.base import Embeddings
from xinference_client.client.restful.restful_client import Client


class XinferenceEmbeddings(Embeddings):
    client: Any
    server_url: Optional[str]
    """URL of the xinference server"""
    model_uid: Optional[str]
    """UID of the launched model"""

    def __init__(
        self, server_url: Optional[str] = None, model_uid: Optional[str] = None
    ):
        super().__init__()

        if server_url is None:
            raise ValueError("Please provide server URL")

        if model_uid is None:
            raise ValueError("Please provide the model UID")

        self.server_url = server_url
        self.model_uid = model_uid
        self.client = Client(server_url)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        model = self.client.get_model(self.model_uid)

        embeddings = [
            model.create_embedding(text)["data"][0]["embedding"] for text in texts
        ]
        vectors = [list(map(float, e)) for e in embeddings]
        normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]

        return normalized_vectors

    def embed_query(self, text: str) -> List[float]:
        model = self.client.get_model(self.model_uid)

        embedding_res = model.create_embedding(text)

        embedding = embedding_res["data"][0]["embedding"]

        vector = list(map(float, embedding))
        normalized_vector = (vector / np.linalg.norm(vector)).tolist()

        return normalized_vector
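XinferenceEmbeddings L2-normalizes every vector before returning it, so cosine similarity downstream reduces to a plain dot product. A standalone sketch of that normalization step (example vector only, not from the original diff):

# Standalone sketch of the normalization used above.
import numpy as np

vector = [3.0, 4.0]
normalized = (vector / np.linalg.norm(vector)).tolist()
print(normalized)                  # [0.6, 0.8]
print(np.linalg.norm(normalized))  # ~1.0 after normalization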
@@ -1,64 +0,0 @@
"""Wrapper around ZhipuAI embedding models."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env

from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI


class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around ZhipuAI embedding models.

    1024 dimensions.
    """

    client: Any  #: :meta private:
    model: str
    """Model name to use."""

    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
    api_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "ZHIPUAI_API_KEY"
        )
        values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to ZhipuAI's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        for text in texts:
            response = self.client.invoke(model=self.model, prompt=text)
            data = response["data"]
            embeddings.append(data.get('embedding'))

        return [list(map(float, e)) for e in embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Call out to ZhipuAI's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]
@@ -1,60 +0,0 @@
from typing import Dict

from langchain.chat_models import ChatAnthropic
from langchain.schema import ChatMessage, BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain.utils import get_from_dict_or_env, check_package_version
from pydantic import root_validator


class AnthropicLLM(ChatAnthropic):
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["anthropic_api_key"] = get_from_dict_or_env(
            values, "anthropic_api_key", "ANTHROPIC_API_KEY"
        )
        # Get custom api url from environment.
        values["anthropic_api_url"] = get_from_dict_or_env(
            values,
            "anthropic_api_url",
            "ANTHROPIC_API_URL",
            default="https://api.anthropic.com",
        )

        try:
            import anthropic

            check_package_version("anthropic", gte_version="0.3")
            values["client"] = anthropic.Anthropic(
                base_url=values["anthropic_api_url"],
                api_key=values["anthropic_api_key"],
                timeout=values["default_request_timeout"],
                max_retries=0
            )
            values["async_client"] = anthropic.AsyncAnthropic(
                base_url=values["anthropic_api_url"],
                api_key=values["anthropic_api_key"],
                timeout=values["default_request_timeout"],
            )
            values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
            values["AI_PROMPT"] = anthropic.AI_PROMPT
            values["count_tokens"] = values["client"].count_tokens
        except ImportError:
            raise ImportError(
                "Could not import anthropic python package. "
                "Please install it with `pip install anthropic`."
            )
        return values

    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
        if isinstance(message, ChatMessage):
            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = f"{self.HUMAN_PROMPT} {message.content}"
        elif isinstance(message, AIMessage):
            message_text = f"{self.AI_PROMPT} {message.content}"
        elif isinstance(message, SystemMessage):
            message_text = f"{message.content}"
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text
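The `_convert_one_message_to_text` override above renders messages with the `HUMAN_PROMPT`/`AI_PROMPT` constants from the `anthropic` package. A small illustration of the resulting prompt format, with the constant values inlined:

# Illustration only: the constants are inlined here with the values the
# anthropic package defines ("\n\nHuman:" and "\n\nAssistant:").
HUMAN_PROMPT = "\n\nHuman:"
AI_PROMPT = "\n\nAssistant:"

rendered = f"{HUMAN_PROMPT} What is the capital of France?{AI_PROMPT}"
print(repr(rendered))
# '\n\nHuman: What is the capital of France?\n\nAssistant:'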
@@ -1,141 +0,0 @@
from typing import Dict, Any, Optional, List, Tuple, Union, cast

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models import AzureChatOpenAI
from langchain.chat_models.openai import _convert_dict_to_message
from pydantic import root_validator

from langchain.schema import ChatResult, BaseMessage, ChatGeneration, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile


class EnhanceAzureChatOpenAI(AzureChatOpenAI):
    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to the OpenAI completion API. Default is a 5 second connect and 300 second read timeout."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return {
            **super()._default_params,
            "engine": self.deployment_name,
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        params = self._client_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [self._convert_message_to_dict(m) for m in messages]
        params = {**params, **kwargs}
        if self.streaming:
            inner_completion = ""
            role = "assistant"
            params["stream"] = True
            function_call: Optional[dict] = None
            for stream_resp in self.completion_with_retry(
                messages=message_dicts, **params
            ):
                if len(stream_resp["choices"]) > 0:
                    role = stream_resp["choices"][0]["delta"].get("role", role)
                    token = stream_resp["choices"][0]["delta"].get("content") or ""
                    inner_completion += token
                    _function_call = stream_resp["choices"][0]["delta"].get("function_call")
                    if _function_call:
                        if function_call is None:
                            function_call = _function_call
                        else:
                            function_call["arguments"] += _function_call["arguments"]
                    if run_manager:
                        run_manager.on_llm_new_token(token)
            message = _convert_dict_to_message(
                {
                    "content": inner_completion,
                    "role": role,
                    "function_call": function_call,
                }
            )
            return ChatResult(generations=[ChatGeneration(message=message)])
        response = self.completion_with_retry(messages=message_dicts, **params)
        return self._create_chat_result(response)

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, LCHumanMessageWithFiles):
            content = [
                {
                    "type": "text",
                    "text": message.content
                }
            ]

            for file in message.files:
                if file.type == PromptMessageFileType.IMAGE:
                    file = cast(ImagePromptMessageFile, file)
                    content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": file.data,
                            "detail": file.detail.value
                        }
                    })

            message_dict = {"role": "user", "content": content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
            if "function_call" in message.additional_kwargs:
                message_dict["function_call"] = message.additional_kwargs["function_call"]
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, FunctionMessage):
            message_dict = {
                "role": "function",
                "content": message.content,
                "name": message.name,
            }
        else:
            raise ValueError(f"Got unknown type {message}")
        if "name" in message.additional_kwargs:
            message_dict["name"] = message.additional_kwargs["name"]
        return message_dict
api/core/third_party/langchain/llms/azure_open_ai.py (vendored, 110 changed lines)
@@ -1,110 +0,0 @@
from typing import Dict, Any, Mapping, Optional, List, Union, Tuple

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import AzureOpenAI
from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
    update_token_usage
from langchain.schema import LLMResult
from pydantic import root_validator


class EnhanceAzureOpenAI(AzureOpenAI):
    openai_api_type: str = "azure"
    openai_api_version: str = ""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to the OpenAI completion API. Default is a 5 second connect and 300 second read timeout."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai

            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        return {**super()._invocation_params, **{
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {**super()._identifying_params, **{
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to OpenAI's endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The full LLM output.

        Example:
            .. code-block:: python

                response = openai.generate(["Tell me a joke."])
        """
        params = self._invocation_params
        params = {**params, **kwargs}
        sub_prompts = self.get_sub_prompts(params, prompts, stop)
        choices = []
        token_usage: Dict[str, int] = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")
                params["stream"] = True
                response = _streaming_response_template()
                for stream_resp in completion_with_retry(
                    self, prompt=_prompts, **params
                ):
                    if len(stream_resp["choices"]) > 0:
                        if run_manager:
                            run_manager.on_llm_new_token(
                                stream_resp["choices"][0]["text"],
                                verbose=self.verbose,
                                logprobs=stream_resp["choices"][0]["logprobs"],
                            )
                        _update_response(response, stream_resp)
                choices.extend(response["choices"])
            else:
                response = completion_with_retry(self, prompt=_prompts, **params)
                choices.extend(response["choices"])
            if not self.streaming:
                # Can't update token usage if streaming
                update_token_usage(_keys, response, token_usage)
        return self.create_llm_result(choices, prompts, token_usage)
api/core/third_party/langchain/llms/baichuan_llm.py (vendored, 315 changed lines)
@@ -1,315 +0,0 @@
"""Wrapper around Baichuan APIs."""
from __future__ import annotations

import hashlib
import json
import logging
import time
from typing import (
    Any,
    Dict,
    List,
    Optional, Iterator,
)

import requests
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import Extra, root_validator, BaseModel

from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
)
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class BaichuanModelAPI(BaseModel):
    api_key: str
    secret_key: str

    base_url: str = "https://api.baichuan-ai.com/v1"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def do_request(self, model: str, messages: list[dict], parameters: dict, **kwargs: Any):
        stream = 'stream' in kwargs and kwargs['stream']

        url = self.base_url + ("/stream/chat" if stream else "/chat")

        data = {
            "model": model,
            "messages": messages,
            "parameters": parameters
        }

        json_data = json.dumps(data)
        time_stamp = int(time.time())
        signature = self._calculate_md5(self.secret_key + json_data + str(time_stamp))

        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.api_key,
            "X-BC-Request-Id": "your requestId",
            "X-BC-Timestamp": str(time_stamp),
            "X-BC-Signature": signature,
            "X-BC-Sign-Algo": "MD5",
        }

        response = requests.post(url, data=json_data, headers=headers, stream=stream, timeout=(5, 60))

        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")

        if not stream:
            json_response = response.json()
            if json_response['code'] != 0:
                raise ValueError(
                    f"API {json_response['code']}"
                    f" error: {json_response['msg']}"
                )
            return json_response
        else:
            return response

    def _calculate_md5(self, input_string):
        md5 = hashlib.md5()
        md5.update(input_string.encode('utf-8'))
        encrypted = md5.hexdigest()
        return encrypted


class BaichuanChatLLM(BaseChatModel):
    """Wrapper around Baichuan large language models.

    To use, you should pass the api_key as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
            model = BaichuanChatLLM(model="<model_name>", api_key="my-api-key", secret_key="my-secret-key")
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    client: Any = None  #: :meta private:
    model: str = "Baichuan2-53B"
    """Model name to use."""
    temperature: float = 0.3
    """A non-negative float that tunes the degree of randomness in generation."""
    top_p: float = 0.85
    """Total probability mass of tokens to consider at each step."""
    streaming: bool = False
    """Whether to stream the response or return it all at once."""
    api_key: Optional[str] = None
    secret_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "BAICHUAN_API_KEY"
        )

        values["secret_key"] = get_from_dict_or_env(
            values, "secret_key", "BAICHUAN_SECRET_KEY"
        )

        values['client'] = BaichuanModelAPI(
            api_key=values['api_key'],
            secret_key=values['secret_key']
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Baichuan API."""
        return {
            "model": self.model,
            "parameters": {
                "temperature": self.temperature,
                "top_p": self.top_p
            }
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "baichuan"

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "user", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_dict

    def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
        role = _dict["role"]
        if role == "user":
            return HumanMessage(content=_dict["content"])
        elif role == "assistant":
            return AIMessage(content=_dict["content"])
        elif role == "system":
            return SystemMessage(content=_dict["content"])
        else:
            return ChatMessage(content=_dict["content"], role=role)

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        dict_messages = []
        for m in messages:
            message = self._convert_message_to_dict(m)
            if dict_messages:
                previous_message = dict_messages[-1]
                if previous_message['role'] == message['role']:
                    dict_messages[-1]['content'] += f"\n{message['content']}"
                else:
                    dict_messages.append(message)
            else:
                dict_messages.append(message)

        return dict_messages

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            generation: Optional[ChatGenerationChunk] = None
            llm_output: Optional[Dict] = None
            for chunk in self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk

                if chunk.generation_info is not None \
                        and 'token_usage' in chunk.generation_info:
                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}

            assert generation is not None
            return ChatResult(generations=[generation], llm_output=llm_output)
        else:
            message_dicts = self._create_message_dicts(messages)
            params = self._default_params
            params["messages"] = message_dicts
            params.update(kwargs)
            response = self.client.do_request(**params)
            return self._create_chat_result(response)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts = self._create_message_dicts(messages)
        params = self._default_params
        params["messages"] = message_dicts
        params.update(kwargs)

        for event in self.client.do_request(stream=True, **params).iter_lines():
            if event:
                event = event.decode("utf-8")

                meta = json.loads(event)

                if meta['code'] != 0:
                    raise ValueError(
                        f"API {meta['code']}"
                        f" error: {meta['msg']}"
                    )

                content = meta['data']['messages'][0]['content']

                chunk_kwargs = {
                    'message': AIMessageChunk(content=content),
                }

                if 'usage' in meta:
                    token_usage = meta['usage']
                    overall_token_usage = {
                        'prompt_tokens': token_usage.get('prompt_tokens', 0),
                        'completion_tokens': token_usage.get('answer_tokens', 0),
                        'total_tokens': token_usage.get('total_tokens', 0)
                    }
                    chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}

                yield ChatGenerationChunk(**chunk_kwargs)
                if run_manager:
                    run_manager.on_llm_new_token(content)

    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
        data = response["data"]
        generations = []
        for res in data["messages"]:
            message = self._convert_dict_to_message(res)
            gen = ChatGeneration(
                message=message
            )
            generations.append(gen)
        usage = response.get("usage")
        token_usage = {
            'prompt_tokens': usage.get('prompt_tokens', 0),
            'completion_tokens': usage.get('answer_tokens', 0),
            'total_tokens': usage.get('total_tokens', 0)
        }
        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages.

        Useful for checking if an input will fit in a model's context window.

        Args:
            messages: The message inputs to tokenize.

        Returns:
            The sum of the number of tokens across the messages.
        """
        return sum([self.get_num_tokens(m.content) for m in messages])

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]

        return {"token_usage": token_usage, "model_name": self.model}
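BaichuanModelAPI above signs each request with the MD5 hex digest of `secret_key + request body + unix timestamp`, sent in the `X-BC-Signature` header alongside `X-BC-Timestamp`. A standalone sketch of that computation (key and payload are placeholders, not from the original diff):

# Standalone sketch of the request signing performed above; key and payload are placeholders.
import hashlib
import json
import time

secret_key = "my-secret-key"  # placeholder
json_data = json.dumps({"model": "Baichuan2-53B", "messages": [], "parameters": {}})
time_stamp = int(time.time())
signature = hashlib.md5((secret_key + json_data + str(time_stamp)).encode('utf-8')).hexdigest()
print(signature)  # value for the X-BC-Signature header; X-BC-Timestamp carries time_stamp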
api/core/third_party/langchain/llms/chat_open_ai.py (vendored, 152 changed lines)
@@ -1,152 +0,0 @@
import os

from typing import Dict, Any, Optional, Union, Tuple, List, cast

from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
from pydantic import root_validator

from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile


class EnhanceChatOpenAI(ChatOpenAI):
    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to the OpenAI completion API. Default is a 5 second connect and 300 second read timeout."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return {
            **super()._default_params,
            "api_type": 'openai',
            "api_base": self.openai_api_base if self.openai_api_base
            else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._client_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [self._convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
        model, encoding = self._get_encoding_model()
        if model.startswith("gpt-3.5-turbo-0301"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}."
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        messages_dict = [self._convert_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Cast str(value) in case the message value is not a string.
                # This occurs with function messages.
                # TODO: Token calculation for the image type is not implemented yet;
                # it would require downloading the image and reading its resolution,
                # which would add request latency.
                if isinstance(value, list):
                    text = ''
                    for item in value:
                        if isinstance(item, dict) and item['type'] == 'text':
                            text += item['text']

                    value = text
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3
        return num_tokens

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, LCHumanMessageWithFiles):
            content = [
                {
                    "type": "text",
                    "text": message.content
                }
            ]

            for file in message.files:
                if file.type == PromptMessageFileType.IMAGE:
                    file = cast(ImagePromptMessageFile, file)
                    content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": file.data,
                            "detail": file.detail.value
                        }
                    })

            message_dict = {"role": "user", "content": content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
            if "function_call" in message.additional_kwargs:
                message_dict["function_call"] = message.additional_kwargs["function_call"]
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, FunctionMessage):
            message_dict = {
                "role": "function",
                "content": message.content,
                "name": message.name,
            }
        else:
            raise ValueError(f"Got unknown type {message}")
        if "name" in message.additional_kwargs:
            message_dict["name"] = message.additional_kwargs["name"]
        return message_dict
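A standalone sketch of the per-message token accounting that get_num_tokens_from_messages above applies (it follows the OpenAI cookbook scheme); tiktoken and the example messages are assumptions, and the name-key adjustment is omitted for brevity:

# Hedged sketch of the cookbook-style token count used above; requires `pip install tiktoken`.
import tiktoken

encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
tokens_per_message = 3  # per-message overhead for gpt-3.5-turbo / gpt-4
num_tokens = 0
for message in messages:
    num_tokens += tokens_per_message
    for value in message.values():
        num_tokens += len(encoding.encode(str(value)))
num_tokens += 3  # every reply is primed with the assistant role
print(num_tokens)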
api/core/third_party/langchain/llms/fake.py (vendored, 7 changed lines)
@@ -1,12 +1,10 @@
import time
from typing import List, Optional, Any, Mapping, Callable
from typing import List, Optional, Any, Mapping

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration

from core.model_providers.models.entity.message import str_to_prompt_messages


class FakeLLM(SimpleChatModel):
    """Fake ChatModel for testing purposes."""
@@ -14,7 +12,6 @@ class FakeLLM(SimpleChatModel):
    streaming: bool = False
    """Whether to stream the results or not."""
    response: str
    num_token_func: Optional[Callable] = None

    @property
    def _llm_type(self) -> str:
@@ -35,7 +32,7 @@ class FakeLLM(SimpleChatModel):
        return {"response": self.response}

    def get_num_tokens(self, text: str) -> int:
        return self.num_token_func(str_to_prompt_messages([text])) if self.num_token_func else 0
        return 0

    def _generate(
        self,
@@ -1,128 +0,0 @@
from typing import Dict, Any, Optional, List, Iterable, Iterator

from huggingface_hub import InferenceClient
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.embeddings.huggingface_hub import VALID_TASKS
from langchain.llms import HuggingFaceEndpoint
from langchain.llms.utils import enforce_stop_tokens
from pydantic import root_validator

from langchain.utils import get_from_dict_or_env


class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
    """HuggingFace Endpoint models.

    To use, you should have the ``huggingface_hub`` python package installed, and the
    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    Only supports `text-generation` and `text2text-generation` for now.

    Example:
        .. code-block:: python

            from langchain.llms import HuggingFaceEndpoint
            endpoint_url = (
                "https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
            )
            hf = HuggingFaceEndpoint(
                endpoint_url=endpoint_url,
                huggingfacehub_api_token="my-api-key"
            )
    """
    client: Any
    streaming: bool = False

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )

        values['client'] = InferenceClient(values['endpoint_url'], token=huggingfacehub_api_token)

        values["huggingfacehub_api_token"] = huggingfacehub_api_token
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to HuggingFace Hub's inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = hf("Tell me a joke.")
        """
        _model_kwargs = self.model_kwargs or {}

        # payload samples
        params = {**_model_kwargs, **kwargs}

        # generation parameter
        gen_kwargs = {
            **params,
            'stop_sequences': stop
        }

        response = self.client.text_generation(prompt, stream=self.streaming, details=True, **gen_kwargs)

        if self.streaming and isinstance(response, Iterable):
            combined_text_output = ""
            for token in self._stream_response(response, run_manager):
                combined_text_output += token
            completion = combined_text_output
        else:
            completion = response.generated_text

        if self.task == "text-generation":
            text = completion
            # Remove prompt if included in generated text.
            if text.startswith(prompt):
                text = text[len(prompt):]
        elif self.task == "text2text-generation":
            text = completion
        else:
            raise ValueError(
                f"Got invalid task {self.task}, "
                f"currently only {VALID_TASKS} are supported"
            )

        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)

        return text

    def _stream_response(
        self,
        response: Iterable,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> Iterator[str]:
        for r in response:
            # skip special tokens
            if r.token.special:
                continue

            token = r.token.text
            if run_manager:
                run_manager.on_llm_new_token(
                    token=token, verbose=self.verbose, log_probs=None
                )

            # yield the generated token
            yield token
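For reference, a hedged usage sketch of the endpoint wrapper above (not part of the original diff); the endpoint URL is the placeholder from its own docstring and the token is a dummy value.

# Hypothetical usage sketch; endpoint URL and token are placeholders, and the
# endpoint is assumed to serve a text-generation task.
llm = HuggingFaceEndpointLLM(
    endpoint_url="https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud",
    huggingfacehub_api_token="my-api-key",
    task="text-generation",
    streaming=False,
)
print(llm("Tell me a joke.", stop=["\n\n"]))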
@@ -1,62 +0,0 @@
from typing import Dict, Optional, List, Any

from huggingface_hub import HfApi, InferenceApi
from langchain import HuggingFaceHub
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_hub import VALID_TASKS
from pydantic import root_validator

from langchain.utils import get_from_dict_or_env


class HuggingFaceHubLLM(HuggingFaceHub):
    """HuggingFaceHub models.

    To use, you should have the ``huggingface_hub`` python package installed, and the
    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    Only supports `text-generation`, `text2text-generation` for now.

    Example:
        .. code-block:: python

            from langchain.llms import HuggingFaceHub
            hf = HuggingFaceHub(repo_id="gpt2", huggingfacehub_api_token="my-api-key")
    """

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        client = InferenceApi(
            repo_id=values["repo_id"],
            token=huggingfacehub_api_token,
            task=values.get("task"),
        )
        client.options = {"wait_for_model": False, "use_gpu": False}
        values["client"] = client
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        hfapi = HfApi(token=self.huggingfacehub_api_token)
        model_info = hfapi.model_info(repo_id=self.repo_id)
        if not model_info:
            raise ValueError(f"Model {self.repo_id} not found.")

        if 'inference' in model_info.cardData and not model_info.cardData['inference']:
            raise ValueError(f"Inference API has been turned off for this model {self.repo_id}.")

        if model_info.pipeline_tag not in VALID_TASKS:
            raise ValueError(f"Model {self.repo_id} is not a valid task, "
                             f"must be one of {VALID_TASKS}.")

        return super()._call(prompt, stop, run_manager, **kwargs)
283
api/core/third_party/langchain/llms/minimax_llm.py
vendored
283
api/core/third_party/langchain/llms/minimax_llm.py
vendored
@@ -1,283 +0,0 @@
|
||||
import json
|
||||
from typing import Dict, Any, Optional, List, Tuple, Iterator
|
||||
|
||||
import requests
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.schema import BaseMessage, ChatResult, HumanMessage, AIMessage, SystemMessage
|
||||
from langchain.schema.messages import AIMessageChunk
|
||||
from langchain.schema.output import ChatGenerationChunk, ChatGeneration
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from pydantic import root_validator, Field, BaseModel
|
||||
|
||||
|
||||
class _MinimaxEndpointClient(BaseModel):
|
||||
"""An API client that talks to a Minimax llm endpoint."""
|
||||
|
||||
host: str
|
||||
group_id: str
|
||||
api_key: str
|
||||
api_url: str
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "api_url" not in values:
|
||||
host = values["host"]
|
||||
group_id = values["group_id"]
|
||||
api_url = f"{host}/v1/text/chatcompletion?GroupId={group_id}"
|
||||
values["api_url"] = api_url
|
||||
return values
|
||||
|
||||
def post(self, **request: Any) -> Any:
|
||||
stream = 'stream' in request and request['stream']
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
response = requests.post(self.api_url, headers=headers, json=request, stream=stream, timeout=(5, 60))
|
||||
if not response.ok:
|
||||
raise ValueError(f"HTTP {response.status_code} error: {response.text}")
|
||||
|
||||
if not stream:
|
||||
if response.json()["base_resp"]["status_code"] > 0:
|
||||
raise ValueError(
|
||||
f"API {response.json()['base_resp']['status_code']}"
|
||||
f" error: {response.json()['base_resp']['status_msg']}"
|
||||
)
|
||||
return response.json()
|
||||
else:
|
||||
return response
|
||||
|
||||
|
||||
class MinimaxChatLLM(BaseChatModel):
|
||||
|
||||
_client: _MinimaxEndpointClient
|
||||
model: str = "abab5.5-chat"
|
||||
"""Model name to use."""
|
||||
max_tokens: int = 256
|
||||
"""Denotes the number of tokens to predict per generation."""
|
||||
temperature: float = 0.7
|
||||
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||
top_p: float = 0.95
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the response or return it all at once."""
|
||||
minimax_api_host: Optional[str] = None
|
||||
minimax_group_id: Optional[str] = None
|
||||
minimax_api_key: Optional[str] = None
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"minimax_api_key": "MINIMAX_API_KEY"}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["minimax_api_key"] = get_from_dict_or_env(
|
||||
values, "minimax_api_key", "MINIMAX_API_KEY"
|
||||
)
|
||||
values["minimax_group_id"] = get_from_dict_or_env(
|
||||
values, "minimax_group_id", "MINIMAX_GROUP_ID"
|
||||
)
|
||||
# Get custom api url from environment.
|
||||
values["minimax_api_host"] = get_from_dict_or_env(
|
||||
values,
|
||||
"minimax_api_host",
|
||||
"MINIMAX_API_HOST",
|
||||
default="https://api.minimax.chat",
|
||||
)
|
||||
values["_client"] = _MinimaxEndpointClient(
|
||||
host=values["minimax_api_host"],
|
||||
api_key=values["minimax_api_key"],
|
||||
group_id=values["minimax_group_id"],
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"tokens_to_generate": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"role_meta": {"user_name": "我", "bot_name": "专家"},
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model": self.model}, **self._default_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "minimax"
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
|
||||
if isinstance(message, HumanMessage):
|
||||
message_dict = {"sender_type": "USER", "text": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"sender_type": "BOT", "text": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
|
||||
def _create_messages_and_prompt(
|
||||
self, messages: List[BaseMessage]
|
||||
) -> Tuple[List[Dict[str, Any]], str]:
|
||||
prompt = ""
|
||||
dict_messages = []
|
||||
for m in messages:
|
||||
if isinstance(m, SystemMessage):
|
||||
if prompt:
|
||||
prompt += "\n"
|
||||
prompt += f"{m.content}"
|
||||
continue
|
||||
|
||||
message = self._convert_message_to_dict(m)
|
||||
dict_messages.append(message)
|
||||
|
||||
prompt = prompt if prompt else ' '
|
||||
|
||||
return dict_messages, prompt
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
llm_output: Optional[Dict] = None
|
||||
for chunk in self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
|
||||
if chunk.generation_info is not None \
|
||||
and 'token_usage' in chunk.generation_info:
|
||||
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
|
||||
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation], llm_output=llm_output)
|
||||
else:
|
||||
message_dicts, prompt = self._create_messages_and_prompt(messages)
|
||||
params = self._default_params
|
||||
params["messages"] = message_dicts
|
||||
params["prompt"] = prompt
|
||||
params.update(kwargs)
|
||||
response = self._client.post(**params)
|
||||
return self._create_chat_result(response, stop)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, prompt = self._create_messages_and_prompt(messages)
|
||||
params = self._default_params
|
||||
params["messages"] = message_dicts
|
||||
params["prompt"] = prompt
|
||||
params["stream"] = True
|
||||
params.update(kwargs)
|
||||
|
||||
for token in self._client.post(**params).iter_lines():
|
||||
if token:
|
||||
token = token.decode("utf-8")
|
||||
|
||||
if not token.startswith("data:"):
|
||||
data = json.loads(token)
|
||||
if "base_resp" in data and data["base_resp"]["status_code"] > 0:
|
||||
raise ValueError(
|
||||
f"API {data['base_resp']['status_code']}"
|
||||
f" error: {data['base_resp']['status_msg']}"
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
token = token[len("data:"):].strip()  # slice off the SSE "data:" prefix; lstrip() would strip characters, not the prefix
|
||||
data = json.loads(token)
|
||||
|
||||
if "base_resp" in data and data["base_resp"]["status_code"] > 0:
|
||||
raise ValueError(
|
||||
f"API {data['base_resp']['status_code']}"
|
||||
f" error: {data['base_resp']['status_msg']}"
|
||||
)
|
||||
|
||||
if not data['choices']:
|
||||
continue
|
||||
|
||||
content = data['choices'][0]['delta']
|
||||
|
||||
chunk_kwargs = {
|
||||
'message': AIMessageChunk(content=content),
|
||||
}
|
||||
|
||||
if 'usage' in data:
|
||||
token_usage = data['usage']
|
||||
overall_token_usage = {
|
||||
'prompt_tokens': 0,
|
||||
'completion_tokens': token_usage.get('total_tokens', 0),
|
||||
'total_tokens': token_usage.get('total_tokens', 0)
|
||||
}
|
||||
chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}
|
||||
|
||||
yield ChatGenerationChunk(**chunk_kwargs)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(content)
|
||||
|
||||
def _create_chat_result(self, response: Dict[str, Any], stop: Optional[List[str]] = None) -> ChatResult:
|
||||
text = response['reply']
|
||||
if stop is not None:
|
||||
# This is required since the stop tokens
|
||||
# are not enforced by the model parameters
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
generations = [ChatGeneration(message=AIMessage(content=text))]
|
||||
usage = response.get("usage")
|
||||
|
||||
# Minimax responses only report total_tokens
|
||||
token_usage = {
|
||||
'prompt_tokens': 0,
|
||||
'completion_tokens': usage.get('total_tokens', 0),
|
||||
'total_tokens': usage.get('total_tokens', 0)
|
||||
}
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in the messages.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens across the messages.
|
||||
"""
|
||||
return sum([self.get_num_tokens(m.content) for m in messages])
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
# Happens in streaming
|
||||
continue
|
||||
token_usage = output["token_usage"]
|
||||
|
||||
return {"token_usage": token_usage, "model_name": self.model}
|
||||
api/core/third_party/langchain/llms/open_ai.py (vendored)
@@ -1,82 +0,0 @@
|
||||
import os
|
||||
|
||||
from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator
|
||||
from langchain import OpenAI
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from pydantic import root_validator
|
||||
|
||||
|
||||
class EnhanceOpenAI(OpenAI):
|
||||
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
|
||||
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
|
||||
max_retries: int = 1
|
||||
"""Maximum number of retries to make when generating."""
|
||||
|
||||
def __new__(cls, **data: Any): # type: ignore
|
||||
return super(EnhanceOpenAI, cls).__new__(cls)
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
try:
|
||||
import openai
|
||||
|
||||
values["client"] = openai.Completion
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import openai python package. "
|
||||
"Please install it with `pip install openai`."
|
||||
)
|
||||
if values["streaming"] and values["n"] > 1:
|
||||
raise ValueError("Cannot stream results when n > 1.")
|
||||
if values["streaming"] and values["best_of"] > 1:
|
||||
raise ValueError("Cannot stream results when best_of > 1.")
|
||||
return values
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
return {**super()._invocation_params, **{
|
||||
"api_type": 'openai',
|
||||
"api_base": self.openai_api_base if self.openai_api_base
|
||||
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_version": None,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
return {**super()._identifying_params, **{
|
||||
"api_type": 'openai',
|
||||
"api_base": self.openai_api_base if self.openai_api_base
|
||||
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
|
||||
"api_version": None,
|
||||
"api_key": self.openai_api_key,
|
||||
"organization": self.openai_organization if self.openai_organization else None,
|
||||
}}
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
params = {**self._invocation_params, **kwargs, "stream": True}
|
||||
self.get_sub_prompts(params, [prompt], stop) # this mutates params
|
||||
for stream_resp in completion_with_retry(
|
||||
self, prompt=prompt, run_manager=run_manager, **params
|
||||
):
|
||||
if 'text' in stream_resp["choices"][0]:
|
||||
chunk = _stream_response_to_generation_chunk(stream_resp)
|
||||
yield chunk
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
chunk.text,
|
||||
verbose=self.verbose,
|
||||
logprobs=chunk.generation_info["logprobs"]
|
||||
if chunk.generation_info
|
||||
else None,
|
||||
)
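
A minimal sketch of how this subclass could be used; the API key is a placeholder and the prompt is illustrative only:

llm = EnhanceOpenAI(openai_api_key="sk-<placeholder>", temperature=0)
print(llm("Say hello in one word."))  # requests go through the _invocation_params override above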
|
||||
api/core/third_party/langchain/llms/openllm.py (vendored)
@@ -1,86 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
import requests
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenLLM(LLM):
|
||||
"""OpenLLM, supporting both in-process model
|
||||
instance and remote OpenLLM servers.
|
||||
|
||||
If you have an OpenLLM server running, you can also use it remotely:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import OpenLLM
|
||||
llm = OpenLLM(server_url='http://localhost:3000')
|
||||
llm("What is the difference between a duck and a goose?")
|
||||
"""
|
||||
|
||||
server_url: Optional[str] = None
|
||||
"""Optional server URL that currently runs a LLMServer with 'openllm start'."""
|
||||
llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Key word arguments to be passed to openllm.LLM"""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "openllm"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"llm_config": self.llm_kwargs,
|
||||
"stop": stop,
|
||||
}
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = requests.post(
|
||||
f'{self.server_url}/v1/generate',
|
||||
headers=headers,
|
||||
json=params
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
|
||||
|
||||
json_response = response.json()
|
||||
completion = json_response["outputs"][0]['text']
|
||||
completion = completion[len(prompt):] if completion.startswith(prompt) else completion  # strip the echoed prompt; lstrip(prompt) would remove characters, not the prefix
|
||||
|
||||
# if stop is not None:
|
||||
# completion = enforce_stop_tokens(completion, stop)
|
||||
|
||||
return completion
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
raise NotImplementedError(
|
||||
"Async call is not supported for OpenLLM at the moment."
|
||||
)
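
Since `_call` above is a thin HTTP wrapper, the equivalent raw request looks roughly like the sketch below (the localhost URL is an assumption matching the docstring example):

import requests

payload = {"prompt": "What is a goose?", "llm_config": {}, "stop": None}
resp = requests.post("http://localhost:3000/v1/generate",
                     headers={"Content-Type": "application/json"},
                     json=payload)
print(resp.json()["outputs"][0]["text"])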
|
||||
@@ -1,75 +0,0 @@
|
||||
from typing import Dict, Optional, List, Any
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import Replicate
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from pydantic import root_validator
|
||||
|
||||
|
||||
class EnhanceReplicate(Replicate):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
replicate_api_token = get_from_dict_or_env(
|
||||
values, "replicate_api_token", "REPLICATE_API_TOKEN"
|
||||
)
|
||||
values["replicate_api_token"] = replicate_api_token
|
||||
return values
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call to replicate endpoint."""
|
||||
try:
|
||||
import replicate as replicate_python
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import replicate python package. "
|
||||
"Please install it with `pip install replicate`."
|
||||
)
|
||||
|
||||
client = replicate_python.Client(api_token=self.replicate_api_token)
|
||||
|
||||
# get the model and version
|
||||
model_str, version_str = self.model.split(":")
|
||||
model = client.models.get(model_str)
|
||||
version = model.versions.get(version_str)
|
||||
|
||||
# sort through the openapi schema to get the name of the first input
|
||||
input_properties = sorted(
|
||||
version.openapi_schema["components"]["schemas"]["Input"][
|
||||
"properties"
|
||||
].items(),
|
||||
key=lambda item: item[1].get("x-order", 0),
|
||||
)
|
||||
first_input_name = input_properties[0][0]
|
||||
inputs = {first_input_name: prompt, **self.input}
|
||||
|
||||
prediction = client.predictions.create(
|
||||
version=version, input={**inputs, **kwargs}
|
||||
)
|
||||
current_completion: str = ""
|
||||
stop_condition_reached = False
|
||||
for output in prediction.output_iterator():
|
||||
current_completion += output
|
||||
|
||||
# test for stop conditions, if specified
|
||||
if stop:
|
||||
for s in stop:
|
||||
if s in current_completion:
|
||||
prediction.cancel()
|
||||
stop_index = current_completion.find(s)
|
||||
current_completion = current_completion[:stop_index]
|
||||
stop_condition_reached = True
|
||||
break
|
||||
|
||||
if stop_condition_reached:
|
||||
break
|
||||
|
||||
if self.streaming and run_manager:
|
||||
run_manager.on_llm_new_token(output)
|
||||
return current_completion
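
A minimal usage sketch; the model identifier and API token are placeholders, and the stop list triggers the server-side cancel logic shown above:

llm = EnhanceReplicate(
    model="<owner>/<model>:<version-hash>",   # placeholder "model:version" string
    replicate_api_token="<token>",            # placeholder; REPLICATE_API_TOKEN also works
    input={"temperature": 0.7},
)
print(llm("Tell me a short joke.", stop=["\n\n"]))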
|
||||
api/core/third_party/langchain/llms/spark.py (vendored)
@@ -1,192 +0,0 @@
|
||||
import re
|
||||
import string
|
||||
import threading
|
||||
from _decimal import Decimal, ROUND_HALF_UP
|
||||
from typing import Dict, List, Optional, Any, Mapping
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, ChatResult, \
|
||||
ChatGeneration
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.third_party.spark.spark_llm import SparkLLMClient
|
||||
|
||||
|
||||
class ChatSpark(BaseChatModel):
|
||||
r"""Wrapper around Spark's large language model.
|
||||
|
||||
To use, you should pass `app_id`, `api_key`, `api_secret`
|
||||
as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
client = SparkLLMClient(
|
||||
model_name="<model_name>",
|
||||
app_id="<app_id>",
|
||||
api_key="<api_key>",
|
||||
api_secret="<api_secret>"
|
||||
)
|
||||
"""
|
||||
client: Any = None #: :meta private:
|
||||
|
||||
model_name: str = "spark"
|
||||
"""The Spark model name."""
|
||||
|
||||
max_tokens: int = 256
|
||||
"""Denotes the number of tokens to predict per generation."""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||
|
||||
top_k: Optional[int] = None
|
||||
"""Number of most likely tokens to consider at each step."""
|
||||
|
||||
user_id: Optional[str] = None
|
||||
"""User ID to use for the model."""
|
||||
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results."""
|
||||
|
||||
app_id: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
api_secret: Optional[str] = None
|
||||
api_domain: Optional[str] = None
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["app_id"] = get_from_dict_or_env(
|
||||
values, "app_id", "SPARK_APP_ID"
|
||||
)
|
||||
values["api_key"] = get_from_dict_or_env(
|
||||
values, "api_key", "SPARK_API_KEY"
|
||||
)
|
||||
values["api_secret"] = get_from_dict_or_env(
|
||||
values, "api_secret", "SPARK_API_SECRET"
|
||||
)
|
||||
|
||||
values["client"] = SparkLLMClient(
|
||||
model_name=values["model_name"],
|
||||
app_id=values["app_id"],
|
||||
api_key=values["api_key"],
|
||||
api_secret=values["api_secret"],
|
||||
api_domain=values.get('api_domain')
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Mapping[str, Any]:
|
||||
"""Get the default parameters for calling Anthropic API."""
|
||||
d = {
|
||||
"max_tokens": self.max_tokens
|
||||
}
|
||||
if self.temperature is not None:
|
||||
d["temperature"] = self.temperature
|
||||
if self.top_k is not None:
|
||||
d["top_k"] = self.top_k
|
||||
return d
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{}, **self._default_params}
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"api_key": "API_KEY", "api_secret": "API_SECRET"}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "spark-chat"
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
def _convert_messages_to_dicts(self, messages: List[BaseMessage]) -> list[dict]:
|
||||
"""Format a list of messages into a full dict list.
|
||||
|
||||
Args:
|
||||
messages (List[BaseMessage]): List of BaseMessage to combine.
|
||||
|
||||
Returns:
|
||||
list[dict]
|
||||
"""
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, ChatMessage):
|
||||
new_messages.append({'role': 'user', 'content': message.content})
|
||||
elif isinstance(message, HumanMessage) or isinstance(message, SystemMessage):
|
||||
new_messages.append({'role': 'user', 'content': message.content})
|
||||
elif isinstance(message, AIMessage):
|
||||
new_messages.append({'role': 'assistant', 'content': message.content})
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
return new_messages
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
messages = self._convert_messages_to_dicts(messages)
|
||||
|
||||
thread = threading.Thread(target=self.client.run, args=(
|
||||
messages,
|
||||
self.user_id,
|
||||
self._default_params,
|
||||
self.streaming
|
||||
))
|
||||
thread.start()
|
||||
|
||||
completion = ""
|
||||
for content in self.client.subscribe():
|
||||
if isinstance(content, dict):
|
||||
delta = content['data']
|
||||
else:
|
||||
delta = content
|
||||
|
||||
completion += delta
|
||||
if self.streaming and run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
delta,
|
||||
)
|
||||
|
||||
thread.join()
|
||||
|
||||
if stop is not None:
|
||||
completion = enforce_stop_tokens(completion, stop)
|
||||
|
||||
message = AIMessage(content=completion)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
message = AIMessage(content='')
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Calculate number of tokens."""
|
||||
total = Decimal(0)
|
||||
words = re.findall(r'\b\w+\b|[{}]|\s'.format(re.escape(string.punctuation)), text)
|
||||
for word in words:
|
||||
if word:
|
||||
if '\u4e00' <= word <= '\u9fff': # Chinese character range
|
||||
total += Decimal('1.5')
|
||||
else:
|
||||
total += Decimal('0.8')
|
||||
return int(total)
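
A worked example of the estimate above (it is a heuristic, not real tokenizer output): matches in the Chinese range count 1.5, everything else 0.8, and the sum is truncated to an integer.

# get_num_tokens("我 like AI.") matches ['我', ' ', 'like', ' ', 'AI', '.']
# 1 * 1.5 + 5 * 0.8 = 5.5  ->  int(...) == 5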
|
||||
@@ -1,82 +0,0 @@
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import Tongyi
|
||||
from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry
|
||||
from langchain.schema import Generation, LLMResult
|
||||
|
||||
|
||||
class EnhanceTongyi(Tongyi):
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
normal_params = {
|
||||
"top_p": self.top_p,
|
||||
"api_key": self.dashscope_api_key
|
||||
}
|
||||
|
||||
return {**normal_params, **self.model_kwargs}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
generations = []
|
||||
params: Dict[str, Any] = {
|
||||
**{"model": self.model_name},
|
||||
**self._default_params,
|
||||
**kwargs,
|
||||
}
|
||||
if self.streaming:
|
||||
if len(prompts) > 1:
|
||||
raise ValueError("Cannot stream results with multiple prompts.")
|
||||
params["stream"] = True
|
||||
text = ''
|
||||
for stream_resp in stream_generate_with_retry(
|
||||
self, prompt=prompts[0], **params
|
||||
):
|
||||
if not generations:
|
||||
current_text = stream_resp["output"]["text"]
|
||||
else:
|
||||
current_text = stream_resp["output"]["text"][len(text):]
|
||||
|
||||
text = stream_resp["output"]["text"]
|
||||
|
||||
generations.append(
|
||||
[
|
||||
Generation(
|
||||
text=current_text,
|
||||
generation_info=dict(
|
||||
finish_reason=stream_resp["output"]["finish_reason"],
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
current_text,
|
||||
verbose=self.verbose,
|
||||
logprobs=None,
|
||||
)
|
||||
else:
|
||||
for prompt in prompts:
|
||||
completion = generate_with_retry(
|
||||
self,
|
||||
prompt=prompt,
|
||||
**params,
|
||||
)
|
||||
generations.append(
|
||||
[
|
||||
Generation(
|
||||
text=completion["output"]["text"],
|
||||
generation_info=dict(
|
||||
finish_reason=completion["output"]["finish_reason"],
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
return LLMResult(generations=generations)
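
One detail worth noting in the streaming branch above: the streamed `output.text` is treated as cumulative, so the per-chunk delta is recovered by slicing off the previously seen prefix. A schematic illustration with made-up values:

# stream_resp["output"]["text"] arrives cumulatively: "He", "Hello", "Hello!"
# current_text emitted per step:                      "He", "llo",   "!"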
|
||||
api/core/third_party/langchain/llms/wenxin.py (vendored)
@@ -1,319 +0,0 @@
|
||||
"""Wrapper around Wenxin APIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional, Iterator, Tuple,
|
||||
)
|
||||
|
||||
import requests
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
|
||||
from langchain.schema.messages import AIMessageChunk
|
||||
from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
|
||||
from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _WenxinEndpointClient(BaseModel):
|
||||
"""An API client that talks to a Wenxin llm endpoint."""
|
||||
|
||||
base_url: str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
|
||||
secret_key: str
|
||||
api_key: str
|
||||
|
||||
def get_access_token(self) -> str:
|
||||
url = f"https://aip.baidubce.com/oauth/2.0/token?client_id={self.api_key}" \
|
||||
f"&client_secret={self.secret_key}&grant_type=client_credentials"
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers)
|
||||
if not response.ok:
|
||||
raise ValueError(f"Wenxin HTTP {response.status_code} error: {response.text}")
|
||||
if 'error' in response.json():
|
||||
raise ValueError(
|
||||
f"Wenxin API {response.json()['error']}"
|
||||
f" error: {response.json()['error_description']}"
|
||||
)
|
||||
|
||||
access_token = response.json()['access_token']
|
||||
|
||||
# todo add cache
|
||||
|
||||
return access_token
|
||||
|
||||
def post(self, request: dict) -> Any:
|
||||
if 'model' not in request:
|
||||
raise ValueError(f"Wenxin Model name is required")
|
||||
|
||||
model_url_map = {
|
||||
'ernie-bot-4': 'completions_pro',
|
||||
'ernie-bot': 'completions',
|
||||
'ernie-bot-turbo': 'eb-instant',
|
||||
'bloomz-7b': 'bloomz_7b1',
|
||||
}
|
||||
|
||||
stream = 'stream' in request and request['stream']
|
||||
|
||||
access_token = self.get_access_token()
|
||||
api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
|
||||
del request['model']
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = requests.post(api_url,
|
||||
headers=headers,
|
||||
json=request,
|
||||
stream=stream)
|
||||
if not response.ok:
|
||||
raise ValueError(f"Wenxin HTTP {response.status_code} error: {response.text}")
|
||||
|
||||
if not stream:
|
||||
json_response = response.json()
|
||||
if 'error_code' in json_response:
|
||||
raise ValueError(
|
||||
f"Wenxin API {json_response['error_code']}"
|
||||
f" error: {json_response['error_msg']}"
|
||||
)
|
||||
return json_response
|
||||
else:
|
||||
return response
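
Putting the client above together: the final request URL is the chat base URL, the per-model suffix from model_url_map, and the OAuth access token. For example (the token value is a placeholder):

# request["model"] == "ernie-bot" resolves to:
# https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=<token>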
|
||||
|
||||
|
||||
class Wenxin(BaseChatModel):
|
||||
"""Wrapper around Wenxin large language models."""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
_client: _WenxinEndpointClient = PrivateAttr()
|
||||
model: str = "ernie-bot"
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.7
|
||||
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||
top_p: float = 0.95
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the response or return it all at once."""
|
||||
api_key: Optional[str] = None
|
||||
secret_key: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["api_key"] = get_from_dict_or_env(
|
||||
values, "api_key", "WENXIN_API_KEY"
|
||||
)
|
||||
values["secret_key"] = get_from_dict_or_env(
|
||||
values, "secret_key", "WENXIN_SECRET_KEY"
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"stream": self.streaming,
|
||||
**self.model_kwargs,
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model": self.model}, **self._default_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "wenxin"
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
self._client = _WenxinEndpointClient(
|
||||
api_key=self.api_key,
|
||||
secret_key=self.secret_key,
|
||||
)
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage]
|
||||
) -> Tuple[List[Dict[str, Any]], str]:
|
||||
dict_messages = []
|
||||
system = None
|
||||
for m in messages:
|
||||
message = self._convert_message_to_dict(m)
|
||||
if message['role'] == 'system':
|
||||
if not system:
|
||||
system = message['content']
|
||||
else:
|
||||
system += f"\n{message['content']}"
|
||||
continue
|
||||
|
||||
if dict_messages:
|
||||
previous_message = dict_messages[-1]
|
||||
if previous_message['role'] == message['role']:
|
||||
dict_messages[-1]['content'] += f"\n{message['content']}"
|
||||
else:
|
||||
dict_messages.append(message)
|
||||
else:
|
||||
dict_messages.append(message)
|
||||
|
||||
return dict_messages, system
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
llm_output: Optional[Dict] = None
|
||||
for chunk in self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
if chunk.generation_info is not None \
|
||||
and 'token_usage' in chunk.generation_info:
|
||||
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
|
||||
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation], llm_output=llm_output)
|
||||
else:
|
||||
message_dicts, system = self._create_message_dicts(messages)
|
||||
request = self._default_params
|
||||
request["messages"] = message_dicts
|
||||
if system:
|
||||
request["system"] = system
|
||||
request.update(kwargs)
|
||||
response = self._client.post(request)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, system = self._create_message_dicts(messages)
|
||||
request = self._default_params
|
||||
request["messages"] = message_dicts
|
||||
if system:
|
||||
request["system"] = system
|
||||
request.update(kwargs)
|
||||
|
||||
for token in self._client.post(request).iter_lines():
|
||||
if token:
|
||||
token = token.decode("utf-8")
|
||||
|
||||
if token.startswith('data:'):
|
||||
completion = json.loads(token[5:])
|
||||
|
||||
chunk_dict = {
|
||||
'message': AIMessageChunk(content=completion['result']),
|
||||
}
|
||||
|
||||
if completion['is_end']:
|
||||
token_usage = completion['usage']
|
||||
token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
|
||||
chunk_dict['generation_info'] = dict({'token_usage': token_usage})
|
||||
|
||||
yield ChatGenerationChunk(**chunk_dict)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(completion['result'])
|
||||
else:
|
||||
try:
|
||||
json_response = json.loads(token)
|
||||
except JSONDecodeError:
|
||||
raise ValueError(f"Wenxin Response Error {token}")
|
||||
|
||||
raise ValueError(
|
||||
f"Wenxin API {json_response['error_code']}"
|
||||
f" error: {json_response['error_msg']}, "
|
||||
f"please confirm if the model you have chosen is already paid for."
|
||||
)
|
||||
|
||||
def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
|
||||
generations = [ChatGeneration(
|
||||
message=AIMessage(content=response['result']),
|
||||
)]
|
||||
token_usage = response.get("usage")
|
||||
token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
|
||||
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in the messages.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens across the messages.
|
||||
"""
|
||||
return sum([self.get_num_tokens(m.content) for m in messages])
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
# Happens in streaming
|
||||
continue
|
||||
token_usage = output["token_usage"]
|
||||
for k, v in token_usage.items():
|
||||
if k in overall_token_usage:
|
||||
overall_token_usage[k] += v
|
||||
else:
|
||||
overall_token_usage[k] = v
|
||||
return {"token_usage": overall_token_usage, "model_name": self.model}
|
||||
@@ -1,196 +0,0 @@
|
||||
from typing import Optional, List, Any, Union, Generator, Mapping
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from xinference_client.client.restful.restful_client import (
|
||||
RESTfulChatglmCppChatModelHandle,
|
||||
RESTfulChatModelHandle,
|
||||
RESTfulGenerateModelHandle, Client,
|
||||
)
|
||||
|
||||
|
||||
class XinferenceLLM(LLM):
|
||||
client: Any
|
||||
server_url: Optional[str]
|
||||
"""URL of the xinference server"""
|
||||
model_uid: Optional[str]
|
||||
"""UID of the launched model"""
|
||||
|
||||
def __init__(
|
||||
self, server_url: Optional[str] = None, model_uid: Optional[str] = None
|
||||
):
|
||||
super().__init__(
|
||||
**{
|
||||
"server_url": server_url,
|
||||
"model_uid": model_uid,
|
||||
}
|
||||
)
|
||||
|
||||
if self.server_url is None:
|
||||
raise ValueError("Please provide server URL")
|
||||
|
||||
if self.model_uid is None:
|
||||
raise ValueError("Please provide the model UID")
|
||||
|
||||
self.client = Client(server_url)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "xinference"
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
**{"server_url": self.server_url},
|
||||
**{"model_uid": self.model_uid},
|
||||
}
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call the xinference model and return the output.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
generate_config: Optional dictionary for the configuration used for
|
||||
generation.
|
||||
|
||||
Returns:
|
||||
The generated string by the model.
|
||||
"""
|
||||
model = self.client.get_model(self.model_uid)
|
||||
|
||||
if isinstance(model, RESTfulChatModelHandle):
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get(
|
||||
"generate_config", {}
|
||||
)
|
||||
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
else:
|
||||
completion = model.chat(prompt=prompt, generate_config=generate_config)
|
||||
return completion["choices"][0]["message"]["content"]
|
||||
elif isinstance(model, RESTfulGenerateModelHandle):
|
||||
generate_config: "LlamaCppGenerateConfig" = kwargs.get(
|
||||
"generate_config", {}
|
||||
)
|
||||
|
||||
if stop:
|
||||
generate_config["stop"] = stop
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
|
||||
else:
|
||||
completion = model.generate(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
return completion["choices"][0]["text"]
|
||||
elif isinstance(model, RESTfulChatglmCppChatModelHandle):
|
||||
generate_config: "ChatglmCppGenerateConfig" = kwargs.get(
|
||||
"generate_config", {}
|
||||
)
|
||||
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
completion = combined_text_output
|
||||
else:
|
||||
completion = model.chat(prompt=prompt, generate_config=generate_config)
|
||||
completion = completion["choices"][0]["message"]["content"]
|
||||
|
||||
if stop is not None:
|
||||
completion = enforce_stop_tokens(completion, stop)
|
||||
|
||||
return completion
|
||||
|
||||
def _stream_generate(
|
||||
self,
|
||||
model: Union[
|
||||
"RESTfulGenerateModelHandle",
|
||||
"RESTfulChatModelHandle",
|
||||
"RESTfulChatglmCppChatModelHandle",
|
||||
],
|
||||
prompt: str,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
generate_config: Optional[
|
||||
Union[
|
||||
"LlamaCppGenerateConfig",
|
||||
"PytorchGenerateConfig",
|
||||
"ChatglmCppGenerateConfig",
|
||||
]
|
||||
] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Args:
|
||||
prompt: The prompt to use for generation.
|
||||
model: The model used for generation.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
generate_config: Optional dictionary for the configuration used for
|
||||
generation.
|
||||
|
||||
Yields:
|
||||
A string token.
|
||||
"""
|
||||
if isinstance(
|
||||
model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
|
||||
):
|
||||
streaming_response = model.chat(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
else:
|
||||
streaming_response = model.generate(
|
||||
prompt=prompt, generate_config=generate_config
|
||||
)
|
||||
|
||||
for chunk in streaming_response:
|
||||
if isinstance(chunk, dict):
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
choice = choices[0]
|
||||
if isinstance(choice, dict):
|
||||
if "text" in choice:
|
||||
token = choice.get("text", "")
|
||||
elif "delta" in choice and "content" in choice["delta"]:
|
||||
token = choice.get("delta").get("content")
|
||||
else:
|
||||
continue
|
||||
log_probs = choice.get("logprobs")
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
token=token, verbose=self.verbose, log_probs=log_probs
|
||||
)
|
||||
yield token
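
A minimal usage sketch; the server URL and model UID are placeholders for a locally launched xinference deployment:

llm = XinferenceLLM(server_url="http://127.0.0.1:9997", model_uid="<model_uid>")
# generate_config={"stream": True, ...} can be passed as a kwarg and is consumed by _call above.
print(llm("What is the capital of France?"))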
|
||||
api/core/third_party/langchain/llms/zhipuai_llm.py (vendored)
@@ -1,315 +0,0 @@
|
||||
"""Wrapper around ZhipuAI APIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import posixpath
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional, Iterator, Sequence,
|
||||
)
|
||||
|
||||
import zhipuai
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
|
||||
from langchain.schema.messages import AIMessageChunk
|
||||
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
|
||||
from pydantic import Extra, root_validator, BaseModel
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from zhipuai.model_api.api import InvokeType
|
||||
from zhipuai.utils import jwt_token
|
||||
from zhipuai.utils.http_client import post, stream
|
||||
from zhipuai.utils.sse_client import SSEClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ZhipuModelAPI(BaseModel):
|
||||
base_url: str
|
||||
api_key: str
|
||||
api_timeout_seconds = 60
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def invoke(self, **kwargs):
|
||||
url = self._build_api_url(kwargs, InvokeType.SYNC)
|
||||
response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
|
||||
if not response['success']:
|
||||
raise ValueError(
|
||||
f"Error Code: {response['code']}, Message: {response['msg']} "
|
||||
)
|
||||
return response
|
||||
|
||||
def sse_invoke(self, **kwargs):
|
||||
url = self._build_api_url(kwargs, InvokeType.SSE)
|
||||
data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
|
||||
return SSEClient(data)
|
||||
|
||||
def _build_api_url(self, kwargs, *path):
|
||||
if kwargs:
|
||||
if "model" not in kwargs:
|
||||
raise Exception("model param missed")
|
||||
model = kwargs.pop("model")
|
||||
else:
|
||||
model = "-"
|
||||
|
||||
return posixpath.join(self.base_url, model, *path)
|
||||
|
||||
def _generate_token(self):
|
||||
if not self.api_key:
|
||||
raise Exception(
|
||||
"api_key not provided, you could provide it."
|
||||
)
|
||||
|
||||
try:
|
||||
return jwt_token.generate_token(self.api_key)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Your api_key is invalid, please check it."
|
||||
)
|
||||
|
||||
|
||||
class ZhipuAIChatLLM(BaseChatModel):
|
||||
"""Wrapper around ZhipuAI large language models.
|
||||
To use, you should pass the api_key as a named parameter to the constructor.
|
||||
Example:
|
||||
.. code-block:: python
|
||||
from core.third_party.langchain.llms.zhipuai import ZhipuAI
|
||||
model = ZhipuAI(model="<model_name>", api_key="my-api-key")
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"api_key": "API_KEY"}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
model: str = "chatglm_turbo"
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.95
|
||||
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||
top_p: float = 0.7
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the response or return it all at once."""
|
||||
api_key: Optional[str] = None
|
||||
|
||||
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["api_key"] = get_from_dict_or_env(
|
||||
values, "api_key", "ZHIPUAI_API_KEY"
|
||||
)
|
||||
|
||||
if 'test' in values['base_url']:
|
||||
values['model'] = 'chatglm_130b_test'
|
||||
|
||||
values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return self._default_params
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "zhipuai"
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
|
||||
def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=_dict["content"])
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage]
|
||||
) -> List[Dict[str, Any]]:
|
||||
dict_messages = []
|
||||
for m in messages:
|
||||
message = self._convert_message_to_dict(m)
|
||||
if dict_messages:
|
||||
previous_message = dict_messages[-1]
|
||||
if previous_message['role'] == message['role']:
|
||||
dict_messages[-1]['content'] += f"\n{message['content']}"
|
||||
else:
|
||||
dict_messages.append(message)
|
||||
else:
|
||||
dict_messages.append(message)
|
||||
|
||||
return dict_messages
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
llm_output: Optional[Dict] = None
|
||||
for chunk in self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
if chunk.generation_info is not None \
|
||||
and 'token_usage' in chunk.generation_info:
|
||||
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
|
||||
continue
|
||||
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation], llm_output=llm_output)
|
||||
else:
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
request = self._default_params
|
||||
request["prompt"] = message_dicts
|
||||
request.update(kwargs)
|
||||
response = self.client.invoke(**request)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
request = self._default_params
|
||||
request["prompt"] = message_dicts
|
||||
request.update(kwargs)
|
||||
|
||||
for event in self.client.sse_invoke(incremental=True, **request).events():
|
||||
if event.event == "add":
|
||||
yield ChatGenerationChunk(message=AIMessageChunk(content=event.data))
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(event.data)
|
||||
elif event.event == "error" or event.event == "interrupted":
|
||||
raise ValueError(
|
||||
f"{event.data}"
|
||||
)
|
||||
elif event.event == "finish":
|
||||
meta = json.loads(event.meta)
|
||||
token_usage = meta['usage']
|
||||
if token_usage is not None:
|
||||
if 'prompt_tokens' not in token_usage:
|
||||
token_usage['prompt_tokens'] = 0
|
||||
if 'completion_tokens' not in token_usage:
|
||||
token_usage['completion_tokens'] = token_usage['total_tokens']
|
||||
|
||||
yield ChatGenerationChunk(
|
||||
message=AIMessageChunk(content=event.data),
|
||||
generation_info=dict({'token_usage': token_usage})
|
||||
)
|
||||
|
||||
def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
|
||||
data = response["data"]
|
||||
generations = []
|
||||
for res in data["choices"]:
|
||||
message = self._convert_dict_to_message(res)
|
||||
gen = ChatGeneration(
|
||||
message=message
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = data.get("usage")
|
||||
if token_usage is not None:
|
||||
if 'prompt_tokens' not in token_usage:
|
||||
token_usage['prompt_tokens'] = 0
|
||||
if 'completion_tokens' not in token_usage:
|
||||
token_usage['completion_tokens'] = token_usage['total_tokens']
|
||||
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
# def get_token_ids(self, text: str) -> List[int]:
|
||||
# """Return the ordered ids of the tokens in a text.
|
||||
#
|
||||
# Args:
|
||||
# text: The string input to tokenize.
|
||||
#
|
||||
# Returns:
|
||||
# A list of ids corresponding to the tokens in the text, in order they occur
|
||||
# in the text.
|
||||
# """
|
||||
# from core.third_party.transformers.Token import ChatGLMTokenizer
|
||||
#
|
||||
# tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b")
|
||||
# return tokenizer.encode(text)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in the messages.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens across the messages.
|
||||
"""
|
||||
return sum([self.get_num_tokens(m.content) for m in messages])
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
# Happens in streaming
|
||||
continue
|
||||
token_usage = output["token_usage"]
|
||||
for k, v in token_usage.items():
|
||||
if k in overall_token_usage:
|
||||
overall_token_usage[k] += v
|
||||
else:
|
||||
overall_token_usage[k] = v
|
||||
return {"token_usage": overall_token_usage, "model_name": self.model}