Skip to content

Commit

Permalink
feat(judge): add FULLTEXT index and tests for Problem model
Browse files Browse the repository at this point in the history
Adds a FULLTEXT index to the `Problem` model as needed by the builtin search. Includes migration and tests (skipped for non-MySQL databases).

see DMOJ/docs#100
  • Loading branch information
JasonLovesDoggo committed Oct 1, 2024
1 parent e381e1f commit 5cb9223
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 28 deletions.
51 changes: 51 additions & 0 deletions judge/migrations/0148_judge_add_fulltext_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Generated by Django 3.2.25 on 2024-10-01 06:08

from django.db import migrations


def execute_mysql_command(apps, schema_editor, sql, error_msg, success_msg):
if schema_editor.connection.vendor != 'mysql':
return

Problem = apps.get_model('judge', 'Problem')
formatted_sql = sql.format(Problem._meta.db_table)

with schema_editor.connection.cursor() as cursor:
try:
cursor.execute(formatted_sql)
print(success_msg)
except Exception as e:
if error_msg in str(e):
print(f'Info: {error_msg}')
else:
raise


def add_fulltext_index(apps, schema_editor):
execute_mysql_command(
apps,
schema_editor,
'ALTER TABLE {} ADD FULLTEXT(code, name, description)',
'Duplicate key name',
'FULLTEXT index added successfully.',
)


def remove_fulltext_index(apps, schema_editor):
execute_mysql_command(
apps,
schema_editor,
'ALTER TABLE {} DROP INDEX code',
'check that column/key exists',
'FULLTEXT index removed successfully.',
)


class Migration(migrations.Migration):
dependencies = [
('judge', '0147_judge_add_tiers'),
]

operations = [
migrations.RunPython(add_fulltext_index, remove_fulltext_index),
]
137 changes: 109 additions & 28 deletions judge/models/tests/test_problem.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,43 @@
from unittest import skipIf

from django.core.exceptions import ValidationError
from django.db import connection
from django.db.models import F
from django.test import SimpleTestCase, TestCase
from django.utils import timezone

from judge.models import Language, LanguageLimit, Problem, Submission
from judge.models.problem import VotePermission, disallowed_characters_validator
from judge.models.tests.util import CommonDataMixin, create_contest, create_contest_participation, \
create_organization, create_problem, create_problem_type, create_solution, create_user
from judge.models.tests.util import (
CommonDataMixin,
create_contest,
create_contest_participation,
create_organization,
create_problem,
create_problem_type,
create_solution,
create_user,
)


class ProblemTestCase(CommonDataMixin, TestCase):
@classmethod
def setUpTestData(self):
def setUpTestData(cls):
super().setUpTestData()

self.users.update({
'staff_problem_edit_only_all': create_user(
username='staff_problem_edit_only_all',
is_staff=True,
user_permissions=('edit_all_problem',),
),
})
cls.users.update(
{
'staff_problem_edit_only_all': create_user(
username='staff_problem_edit_only_all',
is_staff=True,
user_permissions=('edit_all_problem',),
),
},
)

create_problem_type(name='type')

self.basic_problem = create_problem(
cls.basic_problem = create_problem(
code='basic',
allowed_languages=Language.objects.values_list('key', flat=True),
types=('type',),
Expand All @@ -35,32 +49,32 @@ def setUpTestData(self):
for lang in Language.objects.filter(common_name=Language.get_python3().common_name):
limits.append(
LanguageLimit(
problem=self.basic_problem,
problem=cls.basic_problem,
language=lang,
time_limit=100,
memory_limit=131072,
),
)
LanguageLimit.objects.bulk_create(limits)

self.organization_private_problem = create_problem(
cls.organization_private_problem = create_problem(
code='organization_private',
time_limit=2,
is_public=True,
is_organization_private=True,
curators=('staff_problem_edit_own', 'staff_problem_edit_own_no_staff'),
)

self.problem_organization = create_organization(
cls.problem_organization = create_organization(
name='problem organization',
admins=('normal', 'staff_problem_edit_public'),
)
self.organization_admin_private_problem = create_problem(
cls.organization_admin_private_problem = create_problem(
code='org_admin_private',
is_organization_private=True,
organizations=('problem organization',),
)
self.organization_admin_problem = create_problem(
cls.organization_admin_problem = create_problem(
code='organization_admin',
organizations=('problem organization',),
)
Expand All @@ -79,7 +93,10 @@ def test_basic_problem(self):

self.assertListEqual(list(self.basic_problem.author_ids), [self.users['normal'].profile.id])
self.assertListEqual(list(self.basic_problem.editor_ids), [self.users['normal'].profile.id])
self.assertListEqual(list(self.basic_problem.tester_ids), [self.users['staff_problem_edit_public'].profile.id])
self.assertListEqual(
list(self.basic_problem.tester_ids),
[self.users['staff_problem_edit_public'].profile.id],
)
self.assertListEqual(list(self.basic_problem.usable_languages), [])
self.assertListEqual(self.basic_problem.types_list, ['type'])
self.assertSetEqual(self.basic_problem.usable_common_names, set())
Expand Down Expand Up @@ -255,7 +272,10 @@ def give_basic_problem_ac(self, user, points=None):
)

def test_problem_voting_permissions(self):
self.assertEqual(self.basic_problem.vote_permission_for_user(self.users['anonymous']), VotePermission.NONE)
self.assertEqual(
self.basic_problem.vote_permission_for_user(self.users['anonymous']),
VotePermission.NONE,
)

now = timezone.now()
basic_contest = create_contest(
Expand All @@ -281,17 +301,29 @@ def test_problem_voting_permissions(self):
banned_from_voting = create_user(username='banned_from_voting')
banned_from_voting.profile.is_banned_from_problem_voting = True
self.give_basic_problem_ac(banned_from_voting)
self.assertEqual(self.basic_problem.vote_permission_for_user(banned_from_voting), VotePermission.VIEW)
self.assertEqual(
self.basic_problem.vote_permission_for_user(banned_from_voting),
VotePermission.VIEW,
)

banned_from_problem = create_user(username='banned_from_problem')
self.basic_problem.banned_users.add(banned_from_problem.profile)
self.give_basic_problem_ac(banned_from_problem)
self.assertEqual(self.basic_problem.vote_permission_for_user(banned_from_problem), VotePermission.VIEW)
self.assertEqual(
self.basic_problem.vote_permission_for_user(banned_from_problem),
VotePermission.VIEW,
)

self.assertEqual(self.basic_problem.vote_permission_for_user(self.users['normal']), VotePermission.VIEW)
self.assertEqual(
self.basic_problem.vote_permission_for_user(self.users['normal']),
VotePermission.VIEW,
)

self.give_basic_problem_ac(self.users['normal'])
self.assertEqual(self.basic_problem.vote_permission_for_user(self.users['normal']), VotePermission.VOTE)
self.assertEqual(
self.basic_problem.vote_permission_for_user(self.users['normal']),
VotePermission.VOTE,
)

partial_ac = create_user(username='partial_ac')
self.give_basic_problem_ac(partial_ac, 0.5) # ensure this value is not equal to its point value
Expand Down Expand Up @@ -330,12 +362,14 @@ class SolutionTestCase(CommonDataMixin, TestCase):
@classmethod
def setUpTestData(self):
super().setUpTestData()
self.users.update({
'staff_solution_see_all': create_user(
username='staff_solution_see_all',
user_permissions=('see_private_solution',),
),
})
self.users.update(
{
'staff_solution_see_all': create_user(
username='staff_solution_see_all',
user_permissions=('see_private_solution',),
),
},
)

now = timezone.now()

Expand Down Expand Up @@ -448,3 +482,50 @@ def test_invalid(self):
disallowed_characters_validator('“')
with self.assertRaisesRegex(ValidationError, 'Disallowed characters: (?=.*‘)(?=.*’)'):
disallowed_characters_validator('‘’')


@skipIf(connection.vendor != 'mysql', 'FULLTEXT search is only supported on MySQL')
class FullTextSearchTestCase(TestCase):
def setUp(self):
Problem.objects.create(code='P1', name='Django Test', description='A test problem for Django')
Problem.objects.create(code='P2', name='Python Challenge', description='A challenging Python problem')
Problem.objects.create(code='P3', name='Database Query', description='A problem about SQL and databases')

def test_fulltext_search_name(self):
results = Problem.objects.filter(name__search='Python')
self.assertEqual(results.count(), 1)
self.assertEqual(results[0].code, 'P2')

def test_fulltext_search_description(self):
results = Problem.objects.filter(description__search='database')
self.assertEqual(results.count(), 1)
self.assertEqual(results[0].code, 'P3')

def test_fulltext_search_multiple_columns(self):
results = Problem.objects.filter(name__search='test') | Problem.objects.filter(description__search='test')
self.assertEqual(results.count(), 1)
self.assertEqual(results[0].code, 'P1')

def test_fulltext_search_ranking(self):
Problem.objects.create(code='P4', name='Advanced Python', description='Python for advanced users')
Problem.objects.create(code='P5', name='Python Basics', description='Introduction to Python programming')

results = Problem.objects.filter(name__search='Python') | Problem.objects.filter(description__search='Python')
results = results.annotate(relevance=F('name__search') + F('description__search')).order_by('-relevance')

self.assertTrue(len(results) > 1)
self.assertEqual(results[0].code, 'P2')

def test_fulltext_search_boolean_mode(self):
results = Problem.objects.filter(description__search='+SQL -Python')
self.assertEqual(results.count(), 1)
self.assertEqual(results[0].code, 'P3')

def test_fulltext_search_no_results(self):
results = Problem.objects.filter(name__search='NonexistentTerm')
self.assertEqual(results.count(), 0)

@classmethod
def tearDownClass(cls):
Problem.objects.all().delete()
super().tearDownClass()

0 comments on commit 5cb9223

Please sign in to comment.