From 802b8181105637cbecdabe67613ef450a254c1e9 Mon Sep 17 00:00:00 2001 From: drunest Date: Tue, 1 Oct 2024 07:23:58 -0500 Subject: [PATCH] feat: dumper commands and codename api --- app/src/auto_validator/core/api.py | 51 +++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/app/src/auto_validator/core/api.py b/app/src/auto_validator/core/api.py index 5acb3aa..1b1599c 100644 --- a/app/src/auto_validator/core/api.py +++ b/app/src/auto_validator/core/api.py @@ -1,6 +1,9 @@ -from rest_framework import mixins, parsers, routers, viewsets +import yaml +from django.conf import settings +from rest_framework import mixins, parsers, routers, status, viewsets from rest_framework.exceptions import AuthenticationFailed from rest_framework.permissions import AllowAny +from rest_framework.response import Response from auto_validator.core.models import Hotkey, Server, UploadedFile, ValidatorInstance from auto_validator.core.serializers import UploadedFileSerializer @@ -9,6 +12,8 @@ from .authentication import HotkeyAuthentication from .utils.bot import trigger_bot_send_message +YAML_FILE_PATH = settings.LOCAL_SUBNETS_SCRIPTS_PATH + "/subnets.yaml" + class FilesViewSet(mixins.CreateModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet): serializer_class = UploadedFileSerializer @@ -50,6 +55,48 @@ def perform_create(self, serializer): ) +class DumperCommandsViewSet(viewsets.ViewSet): + parser_classes = [parsers.MultiPartParser] + permission_classes = [AllowAny] + + def list(self, request): + subnet_identifier = request.headers.get("SubnetID") + if not subnet_identifier: + return Response({"error": "subnet_identifier is required"}, status=status.HTTP_400_BAD_REQUEST) + + try: + with open(YAML_FILE_PATH) as file: + data = yaml.safe_load(file) + if subnet_identifier in data: + return Response(data[subnet_identifier].get("dumper_commands", [])) + else: + return Response({"error": "subnet_identifier not found"}, status=status.HTTP_404_NOT_FOUND) + except Exception as e: + return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +class NormalizedCodenameViewSet(viewsets.ViewSet): + parser_classes = [parsers.MultiPartParser] + permission_classes = [AllowAny] + + def list(self, request): + subnet_identifier = request.headers.get("SubnetID") + if not subnet_identifier: + return Response({"error": "subnet_identifier is required"}, status=status.HTTP_400_BAD_REQUEST) + + try: + with open(YAML_FILE_PATH) as file: + data = yaml.safe_load(file) + codename_lower = subnet_identifier.lower() + for normalized_codename, sn_config in data.items(): + codenames = sn_config.get("codename_list", []) + if codename_lower in map(str.lower, codenames): + return Response(normalized_codename) + return Response({"error": "subnet_identifier not found"}, status=status.HTTP_404_NOT_FOUND) + except Exception as e: + return Response({"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + class APIRootView(routers.DefaultRouter.APIRootView): description = "api-root" @@ -60,3 +107,5 @@ class APIRouter(routers.DefaultRouter): router = APIRouter() router.register(r"files", FilesViewSet, basename="file") +router.register(r"commands", DumperCommandsViewSet, basename="commands") +router.register(r"codename", NormalizedCodenameViewSet, basename="codename")