feat: upgrade langchain (#430)

Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
John Wang
2023-06-25 16:49:14 +08:00
committed by GitHub
parent 1dee5de9b4
commit 3241e4015b
91 changed files with 2703 additions and 3153 deletions

View File

@@ -1,7 +1,6 @@
from typing import Union, Optional
from typing import Union, Optional, List
from langchain.callbacks import CallbackManager
from langchain.llms.fake import FakeListLLM
from langchain.callbacks.base import BaseCallbackHandler
from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
@@ -32,12 +31,11 @@ class LLMBuilder:
"""
@classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
if model_name == 'fake':
return FakeListLLM(responses=[])
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
provider = cls.get_default_provider(tenant_id)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
mode = cls.get_mode_by_model(model_name)
if mode == 'chat':
if provider == 'openai':
@@ -52,16 +50,21 @@ class LLMBuilder:
else:
raise ValueError(f"model name {model_name} is not supported.")
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
model_kwargs = {
'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0),
}
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
return llm_cls(
model_name=model_name,
temperature=kwargs.get('temperature', 0),
max_tokens=kwargs.get('max_tokens', 256),
top_p=kwargs.get('top_p', 1),
frequency_penalty=kwargs.get('frequency_penalty', 0),
presence_penalty=kwargs.get('presence_penalty', 0),
callback_manager=kwargs.get('callback_manager', None),
**model_extras_kwargs,
callbacks=kwargs.get('callbacks', None),
streaming=kwargs.get('streaming', False),
# request_timeout=None
**model_credentials
@@ -69,7 +72,7 @@ class LLMBuilder:
@classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
model_name = model.get("name")
completion_params = model.get("completion_params", {})
@@ -82,7 +85,7 @@ class LLMBuilder:
frequency_penalty=completion_params.get('frequency_penalty', 0.1),
presence_penalty=completion_params.get('presence_penalty', 0.1),
streaming=streaming,
callback_manager=callback_manager
callbacks=callbacks
)
@classmethod

View File

@@ -42,7 +42,10 @@ class AzureProvider(BaseProvider):
"""
config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure'
config['deployment_name'] = model_id.replace('.', '') if model_id else None
if model_id == 'text-embedding-ada-002':
config['deployment'] = model_id.replace('.', '') if model_id else None
else:
config['deployment_name'] = model_id.replace('.', '') if model_id else None
return config
def get_provider_name(self):

View File

@@ -1,3 +1,4 @@
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any
@@ -68,60 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
return message_tokens
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(result, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
@handle_llm_exceptions
def generate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(messages, stop)
return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(messages, stop)
return await super().agenerate(messages, stop, callbacks, **kwargs)

View File

@@ -1,5 +1,4 @@
import os
from langchain.callbacks.manager import Callbacks
from langchain.llms import AzureOpenAI
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Mapping, Any
@@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI):
@handle_llm_exceptions
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(prompts, stop)
return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop)
return await super().agenerate(prompts, stop, callbacks, **kwargs)

View File

@@ -1,6 +1,7 @@
import os
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import ChatOpenAI
from typing import Optional, List, Dict, Any
@@ -70,57 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI):
return message_tokens
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
)
chat_result = super()._generate(messages, stop)
result = LLMResult(
generations=[chat_result.generations],
llm_output=chat_result.llm_output
)
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(result, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(result, verbose=self.verbose)
return chat_result
@handle_llm_exceptions
def generate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(messages, stop)
return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(messages, stop)
return await super().agenerate(messages, stop, callbacks, **kwargs)

View File

@@ -1,5 +1,6 @@
import os
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI
@@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI):
"organization": self.openai_organization if self.openai_organization else None,
}}
@handle_llm_exceptions
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(prompts, stop)
return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async
async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop)
return await super().agenerate(prompts, stop, callbacks, **kwargs)