mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 11:26:52 +08:00
FEAT: NEW WORKFLOW ENGINE (#3160)
Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
'use client'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback } from 'react'
|
||||
import AddButton from '@/app/components/base/button/add-button'
|
||||
import SelectDataset from '@/app/components/app/configuration/dataset-config/select-dataset'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
|
||||
type Props = {
|
||||
selectedIds: string[]
|
||||
onChange: (dataSets: DataSet[]) => void
|
||||
}
|
||||
|
||||
const AddDataset: FC<Props> = ({
|
||||
selectedIds,
|
||||
onChange,
|
||||
}) => {
|
||||
const [isShowModal, {
|
||||
setTrue: showModal,
|
||||
setFalse: hideModal,
|
||||
}] = useBoolean(false)
|
||||
|
||||
const handleSelect = useCallback((datasets: DataSet[]) => {
|
||||
onChange(datasets)
|
||||
hideModal()
|
||||
}, [onChange, hideModal])
|
||||
return (
|
||||
<div>
|
||||
<AddButton onClick={showModal} />
|
||||
{isShowModal && (
|
||||
<SelectDataset
|
||||
isShow={isShowModal}
|
||||
onClose={hideModal}
|
||||
selectedIds={selectedIds}
|
||||
onSelect={handleSelect}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default React.memo(AddDataset)
|
||||
@@ -0,0 +1,85 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback } from 'react'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
import { Settings01, Trash03 } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import FileIcon from '@/app/components/base/file-icon'
|
||||
import { Folder } from '@/app/components/base/icons/src/vender/solid/files'
|
||||
import SettingsModal from '@/app/components/app/configuration/dataset-config/settings-modal'
|
||||
import Drawer from '@/app/components/base/drawer'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
|
||||
type Props = {
|
||||
payload: DataSet
|
||||
onRemove: () => void
|
||||
onChange: (dataSet: DataSet) => void
|
||||
readonly?: boolean
|
||||
}
|
||||
|
||||
const DatasetItem: FC<Props> = ({
|
||||
payload,
|
||||
onRemove,
|
||||
onChange,
|
||||
readonly,
|
||||
}) => {
|
||||
const media = useBreakpoints()
|
||||
const isMobile = media === MediaType.mobile
|
||||
|
||||
const [isShowSettingsModal, {
|
||||
setTrue: showSettingsModal,
|
||||
setFalse: hideSettingsModal,
|
||||
}] = useBoolean(false)
|
||||
|
||||
const handleSave = useCallback((newDataset: DataSet) => {
|
||||
onChange(newDataset)
|
||||
hideSettingsModal()
|
||||
}, [hideSettingsModal, onChange])
|
||||
|
||||
return (
|
||||
<div className='flex items-center h-10 justify-between rounded-xl px-2 bg-white border border-gray-200 cursor-pointer group/dataset-item'>
|
||||
<div className='w-0 grow flex items-center space-x-1.5'>
|
||||
{
|
||||
payload.data_source_type === DataSourceType.NOTION
|
||||
? (
|
||||
<div className='shrink-0 flex items-center justify-center w-6 h-6 rounded-md border-[0.5px] border-[#EAECF5]'>
|
||||
<FileIcon type='notion' className='w-4 h-4' />
|
||||
</div>
|
||||
)
|
||||
: <div className='shrink-0 flex items-center justify-center w-6 h-6 bg-[#F5F8FF] rounded-md border-[0.5px] border-[#E0EAFF]'>
|
||||
<Folder className='w-4 h-4 text-[#444CE7]' />
|
||||
</div>
|
||||
}
|
||||
<div className='w-0 grow text-[13px] font-normal text-gray-800 truncate'>{payload.name}</div>
|
||||
</div>
|
||||
{!readonly && (
|
||||
<div className='hidden group-hover/dataset-item:flex shrink-0 ml-2 items-center space-x-1'>
|
||||
<div
|
||||
className='flex items-center justify-center w-6 h-6 hover:bg-black/5 rounded-md cursor-pointer'
|
||||
onClick={showSettingsModal}
|
||||
>
|
||||
<Settings01 className='w-4 h-4 text-gray-500' />
|
||||
</div>
|
||||
<div
|
||||
className='flex items-center justify-center w-6 h-6 hover:bg-black/5 rounded-md cursor-pointer'
|
||||
onClick={onRemove}
|
||||
>
|
||||
<Trash03 className='w-4 h-4 text-gray-500' />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isShowSettingsModal && (
|
||||
<Drawer isOpen={isShowSettingsModal} onClose={hideSettingsModal} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'>
|
||||
<SettingsModal
|
||||
currentDataset={payload}
|
||||
onCancel={hideSettingsModal}
|
||||
onSave={handleSave}
|
||||
/>
|
||||
</Drawer>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default React.memo(DatasetItem)
|
||||
@@ -0,0 +1,62 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback } from 'react'
|
||||
import produce from 'immer'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Item from './dataset-item'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
type Props = {
|
||||
list: DataSet[]
|
||||
onChange: (list: DataSet[]) => void
|
||||
readonly?: boolean
|
||||
}
|
||||
|
||||
const DatasetList: FC<Props> = ({
|
||||
list,
|
||||
onChange,
|
||||
readonly,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const handleRemove = useCallback((index: number) => {
|
||||
return () => {
|
||||
const newList = produce(list, (draft) => {
|
||||
draft.splice(index, 1)
|
||||
})
|
||||
onChange(newList)
|
||||
}
|
||||
}, [list, onChange])
|
||||
|
||||
const handleChange = useCallback((index: number) => {
|
||||
return (value: DataSet) => {
|
||||
const newList = produce(list, (draft) => {
|
||||
draft[index] = value
|
||||
})
|
||||
onChange(newList)
|
||||
}
|
||||
}, [list, onChange])
|
||||
return (
|
||||
<div className='space-y-1'>
|
||||
{list.length
|
||||
? list.map((item, index) => {
|
||||
return (
|
||||
<Item
|
||||
key={index}
|
||||
payload={item}
|
||||
onRemove={handleRemove(index)}
|
||||
onChange={handleChange(index)}
|
||||
readonly={readonly}
|
||||
/>
|
||||
)
|
||||
})
|
||||
: (
|
||||
<div className='p-3 text-xs text-center text-gray-500 rounded-lg cursor-default select-none bg-gray-50'>
|
||||
{t('appDebug.datasetConfig.knowledgeTip')}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default React.memo(DatasetList)
|
||||
@@ -0,0 +1,134 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import cn from 'classnames'
|
||||
import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types'
|
||||
import type { ModelConfig } from '../../../types'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
PortalToFollowElemTrigger,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content'
|
||||
import { RETRIEVE_TYPE } from '@/types/app'
|
||||
import { DATASET_DEFAULT } from '@/config'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
|
||||
import type {
|
||||
DatasetConfigs,
|
||||
} from '@/models/debug'
|
||||
import { ChevronDown } from '@/app/components/base/icons/src/vender/line/arrows'
|
||||
|
||||
type Props = {
|
||||
payload: {
|
||||
retrieval_mode: RETRIEVE_TYPE
|
||||
multiple_retrieval_config?: MultipleRetrievalConfig
|
||||
single_retrieval_config?: SingleRetrievalConfig
|
||||
}
|
||||
onRetrievalModeChange: (mode: RETRIEVE_TYPE) => void
|
||||
onMultipleRetrievalConfigChange: (config: MultipleRetrievalConfig) => void
|
||||
singleRetrievalModelConfig?: ModelConfig
|
||||
onSingleRetrievalModelChange?: (config: ModelConfig) => void
|
||||
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
|
||||
readonly?: boolean
|
||||
}
|
||||
|
||||
const RetrievalConfig: FC<Props> = ({
|
||||
payload,
|
||||
onRetrievalModeChange,
|
||||
onMultipleRetrievalConfigChange,
|
||||
singleRetrievalModelConfig,
|
||||
onSingleRetrievalModelChange,
|
||||
onSingleRetrievalModelParamsChange,
|
||||
readonly,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const [open, setOpen] = useState(false)
|
||||
|
||||
const {
|
||||
defaultModel: rerankDefaultModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const { multiple_retrieval_config } = payload
|
||||
const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => {
|
||||
if (isRetrievalModeChange) {
|
||||
onRetrievalModeChange(configs.retrieval_model)
|
||||
return
|
||||
}
|
||||
onMultipleRetrievalConfigChange({
|
||||
top_k: configs.top_k,
|
||||
score_threshold: configs.score_threshold_enabled ? (configs.score_threshold || DATASET_DEFAULT.score_threshold) : null,
|
||||
reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay
|
||||
? undefined
|
||||
: (!configs.reranking_model?.reranking_provider_name
|
||||
? {
|
||||
provider: rerankDefaultModel?.provider?.provider || '',
|
||||
model: rerankDefaultModel?.model || '',
|
||||
}
|
||||
: {
|
||||
provider: configs.reranking_model?.reranking_provider_name,
|
||||
model: configs.reranking_model?.reranking_model_name,
|
||||
}),
|
||||
})
|
||||
}, [onMultipleRetrievalConfigChange, payload.retrieval_mode, rerankDefaultModel?.provider?.provider, rerankDefaultModel?.model, onRetrievalModeChange])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
placement='bottom-end'
|
||||
offset={{
|
||||
// mainAxis: 12,
|
||||
crossAxis: -2,
|
||||
}}
|
||||
>
|
||||
<PortalToFollowElemTrigger
|
||||
onClick={() => {
|
||||
if (readonly)
|
||||
return
|
||||
setOpen(v => !v)
|
||||
}}
|
||||
>
|
||||
<div className={cn(!readonly && 'cursor-pointer', open && 'bg-gray-100', 'flex items-center h-6 px-2 rounded-md hover:bg-gray-100 group select-none')}>
|
||||
<div className={cn(open ? 'text-gray-700' : 'text-gray-500', 'leading-[18px] text-xs font-medium group-hover:bg-gray-100')}>{payload.retrieval_mode === RETRIEVE_TYPE.oneWay ? t('appDebug.datasetConfig.retrieveOneWay.title') : t('appDebug.datasetConfig.retrieveMultiWay.title')}</div>
|
||||
{!readonly && <ChevronDown className='w-3 h-3 ml-1' />}
|
||||
</div>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent style={{ zIndex: 1001 }}>
|
||||
<div className='w-[404px] pt-3 pb-4 px-4 shadow-xl rounded-2xl border border-gray-200 bg-white'>
|
||||
<ConfigRetrievalContent
|
||||
datasetConfigs={
|
||||
{
|
||||
retrieval_model: payload.retrieval_mode,
|
||||
reranking_model: !multiple_retrieval_config?.reranking_model?.provider
|
||||
? {
|
||||
reranking_provider_name: rerankDefaultModel?.provider?.provider || '',
|
||||
reranking_model_name: rerankDefaultModel?.model || '',
|
||||
}
|
||||
: {
|
||||
reranking_provider_name: multiple_retrieval_config?.reranking_model?.provider || '',
|
||||
reranking_model_name: multiple_retrieval_config?.reranking_model?.model || '',
|
||||
},
|
||||
top_k: multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
|
||||
score_threshold_enabled: !(multiple_retrieval_config?.score_threshold === undefined || multiple_retrieval_config?.score_threshold === null),
|
||||
score_threshold: multiple_retrieval_config?.score_threshold,
|
||||
datasets: {
|
||||
datasets: [],
|
||||
},
|
||||
}
|
||||
}
|
||||
onChange={handleChange}
|
||||
isInWorkflow
|
||||
singleRetrievalModelConfig={singleRetrievalModelConfig}
|
||||
onSingleRetrievalModelChange={onSingleRetrievalModelChange}
|
||||
onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange}
|
||||
/>
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
)
|
||||
}
|
||||
export default React.memo(RetrievalConfig)
|
||||
@@ -0,0 +1,46 @@
|
||||
import { BlockEnum } from '../../types'
|
||||
import type { NodeDefault } from '../../types'
|
||||
import type { KnowledgeRetrievalNodeType } from './types'
|
||||
import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
|
||||
|
||||
import { RETRIEVE_TYPE } from '@/types/app'
|
||||
const i18nPrefix = 'workflow'
|
||||
|
||||
const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
|
||||
defaultValue: {
|
||||
query_variable_selector: [],
|
||||
dataset_ids: [],
|
||||
retrieval_mode: RETRIEVE_TYPE.oneWay,
|
||||
},
|
||||
getAvailablePrevNodes(isChatMode: boolean) {
|
||||
const nodes = isChatMode
|
||||
? ALL_CHAT_AVAILABLE_BLOCKS
|
||||
: ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End)
|
||||
return nodes
|
||||
},
|
||||
getAvailableNextNodes(isChatMode: boolean) {
|
||||
const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS
|
||||
return nodes
|
||||
},
|
||||
checkValid(payload: KnowledgeRetrievalNodeType, t: any) {
|
||||
let errorMessages = ''
|
||||
if (!errorMessages && (!payload.query_variable_selector || payload.query_variable_selector.length === 0))
|
||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.queryVariable`) })
|
||||
|
||||
if (!errorMessages && (!payload.dataset_ids || payload.dataset_ids.length === 0))
|
||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`) })
|
||||
|
||||
if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.multiWay && !payload.multiple_retrieval_config?.reranking_model?.provider)
|
||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })
|
||||
|
||||
if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.oneWay && !payload.single_retrieval_config?.model?.provider)
|
||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t('common.modelProvider.systemReasoningModel.key') })
|
||||
|
||||
return {
|
||||
isValid: !errorMessages,
|
||||
errorMessage: errorMessages,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
export default nodeDefault
|
||||
@@ -0,0 +1,46 @@
|
||||
import { type FC, useEffect, useState } from 'react'
|
||||
import React from 'react'
|
||||
import type { KnowledgeRetrievalNodeType } from './types'
|
||||
import { Folder } from '@/app/components/base/icons/src/vender/solid/files'
|
||||
import type { NodeProps } from '@/app/components/workflow/types'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
|
||||
const Node: FC<NodeProps<KnowledgeRetrievalNodeType>> = ({
|
||||
data,
|
||||
}) => {
|
||||
const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
if (data.dataset_ids?.length > 0) {
|
||||
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: data.dataset_ids } })
|
||||
setSelectedDatasets(dataSetsWithDetail)
|
||||
}
|
||||
else {
|
||||
setSelectedDatasets([])
|
||||
}
|
||||
})()
|
||||
}, [data.dataset_ids])
|
||||
|
||||
if (!selectedDatasets.length)
|
||||
return null
|
||||
|
||||
return (
|
||||
<div className='mb-1 px-3 py-1'>
|
||||
<div className='space-y-0.5'>
|
||||
{selectedDatasets.map(({ id, name }) => (
|
||||
<div key={id} className='flex items-center h-[26px] bg-gray-100 rounded-md px-1 text-xs font-normal text-gray-700'>
|
||||
<div className='mr-1 shrink-0 p-1 bg-[#F5F8FF] rounded-md border-[0.5px] border-[#E0EAFF]'>
|
||||
<Folder className='w-3 h-3 text-[#444CE7]' />
|
||||
</div>
|
||||
<div className='text-xs font-normal text-gray-700'>
|
||||
{name}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(Node)
|
||||
165
web/app/components/workflow/nodes/knowledge-retrieval/panel.tsx
Normal file
165
web/app/components/workflow/nodes/knowledge-retrieval/panel.tsx
Normal file
@@ -0,0 +1,165 @@
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import VarReferencePicker from '../_base/components/variable/var-reference-picker'
|
||||
import useConfig from './use-config'
|
||||
import RetrievalConfig from './components/retrieval-config'
|
||||
import AddKnowledge from './components/add-dataset'
|
||||
import DatasetList from './components/dataset-list'
|
||||
import type { KnowledgeRetrievalNodeType } from './types'
|
||||
import Field from '@/app/components/workflow/nodes/_base/components/field'
|
||||
import Split from '@/app/components/workflow/nodes/_base/components/split'
|
||||
import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars'
|
||||
import { InputVarType, type NodePanelProps } from '@/app/components/workflow/types'
|
||||
import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form'
|
||||
import ResultPanel from '@/app/components/workflow/run/result-panel'
|
||||
|
||||
const i18nPrefix = 'workflow.nodes.knowledgeRetrieval'
|
||||
|
||||
const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
|
||||
id,
|
||||
data,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const {
|
||||
readOnly,
|
||||
inputs,
|
||||
handleQueryVarChange,
|
||||
filterVar,
|
||||
handleModelChanged,
|
||||
handleCompletionParamsChange,
|
||||
handleRetrievalModeChange,
|
||||
handleMultipleRetrievalConfigChange,
|
||||
selectedDatasets,
|
||||
handleOnDatasetsChange,
|
||||
isShowSingleRun,
|
||||
hideSingleRun,
|
||||
runningStatus,
|
||||
handleRun,
|
||||
handleStop,
|
||||
query,
|
||||
setQuery,
|
||||
runResult,
|
||||
} = useConfig(id, data)
|
||||
|
||||
return (
|
||||
<div className='mt-2'>
|
||||
<div className='px-4 pb-4 space-y-4'>
|
||||
{/* {JSON.stringify(inputs, null, 2)} */}
|
||||
<Field
|
||||
title={t(`${i18nPrefix}.queryVariable`)}
|
||||
>
|
||||
<VarReferencePicker
|
||||
nodeId={id}
|
||||
readonly={readOnly}
|
||||
isShowNodeName
|
||||
value={inputs.query_variable_selector}
|
||||
onChange={handleQueryVarChange}
|
||||
filterVar={filterVar}
|
||||
/>
|
||||
</Field>
|
||||
|
||||
<Field
|
||||
title={t(`${i18nPrefix}.knowledge`)}
|
||||
operations={
|
||||
<div className='flex items-center space-x-1'>
|
||||
<RetrievalConfig
|
||||
payload={{
|
||||
retrieval_mode: inputs.retrieval_mode,
|
||||
multiple_retrieval_config: inputs.multiple_retrieval_config,
|
||||
single_retrieval_config: inputs.single_retrieval_config,
|
||||
}}
|
||||
onRetrievalModeChange={handleRetrievalModeChange}
|
||||
onMultipleRetrievalConfigChange={handleMultipleRetrievalConfigChange}
|
||||
singleRetrievalModelConfig={inputs.single_retrieval_config?.model}
|
||||
onSingleRetrievalModelChange={handleModelChanged as any}
|
||||
onSingleRetrievalModelParamsChange={handleCompletionParamsChange}
|
||||
readonly={readOnly}
|
||||
/>
|
||||
{!readOnly && (<div className='w-px h-3 bg-gray-200'></div>)}
|
||||
{!readOnly && (
|
||||
<AddKnowledge
|
||||
selectedIds={inputs.dataset_ids}
|
||||
onChange={handleOnDatasetsChange}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<DatasetList
|
||||
list={selectedDatasets}
|
||||
onChange={handleOnDatasetsChange}
|
||||
readonly={readOnly}
|
||||
/>
|
||||
</Field>
|
||||
</div>
|
||||
|
||||
<Split />
|
||||
<div className='px-4 pt-4 pb-2'>
|
||||
<OutputVars>
|
||||
<>
|
||||
<VarItem
|
||||
name='result'
|
||||
type='Array[Object]'
|
||||
description={t(`${i18nPrefix}.outputVars.output`)}
|
||||
subItems={[
|
||||
{
|
||||
name: 'content',
|
||||
type: 'string',
|
||||
description: t(`${i18nPrefix}.outputVars.content`),
|
||||
},
|
||||
// url, title, link like bing search reference result: link, link page title, link page icon
|
||||
{
|
||||
name: 'title',
|
||||
type: 'string',
|
||||
description: t(`${i18nPrefix}.outputVars.title`),
|
||||
},
|
||||
{
|
||||
name: 'url',
|
||||
type: 'string',
|
||||
description: t(`${i18nPrefix}.outputVars.url`),
|
||||
},
|
||||
{
|
||||
name: 'icon',
|
||||
type: 'string',
|
||||
description: t(`${i18nPrefix}.outputVars.icon`),
|
||||
},
|
||||
{
|
||||
name: 'metadata',
|
||||
type: 'object',
|
||||
description: t(`${i18nPrefix}.outputVars.metadata`),
|
||||
},
|
||||
]}
|
||||
/>
|
||||
|
||||
</>
|
||||
</OutputVars>
|
||||
{isShowSingleRun && (
|
||||
<BeforeRunForm
|
||||
nodeName={inputs.title}
|
||||
onHide={hideSingleRun}
|
||||
forms={[
|
||||
{
|
||||
inputs: [{
|
||||
label: t(`${i18nPrefix}.queryVariable`)!,
|
||||
variable: 'query',
|
||||
type: InputVarType.paragraph,
|
||||
required: true,
|
||||
}],
|
||||
values: { query },
|
||||
onChange: keyValue => setQuery((keyValue as any).query),
|
||||
},
|
||||
]}
|
||||
runningStatus={runningStatus}
|
||||
onRun={handleRun}
|
||||
onStop={handleStop}
|
||||
result={<ResultPanel {...runResult} showSteps={false} />}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(Panel)
|
||||
@@ -0,0 +1,23 @@
|
||||
import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types'
|
||||
import type { RETRIEVE_TYPE } from '@/types/app'
|
||||
|
||||
export type MultipleRetrievalConfig = {
|
||||
top_k: number
|
||||
score_threshold: number | null | undefined
|
||||
reranking_model?: {
|
||||
provider: string
|
||||
model: string
|
||||
}
|
||||
}
|
||||
|
||||
export type SingleRetrievalConfig = {
|
||||
model: ModelConfig
|
||||
}
|
||||
|
||||
export type KnowledgeRetrievalNodeType = CommonNodeType & {
|
||||
query_variable_selector: ValueSelector
|
||||
dataset_ids: string[]
|
||||
retrieval_mode: RETRIEVE_TYPE
|
||||
multiple_retrieval_config?: MultipleRetrievalConfig
|
||||
single_retrieval_config?: SingleRetrievalConfig
|
||||
}
|
||||
@@ -0,0 +1,272 @@
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import produce from 'immer'
|
||||
import { isEqual } from 'lodash-es'
|
||||
import type { ValueSelector, Var } from '../../types'
|
||||
import { BlockEnum, VarType } from '../../types'
|
||||
import {
|
||||
useIsChatMode, useNodesReadOnly,
|
||||
useWorkflow,
|
||||
} from '../../hooks'
|
||||
import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
|
||||
import { RETRIEVE_TYPE } from '@/types/app'
|
||||
import { DATASET_DEFAULT } from '@/config'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { fetchDatasets } from '@/service/datasets'
|
||||
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
|
||||
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
|
||||
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
const { nodesReadOnly: readOnly } = useNodesReadOnly()
|
||||
const isChatMode = useIsChatMode()
|
||||
const { getBeforeNodesInSameBranch } = useWorkflow()
|
||||
const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
|
||||
const startNodeId = startNode?.id
|
||||
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
||||
|
||||
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
|
||||
const newInputs = produce(s, (draft) => {
|
||||
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
|
||||
delete draft.single_retrieval_config
|
||||
else
|
||||
delete draft.multiple_retrieval_config
|
||||
})
|
||||
// not work in pass to draft...
|
||||
doSetInputs(newInputs)
|
||||
}, [doSetInputs])
|
||||
|
||||
const inputRef = useRef(inputs)
|
||||
useEffect(() => {
|
||||
inputRef.current = inputs
|
||||
}, [inputs])
|
||||
|
||||
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.query_variable_selector = newVar as ValueSelector
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [inputs, setInputs])
|
||||
|
||||
const {
|
||||
currentProvider,
|
||||
currentModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
|
||||
|
||||
const {
|
||||
defaultModel: rerankDefaultModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
if (!draft.single_retrieval_config) {
|
||||
draft.single_retrieval_config = {
|
||||
model: {
|
||||
provider: '',
|
||||
name: '',
|
||||
mode: '',
|
||||
completion_params: {},
|
||||
},
|
||||
}
|
||||
}
|
||||
const draftModel = draft.single_retrieval_config?.model
|
||||
draftModel.provider = model.provider
|
||||
draftModel.name = model.modelId
|
||||
draftModel.mode = model.mode!
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [setInputs])
|
||||
|
||||
const handleCompletionParamsChange = useCallback((newParams: Record<string, any>) => {
|
||||
// inputRef.current.single_retrieval_config?.model is old when change the provider...
|
||||
if (isEqual(newParams, inputRef.current.single_retrieval_config?.model.completion_params))
|
||||
return
|
||||
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
if (!draft.single_retrieval_config) {
|
||||
draft.single_retrieval_config = {
|
||||
model: {
|
||||
provider: '',
|
||||
name: '',
|
||||
mode: '',
|
||||
completion_params: {},
|
||||
},
|
||||
}
|
||||
}
|
||||
draft.single_retrieval_config.model.completion_params = newParams
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [setInputs])
|
||||
|
||||
// set defaults models
|
||||
useEffect(() => {
|
||||
const inputs = inputRef.current
|
||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider)
|
||||
return
|
||||
|
||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
|
||||
return
|
||||
|
||||
const newInput = produce(inputs, (draft) => {
|
||||
if (currentProvider?.provider && currentModel?.model) {
|
||||
const hasSetModel = draft.single_retrieval_config?.model?.provider
|
||||
if (!hasSetModel) {
|
||||
draft.single_retrieval_config = {
|
||||
model: {
|
||||
provider: currentProvider?.provider,
|
||||
name: currentModel?.model,
|
||||
mode: currentModel?.model_properties?.mode as string,
|
||||
completion_params: {},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = {
|
||||
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
||||
score_threshold: multipleRetrievalConfig?.score_threshold,
|
||||
reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay
|
||||
? undefined
|
||||
: (!multipleRetrievalConfig?.reranking_model?.provider
|
||||
? {
|
||||
provider: rerankDefaultModel?.provider?.provider || '',
|
||||
model: rerankDefaultModel?.model || '',
|
||||
}
|
||||
: multipleRetrievalConfig?.reranking_model),
|
||||
}
|
||||
})
|
||||
setInputs(newInput)
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [currentProvider?.provider, currentModel, rerankDefaultModel])
|
||||
|
||||
const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.retrieval_mode = newMode
|
||||
if (newMode === RETRIEVE_TYPE.multiWay) {
|
||||
draft.multiple_retrieval_config = {
|
||||
top_k: draft.multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
|
||||
score_threshold: draft.multiple_retrieval_config?.score_threshold,
|
||||
reranking_model: !draft.multiple_retrieval_config?.reranking_model?.provider
|
||||
? {
|
||||
provider: rerankDefaultModel?.provider?.provider || '',
|
||||
model: rerankDefaultModel?.model || '',
|
||||
}
|
||||
: draft.multiple_retrieval_config?.reranking_model,
|
||||
}
|
||||
}
|
||||
else {
|
||||
const hasSetModel = draft.single_retrieval_config?.model?.provider
|
||||
if (!hasSetModel) {
|
||||
draft.single_retrieval_config = {
|
||||
model: {
|
||||
provider: currentProvider?.provider || '',
|
||||
name: currentModel?.model || '',
|
||||
mode: currentModel?.model_properties?.mode as string,
|
||||
completion_params: {},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, rerankDefaultModel?.model, rerankDefaultModel?.provider?.provider, setInputs])
|
||||
|
||||
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.multiple_retrieval_config = newConfig
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [inputs, setInputs])
|
||||
|
||||
// datasets
|
||||
const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
const inputs = inputRef.current
|
||||
const datasetIds = inputs.dataset_ids
|
||||
if (datasetIds?.length > 0) {
|
||||
const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } })
|
||||
setSelectedDatasets(dataSetsWithDetail)
|
||||
}
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.dataset_ids = datasetIds
|
||||
})
|
||||
setInputs(newInputs)
|
||||
})()
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
let query_variable_selector: ValueSelector = inputs.query_variable_selector
|
||||
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
|
||||
query_variable_selector = [startNodeId, 'sys.query']
|
||||
|
||||
setInputs({
|
||||
...inputs,
|
||||
query_variable_selector,
|
||||
})
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.dataset_ids = newDatasets.map(d => d.id)
|
||||
})
|
||||
setInputs(newInputs)
|
||||
setSelectedDatasets(newDatasets)
|
||||
}, [inputs, setInputs])
|
||||
|
||||
const filterVar = useCallback((varPayload: Var) => {
|
||||
return varPayload.type === VarType.string
|
||||
}, [])
|
||||
|
||||
// single run
|
||||
const {
|
||||
isShowSingleRun,
|
||||
hideSingleRun,
|
||||
runningStatus,
|
||||
handleRun,
|
||||
handleStop,
|
||||
runInputData,
|
||||
setRunInputData,
|
||||
runResult,
|
||||
} = useOneStepRun<KnowledgeRetrievalNodeType>({
|
||||
id,
|
||||
data: inputs,
|
||||
defaultRunInputData: {
|
||||
query: '',
|
||||
},
|
||||
})
|
||||
|
||||
const query = runInputData.query
|
||||
const setQuery = useCallback((newQuery: string) => {
|
||||
setRunInputData({
|
||||
...runInputData,
|
||||
query: newQuery,
|
||||
})
|
||||
}, [runInputData, setRunInputData])
|
||||
|
||||
return {
|
||||
readOnly,
|
||||
inputs,
|
||||
handleQueryVarChange,
|
||||
filterVar,
|
||||
handleRetrievalModeChange,
|
||||
handleMultipleRetrievalConfigChange,
|
||||
handleModelChanged,
|
||||
handleCompletionParamsChange,
|
||||
selectedDatasets,
|
||||
handleOnDatasetsChange,
|
||||
isShowSingleRun,
|
||||
hideSingleRun,
|
||||
runningStatus,
|
||||
handleRun,
|
||||
handleStop,
|
||||
query,
|
||||
setQuery,
|
||||
runResult,
|
||||
}
|
||||
}
|
||||
|
||||
export default useConfig
|
||||
@@ -0,0 +1,5 @@
|
||||
import type { KnowledgeRetrievalNodeType } from './types'
|
||||
|
||||
export const checkNodeValid = (payload: KnowledgeRetrievalNodeType) => {
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user