Add methods that automatically handle paginated endpoints

This commit is contained in:
Adam Goldsmith 2024-11-26 12:42:45 -05:00
parent 79877e0303
commit 0c43370add
8 changed files with 534 additions and 25 deletions

View File

@ -4,6 +4,6 @@ Wrapper for Unifi Access API.
See the [official API Reference](https://core-config-gfoz.uid.alpha.ui.com/configs/unifi-access/api_reference.pdf) for more details.
"""
from ._client import AccessClient, Response, ResponseCode
from ._client import AccessClient, PaginatedResponse, ResponseCode, UnifiAccessError
__all__ = ["AccessClient", "Response", "ResponseCode"]
__all__ = ["AccessClient", "PaginatedResponse", "ResponseCode", "UnifiAccessError"]

View File

@ -1,7 +1,9 @@
import datetime
import functools
import math
from collections.abc import Iterable, Sequence
from enum import StrEnum, auto
from typing import Any, Generic, Literal, Never, Self, TypeVar
from typing import Any, Generic, Literal, Never, Protocol, Self
import requests
from pydantic import (
@ -10,6 +12,7 @@ from pydantic import (
RootModel,
TypeAdapter,
)
from typing_extensions import TypeVar
from unifi_access.schemas import (
AccessPolicy,
@ -63,7 +66,7 @@ from unifi_access.schemas._base import (
UnixTimestampDateTime,
)
from unifi_access.schemas._identity import IdentityResource, IdentityResourceType
from unifi_access.schemas._system_log import FetchSystemLogsResponse
from unifi_access.schemas._system_log import FetchSystemLogsResponse, SystemLogEntry
class ResponseCode(StrEnum):
@ -198,27 +201,63 @@ class ResponsePagination(ForbidExtraBaseModel):
# TODO: this has nicer syntax in Python 3.12, but not currently supported in Pydantic
ResponseDataType = TypeVar("ResponseDataType")
ResponsePaginationType = TypeVar("ResponsePaginationType", default=None)
class SuccessResponse(ForbidExtraBaseModel, Generic[ResponseDataType]):
class SuccessResponse(
ForbidExtraBaseModel, Generic[ResponseDataType, ResponsePaginationType]
):
"""A successful response containing data"""
code: Literal[ResponseCode.SUCCESS]
msg: str
# sometimes the Access API omits this when it would be null
data: ResponseDataType = Field(default=None, validate_default=True) # type: ignore
pagination: ResponsePagination | None = None
pagination: ResponsePaginationType = Field(default=None, validate_default=True) # type: ignore
def success_or_raise(self) -> Self:
return self
class Response(RootModel[SuccessResponse[ResponseDataType] | ErrorResponse]):
class BaseResponse(
RootModel[
SuccessResponse[ResponseDataType, ResponsePaginationType] | ErrorResponse
],
Generic[ResponseDataType, ResponsePaginationType],
):
@classmethod
def validate_and_unwrap(cls, r: requests.Response) -> ResponseDataType:
return cls.model_validate_json(r.content).root.success_or_raise().data
class PaginatedResponse(BaseResponse[ResponseDataType, ResponsePagination]):
pass
# TODO: this really shouldn't be necessary, but the default on ResponsePaginationType doesn't seem
# to work well through the RootModel in 3.11
class Response(BaseResponse[ResponseDataType, None]):
pass
T = TypeVar("T")
class PageNumberable(Protocol, Generic[T]):
def __call__(
self, page_num: int
) -> SuccessResponse[list[T], ResponsePagination]: ...
def iterate_pages(fxn: PageNumberable[T]) -> Iterable[T]:
resp = fxn(page_num=1)
yield from resp.data
total_pages = math.ceil(resp.pagination.total / resp.pagination.page_size)
for page_num in range(2, total_pages + 1):
yield from fxn(page_num=page_num).data
class RequestPagination(ForbidExtraBaseModel):
page_num: int | None = None
page_size: int | None = None
@ -324,8 +363,11 @@ class AccessClient:
expand_access_policies: bool = False,
page_num: int | None = None,
page_size: int | None = None,
) -> SuccessResponse[list[FullUser]]:
"""3.5 Fetch All Users"""
) -> SuccessResponse[list[FullUser], ResponsePagination]:
"""3.5 Fetch All Users
If you don't need to manually handle pagination, consider [`fetch_all_users__unpaged`][unifi_access.AccessClient.fetch_all_users__unpaged].
"""
class FetchAllUsersParams(RequestPagination):
model_config = ConfigDict(populate_by_name=True)
@ -339,9 +381,27 @@ class AccessClient:
r = self._session.get(f"{self._base_url}/users", params=params)
return (
Response[list[FullUser]].model_validate_json(r.content).root
PaginatedResponse[list[FullUser]].model_validate_json(r.content).root
).success_or_raise()
def fetch_all_users__unpaged(
self,
expand_access_policies: bool = False,
page_size: int | None = None,
) -> Iterable[FullUser]:
"""3.5 Fetch All Users
This will automatically handle pagination.
If you need more control, consider [`fetch_all_users`][unifi_access.AccessClient.fetch_all_users].
"""
yield from iterate_pages(
functools.partial(
self.fetch_all_users,
expand_access_policies=expand_access_policies,
page_size=page_size,
)
)
def assign_access_policy_to_user(
self, user_id: UserId, access_policy_ids: list[AccessPolicyId]
) -> None:
@ -600,8 +660,11 @@ class AccessClient:
expand: list[FetchAllVisitorsExpansion] | None = None,
page_num: int | None = None,
page_size: int | None = None,
) -> SuccessResponse[list[Visitor]]:
"""4.4 Fetch All Visitors"""
) -> SuccessResponse[list[Visitor], ResponsePagination]:
"""4.4 Fetch All Visitors
If you don't need to manually handle pagination, consider [`fetch_all_visitors__unpaged`][unifi_access.AccessClient.fetch_all_visitors__unpaged].
"""
class FetchAllVisitorsRequest(RequestPagination):
model_config = ConfigDict(populate_by_name=True)
@ -630,9 +693,31 @@ class AccessClient:
r = self._session.get(f"{self._base_url}/visitors", params=params)
return (
Response[list[Visitor]].model_validate_json(r.content).root
PaginatedResponse[list[Visitor]].model_validate_json(r.content).root
).success_or_raise()
def fetch_all_visitors__unpaged(
self,
status: VisitorStatus | None = None,
keyword: str | None = None,
expand: list[FetchAllVisitorsExpansion] | None = None,
page_size: int | None = None,
) -> Iterable[Visitor]:
"""4.4 Fetch All Visitors
This will automatically handle pagination.
If you need more control, consider [`fetch_all_visitors`][unifi_access.AccessClient.fetch_all_visitors].
"""
yield from iterate_pages(
functools.partial(
self.fetch_all_visitors,
status=status,
keyword=keyword,
expand=expand,
page_size=page_size,
)
)
def update_visitor( # noqa: PLR0913
self,
visitor_id: VisitorId,
@ -1006,16 +1091,31 @@ class AccessClient:
def fetch_all_nfc_cards(
self, page_num: int | None = None, page_size: int | None = None
) -> SuccessResponse[list[NfcCard]]:
"""6.8 Fetch NFC Cards"""
) -> SuccessResponse[list[NfcCard], ResponsePagination]:
"""6.8 Fetch NFC Cards
If you don't need to manually handle pagination, consider [`fetch_all_nfc_cards__unpaged`][unifi_access.AccessClient.fetch_all_nfc_cards__unpaged].
"""
params = RequestPagination(page_num=page_num, page_size=page_size)
r = self._session.get(
f"{self._base_url}/credentials/nfc_cards/tokens", params=params
)
return (
Response[list[NfcCard]].model_validate_json(r.content).root
PaginatedResponse[list[NfcCard]].model_validate_json(r.content).root
).success_or_raise()
def fetch_all_nfc_cards__unpaged(
self, page_size: int | None = None
) -> Iterable[NfcCard]:
"""6.8 Fetch NFC Cards
This will automatically handle pagination.
If you need more control, consider [`fetch_all_nfc_cards`][unifi_access.AccessClient.fetch_all_nfc_cards].
"""
yield from iterate_pages(
functools.partial(self.fetch_all_nfc_cards, page_size=page_size)
)
def delete_nfc_card(self, nfc_card_token: NfcCardToken) -> Literal["success"]:
"""6.7 Fetch NFC Card"""
r = self._session.delete(
@ -1141,8 +1241,11 @@ class AccessClient:
actor_id: ActorId | None = None,
page_num: int | None = None,
page_size: int | None = None,
) -> SuccessResponse[FetchSystemLogsResponse]:
"""9.2 Fetch System Logs"""
) -> SuccessResponse[FetchSystemLogsResponse, ResponsePagination]:
"""9.2 Fetch System Logs
If you don't need to manually handle pagination, consider [`fetch_all_system_logs__unpaged`][unifi_access.AccessClient.fetch_all_system_logs__unpaged].
"""
params = RequestPagination(page_num=page_num, page_size=page_size).model_dump(
exclude_none=True
)
@ -1164,9 +1267,46 @@ class AccessClient:
f"{self._base_url}/system/logs", params=params, json=body
)
return (
Response[FetchSystemLogsResponse].model_validate_json(r.content)
PaginatedResponse[FetchSystemLogsResponse].model_validate_json(r.content)
).root.success_or_raise()
def fetch_system_logs__unpaged(
self,
topic: SystemLogTopic,
since: datetime.datetime | None = None,
until: datetime.datetime | None = None,
actor_id: ActorId | None = None,
page_size: int | None = None,
) -> Iterable[SystemLogEntry]:
"""9.2 Fetch System Logs
This will automatically handle pagination.
If you need more control, consider [`fetch_all_system_logs`][unifi_access.AccessClient.fetch_all_system_logs].
"""
# can't just use `functools.partial` here, because we need to have `.data` return a list
# instead of `FetchSystemLogsResponse`
def extract_hits_wrapper(
page_num: int,
) -> SuccessResponse[list[SystemLogEntry], ResponsePagination]:
resp = self.fetch_system_logs(
topic=topic,
since=since,
until=until,
actor_id=actor_id,
page_size=page_size,
page_num=page_num,
)
return SuccessResponse[list[SystemLogEntry], ResponsePagination](
code=resp.code,
msg=resp.msg,
data=resp.data.hits,
pagination=resp.pagination,
)
yield from iterate_pages(extract_hits_wrapper)
def export_system_logs(
self,
topic: SystemLogTopic,

View File

@ -44,7 +44,7 @@ def test_user_lifecycle(live_access_client: AccessClient, user: User) -> None:
# Check for the user in full user list
# TODO: test pagination
all_users = live_access_client.fetch_all_users().data
all_users = live_access_client.fetch_all_users__unpaged()
matching_user = next(u for u in all_users if u.id == user.id)
assert matching_user.id == user.id
assert matching_user.first_name == "Test"

View File

@ -54,11 +54,11 @@ def test_visitor_lifecycle(
)
assert updated_visitor.first_name == "Updated Test"
all_visitors = live_access_client.fetch_all_visitors().data
all_visitors = live_access_client.fetch_all_visitors__unpaged()
matched_visitor = next(v for v in all_visitors if v.id == visitor.id)
assert matched_visitor.first_name == "Updated Test"
expanded_all_visitors = live_access_client.fetch_all_visitors(
expanded_all_visitors = live_access_client.fetch_all_visitors__unpaged(
expand=[
FetchAllVisitorsExpansion.ACCESS_POLICY,
FetchAllVisitorsExpansion.RESOURCE,
@ -66,16 +66,16 @@ def test_visitor_lifecycle(
FetchAllVisitorsExpansion.NFC_CARD,
FetchAllVisitorsExpansion.PIN_CODE,
]
).data
)
expanded_matched_visitor = next(
v for v in expanded_all_visitors if v.id == visitor.id
)
assert expanded_matched_visitor.first_name == "Updated Test"
# TODO: test expanded contents
non_expanded_all_visitors = live_access_client.fetch_all_visitors(
non_expanded_all_visitors = live_access_client.fetch_all_visitors__unpaged(
expand=[FetchAllVisitorsExpansion.NONE]
).data
)
non_expanded_matched_visitor = next(
v for v in non_expanded_all_visitors if v.id == visitor.id
)

View File

@ -167,6 +167,77 @@ class CredentialTests(UnifiAccessTests):
resp = self.client.fetch_all_nfc_cards(page_num=1, page_size=25)
# NOTE: not taken from API docs examples
@responses.activate
def test_fetch_all_nfc_cards__unpaged(self) -> None:
"""6.8 Fetch All NFC Cards"""
responses.get(
f"https://{self.host}/api/v1/developer/credentials/nfc_cards/tokens",
match=[
matchers.header_matcher(self.common_headers),
matchers.query_param_matcher({"page_num": 1, "page_size": 1}),
],
json={
"code": "SUCCESS",
"data": [
{
"alias": "",
"card_type": "ua_card",
"display_id": "100004",
"note": "100004",
"status": "assigned",
"token": "9e24cdfafebf63e58fd02c5f67732b478948e5793d31124239597d9a86b30dc4",
"user": {
"avatar": "",
"first_name": "H",
"id": "e0051e08-c4d5-43db-87c8-a9b19cb66513",
"last_name": "L",
"name": "H L",
},
"user_id": "e0051e08-c4d5-43db-87c8-a9b19cb66513",
"user_type": "USER",
},
],
"msg": "succ",
"pagination": {"page_num": 1, "page_size": 1, "total": 2},
},
)
responses.get(
f"https://{self.host}/api/v1/developer/credentials/nfc_cards/tokens",
match=[
matchers.header_matcher(self.common_headers),
matchers.query_param_matcher({"page_num": 2, "page_size": 1}),
],
json={
"code": "SUCCESS",
"data": [
{
"alias": "F77D69B03",
"card_type": "ua_card",
"display_id": "100005",
"note": "100005",
"status": "assigned",
"token": "f77d69b08eaf5eb5d647ac1a0a73580f1b27494b345f40f54fa022a8741fa15c",
"user": {
"avatar": "",
"first_name": "H2",
"id": "34dc90a7-409f-4bf8-a5a8-1c59535a21b9",
"last_name": "L",
"name": "H2 L",
},
"user_id": "34dc90a7-409f-4bf8-a5a8-1c59535a21b9",
"user_type": "VISITOR",
},
],
"msg": "succ",
"pagination": {"page_num": 1, "page_size": 1, "total": 2},
},
)
resp = list(self.client.fetch_all_nfc_cards__unpaged(page_size=1))
assert resp[0].display_id == "100004"
assert resp[1].display_id == "100005"
@responses.activate
def test_delete_nfc_card(self) -> None:
"""6.9 Delete NFC Card"""

View File

@ -83,6 +83,128 @@ class SystemLogTests(UnifiAccessTests):
actor_id=UserId("3e1f196e-c97b-4748-aecb-eab5e9c251b2"),
)
# NOTE: not taken from API docs examples
@responses.activate
def test_fetch_system_logs__unpaged(self) -> None:
"""9.2 Fetch System Logs"""
responses.post(
f"https://{self.host}/api/v1/developer/system/logs",
match=[
matchers.header_matcher(self.common_headers),
matchers.query_param_matcher({"page_size": 1, "page_num": 1}),
matchers.json_params_matcher({"topic": "door_openings"}),
],
json={
"code": "SUCCESS",
"data": {
"hits": [
{
"@timestamp": "2023-07-11T12:11:27Z",
"_id": "",
"_source": {
"actor": {
"alternate_id": "",
"alternate_name": "",
"display_name": "N/A",
"id": "",
"type": "user",
},
"authentication": {
"credential_provider": "NFC",
"issuer": "6FC02554",
},
"event": {
"display_message": "Access Denied / Unknown (NFC)",
"published": 1689077487000,
"reason": "",
"result": "BLOCKED",
"type": "access.door.unlock",
"log_key": "",
},
"target": [
{
"alternate_id": "",
"alternate_name": "",
"display_name": "UA-HUB-3855",
"id": "7483c2773855",
"type": "UAH",
}
],
},
"tag": "access",
}
]
},
"msg": "succ",
"pagination": {"page_num": 1, "page_size": 1, "total": 2},
},
)
responses.post(
f"https://{self.host}/api/v1/developer/system/logs",
match=[
matchers.header_matcher(self.common_headers),
matchers.query_param_matcher({"page_size": 1, "page_num": 2}),
matchers.json_params_matcher({"topic": "door_openings"}),
],
json={
"code": "SUCCESS",
"data": {
"hits": [
{
"@timestamp": "2023-07-12T12:11:27Z",
"_id": "",
"_source": {
"actor": {
"alternate_id": "",
"alternate_name": "",
"display_name": "N/A",
"id": "",
"type": "user",
},
"authentication": {
"credential_provider": "NFC",
"issuer": "6FC02554",
},
"event": {
"display_message": "Access Denied / Unknown (NFC)",
"published": 1689077487000,
"reason": "",
"result": "BLOCKED",
"type": "access.door.unlock",
"log_key": "",
},
"target": [
{
"alternate_id": "",
"alternate_name": "",
"display_name": "UA-HUB-3855",
"id": "7483c2773855",
"type": "UAH",
}
],
},
"tag": "access",
}
]
},
"msg": "succ",
"pagination": {"page_num": 2, "page_size": 1, "total": 2},
},
)
resp = list(
self.client.fetch_system_logs__unpaged(
page_size=1,
topic=SystemLogTopic.DOOR_OPENINGS,
)
)
assert resp[0].timestamp == datetime.datetime.fromisoformat(
"2023-07-11T12:11:27Z"
)
assert resp[1].timestamp == datetime.datetime.fromisoformat(
"2023-07-12T12:11:27Z"
)
@responses.activate
def test_export_system_logs(self) -> None:
"""9.3 Export System Logs"""

View File

@ -262,6 +262,85 @@ class UserTests(UnifiAccessTests):
assert resp.pagination
# TODO: verify correctness of data?
# NOTE: not taken from API docs examples
@responses.activate
def test_fetch_all_users__unpaged(self) -> None:
"""3.5 Fetch All Users, with pagination"""
responses.get(
f"https://{self.host}/api/v1/developer/users",
match=[
matchers.header_matcher(self.common_headers),
matchers.query_param_matcher({"page_num": 1, "page_size": 1}),
],
json={
"code": "SUCCESS",
"data": [
{
"access_policy_ids": ["73f15cab-c725-4a76-a419-a4026d131e96"],
"employee_number": "",
"first_name": "UniFi",
"id": "83569f9b-0096-48ab-b2e4-5c9a598568a8",
"last_name": "User",
"user_email": "",
"nfc_cards": [],
"onboard_time": 0,
"pin_code": None,
"status": "ACTIVE",
"alias": "",
"avatar_relative_path": "",
"email": "",
"email_status": "UNVERIFIED",
"full_name": "UniFi User",
"phone": "",
"username": "",
},
],
"msg": "success",
"pagination": {"page_num": 1, "page_size": 1, "total": 2},
},
)
responses.get(
f"https://{self.host}/api/v1/developer/users",
match=[
matchers.header_matcher(self.common_headers),
matchers.query_param_matcher({"page_num": 2, "page_size": 1}),
],
json={
"code": "SUCCESS",
"data": [
{
"access_policy_ids": ["c1682fb8-ef6e-4fe2-aa8a-b6f29df753ff"],
"employee_number": "",
"first_name": "Ttttt",
"id": "3a3ba57a-796e-46e0-b8f3-478bb70a114d",
"last_name": "Tttt",
"nfc_cards": [],
"onboard_time": 1689048000,
"pin_code": None,
"status": "ACTIVE",
"alias": "",
"avatar_relative_path": "",
"user_email": "",
"email": "",
"email_status": "UNVERIFIED",
"full_name": "Ttttt Tttt",
"phone": "",
"username": "",
},
],
"msg": "success",
"pagination": {"page_num": 2, "page_size": 1, "total": 2},
},
)
resp = list(
self.client.fetch_all_users__unpaged(
expand_access_policies=False, page_size=1
)
)
assert resp[0].id == "83569f9b-0096-48ab-b2e4-5c9a598568a8"
assert resp[1].id == "3a3ba57a-796e-46e0-b8f3-478bb70a114d"
@responses.activate
def test_assign_access_policy_to_user(self) -> None:
"""3.6 Assign Access Policy to User"""

View File

@ -454,6 +454,103 @@ class VisitorTests(UnifiAccessTests):
assert resp.pagination
# TODO: verify correctness of data?
# NOTE: not taken from API docs examples
@responses.activate
def test_fetch_all_visitors__unpaged(self) -> None:
"""4.4 Fetch All Visitors"""
responses.get(
f"https://{self.host}/api/v1/developer/visitors",
match=[
matchers.header_matcher(self.common_headers),
matchers.query_param_matcher({"page_num": 1, "page_size": 1}),
],
json={
"code": "SUCCESS",
"data": [
{
"avatar": "",
"email": "",
"end_time": 1731880901,
"first_name": "Test",
"id": "faaffd2e-b555-4991-810f-c18b36407c55",
"inviter_id": "",
"inviter_name": "",
"last_name": "Visitor",
"location_id": "",
"mobile_phone": "",
"nfc_cards": [],
"remarks": "",
"resources": [],
"schedule": {
"holiday_group": None,
"holiday_group_id": "",
"holiday_schedule": [],
"id": "",
"is_default": False,
"name": "",
"type": "",
"weekly": None,
},
"schedule_id": "",
"start_time": 1731794501,
"status": "UPCOMING",
"visit_reason": "Business",
"visitor_company": "",
}
],
"msg": "succ",
"pagination": {"page_num": 1, "page_size": 1, "total": 2},
},
)
responses.get(
f"https://{self.host}/api/v1/developer/visitors",
match=[
matchers.header_matcher(self.common_headers),
matchers.query_param_matcher({"page_num": 2, "page_size": 1}),
],
json={
"code": "SUCCESS",
"data": [
{
"avatar": "",
"email": "",
"end_time": 1731880901,
"first_name": "Test",
"id": "173c4cb9-e174-4a83-89fa-01ba8f25362f",
"inviter_id": "",
"inviter_name": "",
"last_name": "Visitor",
"location_id": "",
"mobile_phone": "",
"nfc_cards": [],
"remarks": "",
"resources": [],
"schedule": {
"holiday_group": None,
"holiday_group_id": "",
"holiday_schedule": [],
"id": "",
"is_default": False,
"name": "",
"type": "",
"weekly": None,
},
"schedule_id": "",
"start_time": 1731794501,
"status": "UPCOMING",
"visit_reason": "Business",
"visitor_company": "",
}
],
"msg": "succ",
"pagination": {"page_num": 2, "page_size": 1, "total": 2},
},
)
resp = list(self.client.fetch_all_visitors__unpaged(page_size=1))
assert resp[0].id == "faaffd2e-b555-4991-810f-c18b36407c55"
assert resp[1].id == "173c4cb9-e174-4a83-89fa-01ba8f25362f"
@responses.activate
def test_update_visitor(self) -> None:
"""4.5 Update Visitor"""