generalize position helper for parsing _position.yaml and sorting objects by name (#2803)

This commit is contained in:
Bowen Liang
2024-03-13 20:29:38 +08:00
committed by GitHub
parent 849dc0560b
commit 8b15b742ad
5 changed files with 95 additions and 46 deletions

View File

@@ -1,8 +1,7 @@
import os.path
from yaml import FullLoader, load
from core.tools.entities.user_entities import UserToolProvider
from core.utils.position_helper import get_position_map, sort_by_position_map
class BuiltinToolProviderSort:
@@ -11,18 +10,14 @@ class BuiltinToolProviderSort:
@classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position:
tmp_position = {}
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
with open(file_path) as f:
for pos, val in enumerate(load(f, Loader=FullLoader)):
tmp_position[val] = pos
cls._position = tmp_position
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
def sort_compare(provider: UserToolProvider) -> int:
def name_func(provider: UserToolProvider) -> str:
if provider.type == UserToolProvider.ProviderType.MODEL:
return cls._position.get(f'model.{provider.name}', 10000)
return cls._position.get(provider.name, 10000)
sorted_providers = sorted(providers, key=sort_compare)
return f'model.{provider.name}'
else:
return provider.name
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
return sorted_providers