Mirror of http://112.124.100.131/huang.ze/ebiz-dify-ai.git
Synced 2025-12-15 22:06:52 +08:00
feat: upgrade langchain (#430)
Co-authored-by: jyong <718720800@qq.com>
@@ -1,7 +1,6 @@
-from typing import Union, Optional
+from typing import Union, Optional, List

-from langchain.callbacks import CallbackManager
-from langchain.llms.fake import FakeListLLM
+from langchain.callbacks.base import BaseCallbackHandler

 from core.constant import llm_constant
 from core.llm.error import ProviderTokenNotInitError
@@ -32,12 +31,11 @@ class LLMBuilder:
     """

     @classmethod
-    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]:
-        if model_name == 'fake':
-            return FakeListLLM(responses=[])
-
+    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
         provider = cls.get_default_provider(tenant_id)

-        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
-
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
             if provider == 'openai':
@@ -52,16 +50,21 @@ class LLMBuilder:
         else:
             raise ValueError(f"model name {model_name} is not supported.")

+        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
+
+        model_kwargs = {
+            'top_p': kwargs.get('top_p', 1),
+            'frequency_penalty': kwargs.get('frequency_penalty', 0),
+            'presence_penalty': kwargs.get('presence_penalty', 0),
+        }
+
+        model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
+
         return llm_cls(
             model_name=model_name,
             temperature=kwargs.get('temperature', 0),
             max_tokens=kwargs.get('max_tokens', 256),
-            top_p=kwargs.get('top_p', 1),
-            frequency_penalty=kwargs.get('frequency_penalty', 0),
-            presence_penalty=kwargs.get('presence_penalty', 0),
-            callback_manager=kwargs.get('callback_manager', None),
+            **model_extras_kwargs,
+            callbacks=kwargs.get('callbacks', None),
             streaming=kwargs.get('streaming', False),
             # request_timeout=None
             **model_credentials
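Note: the model_extras_kwargs indirection above reflects a difference between langchain's two wrapper families at this version: the completion-mode OpenAI class declares top_p, frequency_penalty and presence_penalty as first-class constructor fields, while the chat wrappers route provider-specific sampling parameters through a model_kwargs dict. A minimal sketch of the two resulting constructions (module paths assumed from the class names in this diff; parameter values illustrative; assumes OPENAI_API_KEY is set in the environment):

    from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
    from core.llm.streamable_open_ai import StreamableOpenAI

    # completion mode: sampling params are first-class fields on the wrapper
    completion_llm = StreamableOpenAI(
        model_name='text-davinci-003',
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )

    # chat mode: the same params travel inside model_kwargs
    chat_llm = StreamableChatOpenAI(
        model_name='gpt-3.5-turbo',
        model_kwargs={'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0},
    )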
@@ -69,7 +72,7 @@ class LLMBuilder:

     @classmethod
     def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
-                          callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
+                          callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
         model_name = model.get("name")
         completion_params = model.get("completion_params", {})
@@ -82,7 +85,7 @@
             frequency_penalty=completion_params.get('frequency_penalty', 0.1),
             presence_penalty=completion_params.get('presence_penalty', 0.1),
             streaming=streaming,
-            callback_manager=callback_manager
+            callbacks=callbacks
         )

     @classmethod
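Note: taken together, the LLMBuilder hunks replace langchain's deprecated CallbackManager plumbing with the newer list-of-handlers callbacks API. A minimal sketch of a caller under the new interface, assuming a valid OpenAI key is configured for the tenant (the tenant id is a placeholder; StdOutCallbackHandler is langchain's built-in stdout handler):

    from langchain.callbacks import StdOutCallbackHandler

    from core.llm.llm_builder import LLMBuilder

    llm = LLMBuilder.to_llm(
        tenant_id='tenant-1234',                # placeholder
        model_name='gpt-3.5-turbo',
        temperature=0,
        streaming=True,
        # replaces callback_manager=CallbackManager([...]) from the old API
        callbacks=[StdOutCallbackHandler()],
    )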
@@ -42,7 +42,10 @@ class AzureProvider(BaseProvider):
         """
         config = self.get_provider_api_key(model_id=model_id)
         config['openai_api_type'] = 'azure'
-        config['deployment_name'] = model_id.replace('.', '') if model_id else None
+        if model_id == 'text-embedding-ada-002':
+            config['deployment'] = model_id.replace('.', '') if model_id else None
+        else:
+            config['deployment_name'] = model_id.replace('.', '') if model_id else None
         return config

     def get_provider_name(self):
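Note: the new branch exists because langchain names the Azure deployment parameter differently across classes: OpenAIEmbeddings expects deployment, while the AzureOpenAI/AzureChatOpenAI LLM wrappers expect deployment_name. A sketch of how the embedding-model config would be consumed, under that assumption (credential values are placeholders):

    from langchain.embeddings import OpenAIEmbeddings

    # shape of the config returned for 'text-embedding-ada-002': it carries
    # 'deployment' rather than 'deployment_name', matching the field that
    # OpenAIEmbeddings expects for Azure deployments
    config = {
        'openai_api_type': 'azure',
        'openai_api_base': 'https://example.openai.azure.com',  # placeholder
        'openai_api_key': 'not-a-real-key',                     # placeholder
        'deployment': 'text-embedding-ada-002',
    }
    embeddings = OpenAIEmbeddings(**config)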
@@ -1,3 +1,4 @@
+from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import AzureChatOpenAI
 from typing import Optional, List, Dict, Any
@@ -68,60 +69,22 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):

         return message_tokens

-    def _generate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-            verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages],
-                verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
     @handle_llm_exceptions
     def generate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)
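Note: the removed _generate/_agenerate overrides existed only to fire on_llm_start/on_llm_end by hand through the old CallbackManager. After the upgrade, langchain's base generate()/agenerate() dispatch those events itself via CallbackManagerForLLMRun, so the subclass only needs to forward the callbacks argument. A minimal sketch of a handler that observes the same events under the new API (the class and hook signatures are langchain's BaseCallbackHandler interface; the print bodies are illustrative):

    from typing import Any, Dict, List

    from langchain.callbacks.base import BaseCallbackHandler
    from langchain.schema import LLMResult


    class LoggingHandler(BaseCallbackHandler):
        # langchain now invokes these hooks itself, so model subclasses no
        # longer need to call on_llm_start/on_llm_end manually
        def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
            print(f"LLM started with {len(prompts)} prompt(s)")

        def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
            print(f"LLM finished: {response.llm_output}")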
@@ -1,5 +1,4 @@
-import os
-
+from langchain.callbacks.manager import Callbacks
 from langchain.llms import AzureOpenAI
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Mapping, Any
@@ -53,12 +52,20 @@ class StreamableAzureOpenAI(AzureOpenAI):

     @handle_llm_exceptions
     def generate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)
@@ -1,6 +1,7 @@
 import os

-from langchain.schema import BaseMessage, ChatResult, LLMResult
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import BaseMessage, LLMResult
 from langchain.chat_models import ChatOpenAI
 from typing import Optional, List, Dict, Any
@@ -70,57 +71,22 @@ class StreamableChatOpenAI(ChatOpenAI):

         return message_tokens

-    def _generate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        self.callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-        )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-        self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
-    async def _agenerate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-            )
-        else:
-            self.callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
-            )
-
-        chat_result = super()._generate(messages, stop)
-
-        result = LLMResult(
-            generations=[chat_result.generations],
-            llm_output=chat_result.llm_output
-        )
-
-        if self.callback_manager.is_async:
-            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
-        else:
-            self.callback_manager.on_llm_end(result, verbose=self.verbose)
-
-        return chat_result
-
     @handle_llm_exceptions
     def generate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(messages, stop)
+        return super().generate(messages, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(messages, stop)
+        return await super().agenerate(messages, stop, callbacks, **kwargs)
@@ -1,5 +1,6 @@
 import os

+from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult
 from typing import Optional, List, Dict, Any, Mapping
 from langchain import OpenAI
@@ -48,15 +49,22 @@ class StreamableOpenAI(OpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }}

     @handle_llm_exceptions
     def generate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return super().generate(prompts, stop)
+        return super().generate(prompts, stop, callbacks, **kwargs)

     @handle_llm_exceptions_async
     async def agenerate(
-            self, prompts: List[str], stop: Optional[List[str]] = None
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            **kwargs: Any,
     ) -> LLMResult:
-        return await super().agenerate(prompts, stop)
+        return await super().agenerate(prompts, stop, callbacks, **kwargs)
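Note: with these signatures, callbacks can be supplied per call as well as at construction time. A minimal usage sketch reusing the LoggingHandler sketched above (model name and prompt are illustrative; assumes OPENAI_API_KEY is set in the environment):

    from core.llm.streamable_open_ai import StreamableOpenAI

    llm = StreamableOpenAI(model_name='text-davinci-003')
    result = llm.generate(
        ["Say hello."],                # prompts
        callbacks=[LoggingHandler()],  # per-call handlers, new in this API
    )
    print(result.generations[0][0].text)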