diff --git a/pokemon_v2/api.py b/pokemon_v2/api.py index 790f62b9..e5405211 100644 --- a/pokemon_v2/api.py +++ b/pokemon_v2/api.py @@ -4,6 +4,7 @@ from rest_framework.response import Response from rest_framework.views import APIView from django.shortcuts import get_object_or_404 from django.http import Http404 +from django.db.models import Q from .models import * from .serializers import * @@ -38,6 +39,15 @@ class NameOrIdRetrieval: idPattern = re.compile(r"^-?[0-9]+$") namePattern = re.compile(r"^[0-9A-Za-z\-\+]+$") + def get_queryset(self): + queryset = super().get_queryset() + filter = self.request.GET.get('q', '') + + if filter: + queryset = queryset.filter(Q(name__icontains=filter)) + + return queryset + def get_object(self): queryset = self.get_queryset() queryset = self.filter_queryset(queryset) diff --git a/pokemon_v2/tests.py b/pokemon_v2/tests.py index e458f467..48b8aa5e 100644 --- a/pokemon_v2/tests.py +++ b/pokemon_v2/tests.py @@ -5099,6 +5099,25 @@ class APITests(APIData, APITestCase): "{}".format(cries_data["legacy"]), ) + # test search pokemon using search query param `q=partial_name` + + response = self.client.get( + "{}/pokemon/?q={}".format(API_V2, pokemon.name[:2]), + HTTP_HOST="testserver", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["count"], 1) + self.assertEqual(response.data["results"][0]["name"], pokemon.name) + + response = self.client.get( + "{}/pokemon/?q={}".format(API_V2, pokemon.name[-3:]), + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["count"], 1) + self.assertEqual(response.data["results"][0]["name"], pokemon.name) + def test_pokemon_form_api(self): pokemon_species = self.setup_pokemon_species_data() pokemon = self.setup_pokemon_data(pokemon_species=pokemon_species)