mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 03:16:51 +08:00
fix: some RAG retrieval bugs (#1577)
Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
@@ -5,18 +5,29 @@ export const isReRankModelSelected = ({
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
retrievalConfig,
|
||||
rerankModelList,
|
||||
indexMethod,
|
||||
}: {
|
||||
rerankDefaultModel?: BackendModel
|
||||
isRerankDefaultModelVaild: boolean
|
||||
retrievalConfig: RetrievalConfig
|
||||
rerankModelList: BackendModel[]
|
||||
indexMethod?: string
|
||||
}) => {
|
||||
const rerankModel = (retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined) || (isRerankDefaultModelVaild ? rerankDefaultModel : undefined)
|
||||
const rerankModelSelected = (() => {
|
||||
if (retrievalConfig.reranking_model?.reranking_model_name)
|
||||
return !!rerankModelList.find(({ model_name }) => model_name === retrievalConfig.reranking_model?.reranking_model_name)
|
||||
|
||||
if (isRerankDefaultModelVaild)
|
||||
return !!rerankDefaultModel
|
||||
|
||||
return false
|
||||
})()
|
||||
|
||||
if (
|
||||
indexMethod === 'high_quality'
|
||||
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.fullText)
|
||||
&& !rerankModel
|
||||
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid)
|
||||
&& !rerankModelSelected
|
||||
)
|
||||
return false
|
||||
|
||||
@@ -35,7 +46,7 @@ export const ensureRerankModelSelected = ({
|
||||
const rerankModel = retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined
|
||||
if (
|
||||
indexMethod === 'high_quality'
|
||||
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.fullText)
|
||||
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid)
|
||||
&& !rerankModel
|
||||
) {
|
||||
return {
|
||||
|
||||
@@ -16,11 +16,23 @@ type Props = {
|
||||
}
|
||||
|
||||
const RetrievalMethodConfig: FC<Props> = ({
|
||||
value,
|
||||
value: passValue,
|
||||
onChange,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const { supportRetrievalMethods } = useProviderContext()
|
||||
const { supportRetrievalMethods, rerankDefaultModel } = useProviderContext()
|
||||
const value = (() => {
|
||||
if (!passValue.reranking_model.reranking_model_name) {
|
||||
return {
|
||||
...passValue,
|
||||
reranking_model: {
|
||||
reranking_provider_name: rerankDefaultModel?.model_provider.provider_name || '',
|
||||
reranking_model_name: rerankDefaultModel?.model_name || '',
|
||||
},
|
||||
}
|
||||
}
|
||||
return passValue
|
||||
})()
|
||||
return (
|
||||
<div className='space-y-2'>
|
||||
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
|
||||
|
||||
@@ -263,6 +263,7 @@ const StepTwo = ({
|
||||
const {
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
rerankModelList,
|
||||
} = useProviderContext()
|
||||
const getCreationParams = () => {
|
||||
let params
|
||||
@@ -282,6 +283,7 @@ const StepTwo = ({
|
||||
!isReRankModelSelected({
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
rerankModelList,
|
||||
// eslint-disable-next-line @typescript-eslint/no-use-before-define
|
||||
retrievalConfig,
|
||||
indexMethod: indexMethod as string,
|
||||
@@ -359,6 +361,9 @@ const StepTwo = ({
|
||||
try {
|
||||
let res
|
||||
const params = getCreationParams()
|
||||
if (!params)
|
||||
return false
|
||||
|
||||
setIsCreating(true)
|
||||
if (!datasetId) {
|
||||
res = await createFirstDocument({
|
||||
|
||||
@@ -3,11 +3,14 @@ import type { FC } from 'react'
|
||||
import React, { useRef, useState } from 'react'
|
||||
import { useClickAway } from 'ahooks'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Toast from '../../base/toast'
|
||||
import { XClose } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import type { RetrievalConfig } from '@/types/app'
|
||||
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
|
||||
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
|
||||
|
||||
type Props = {
|
||||
indexMethod: string
|
||||
@@ -33,6 +36,32 @@ const ModifyRetrievalModal: FC<Props> = ({
|
||||
onHide()
|
||||
}, ref)
|
||||
|
||||
const {
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
rerankModelList,
|
||||
} = useProviderContext()
|
||||
|
||||
const handleSave = () => {
|
||||
if (
|
||||
!isReRankModelSelected({
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
rerankModelList,
|
||||
retrievalConfig,
|
||||
indexMethod,
|
||||
})
|
||||
) {
|
||||
Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
|
||||
return
|
||||
}
|
||||
onSave(ensureRerankModelSelected({
|
||||
rerankDefaultModel: rerankDefaultModel!,
|
||||
retrievalConfig,
|
||||
indexMethod,
|
||||
}))
|
||||
}
|
||||
|
||||
if (!isShow)
|
||||
return null
|
||||
|
||||
@@ -87,7 +116,7 @@ const ModifyRetrievalModal: FC<Props> = ({
|
||||
}}
|
||||
>
|
||||
<Button className='mr-2 flex-shrink-0' onClick={onHide}>{t('common.operation.cancel')}</Button>
|
||||
<Button type='primary' className='flex-shrink-0' onClick={() => onSave(retrievalConfig)} >{t('common.operation.save')}</Button>
|
||||
<Button type='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -59,6 +59,7 @@ const Form = () => {
|
||||
const {
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
rerankModelList,
|
||||
} = useProviderContext()
|
||||
|
||||
const handleSave = async () => {
|
||||
@@ -72,6 +73,7 @@ const Form = () => {
|
||||
!isReRankModelSelected({
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelVaild,
|
||||
rerankModelList,
|
||||
retrievalConfig,
|
||||
indexMethod,
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user