diff --git a/paperwork/tests.py b/paperwork/tests.py index 9f4fb61..d4b939b 100644 --- a/paperwork/tests.py +++ b/paperwork/tests.py @@ -1,3 +1,5 @@ +from itertools import chain + from django.contrib.auth import get_user_model from django.contrib.auth.models import Permission from django.contrib.contenttypes.models import ContentType @@ -78,29 +80,42 @@ class InstructorOrVendorReportTestCase(PermissionRequiredViewTestCaseMixin, Test @st.composite -def random_certifications(draw): - departments = draw(st.lists(from_model(Department), min_size=1)) - definitions = draw( - st.lists( - from_model( - CertificationDefinition, department=st.sampled_from(departments) - ), - min_size=1, - ) - ) - certification_versions = draw( - st.lists( - from_model(CertificationVersion, definition=st.sampled_from(definitions)), - min_size=1, - ) - ) - - return draw( - st.lists( +def random_certifications( + draw, +) -> list[Certification]: + def certifications(version: CertificationVersion): + return st.lists( from_model( Certification, number=st.none(), - certification_version=st.sampled_from(certification_versions), + certification_version=st.just(version), + ), + max_size=10, + ) + + def versions_with_certifications(definition: CertificationDefinition): + return st.lists( + from_model(CertificationVersion, definition=st.just(definition)).flatmap( + certifications + ), + max_size=2, + ) + + def definitions_with_versions(department: Department): + return st.lists( + from_model(CertificationDefinition, department=st.just(department)).flatmap( + versions_with_certifications + ), + max_size=2, + ) + + return draw( + st.lists( + from_model(Department).flatmap(definitions_with_versions), + max_size=2, + ).map( + lambda x: list( + chain.from_iterable(chain.from_iterable(chain.from_iterable(x))) ) ) ) @@ -111,7 +126,7 @@ class CertifiersReportTestCase(PermissionRequiredViewTestCaseMixin, TestCase): path = "/paperwork/certifiers" @given(certifications=random_certifications()) - def test_certifers_report(self, certifications: list[Certification]) -> None: + def test_certifiers_report(self, certifications: list[Certification]) -> None: self.client.force_login(self.user_with_permission) response = self.client.get(self.path) self.assertEqual(response.status_code, 200)