Feat/huggingface embedding support (#1211)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
Garfield Dai
2023-09-22 13:59:02 +08:00
committed by GitHub
parent 32d9b6181c
commit e409895c02
10 changed files with 416 additions and 28 deletions

View File

@@ -48,6 +48,15 @@ const config: ProviderConfig = {
]
}
if (v?.huggingfacehub_api_type === 'inference_endpoints') {
if (v.model_type === 'embeddings') {
return [
'huggingfacehub_api_token',
'huggingface_namespace',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
]
}
return [
'huggingfacehub_api_token',
'model_name',
@@ -68,14 +77,27 @@ const config: ProviderConfig = {
]
}
if (v?.huggingfacehub_api_type === 'inference_endpoints') {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
'model_type',
]
if (v.model_type === 'embeddings') {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'huggingface_namespace',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
'model_type',
]
}
else {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
'model_type',
]
}
}
return filteredKeys.reduce((prev: FormValue, next: string) => {
prev[next] = v?.[next] || ''
@@ -83,6 +105,31 @@ const config: ProviderConfig = {
}, {})
},
fields: [
{
type: 'radio',
key: 'model_type',
required: true,
label: {
'en': 'Model Type',
'zh-Hans': '模型类型',
},
options: [
{
key: 'text-generation',
label: {
'en': 'Text Generation',
'zh-Hans': '文本生成',
},
},
{
key: 'embeddings',
label: {
'en': 'Embeddings',
'zh-Hans': 'Embeddings',
},
},
],
},
{
type: 'radio',
key: 'huggingfacehub_api_type',
@@ -121,6 +168,20 @@ const config: ProviderConfig = {
'zh-Hans': '在此输入您的 Hugging Face Hub API Token',
},
},
{
hidden: (value?: FormValue) => !(value?.huggingfacehub_api_type === 'inference_endpoints' && value?.model_type === 'embeddings'),
type: 'text',
key: 'huggingface_namespace',
required: true,
label: {
'en': 'User Name / Organization Name',
'zh-Hans': '用户名 / 组织名称',
},
placeholder: {
'en': 'Enter your User Name / Organization Name here',
'zh-Hans': '在此输入您的用户名 / 组织名称',
},
},
{
type: 'text',
key: 'model_name',
@@ -148,7 +209,7 @@ const config: ProviderConfig = {
},
},
{
hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api',
hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api' || value?.model_type === 'embeddings',
type: 'radio',
key: 'task_type',
required: true,
@@ -173,6 +234,25 @@ const config: ProviderConfig = {
},
],
},
{
hidden: (value?: FormValue) => !(value?.huggingfacehub_api_type === 'inference_endpoints' && value?.model_type === 'embeddings'),
type: 'radio',
key: 'task_type',
required: true,
label: {
'en': 'Task',
'zh-Hans': 'Task',
},
options: [
{
key: 'feature-extraction',
label: {
'en': 'Feature Extraction',
'zh-Hans': 'Feature Extraction',
},
},
],
},
],
},
}

View File

@@ -1,7 +1,7 @@
import { useEffect, useState } from 'react'
import type { Dispatch, FC, SetStateAction } from 'react'
import { useContext } from 'use-context-selector'
import type { Field, FormValue, ProviderConfigModal } from '../declarations'
import { type Field, type FormValue, type ProviderConfigModal, ProviderEnum } from '../declarations'
import { useValidate } from '../../key-validator/hooks'
import { ValidatingTip } from '../../key-validator/ValidateStatus'
import { validateModelProviderFn } from '../utils'
@@ -85,10 +85,31 @@ const Form: FC<FormProps> = ({
}
const handleFormChange = (k: string, v: string) => {
if (mode === 'edit' && !cleared)
if (mode === 'edit' && !cleared) {
handleClear({ [k]: v })
else
handleMultiFormChange({ ...value, [k]: v }, k)
}
else {
const extraValue: Record<string, string> = {}
if (
(
(k === 'model_type' && v === 'embeddings' && value.huggingfacehub_api_type === 'inference_endpoints')
|| (k === 'huggingfacehub_api_type' && v === 'inference_endpoints' && value.model_type === 'embeddings')
)
&& modelModal?.key === ProviderEnum.huggingface_hub
)
extraValue.task_type = 'feature-extraction'
if (
(
(k === 'model_type' && v === 'text-generation' && value.huggingfacehub_api_type === 'inference_endpoints')
|| (k === 'huggingfacehub_api_type' && v === 'inference_endpoints' && value.model_type === 'text-generation')
)
&& modelModal?.key === ProviderEnum.huggingface_hub
)
extraValue.task_type = 'text-generation'
handleMultiFormChange({ ...value, [k]: v, ...extraValue }, k)
}
}
const handleFocus = () => {

View File

@@ -92,7 +92,7 @@ const ModelModal: FC<ModelModalProps> = ({
return (
<Portal>
<div className='fixed inset-0 flex items-center justify-center bg-black/[.25]'>
<div className='w-[640px] max-h-screen bg-white shadow-xl rounded-2xl overflow-y-auto'>
<div className='w-[640px] max-h-[calc(100vh-120px)] bg-white shadow-xl rounded-2xl overflow-y-auto'>
<div className='px-8 pt-8'>
<div className='flex justify-between items-center mb-2'>
<div className='text-xl font-semibold text-gray-900'>{renderTitlePrefix()}</div>