fix: some RAG retrieval bugs (#1577)

Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
zxhlyh
2023-11-21 13:46:07 +08:00
committed by GitHub
parent d0456d0f42
commit 6768fd4d87
15 changed files with 267 additions and 106 deletions

View File

@@ -5,18 +5,29 @@ export const isReRankModelSelected = ({
rerankDefaultModel,
isRerankDefaultModelVaild,
retrievalConfig,
rerankModelList,
indexMethod,
}: {
rerankDefaultModel?: BackendModel
isRerankDefaultModelVaild: boolean
retrievalConfig: RetrievalConfig
rerankModelList: BackendModel[]
indexMethod?: string
}) => {
const rerankModel = (retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined) || (isRerankDefaultModelVaild ? rerankDefaultModel : undefined)
const rerankModelSelected = (() => {
if (retrievalConfig.reranking_model?.reranking_model_name)
return !!rerankModelList.find(({ model_name }) => model_name === retrievalConfig.reranking_model?.reranking_model_name)
if (isRerankDefaultModelVaild)
return !!rerankDefaultModel
return false
})()
if (
indexMethod === 'high_quality'
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.fullText)
&& !rerankModel
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid)
&& !rerankModelSelected
)
return false
@@ -35,7 +46,7 @@ export const ensureRerankModelSelected = ({
const rerankModel = retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined
if (
indexMethod === 'high_quality'
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.fullText)
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid)
&& !rerankModel
) {
return {

View File

@@ -16,11 +16,23 @@ type Props = {
}
const RetrievalMethodConfig: FC<Props> = ({
value,
value: passValue,
onChange,
}) => {
const { t } = useTranslation()
const { supportRetrievalMethods } = useProviderContext()
const { supportRetrievalMethods, rerankDefaultModel } = useProviderContext()
const value = (() => {
if (!passValue.reranking_model.reranking_model_name) {
return {
...passValue,
reranking_model: {
reranking_provider_name: rerankDefaultModel?.model_provider.provider_name || '',
reranking_model_name: rerankDefaultModel?.model_name || '',
},
}
}
return passValue
})()
return (
<div className='space-y-2'>
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (

View File

@@ -263,6 +263,7 @@ const StepTwo = ({
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()
const getCreationParams = () => {
let params
@@ -282,6 +283,7 @@ const StepTwo = ({
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
// eslint-disable-next-line @typescript-eslint/no-use-before-define
retrievalConfig,
indexMethod: indexMethod as string,
@@ -359,6 +361,9 @@ const StepTwo = ({
try {
let res
const params = getCreationParams()
if (!params)
return false
setIsCreating(true)
if (!datasetId) {
res = await createFirstDocument({

View File

@@ -3,11 +3,14 @@ import type { FC } from 'react'
import React, { useRef, useState } from 'react'
import { useClickAway } from 'ahooks'
import { useTranslation } from 'react-i18next'
import Toast from '../../base/toast'
import { XClose } from '@/app/components/base/icons/src/vender/line/general'
import type { RetrievalConfig } from '@/types/app'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import Button from '@/app/components/base/button'
import { useProviderContext } from '@/context/provider-context'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
type Props = {
indexMethod: string
@@ -33,6 +36,32 @@ const ModifyRetrievalModal: FC<Props> = ({
onHide()
}, ref)
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()
const handleSave = () => {
if (
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig,
indexMethod,
})
) {
Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
return
}
onSave(ensureRerankModelSelected({
rerankDefaultModel: rerankDefaultModel!,
retrievalConfig,
indexMethod,
}))
}
if (!isShow)
return null
@@ -87,7 +116,7 @@ const ModifyRetrievalModal: FC<Props> = ({
}}
>
<Button className='mr-2 flex-shrink-0' onClick={onHide}>{t('common.operation.cancel')}</Button>
<Button type='primary' className='flex-shrink-0' onClick={() => onSave(retrievalConfig)} >{t('common.operation.save')}</Button>
<Button type='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button>
</div>
</div>
)

View File

@@ -59,6 +59,7 @@ const Form = () => {
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()
const handleSave = async () => {
@@ -72,6 +73,7 @@ const Form = () => {
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig,
indexMethod,
})