From bb6577f3bb36e4ef0750ecb7ac9d5a3e4d60cbff Mon Sep 17 00:00:00 2001 From: Adam Goldsmith Date: Mon, 10 Apr 2023 14:29:12 -0400 Subject: [PATCH] paperwork: Improve typing for models and admin --- paperwork/admin.py | 44 ++++++++++++++++++++++++++++---------------- paperwork/models.py | 43 +++++++++++++++++++++++++++++-------------- 2 files changed, 57 insertions(+), 30 deletions(-) diff --git a/paperwork/admin.py b/paperwork/admin.py index 1be7af6..83b6ce9 100644 --- a/paperwork/admin.py +++ b/paperwork/admin.py @@ -1,16 +1,22 @@ +from typing import Optional, Any, Type, cast + from django import forms from django.core import mail from django.contrib import admin, messages from django.db.models import Value +from django.db.models.query import QuerySet from django.db.models.functions import Now, Concat, LPad +from django.http import HttpRequest from .models import ( + AbstractAudit, CmsRedRiverVeteransScholarship, Department, CertificationDefinition, Certification, CertificationAudit, CertificationVersion, + CertificationVersionAnnotated, InstructorOrVendor, SpecialProgram, Waiver, @@ -23,7 +29,7 @@ from .certification_emails import all_certification_emails class AlwaysChangedModelForm(forms.models.ModelForm): """By always returning true even unchanged inlines will get validated and saved.""" - def has_changed(self): + def has_changed(self) -> bool: return True @@ -31,7 +37,11 @@ class AbstractAuditInline(admin.TabularInline): extra = 0 form = AlwaysChangedModelForm - def get_formset(self, request, obj=None, **kwargs): + def get_formset( + self, request: HttpRequest, obj: Optional[AbstractAudit] = None, **kwargs: Any + ) -> Type[ + "forms.models.BaseInlineFormSet[AbstractAudit, Any, forms.models.ModelForm[Any]]" + ]: formset = super().get_formset(request, obj, **kwargs) formset.form.base_fields["user"].initial = request.user return formset @@ -60,11 +70,11 @@ class CertificationVersionInline(admin.TabularInline): ) @admin.display(description="Latest", boolean=True) - def is_latest(self, obj): + def is_latest(self, obj: CertificationVersionAnnotated) -> bool: return obj.is_latest @admin.display(description="Current", boolean=True) - def is_current(self, obj): + def is_current(self, obj: CertificationVersionAnnotated) -> bool: return obj.is_current @@ -80,8 +90,8 @@ class CertificationDefinitionAdmin(admin.ModelAdmin): inlines = [CertificationVersionInline] @admin.display(description="Latest Version") - def latest_semantic_version(self, obj): - return obj.latest_version().semantic_version() + def latest_semantic_version(self, obj: CertificationDefinition) -> str: + return str(obj.latest_version().semantic_version()) class CertificationAuditInline(AbstractAuditInline): @@ -100,7 +110,7 @@ class CertificationAdmin(admin.ModelAdmin): exclude = ["shop_lead_notified"] inlines = [CertificationAuditInline] - def get_queryset(self, request): + def get_queryset(self, request: HttpRequest) -> QuerySet[Certification]: qs = super().get_queryset(request) return qs.prefetch_related("certification_version__definition__department") @@ -108,7 +118,7 @@ class CertificationAdmin(admin.ModelAdmin): description="Certification Name", ordering="certification_version__definition__certification_name", ) - def certification_name(self, obj): + def certification_name(self, obj: Certification) -> str: return obj.certification_version.definition.certification_name @admin.display( @@ -123,22 +133,22 @@ class CertificationAdmin(admin.ModelAdmin): ) ), ) - def certification_semantic_version(self, obj): - return obj.certification_version.semantic_version() + def certification_semantic_version(self, obj: Certification) -> str: + return str(obj.certification_version.semantic_version()) @admin.display(description="Current", boolean=True) - def is_current(self, obj): - return obj.certification_version.is_current + def is_current(self, obj: Certification) -> bool: + return cast(CertificationVersionAnnotated, obj.certification_version).is_current @admin.display( description="Department", ordering="certification_version__definition__department", ) - def certification_department(self, obj): + def certification_department(self, obj: Certification) -> Department: return obj.certification_version.definition.department @admin.display(description="Latest Audit") - def latest_audit(self, obj): + def latest_audit(self, obj: Certification) -> CertificationAudit: return obj.audits.latest() list_display = [ @@ -167,7 +177,9 @@ class CertificationAdmin(admin.ModelAdmin): @admin.action( description="Notify Shop Leads and Members of selected certifications" ) - def send_notifications(self, request, queryset): + def send_notifications( + self, request: HttpRequest, queryset: QuerySet[Certification] + ) -> None: try: emails = list(all_certification_emails(queryset)) @@ -244,7 +256,7 @@ class WaiverAdmin(admin.ModelAdmin): inlines = [WaiverAuditInline] @admin.display(description="Latest Audit") - def latest_audit(self, obj): + def latest_audit(self, obj: Waiver) -> WaiverAudit: return obj.audits.latest() diff --git a/paperwork/models.py b/paperwork/models.py index 53e05aa..951be4a 100644 --- a/paperwork/models.py +++ b/paperwork/models.py @@ -1,11 +1,13 @@ import datetime import re +from typing import TypedDict, TYPE_CHECKING, Optional from semver import VersionInfo from django.db import models from django.db.models import OuterRef, Q, ExpressionWrapper, Subquery from django.conf import settings from django.core.validators import RegexValidator +from django_stubs_ext import WithAnnotations from membershipworks.models import Member, Flag as MembershipWorksFlag @@ -16,7 +18,7 @@ class AbstractAudit(models.Model): good = models.BooleanField(default=False) notes = models.CharField(max_length=255, blank=True) - def __str__(self): + def __str__(self) -> str: return f"{'Good' if self.good else 'Bad'} audit at {self.date} by {self.user}" class Meta: @@ -57,7 +59,7 @@ class CmsRedRiverVeteransScholarship(models.Model): db_column="Program Status", max_length=16, blank=True, null=True ) - def __str__(self): + def __str__(self) -> str: return f"{self.program_name} {self.member_name}" class Meta: @@ -84,19 +86,19 @@ class Department(models.Model): ) list_reply_to_address = models.EmailField(max_length=254, blank=True) - def __str__(self): + def __str__(self) -> str: return self.name @property - def list_name(self): + def list_name(self) -> Optional[str]: if self.has_mailing_list: return self.name.replace(" ", "_") + "-info" else: return None @property - def list_address(self): - if self.has_mailing_list: + def list_address(self) -> Optional[str]: + if self.list_name: return self.list_name + "@claremontmakerspace.org" else: return None @@ -111,7 +113,7 @@ class CertificationDefinition(models.Model): ) department = models.ForeignKey(Department, models.PROTECT) - def __str__(self): + def __str__(self) -> str: return f"{self.certification_name} <{self.department}>" class Meta: @@ -122,8 +124,13 @@ class CertificationDefinition(models.Model): return self.certificationversion_set.latest() -class CertificationVersionManager(models.Manager): - def get_queryset(self): +class CertificationVersionAnnotations(TypedDict): + is_latest: bool + is_current: bool + + +class CertificationVersionManager(models.Manager["CertificationVersion"]): + def get_queryset(self) -> models.QuerySet["CertificationVersion"]: qs = super().get_queryset() latest = qs.filter(definition__pk=OuterRef("definition__pk")).reverse() return qs.annotate( @@ -151,7 +158,7 @@ class CertificationVersion(models.Model): prerelease = models.CharField(max_length=255, blank=True) approval_date = models.DateField(blank=True, null=True) - def __str__(self): + def __str__(self) -> str: return f"{self.definition} [{self.semantic_version()}]" class Meta: @@ -182,6 +189,14 @@ class CertificationVersion(models.Model): ) +if TYPE_CHECKING: + CertificationVersionAnnotated = WithAnnotations[ + CertificationVersion, CertificationVersionAnnotations + ] +else: + CertificationVersionAnnotated = WithAnnotations[CertificationVersion] + + class Certification(models.Model): number = models.AutoField(db_column="Number", primary_key=True) certification_version = models.ForeignKey( @@ -206,7 +221,7 @@ class Certification(models.Model): ) notes = models.CharField(db_column="Notes", max_length=255, blank=True, null=True) - def __str__(self): + def __str__(self) -> str: return f"{self.name} - {self.certification_version}" class Meta: @@ -237,7 +252,7 @@ class InstructorOrVendor(models.Model): db_column="email address", max_length=255, blank=True, null=True ) - def __str__(self): + def __str__(self) -> str: return f"{self.name}" class Meta: @@ -274,7 +289,7 @@ class SpecialProgram(models.Model): db_column="Program Status", max_length=16, blank=True, null=True ) - def __str__(self): + def __str__(self) -> str: return self.program_name class Meta: @@ -300,7 +315,7 @@ class Waiver(models.Model): ) guardian_date = models.DateField(db_column="Guardian Date", blank=True, null=True) - def __str__(self): + def __str__(self) -> str: return f"{self.name} {self.date}" class Meta: