feat: hf inference endpoint stream support (#1028)

This commit is contained in:
takatost
2023-08-26 19:48:34 +08:00
committed by GitHub
parent 6c148b223d
commit 0796791de5
14 changed files with 128 additions and 40 deletions

View File

@@ -75,7 +75,7 @@ class AnthropicModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True

View File

@@ -141,6 +141,6 @@ class AzureOpenAIModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
return True
@property
def support_streaming(self):
return True

View File

@@ -138,7 +138,7 @@ class BaseLLM(BaseProviderModel):
result = self._run(
messages=messages,
stop=stop,
callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
**kwargs
)
except Exception as ex:
@@ -149,7 +149,7 @@ class BaseLLM(BaseProviderModel):
else:
completion_content = result.generations[0][0].text
if self.streaming and not self.support_streaming():
if self.streaming and not self.support_streaming:
# use FakeLLM to simulate streaming when current model not support streaming but streaming is True
prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
fake_llm = FakeLLM(
@@ -298,8 +298,8 @@ class BaseLLM(BaseProviderModel):
else:
self.client.callbacks.extend(callbacks)
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return False
def get_prompt(self, mode: str,

View File

@@ -61,7 +61,3 @@ class ChatGLMModel(BaseLLM):
return LLMBadRequestError(f"ChatGLM: {str(ex)}")
else:
return ex
@classmethod
def support_streaming(cls):
return False

View File

@@ -17,12 +17,18 @@ class HuggingfaceHubModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
streaming = self.streaming
if 'baichuan' in self.name.lower():
streaming = False
client = HuggingFaceEndpointLLM(
endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
task=self.credentials['task_type'],
model_kwargs=provider_model_kwargs,
huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
callbacks=self.callbacks
callbacks=self.callbacks,
streaming=streaming
)
else:
client = HuggingFaceHub(
@@ -76,7 +82,10 @@ class HuggingfaceHubModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")
@classmethod
def support_streaming(cls):
return False
@property
def support_streaming(self):
if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
if 'baichuan' in self.name.lower():
return False
return True

View File

@@ -154,8 +154,8 @@ class OpenAIModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True
# def is_model_valid_or_raise(self):

View File

@@ -63,7 +63,3 @@ class OpenLLMModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"OpenLLM: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@@ -91,6 +91,6 @@ class ReplicateModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
return True
@property
def support_streaming(self):
return True

View File

@@ -65,6 +65,6 @@ class SparkModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
return True
@property
def support_streaming(self):
return True

View File

@@ -69,6 +69,6 @@ class TongyiModel(BaseLLM):
else:
return ex
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True

View File

@@ -57,7 +57,3 @@ class WenxinModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@@ -74,6 +74,6 @@ class XinferenceModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Xinference: {str(ex)}")
@classmethod
def support_streaming(cls):
@property
def support_streaming(self):
return True