mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
feat: n to 1 retrieval legacy (#6554)
This commit is contained in:
@@ -4,15 +4,17 @@ import React, { useCallback } from 'react'
|
||||
import { useBoolean } from 'ahooks'
|
||||
import {
|
||||
RiDeleteBinLine,
|
||||
RiEditLine,
|
||||
} from '@remixicon/react'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import { DataSourceType } from '@/models/datasets'
|
||||
import { Settings01 } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import FileIcon from '@/app/components/base/file-icon'
|
||||
import { Folder } from '@/app/components/base/icons/src/vender/solid/files'
|
||||
import SettingsModal from '@/app/components/app/configuration/dataset-config/settings-modal'
|
||||
import Drawer from '@/app/components/base/drawer'
|
||||
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
import { useKnowledge } from '@/hooks/use-knowledge'
|
||||
|
||||
type Props = {
|
||||
payload: DataSet
|
||||
@@ -29,6 +31,7 @@ const DatasetItem: FC<Props> = ({
|
||||
}) => {
|
||||
const media = useBreakpoints()
|
||||
const isMobile = media === MediaType.mobile
|
||||
const { formatIndexingTechniqueAndMethod } = useKnowledge()
|
||||
|
||||
const [isShowSettingsModal, {
|
||||
setTrue: showSettingsModal,
|
||||
@@ -62,7 +65,7 @@ const DatasetItem: FC<Props> = ({
|
||||
className='flex items-center justify-center w-6 h-6 hover:bg-black/5 rounded-md cursor-pointer'
|
||||
onClick={showSettingsModal}
|
||||
>
|
||||
<Settings01 className='w-4 h-4 text-gray-500' />
|
||||
<RiEditLine className='w-4 h-4 text-gray-500' />
|
||||
</div>
|
||||
<div
|
||||
className='flex items-center justify-center w-6 h-6 hover:bg-black/5 rounded-md cursor-pointer'
|
||||
@@ -72,6 +75,10 @@ const DatasetItem: FC<Props> = ({
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<Badge
|
||||
className='group-hover/dataset-item:hidden shrink-0'
|
||||
text={formatIndexingTechniqueAndMethod(payload.indexing_technique, payload.retrieval_model_dict?.search_method)}
|
||||
/>
|
||||
|
||||
{isShowSettingsModal && (
|
||||
<Drawer isOpen={isShowSettingsModal} onClose={hideSettingsModal} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'>
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback, useState } from 'react'
|
||||
import { RiEqualizer2Line } from '@remixicon/react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiArrowDownSLine } from '@remixicon/react'
|
||||
import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types'
|
||||
import type { ModelConfig } from '../../../types'
|
||||
import cn from '@/utils/classnames'
|
||||
@@ -16,10 +16,9 @@ import { RETRIEVE_TYPE } from '@/types/app'
|
||||
import { DATASET_DEFAULT } from '@/config'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
|
||||
import type {
|
||||
DatasetConfigs,
|
||||
} from '@/models/debug'
|
||||
import Button from '@/app/components/base/button'
|
||||
import type { DatasetConfigs } from '@/models/debug'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
|
||||
type Props = {
|
||||
payload: {
|
||||
@@ -33,6 +32,9 @@ type Props = {
|
||||
onSingleRetrievalModelChange?: (config: ModelConfig) => void
|
||||
onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void
|
||||
readonly?: boolean
|
||||
openFromProps?: boolean
|
||||
onOpenFromPropsChange?: (openFromProps: boolean) => void
|
||||
selectedDatasets: DataSet[]
|
||||
}
|
||||
|
||||
const RetrievalConfig: FC<Props> = ({
|
||||
@@ -43,10 +45,18 @@ const RetrievalConfig: FC<Props> = ({
|
||||
onSingleRetrievalModelChange,
|
||||
onSingleRetrievalModelParamsChange,
|
||||
readonly,
|
||||
openFromProps,
|
||||
onOpenFromPropsChange,
|
||||
selectedDatasets,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const [open, setOpen] = useState(false)
|
||||
const mergedOpen = openFromProps !== undefined ? openFromProps : open
|
||||
|
||||
const handleOpen = useCallback((newOpen: boolean) => {
|
||||
setOpen(newOpen)
|
||||
onOpenFromPropsChange?.(newOpen)
|
||||
}, [onOpenFromPropsChange])
|
||||
|
||||
const {
|
||||
defaultModel: rerankDefaultModel,
|
||||
@@ -72,16 +82,18 @@ const RetrievalConfig: FC<Props> = ({
|
||||
provider: configs.reranking_model?.reranking_provider_name,
|
||||
model: configs.reranking_model?.reranking_model_name,
|
||||
}),
|
||||
reranking_mode: configs.reranking_mode,
|
||||
weights: configs.weights as any,
|
||||
reranking_enable: configs.reranking_enable,
|
||||
})
|
||||
}, [onMultipleRetrievalConfigChange, payload.retrieval_mode, rerankDefaultModel?.provider?.provider, rerankDefaultModel?.model, onRetrievalModeChange])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
open={open}
|
||||
onOpenChange={setOpen}
|
||||
open={mergedOpen}
|
||||
onOpenChange={handleOpen}
|
||||
placement='bottom-end'
|
||||
offset={{
|
||||
// mainAxis: 12,
|
||||
crossAxis: -2,
|
||||
}}
|
||||
>
|
||||
@@ -89,13 +101,18 @@ const RetrievalConfig: FC<Props> = ({
|
||||
onClick={() => {
|
||||
if (readonly)
|
||||
return
|
||||
setOpen(v => !v)
|
||||
handleOpen(!mergedOpen)
|
||||
}}
|
||||
>
|
||||
<div className={cn(!readonly && 'cursor-pointer', open && 'bg-gray-100', 'flex items-center h-6 px-2 rounded-md hover:bg-gray-100 group select-none')}>
|
||||
<div className={cn(open ? 'text-gray-700' : 'text-gray-500', 'leading-[18px] text-xs font-medium group-hover:bg-gray-100')}>{payload.retrieval_mode === RETRIEVE_TYPE.oneWay ? t('appDebug.datasetConfig.retrieveOneWay.title') : t('appDebug.datasetConfig.retrieveMultiWay.title')}</div>
|
||||
{!readonly && <RiArrowDownSLine className='w-3 h-3 ml-1' />}
|
||||
</div>
|
||||
<Button
|
||||
variant='ghost'
|
||||
size='small'
|
||||
disabled={readonly}
|
||||
className={cn(open && 'bg-components-button-ghost-bg-hover')}
|
||||
>
|
||||
<RiEqualizer2Line className='mr-1 w-3.5 h-3.5' />
|
||||
{t('dataset.retrievalSettings')}
|
||||
</Button>
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent style={{ zIndex: 1001 }}>
|
||||
<div className='w-[404px] pt-3 pb-4 px-4 shadow-xl rounded-2xl border border-gray-200 bg-white'>
|
||||
@@ -103,21 +120,24 @@ const RetrievalConfig: FC<Props> = ({
|
||||
datasetConfigs={
|
||||
{
|
||||
retrieval_model: payload.retrieval_mode,
|
||||
reranking_model: !multiple_retrieval_config?.reranking_model?.provider
|
||||
reranking_model: multiple_retrieval_config?.reranking_model?.provider
|
||||
? {
|
||||
reranking_provider_name: rerankDefaultModel?.provider?.provider || '',
|
||||
reranking_model_name: rerankDefaultModel?.model || '',
|
||||
reranking_provider_name: multiple_retrieval_config.reranking_model?.provider,
|
||||
reranking_model_name: multiple_retrieval_config.reranking_model?.model,
|
||||
}
|
||||
: {
|
||||
reranking_provider_name: multiple_retrieval_config?.reranking_model?.provider || '',
|
||||
reranking_model_name: multiple_retrieval_config?.reranking_model?.model || '',
|
||||
reranking_provider_name: '',
|
||||
reranking_model_name: '',
|
||||
},
|
||||
top_k: multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
|
||||
score_threshold_enabled: !(multiple_retrieval_config?.score_threshold === undefined || multiple_retrieval_config?.score_threshold === null),
|
||||
score_threshold_enabled: !(multiple_retrieval_config?.score_threshold === undefined || multiple_retrieval_config.score_threshold === null),
|
||||
score_threshold: multiple_retrieval_config?.score_threshold,
|
||||
datasets: {
|
||||
datasets: [],
|
||||
},
|
||||
reranking_mode: multiple_retrieval_config?.reranking_mode,
|
||||
weights: multiple_retrieval_config?.weights,
|
||||
reranking_enable: multiple_retrieval_config?.reranking_enable,
|
||||
}
|
||||
}
|
||||
onChange={handleChange}
|
||||
@@ -125,6 +145,7 @@ const RetrievalConfig: FC<Props> = ({
|
||||
singleRetrievalModelConfig={singleRetrievalModelConfig}
|
||||
onSingleRetrievalModelChange={onSingleRetrievalModelChange}
|
||||
onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange}
|
||||
selectedDatasets={selectedDatasets}
|
||||
/>
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
|
||||
@@ -2,7 +2,7 @@ import { BlockEnum } from '../../types'
|
||||
import type { NodeDefault } from '../../types'
|
||||
import type { KnowledgeRetrievalNodeType } from './types'
|
||||
import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
|
||||
|
||||
import { DATASET_DEFAULT } from '@/config'
|
||||
import { RETRIEVE_TYPE } from '@/types/app'
|
||||
const i18nPrefix = 'workflow'
|
||||
|
||||
@@ -10,7 +10,12 @@ const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
|
||||
defaultValue: {
|
||||
query_variable_selector: [],
|
||||
dataset_ids: [],
|
||||
retrieval_mode: RETRIEVE_TYPE.oneWay,
|
||||
retrieval_mode: RETRIEVE_TYPE.multiWay,
|
||||
multiple_retrieval_config: {
|
||||
top_k: DATASET_DEFAULT.top_k,
|
||||
score_threshold: undefined,
|
||||
reranking_enable: false,
|
||||
},
|
||||
},
|
||||
getAvailablePrevNodes(isChatMode: boolean) {
|
||||
const nodes = isChatMode
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
import { useMemo } from 'react'
|
||||
import { getSelectedDatasetsMode } from './utils'
|
||||
import type {
|
||||
DataSet,
|
||||
SelectedDatasetsMode,
|
||||
} from '@/models/datasets'
|
||||
|
||||
export const useSelectedDatasetsMode = (datasets: DataSet[]) => {
|
||||
const selectedDatasetsMode: SelectedDatasetsMode = useMemo(() => {
|
||||
return getSelectedDatasetsMode(datasets)
|
||||
}, [datasets])
|
||||
|
||||
return selectedDatasetsMode
|
||||
}
|
||||
@@ -1,5 +1,8 @@
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
} from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import VarReferencePicker from '../_base/components/variable/var-reference-picker'
|
||||
import useConfig from './use-config'
|
||||
@@ -41,8 +44,14 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
|
||||
query,
|
||||
setQuery,
|
||||
runResult,
|
||||
rerankModelOpen,
|
||||
setRerankModelOpen,
|
||||
} = useConfig(id, data)
|
||||
|
||||
const handleOpenFromPropsChange = useCallback((openFromProps: boolean) => {
|
||||
setRerankModelOpen(openFromProps)
|
||||
}, [setRerankModelOpen])
|
||||
|
||||
return (
|
||||
<div className='mt-2'>
|
||||
<div className='px-4 pb-4 space-y-4'>
|
||||
@@ -75,7 +84,10 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
|
||||
singleRetrievalModelConfig={inputs.single_retrieval_config?.model}
|
||||
onSingleRetrievalModelChange={handleModelChanged as any}
|
||||
onSingleRetrievalModelParamsChange={handleCompletionParamsChange}
|
||||
readonly={readOnly}
|
||||
readonly={readOnly || !selectedDatasets.length}
|
||||
openFromProps={rerankModelOpen}
|
||||
onOpenFromPropsChange={handleOpenFromPropsChange}
|
||||
selectedDatasets={selectedDatasets}
|
||||
/>
|
||||
{!readOnly && (<div className='w-px h-3 bg-gray-200'></div>)}
|
||||
{!readOnly && (
|
||||
@@ -162,4 +174,4 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(Panel)
|
||||
export default memo(Panel)
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types'
|
||||
import type { RETRIEVE_TYPE } from '@/types/app'
|
||||
import type {
|
||||
RerankingModeEnum,
|
||||
WeightedScoreEnum,
|
||||
} from '@/models/datasets'
|
||||
|
||||
export type MultipleRetrievalConfig = {
|
||||
top_k: number
|
||||
@@ -8,6 +12,19 @@ export type MultipleRetrievalConfig = {
|
||||
provider: string
|
||||
model: string
|
||||
}
|
||||
reranking_mode?: RerankingModeEnum
|
||||
weights?: {
|
||||
weight_type: WeightedScoreEnum
|
||||
vector_setting: {
|
||||
vector_weight: number
|
||||
embedding_provider_name: string
|
||||
embedding_model_name: string
|
||||
}
|
||||
keyword_setting: {
|
||||
keyword_weight: number
|
||||
}
|
||||
}
|
||||
reranking_enable?: boolean
|
||||
}
|
||||
|
||||
export type SingleRetrievalConfig = {
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react'
|
||||
import produce from 'immer'
|
||||
import { isEqual } from 'lodash-es'
|
||||
import type { ValueSelector, Var } from '../../types'
|
||||
@@ -8,6 +13,10 @@ import {
|
||||
useWorkflow,
|
||||
} from '../../hooks'
|
||||
import type { KnowledgeRetrievalNodeType, MultipleRetrievalConfig } from './types'
|
||||
import {
|
||||
getMultipleRetrievalConfig,
|
||||
getSelectedDatasetsMode,
|
||||
} from './utils'
|
||||
import { RETRIEVE_TYPE } from '@/types/app'
|
||||
import { DATASET_DEFAULT } from '@/config'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
@@ -126,34 +135,20 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
draft.multiple_retrieval_config = {
|
||||
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
||||
score_threshold: multipleRetrievalConfig?.score_threshold,
|
||||
reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay
|
||||
? undefined
|
||||
: (!multipleRetrievalConfig?.reranking_model?.provider
|
||||
? {
|
||||
provider: rerankDefaultModel?.provider?.provider || '',
|
||||
model: rerankDefaultModel?.model || '',
|
||||
}
|
||||
: multipleRetrievalConfig?.reranking_model),
|
||||
reranking_model: multipleRetrievalConfig?.reranking_model,
|
||||
}
|
||||
})
|
||||
setInputs(newInput)
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [currentProvider?.provider, currentModel, rerankDefaultModel])
|
||||
|
||||
const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
|
||||
const [rerankModelOpen, setRerankModelOpen] = useState(false)
|
||||
const handleRetrievalModeChange = useCallback((newMode: RETRIEVE_TYPE) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.retrieval_mode = newMode
|
||||
if (newMode === RETRIEVE_TYPE.multiWay) {
|
||||
draft.multiple_retrieval_config = {
|
||||
top_k: draft.multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k,
|
||||
score_threshold: draft.multiple_retrieval_config?.score_threshold,
|
||||
reranking_model: !draft.multiple_retrieval_config?.reranking_model?.provider
|
||||
? {
|
||||
provider: rerankDefaultModel?.provider?.provider || '',
|
||||
model: rerankDefaultModel?.model || '',
|
||||
}
|
||||
: draft.multiple_retrieval_config?.reranking_model,
|
||||
}
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets)
|
||||
}
|
||||
else {
|
||||
const hasSetModel = draft.single_retrieval_config?.model?.provider
|
||||
@@ -170,17 +165,16 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, rerankDefaultModel?.model, rerankDefaultModel?.provider?.provider, setInputs])
|
||||
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets])
|
||||
|
||||
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.multiple_retrieval_config = newConfig
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets)
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [inputs, setInputs])
|
||||
}, [inputs, setInputs, selectedDatasets])
|
||||
|
||||
// datasets
|
||||
const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([])
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
const inputs = inputRef.current
|
||||
@@ -210,12 +204,25 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
}, [])
|
||||
|
||||
const handleOnDatasetsChange = useCallback((newDatasets: DataSet[]) => {
|
||||
const {
|
||||
allEconomic,
|
||||
mixtureHighQualityAndEconomic,
|
||||
inconsistentEmbeddingModel,
|
||||
} = getSelectedDatasetsMode(newDatasets)
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.dataset_ids = newDatasets.map(d => d.id)
|
||||
|
||||
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets)
|
||||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
setSelectedDatasets(newDatasets)
|
||||
}, [inputs, setInputs])
|
||||
|
||||
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel)
|
||||
setRerankModelOpen(true)
|
||||
}, [inputs, setInputs, payload.retrieval_mode])
|
||||
|
||||
const filterVar = useCallback((varPayload: Var) => {
|
||||
return varPayload.type === VarType.string
|
||||
@@ -266,6 +273,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
query,
|
||||
setQuery,
|
||||
runResult,
|
||||
rerankModelOpen,
|
||||
setRerankModelOpen,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,124 @@
|
||||
import type { KnowledgeRetrievalNodeType } from './types'
|
||||
import { uniq } from 'lodash-es'
|
||||
import type { MultipleRetrievalConfig } from './types'
|
||||
import type {
|
||||
DataSet,
|
||||
SelectedDatasetsMode,
|
||||
} from '@/models/datasets'
|
||||
import {
|
||||
DEFAULT_WEIGHTED_SCORE,
|
||||
RerankingModeEnum,
|
||||
WeightedScoreEnum,
|
||||
} from '@/models/datasets'
|
||||
import { RETRIEVE_METHOD } from '@/types/app'
|
||||
import { DATASET_DEFAULT } from '@/config'
|
||||
|
||||
export const checkNodeValid = (payload: KnowledgeRetrievalNodeType) => {
|
||||
export const checkNodeValid = () => {
|
||||
return true
|
||||
}
|
||||
|
||||
export const getSelectedDatasetsMode = (datasets: DataSet[]) => {
|
||||
let allHighQuality = true
|
||||
let allHighQualityVectorSearch = true
|
||||
let allHighQualityFullTextSearch = true
|
||||
let allEconomic = true
|
||||
let mixtureHighQualityAndEconomic = true
|
||||
let inconsistentEmbeddingModel = false
|
||||
if (!datasets.length) {
|
||||
allHighQuality = false
|
||||
allHighQualityVectorSearch = false
|
||||
allHighQualityFullTextSearch = false
|
||||
allEconomic = false
|
||||
mixtureHighQualityAndEconomic = false
|
||||
inconsistentEmbeddingModel = false
|
||||
}
|
||||
datasets.forEach((dataset) => {
|
||||
if (dataset.indexing_technique === 'economy') {
|
||||
allHighQuality = false
|
||||
allHighQualityVectorSearch = false
|
||||
allHighQualityFullTextSearch = false
|
||||
}
|
||||
if (dataset.indexing_technique === 'high_quality') {
|
||||
allEconomic = false
|
||||
|
||||
if (dataset.retrieval_model_dict.search_method !== RETRIEVE_METHOD.semantic)
|
||||
allHighQualityVectorSearch = false
|
||||
|
||||
if (dataset.retrieval_model_dict.search_method !== RETRIEVE_METHOD.fullText)
|
||||
allHighQualityFullTextSearch = false
|
||||
}
|
||||
})
|
||||
|
||||
if (allHighQuality || allEconomic)
|
||||
mixtureHighQualityAndEconomic = false
|
||||
|
||||
if (allHighQuality)
|
||||
inconsistentEmbeddingModel = uniq(datasets.map(item => item.embedding_model)).length > 1
|
||||
|
||||
return {
|
||||
allHighQuality,
|
||||
allHighQualityVectorSearch,
|
||||
allHighQualityFullTextSearch,
|
||||
allEconomic,
|
||||
mixtureHighQualityAndEconomic,
|
||||
inconsistentEmbeddingModel,
|
||||
} as SelectedDatasetsMode
|
||||
}
|
||||
|
||||
export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetrievalConfig, selectedDatasets: DataSet[]) => {
|
||||
const {
|
||||
allHighQuality,
|
||||
allHighQualityVectorSearch,
|
||||
allHighQualityFullTextSearch,
|
||||
allEconomic,
|
||||
mixtureHighQualityAndEconomic,
|
||||
inconsistentEmbeddingModel,
|
||||
} = getSelectedDatasetsMode(selectedDatasets)
|
||||
|
||||
const {
|
||||
top_k = DATASET_DEFAULT.top_k,
|
||||
score_threshold,
|
||||
reranking_mode,
|
||||
reranking_model,
|
||||
weights,
|
||||
reranking_enable,
|
||||
} = multipleRetrievalConfig || { top_k: DATASET_DEFAULT.top_k }
|
||||
|
||||
const result = {
|
||||
top_k,
|
||||
score_threshold,
|
||||
reranking_mode,
|
||||
reranking_model,
|
||||
weights,
|
||||
reranking_enable,
|
||||
}
|
||||
|
||||
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel)
|
||||
result.reranking_mode = RerankingModeEnum.RerankingModel
|
||||
|
||||
if (allHighQuality && !inconsistentEmbeddingModel && reranking_mode === undefined)
|
||||
result.reranking_mode = RerankingModeEnum.WeightedScore
|
||||
|
||||
if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && !weights) {
|
||||
result.weights = {
|
||||
weight_type: WeightedScoreEnum.Customized,
|
||||
vector_setting: {
|
||||
vector_weight: allHighQualityVectorSearch
|
||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic
|
||||
: allHighQualityFullTextSearch
|
||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic
|
||||
: DEFAULT_WEIGHTED_SCORE.other.semantic,
|
||||
embedding_provider_name: selectedDatasets[0].embedding_model_provider,
|
||||
embedding_model_name: selectedDatasets[0].embedding_model,
|
||||
},
|
||||
keyword_setting: {
|
||||
keyword_weight: allHighQualityVectorSearch
|
||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword
|
||||
: allHighQualityFullTextSearch
|
||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword
|
||||
: DEFAULT_WEIGHTED_SCORE.other.keyword,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user