FEAT: NEW WORKFLOW ENGINE (#3160)

Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
takatost
2024-04-08 18:51:46 +08:00
committed by GitHub
parent 2fb9850af5
commit 7753ba2d37
1161 changed files with 103836 additions and 10327 deletions

View File

@@ -0,0 +1,41 @@
'use client'
import { useBoolean } from 'ahooks'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import AddButton from '@/app/components/base/button/add-button'
import SelectDataset from '@/app/components/app/configuration/dataset-config/select-dataset'
import type { DataSet } from '@/models/datasets'
type Props = {
selectedIds: string[]
onChange: (dataSets: DataSet[]) => void
}
const AddDataset: FC<Props> = ({
selectedIds,
onChange,
}) => {
const [isShowModal, {
setTrue: showModal,
setFalse: hideModal,
}] = useBoolean(false)
const handleSelect = useCallback((datasets: DataSet[]) => {
onChange(datasets)
hideModal()
}, [onChange, hideModal])
return (
<div>
<AddButton onClick={showModal} />
{isShowModal && (
<SelectDataset
isShow={isShowModal}
onClose={hideModal}
selectedIds={selectedIds}
onSelect={handleSelect}
/>
)}
</div>
)
}
export default React.memo(AddDataset)

View File

@@ -0,0 +1,85 @@
'use client'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import { useBoolean } from 'ahooks'
import type { DataSet } from '@/models/datasets'
import { DataSourceType } from '@/models/datasets'
import { Settings01, Trash03 } from '@/app/components/base/icons/src/vender/line/general'
import FileIcon from '@/app/components/base/file-icon'
import { Folder } from '@/app/components/base/icons/src/vender/solid/files'
import SettingsModal from '@/app/components/app/configuration/dataset-config/settings-modal'
import Drawer from '@/app/components/base/drawer'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
type Props = {
payload: DataSet
onRemove: () => void
onChange: (dataSet: DataSet) => void
readonly?: boolean
}
const DatasetItem: FC<Props> = ({
payload,
onRemove,
onChange,
readonly,
}) => {
const media = useBreakpoints()
const isMobile = media === MediaType.mobile
const [isShowSettingsModal, {
setTrue: showSettingsModal,
setFalse: hideSettingsModal,
}] = useBoolean(false)
const handleSave = useCallback((newDataset: DataSet) => {
onChange(newDataset)
hideSettingsModal()
}, [hideSettingsModal, onChange])
return (
<div className='flex items-center h-10 justify-between rounded-xl px-2 bg-white border border-gray-200 cursor-pointer group/dataset-item'>
<div className='w-0 grow flex items-center space-x-1.5'>
{
payload.data_source_type === DataSourceType.NOTION
? (
<div className='shrink-0 flex items-center justify-center w-6 h-6 rounded-md border-[0.5px] border-[#EAECF5]'>
<FileIcon type='notion' className='w-4 h-4' />
</div>
)
: <div className='shrink-0 flex items-center justify-center w-6 h-6 bg-[#F5F8FF] rounded-md border-[0.5px] border-[#E0EAFF]'>
<Folder className='w-4 h-4 text-[#444CE7]' />
</div>
}
<div className='w-0 grow text-[13px] font-normal text-gray-800 truncate'>{payload.name}</div>
</div>
{!readonly && (
<div className='hidden group-hover/dataset-item:flex shrink-0 ml-2 items-center space-x-1'>
<div
className='flex items-center justify-center w-6 h-6 hover:bg-black/5 rounded-md cursor-pointer'
onClick={showSettingsModal}
>
<Settings01 className='w-4 h-4 text-gray-500' />
</div>
<div
className='flex items-center justify-center w-6 h-6 hover:bg-black/5 rounded-md cursor-pointer'
onClick={onRemove}
>
<Trash03 className='w-4 h-4 text-gray-500' />
</div>
</div>
)}
{isShowSettingsModal && (
<Drawer isOpen={isShowSettingsModal} onClose={hideSettingsModal} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'>
<SettingsModal
currentDataset={payload}
onCancel={hideSettingsModal}
onSave={handleSave}
/>
</Drawer>
)}
</div>
)
}
export default React.memo(DatasetItem)

View File

@@ -0,0 +1,62 @@
'use client'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import produce from 'immer'
import { useTranslation } from 'react-i18next'
import Item from './dataset-item'
import type { DataSet } from '@/models/datasets'
type Props = {
list: DataSet[]
onChange: (list: DataSet[]) => void
readonly?: boolean
}
const DatasetList: FC<Props> = ({
list,
onChange,
readonly,
}) => {
const { t } = useTranslation()
const handleRemove = useCallback((index: number) => {
return () => {
const newList = produce(list, (draft) => {
draft.splice(index, 1)
})
onChange(newList)
}
}, [list, onChange])
const handleChange = useCallback((index: number) => {
return (value: DataSet) => {
const newList = produce(list, (draft) => {
draft[index] = value
})
onChange(newList)
}
}, [list, onChange])
return (
<div className='space-y-1'>
{list.length
? list.map((item, index) => {
return (
<Item
key={index}
payload={item}
onRemove={handleRemove(index)}
onChange={handleChange(index)}
readonly={readonly}
/>
)
})
: (
<div className='p-3 text-xs text-center text-gray-500 rounded-lg cursor-default select-none bg-gray-50'>
{t('appDebug.datasetConfig.knowledgeTip')}
</div>
)
}
</div>
)
}
export default React.memo(DatasetList)

View File

@@ -0,0 +1,134 @@
'use client'
import type { FC } from 'react'
import React, { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types'
import type { ModelConfig } from '../../../types'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content'
import { RETRIEVE_TYPE } from '@/types/app'
import { DATASET_DEFAULT } from '@/config'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type {
DatasetConfigs,
} from '@/models/debug'
import { ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows'
type Props = {
payload: {
retrieval_mode: RETRIEVE_TYPE
multiple_retrieval_config?: MultipleRetrievalConfig
single_retrieval_config?: SingleRetrievalConfig
}
onRetrievalModeChange: (mode: RETRIEVE_TYPE) => void
onMultipleRetrievalConfigChange: (config: MultipleRetrievalConfig) => void
singleRetrievalModelConfig?: ModelConfig
onSingleRetrievalModelChange?: (config: ModelConfig) => void
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
readonly?: boolean
}
const RetrievalConfig: FC<Props> = ({
payload,
onRetrievalModeChange,
onMultipleRetrievalConfigChange,
singleRetrievalModelConfig,
onSingleRetrievalModelChange,
onSingleRetrievalModelParamsChange,
readonly,
}) => {
const { t } = useTranslation()
const [open, setOpen] = useState(false)
const {
defaultModel: rerankDefaultModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const { multiple_retrieval_config } = payload
const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => {
if (isRetrievalModeChange) {
onRetrievalModeChange(configs.retrieval_model)
return
}
onMultipleRetrievalConfigChange({
top_k: configs.top_k,
score_threshold: configs.score_threshold_enabled ? (configs.score_threshold || DATASET_DEFAULT.score_threshold) : null,
reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay
? undefined
: (!configs.reranking_model?.reranking_provider_name
? {
provider: rerankDefaultModel?.provider?.provider || '',
model: rerankDefaultModel?.model || '',
}
: {
provider: configs.reranking_model?.reranking_provider_name,
model: configs.reranking_model?.reranking_model_name,
}),
})
}, [onMultipleRetrievalConfigChange, payload.retrieval_mode, rerankDefaultModel?.provider?.provider, rerankDefaultModel?.model, onRetrievalModeChange])
return (
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement='bottom-end'
offset={{
// mainAxis: 12,
crossAxis: -2,
}}
>
<PortalToFollowElemTrigger
onClick={() => {
if (readonly)
return
setOpen(v => !v)
}}
>
<div className={cn(!readonly && 'cursor-pointer', open && 'bg-gray-100', 'flex items-center h-6 px-2 rounded-md hover:bg-gray-100 group select-none')}>
<div className={cn(open ? 'text-gray-700' : 'text-gray-500', 'leading-[18px] text-xs font-medium group-hover:bg-gray-100')}>{payload.retrieval_mode === RETRIEVE_TYPE.oneWay ? t('appDebug.datasetConfig.retrieveOneWay.title') : t('appDebug.datasetConfig.retrieveMultiWay.title')}</div>
{!readonly && <ChevronDown className='w-3 h-3 ml-1' />}
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent style={{ zIndex: 1001 }}>
<div className='w-[404px] pt-3 pb-4 px-4 shadow-xl rounded-2xl border border-gray-200 bg-white'>
<ConfigRetrievalContent
datasetConfigs={
{
retrieval_model: payload.retrieval_mode,
reranking_model: !multiple_retrieval_config?.reranking_model?.provider
? {
reranking_provider_name: rerankDefaultModel?.provider?.provider || '',
reranking_model_name: rerankDefaultModel?.model || '',
}
: {
reranking_provider_name: multiple_retrieval_config?.reranking_model?.provider || '',
reranking_model_name: multiple_retrieval_config?.reranking_model?.model || '',
},
top_k: multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
score_threshold_enabled: !(multiple_retrieval_config?.score_threshold === undefined || multiple_retrieval_config?.score_threshold === null),
score_threshold: multiple_retrieval_config?.score_threshold,
datasets: {
datasets: [],
},
}
}
onChange={handleChange}
isInWorkflow
singleRetrievalModelConfig={singleRetrievalModelConfig}
onSingleRetrievalModelChange={onSingleRetrievalModelChange}
onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange}
/>
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
)
}
export default React.memo(RetrievalConfig)

View File

@@ -0,0 +1,46 @@
import { BlockEnum } from '../../types'
import type { NodeDefault } from '../../types'
import type { KnowledgeRetrievalNodeType } from './types'
import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
import { RETRIEVE_TYPE } from '@/types/app'
const i18nPrefix = 'workflow'
const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
defaultValue: {
query_variable_selector: [],
dataset_ids: [],
retrieval_mode: RETRIEVE_TYPE.oneWay,
},
getAvailablePrevNodes(isChatMode: boolean) {
const nodes = isChatMode
? ALL_CHAT_AVAILABLE_BLOCKS
: ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End)
return nodes
},
getAvailableNextNodes(isChatMode: boolean) {
const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS
return nodes
},
checkValid(payload: KnowledgeRetrievalNodeType, t: any) {
let errorMessages = ''
if (!errorMessages && (!payload.query_variable_selector || payload.query_variable_selector.length === 0))
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.queryVariable`) })
if (!errorMessages && (!payload.dataset_ids || payload.dataset_ids.length === 0))
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`) })
if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.multiWay && !payload.multiple_retrieval_config?.reranking_model?.provider)
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })
if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.oneWay && !payload.single_retrieval_config?.model?.provider)
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t('common.modelProvider.systemReasoningModel.key') })
return {
isValid: !errorMessages,
errorMessage: errorMessages,
}
},
}
export default nodeDefault

View File

@@ -0,0 +1,46 @@
import { type FC, useEffect, useState } from 'react'
import React from 'react'
import type { KnowledgeRetrievalNodeType } from './types'
import { Folder } from '@/app/components/base/icons/src/vender/solid/files'
import type { NodeProps } from '@/app/components/workflow/types'
import { fetchDatasets } from '@/service/datasets'
import type { DataSet } from '@/models/datasets'
const Node: FC<NodeProps<KnowledgeRetrievalNodeType>> = ({
data,
}) => {
const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
useEffect(() => {
(async () => {
if (data.dataset_ids?.length > 0) {
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: data.dataset_ids } })
setSelectedDatasets(dataSetsWithDetail)
}
else {
setSelectedDatasets([])
}
})()
}, [data.dataset_ids])
if (!selectedDatasets.length)
return null
return (
<div className='mb-1 px-3 py-1'>
<div className='space-y-0.5'>
{selectedDatasets.map(({ id, name }) => (
<div key={id} className='flex items-center h-[26px] bg-gray-100 rounded-md px-1 text-xs font-normal text-gray-700'>
<div className='mr-1 shrink-0 p-1 bg-[#F5F8FF] rounded-md border-[0.5px] border-[#E0EAFF]'>
<Folder className='w-3 h-3 text-[#444CE7]' />
</div>
<div className='text-xs font-normal text-gray-700'>
{name}
</div>
</div>
))}
</div>
</div>
)
}
export default React.memo(Node)

View File

@@ -0,0 +1,165 @@
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import VarReferencePicker from '../_base/components/variable/var-reference-picker'
import useConfig from './use-config'
import RetrievalConfig from './components/retrieval-config'
import AddKnowledge from './components/add-dataset'
import DatasetList from './components/dataset-list'
import type { KnowledgeRetrievalNodeType } from './types'
import Field from '@/app/components/workflow/nodes/_base/components/field'
import Split from '@/app/components/workflow/nodes/_base/components/split'
import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars'
import { InputVarType, type NodePanelProps } from '@/app/components/workflow/types'
import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form'
import ResultPanel from '@/app/components/workflow/run/result-panel'
const i18nPrefix = 'workflow.nodes.knowledgeRetrieval'
const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
id,
data,
}) => {
const { t } = useTranslation()
const {
readOnly,
inputs,
handleQueryVarChange,
filterVar,
handleModelChanged,
handleCompletionParamsChange,
handleRetrievalModeChange,
handleMultipleRetrievalConfigChange,
selectedDatasets,
handleOnDatasetsChange,
isShowSingleRun,
hideSingleRun,
runningStatus,
handleRun,
handleStop,
query,
setQuery,
runResult,
} = useConfig(id, data)
return (
<div className='mt-2'>
<div className='px-4 pb-4 space-y-4'>
{/* {JSON.stringify(inputs, null, 2)} */}
<Field
title={t(`${i18nPrefix}.queryVariable`)}
>
<VarReferencePicker
nodeId={id}
readonly={readOnly}
isShowNodeName
value={inputs.query_variable_selector}
onChange={handleQueryVarChange}
filterVar={filterVar}
/>
</Field>
<Field
title={t(`${i18nPrefix}.knowledge`)}
operations={
<div className='flex items-center space-x-1'>
<RetrievalConfig
payload={{
retrieval_mode: inputs.retrieval_mode,
multiple_retrieval_config: inputs.multiple_retrieval_config,
single_retrieval_config: inputs.single_retrieval_config,
}}
onRetrievalModeChange={handleRetrievalModeChange}
onMultipleRetrievalConfigChange={handleMultipleRetrievalConfigChange}
singleRetrievalModelConfig={inputs.single_retrieval_config?.model}
onSingleRetrievalModelChange={handleModelChanged as any}
onSingleRetrievalModelParamsChange={handleCompletionParamsChange}
readonly={readOnly}
/>
{!readOnly && (<div className='w-px h-3 bg-gray-200'></div>)}
{!readOnly && (
<AddKnowledge
selectedIds={inputs.dataset_ids}
onChange={handleOnDatasetsChange}
/>
)}
</div>
}
>
<DatasetList
list={selectedDatasets}
onChange={handleOnDatasetsChange}
readonly={readOnly}
/>
</Field>
</div>
<Split />
<div className='px-4 pt-4 pb-2'>
<OutputVars>
<>
<VarItem
name='result'
type='Array[Object]'
description={t(`${i18nPrefix}.outputVars.output`)}
subItems={[
{
name: 'content',
type: 'string',
description: t(`${i18nPrefix}.outputVars.content`),
},
// url, title, link like bing search reference result: link, link page title, link page icon
{
name: 'title',
type: 'string',
description: t(`${i18nPrefix}.outputVars.title`),
},
{
name: 'url',
type: 'string',
description: t(`${i18nPrefix}.outputVars.url`),
},
{
name: 'icon',
type: 'string',
description: t(`${i18nPrefix}.outputVars.icon`),
},
{
name: 'metadata',
type: 'object',
description: t(`${i18nPrefix}.outputVars.metadata`),
},
]}
/>
</>
</OutputVars>
{isShowSingleRun && (
<BeforeRunForm
nodeName={inputs.title}
onHide={hideSingleRun}
forms={[
{
inputs: [{
label: t(`${i18nPrefix}.queryVariable`)!,
variable: 'query',
type: InputVarType.paragraph,
required: true,
}],
values: { query },
onChange: keyValue => setQuery((keyValue as any).query),
},
]}
runningStatus={runningStatus}
onRun={handleRun}
onStop={handleStop}
result={<ResultPanel {...runResult} showSteps={false} />}
/>
)}
</div>
</div>
)
}
export default React.memo(Panel)

View File

@@ -0,0 +1,23 @@
import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types'
import type { RETRIEVE_TYPE } from '@/types/app'
export type MultipleRetrievalConfig = {
top_k: number
score_threshold: number | null | undefined
reranking_model?: {
provider: string
model: string
}
}
export type SingleRetrievalConfig = {
model: ModelConfig
}
export type KnowledgeRetrievalNodeType = CommonNodeType & {
query_variable_selector: ValueSelector
dataset_ids: string[]
retrieval_mode: RETRIEVE_TYPE
multiple_retrieval_config?: MultipleRetrievalConfig
single_retrieval_config?: SingleRetrievalConfig
}

View File

@@ -0,0 +1,272 @@
import { useCallback, useEffect, useRef, useState } from 'react'
import produce from 'immer'
import { isEqual } from 'lodash-es'
import type { ValueSelector, Var } from '../../types'
import { BlockEnum, VarType } from '../../types'
import {
useIsChatMode, useNodesReadOnly,
useWorkflow,
} from '../../hooks'
import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
import { RETRIEVE_TYPE } from '@/types/app'
import { DATASET_DEFAULT } from '@/config'
import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets'
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
const { nodesReadOnly: readOnly } = useNodesReadOnly()
const isChatMode = useIsChatMode()
const { getBeforeNodesInSameBranch } = useWorkflow()
const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
const startNodeId = startNode?.id
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
const newInputs = produce(s, (draft) => {
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
delete draft.single_retrieval_config
else
delete draft.multiple_retrieval_config
})
// not work in pass to draft...
doSetInputs(newInputs)
}, [doSetInputs])
const inputRef = useRef(inputs)
useEffect(() => {
inputRef.current = inputs
}, [inputs])
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
const newInputs = produce(inputs, (draft) => {
draft.query_variable_selector = newVar as ValueSelector
})
setInputs(newInputs)
}, [inputs, setInputs])
const {
currentProvider,
currentModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
const {
defaultModel: rerankDefaultModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
const newInputs = produce(inputRef.current, (draft) => {
if (!draft.single_retrieval_config) {
draft.single_retrieval_config = {
model: {
provider: '',
name: '',
mode: '',
completion_params: {},
},
}
}
const draftModel = draft.single_retrieval_config?.model
draftModel.provider = model.provider
draftModel.name = model.modelId
draftModel.mode = model.mode!
})
setInputs(newInputs)
}, [setInputs])
const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
// inputRef.current.single_retrieval_config?.model is old when change the provider...
if (isEqual(newParams, inputRef.current.single_retrieval_config?.model.completion_params))
return
const newInputs = produce(inputRef.current, (draft) => {
if (!draft.single_retrieval_config) {
draft.single_retrieval_config = {
model: {
provider: '',
name: '',
mode: '',
completion_params: {},
},
}
}
draft.single_retrieval_config.model.completion_params = newParams
})
setInputs(newInputs)
}, [setInputs])
// set defaults models
useEffect(() => {
const inputs = inputRef.current
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider)
return
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
return
const newInput = produce(inputs, (draft) => {
if (currentProvider?.provider && currentModel?.model) {
const hasSetModel = draft.single_retrieval_config?.model?.provider
if (!hasSetModel) {
draft.single_retrieval_config = {
model: {
provider: currentProvider?.provider,
name: currentModel?.model,
mode: currentModel?.model_properties?.mode as string,
completion_params: {},
},
}
}
}
const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = {
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
score_threshold: multipleRetrievalConfig?.score_threshold,
reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay
? undefined
: (!multipleRetrievalConfig?.reranking_model?.provider
? {
provider: rerankDefaultModel?.provider?.provider || '',
model: rerankDefaultModel?.model || '',
}
: multipleRetrievalConfig?.reranking_model),
}
})
setInputs(newInput)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentProvider?.provider, currentModel, rerankDefaultModel])
const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
const newInputs = produce(inputs, (draft) => {
draft.retrieval_mode = newMode
if (newMode === RETRIEVE_TYPE.multiWay) {
draft.multiple_retrieval_config = {
top_k: draft.multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
score_threshold: draft.multiple_retrieval_config?.score_threshold,
reranking_model: !draft.multiple_retrieval_config?.reranking_model?.provider
? {
provider: rerankDefaultModel?.provider?.provider || '',
model: rerankDefaultModel?.model || '',
}
: draft.multiple_retrieval_config?.reranking_model,
}
}
else {
const hasSetModel = draft.single_retrieval_config?.model?.provider
if (!hasSetModel) {
draft.single_retrieval_config = {
model: {
provider: currentProvider?.provider || '',
name: currentModel?.model || '',
mode: currentModel?.model_properties?.mode as string,
completion_params: {},
},
}
}
}
})
setInputs(newInputs)
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, rerankDefaultModel?.model, rerankDefaultModel?.provider?.provider, setInputs])
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
const newInputs = produce(inputs, (draft) => {
draft.multiple_retrieval_config = newConfig
})
setInputs(newInputs)
}, [inputs, setInputs])
// datasets
const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
useEffect(() => {
(async () => {
const inputs = inputRef.current
const datasetIds = inputs.dataset_ids
if (datasetIds?.length > 0) {
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
setSelectedDatasets(dataSetsWithDetail)
}
const newInputs = produce(inputs, (draft) => {
draft.dataset_ids = datasetIds
})
setInputs(newInputs)
})()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
useEffect(() => {
let query_variable_selector: ValueSelector = inputs.query_variable_selector
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
query_variable_selector = [startNodeId, 'sys.query']
setInputs({
...inputs,
query_variable_selector,
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
const newInputs = produce(inputs, (draft) => {
draft.dataset_ids = newDatasets.map(d => d.id)
})
setInputs(newInputs)
setSelectedDatasets(newDatasets)
}, [inputs, setInputs])
const filterVar = useCallback((varPayload: Var) => {
return varPayload.type === VarType.string
}, [])
// single run
const {
isShowSingleRun,
hideSingleRun,
runningStatus,
handleRun,
handleStop,
runInputData,
setRunInputData,
runResult,
} = useOneStepRun<KnowledgeRetrievalNodeType>({
id,
data: inputs,
defaultRunInputData: {
query: '',
},
})
const query = runInputData.query
const setQuery = useCallback((newQuery: string) => {
setRunInputData({
...runInputData,
query: newQuery,
})
}, [runInputData, setRunInputData])
return {
readOnly,
inputs,
handleQueryVarChange,
filterVar,
handleRetrievalModeChange,
handleMultipleRetrievalConfigChange,
handleModelChanged,
handleCompletionParamsChange,
selectedDatasets,
handleOnDatasetsChange,
isShowSingleRun,
hideSingleRun,
runningStatus,
handleRun,
handleStop,
query,
setQuery,
runResult,
}
}
export default useConfig

View File

@@ -0,0 +1,5 @@
import type { KnowledgeRetrievalNodeType } from './types'
export const checkNodeValid = (payload: KnowledgeRetrievalNodeType) => {
return true
}