fix: some RAG retrieval bugs (#1577)

Co-authored-by: Joel <iamjoel007@gmail.com>
This commit is contained in:
zxhlyh
2023-11-21 13:46:07 +08:00
committed by GitHub
parent d0456d0f42
commit 6768fd4d87
15 changed files with 267 additions and 106 deletions

View File

@@ -5,18 +5,29 @@ export const isReRankModelSelected = ({
rerankDefaultModel,
isRerankDefaultModelVaild,
retrievalConfig,
rerankModelList,
indexMethod,
}: {
rerankDefaultModel?: BackendModel
isRerankDefaultModelVaild: boolean
retrievalConfig: RetrievalConfig
rerankModelList: BackendModel[]
indexMethod?: string
}) => {
const rerankModel = (retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined) || (isRerankDefaultModelVaild ? rerankDefaultModel : undefined)
const rerankModelSelected = (() => {
if (retrievalConfig.reranking_model?.reranking_model_name)
return !!rerankModelList.find(({ model_name }) => model_name === retrievalConfig.reranking_model?.reranking_model_name)
if (isRerankDefaultModelVaild)
return !!rerankDefaultModel
return false
})()
if (
indexMethod === 'high_quality'
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.fullText)
&& !rerankModel
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid)
&& !rerankModelSelected
)
return false
@@ -35,7 +46,7 @@ export const ensureRerankModelSelected = ({
const rerankModel = retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined
if (
indexMethod === 'high_quality'
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.fullText)
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid)
&& !rerankModel
) {
return {

View File

@@ -16,11 +16,23 @@ type Props = {
}
const RetrievalMethodConfig: FC<Props> = ({
value,
value: passValue,
onChange,
}) => {
const { t } = useTranslation()
const { supportRetrievalMethods } = useProviderContext()
const { supportRetrievalMethods, rerankDefaultModel } = useProviderContext()
const value = (() => {
if (!passValue.reranking_model.reranking_model_name) {
return {
...passValue,
reranking_model: {
reranking_provider_name: rerankDefaultModel?.model_provider.provider_name || '',
reranking_model_name: rerankDefaultModel?.model_name || '',
},
}
}
return passValue
})()
return (
<div className='space-y-2'>
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (