Feat/add free provider apply (#829)

This commit is contained in:
takatost
2023-08-14 12:44:35 +08:00
committed by GitHub
parent 42a417167f
commit cc52cdc2a9
5 changed files with 54 additions and 3 deletions

View File

@@ -3,7 +3,6 @@ import logging
from json import JSONDecodeError
from typing import Type
from flask import current_app
from langchain.schema import HumanMessage
from core.helper import encrypter

View File

@@ -50,6 +50,7 @@ class ChatSpark(BaseChatModel):
app_id: Optional[str] = None
api_key: Optional[str] = None
api_secret: Optional[str] = None
api_domain: Optional[str] = None
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@@ -68,6 +69,7 @@ class ChatSpark(BaseChatModel):
app_id=values["app_id"],
api_key=values["api_key"],
api_secret=values["api_secret"],
api_domain=values.get('api_domain')
)
return values

View File

@@ -16,9 +16,9 @@ import websocket
class SparkLLMClient:
def __init__(self, app_id: str, api_key: str, api_secret: str):
def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
self.api_base = "ws://spark-api.xf-yun.com/v1.1/chat"
self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat')
self.app_id = app_id
self.ws_url = self.create_url(
urlparse(self.api_base).netloc,