feat: hf inference endpoint stream support (#1028)
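Every hunk below applies the same refactor: support_streaming changes from a classmethod into a property, so a model can decide streaming support from instance state (credentials, model name) instead of answering once per class. A minimal sketch of the pattern, with hypothetical class names:

    class Before:
        @classmethod
        def support_streaming(cls):
            # class-level: every instance gives the same answer
            return False

    class After:
        def __init__(self, name: str):
            self.name = name

        @property
        def support_streaming(self):
            # instance-level: per-model state can change the answer
            return 'baichuan' not in self.name.lower()

    assert After('gpt2').support_streaming
    assert not After('Baichuan-13B').support_streaming

Call sites drop the parentheses: self.support_streaming() becomes self.support_streaming.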
@@ -75,7 +75,7 @@ class AnthropicModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
 
@@ -141,6 +141,6 @@ class AzureOpenAIModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
-        return True
+    @property
+    def support_streaming(self):
+        return True
@@ -138,7 +138,7 @@ class BaseLLM(BaseProviderModel):
             result = self._run(
                 messages=messages,
                 stop=stop,
-                callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
+                callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
                 **kwargs
             )
         except Exception as ex:
@@ -149,7 +149,7 @@ class BaseLLM(BaseProviderModel):
         else:
             completion_content = result.generations[0][0].text
 
-        if self.streaming and not self.support_streaming():
+        if self.streaming and not self.support_streaming:
             # use FakeLLM to simulate streaming when current model not support streaming but streaming is True
             prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
             fake_llm = FakeLLM(
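When streaming is requested but the model reports support_streaming as False, the code above routes the finished completion through FakeLLM so the caller still receives it chunk by chunk. A rough sketch of that fallback, assuming nothing about FakeLLM's real signature:

    from typing import Iterator

    def simulate_stream(completion: str, chunk_size: int = 8) -> Iterator[str]:
        # replay an already-complete response in small pieces
        for i in range(0, len(completion), chunk_size):
            yield completion[i:i + chunk_size]

    streaming, support_streaming = True, False
    if streaming and not support_streaming:
        for chunk in simulate_stream("a blocking completion, replayed as a stream"):
            print(chunk, end="", flush=True)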
@@ -298,8 +298,8 @@ class BaseLLM(BaseProviderModel):
         else:
             self.client.callbacks.extend(callbacks)
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return False
 
     def get_prompt(self, mode: str,
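BaseLLM now supplies the default as a property returning False, so subclasses that cannot stream need no override of their own; the ChatGLM, OpenLLM and Wenxin hunks below are pure deletions for exactly that reason. In outline:

    class BaseLLM:
        @property
        def support_streaming(self):
            return False

    class ChatGLMModel(BaseLLM):
        pass  # override deleted; inherits the non-streaming default

    assert ChatGLMModel().support_streaming is False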
@@ -61,7 +61,3 @@ class ChatGLMModel(BaseLLM):
             return LLMBadRequestError(f"ChatGLM: {str(ex)}")
         else:
             return ex
-
-    @classmethod
-    def support_streaming(cls):
-        return False
@@ -17,12 +17,18 @@ class HuggingfaceHubModel(BaseLLM):
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
         if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
+            streaming = self.streaming
+
+            if 'baichuan' in self.name.lower():
+                streaming = False
+
             client = HuggingFaceEndpointLLM(
                 endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
                 task=self.credentials['task_type'],
                 model_kwargs=provider_model_kwargs,
                 huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
-                callbacks=self.callbacks
+                callbacks=self.callbacks,
+                streaming=streaming
             )
         else:
             client = HuggingFaceHub(
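_init_client now threads a streaming flag into HuggingFaceEndpointLLM, forcing it off for baichuan models, which this commit treats as unable to stream on inference endpoints; the hosted-hub branch (HuggingFaceHub) is left untouched and stays non-streaming.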
@@ -76,7 +82,10 @@ class HuggingfaceHubModel(BaseLLM):
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
 
-    @classmethod
-    def support_streaming(cls):
-        return False
+    @property
+    def support_streaming(self):
+        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
+            if 'baichuan' in self.name.lower():
+                return False
+            return True
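Note that the new property returns an explicit value only on the inference_endpoints path; for any other API type it falls off the end of the function and yields None. Because BaseLLM tests "if self.streaming and not self.support_streaming:", that falsy None still sends hosted-hub models through the simulated-streaming fallback. A sketch of the control flow, with the non-endpoint api_type value being an assumption:

    def support_streaming(api_type: str, name: str):
        if api_type == 'inference_endpoints':
            if 'baichuan' in name.lower():
                return False
            return True
        # falls through: None for other api types, e.g. the hosted hub

    assert not support_streaming('hosted_inference_api', 'gpt2')  # None is falsy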
@@ -154,8 +154,8 @@ class OpenAIModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
 
     # def is_model_valid_or_raise(self):
@@ -63,7 +63,3 @@ class OpenLLMModel(BaseLLM):
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"OpenLLM: {str(ex)}")
-
-    @classmethod
-    def support_streaming(cls):
-        return False
@@ -91,6 +91,6 @@ class ReplicateModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
-        return True
+    @property
+    def support_streaming(self):
+        return True
@@ -65,6 +65,6 @@ class SparkModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
-        return True
+    @property
+    def support_streaming(self):
+        return True
@@ -69,6 +69,6 @@ class TongyiModel(BaseLLM):
         else:
             return ex
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True
@@ -57,7 +57,3 @@ class WenxinModel(BaseLLM):
 
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Wenxin: {str(ex)}")
-
-    @classmethod
-    def support_streaming(cls):
-        return False
@@ -74,6 +74,6 @@ class XinferenceModel(BaseLLM):
     def handle_exceptions(self, ex: Exception) -> Exception:
         return LLMBadRequestError(f"Xinference: {str(ex)}")
 
-    @classmethod
-    def support_streaming(cls):
+    @property
+    def support_streaming(self):
         return True