mirror of
http://112.124.100.131/huang.ze/ebiz-dify-ai.git
synced 2025-12-10 11:26:52 +08:00
generalize position helper for parsing _position.yaml and sorting objects by name (#2803)
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
import os.path
|
||||
|
||||
from yaml import FullLoader, load
|
||||
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.utils.position_helper import get_position_map, sort_by_position_map
|
||||
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
@@ -11,18 +10,14 @@ class BuiltinToolProviderSort:
|
||||
@classmethod
|
||||
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
||||
if not cls._position:
|
||||
tmp_position = {}
|
||||
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
|
||||
with open(file_path) as f:
|
||||
for pos, val in enumerate(load(f, Loader=FullLoader)):
|
||||
tmp_position[val] = pos
|
||||
cls._position = tmp_position
|
||||
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
def sort_compare(provider: UserToolProvider) -> int:
|
||||
def name_func(provider: UserToolProvider) -> str:
|
||||
if provider.type == UserToolProvider.ProviderType.MODEL:
|
||||
return cls._position.get(f'model.{provider.name}', 10000)
|
||||
return cls._position.get(provider.name, 10000)
|
||||
|
||||
sorted_providers = sorted(providers, key=sort_compare)
|
||||
return f'model.{provider.name}'
|
||||
else:
|
||||
return provider.name
|
||||
|
||||
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
|
||||
|
||||
return sorted_providers
|
||||
Reference in New Issue
Block a user