Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions validity/netbox_changes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from validity import config


def get_base_table_kwargs(self):
return {"user": self.request.user} if config.netbox_version < "4.5.4" else {}


StrFilterLookup = locate("strawberry_django.StrFilterLookup") if config.netbox_version >= "4.5.5" else FilterLookup[str]

if config.netbox_version >= "4.5.0":
Expand Down
67 changes: 59 additions & 8 deletions validity/tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import textwrap
from functools import partial
from http import HTTPStatus
Expand All @@ -7,6 +8,7 @@
import pytest
from base import ViewTest
from django.urls import reverse
from django.utils import timezone
from django.utils.functional import classproperty
from factories import (
BackupPointFactory,
Expand Down Expand Up @@ -167,6 +169,15 @@ def test_get_serialized_state(admin_client, item, monkeypatch):
assert resp.status_code == HTTPStatus.OK


@pytest.mark.parametrize("query_params", [{}, {"sort": "test"}, {"sort": "-created"}])
@pytest.mark.django_db
def test_device_results(admin_client, query_params):
device = DeviceFactory()
CompTestResultFactory(device=device)
resp = admin_client.get(f"/dcim/devices/{device.pk}/results/", query_params)
assert resp.status_code == HTTPStatus.OK


@pytest.mark.parametrize("query_params", [{}, {"sort": "device"}, {"sort": "-device"}])
@pytest.mark.django_db
def test_report_devices(admin_client, query_params):
Expand Down Expand Up @@ -263,35 +274,75 @@ def test_datasource_devices(admin_client):
assert resp.status_code == HTTPStatus.OK


class TestRunTests:
url = "/plugins/validity/tests/run/"
class TestRunTestsView:
"""Covers RunTestsView (validity.views.script.RunTestsView)."""

@staticmethod
def _url():
return reverse("plugins:validity:compliancetest_run")

def test_get(self, admin_client):
resp = admin_client.get(self.url)
resp = admin_client.get(self._url())
assert resp.status_code == HTTPStatus.OK

@pytest.mark.parametrize(
"form_data, status_code, has_workers",
[
({}, HTTPStatus.FOUND, True),
({}, HTTPStatus.OK, False),
({"devices": [1, 2]}, HTTPStatus.OK, True), # devices do not exist
({"devices": [1, 2]}, HTTPStatus.OK, True), # devices do not exist — invalid choices
],
)
def test_post(self, admin_client, di, form_data, status_code, has_workers):
launcher = Mock(**{"has_workers": has_workers, "return_value.pk": 1})
with di.override({dependencies.runtests_launcher: lambda: launcher}):
result = admin_client.post(self.url, form_data)
result = admin_client.post(self._url(), form_data)
assert result.status_code == status_code
if status_code == HTTPStatus.FOUND: # if form is valid
launcher.assert_called_once()
assert isinstance(launcher.call_args.args[0], RunTestsParams)

@pytest.mark.django_db
def test_post_with_valid_devices(self, admin_client, di):
d1, d2 = DeviceFactory(), DeviceFactory()
launcher = Mock(has_workers=True, return_value=Mock(pk=1))
with di.override({dependencies.runtests_launcher: lambda: launcher}):
resp = admin_client.post(self._url(), {"devices": [d1.pk, d2.pk]})
assert resp.status_code == HTTPStatus.FOUND
launcher.assert_called_once()
assert isinstance(launcher.call_args.args[0], RunTestsParams)


@pytest.mark.parametrize("job_factory", [RunTestsJobFactory, DSBackupJobFactory])
def test_scriptresult(admin_client, job_factory):
job = job_factory(status="completed")
resp = admin_client.get(f"/plugins/validity/scripts/results/{job.pk}/")
@pytest.mark.django_db
def test_script_result_view_completed_job(admin_client, job_factory):
"""
Full GET for a finished job: ScriptResultView only builds the log table when ``job.completed``
(the *timestamp* field) is set — same as real jobs after ``terminate()``. ``status`` alone is not enough.
"""
completed_at = timezone.now()
job = job_factory(
status="completed",
started=completed_at - datetime.timedelta(minutes=1),
completed=completed_at,
data={"output": "test output", "log": []},
)
assert job.completed, "need completion timestamp set or get_table is skipped (differs from browser)"

url = reverse("plugins:validity:script_result", kwargs={"pk": job.pk})
resp = admin_client.get(url)
assert resp.status_code == HTTPStatus.OK, getattr(resp, "content", b"")[:2000]


@pytest.mark.parametrize("job_factory", [RunTestsJobFactory, DSBackupJobFactory])
@pytest.mark.django_db
def test_script_result_view_incomplete_job(admin_client, job_factory):
"""Running job has no completion timestamp, so get_table is not used."""
job = job_factory(status="running", started=timezone.now())
assert not job.completed

url = reverse("plugins:validity:script_result", kwargs={"pk": job.pk})
resp = admin_client.get(url)
assert resp.status_code == HTTPStatus.OK


Expand Down
5 changes: 2 additions & 3 deletions validity/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from utilities.views import ObjectPermissionRequiredMixin as _ObjectPermissionRequiredMixin
from utilities.views import ViewTab

from validity import filtersets, forms, models, scripts, tables
from validity import filtersets, forms, models, netbox_changes, scripts, tables
from validity.utils.misc import partialcls


Expand Down Expand Up @@ -100,8 +100,7 @@ def get_table(self, **kwargs):
table.exclude = (self.result_relation,)
return table

def get_table_kwargs(self):
return {"user": self.request.user}
get_table_kwargs = netbox_changes.get_base_table_kwargs

def get_queryset(self):
return self.queryset.filter(**{self.result_relation: self.kwargs["pk"]})
Expand Down
5 changes: 2 additions & 3 deletions validity/views/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from netbox.views import generic
from utilities.views import ViewTab, register_model_view

from validity import config, filtersets, forms, models, tables
from validity import filtersets, forms, models, netbox_changes, tables
from validity.choices import DeviceGroupByChoices, SeverityChoices
from .base import FilterViewWithForm, ObjectPermissionRequiredMixin, TestResultBaseView

Expand Down Expand Up @@ -102,8 +102,7 @@ def get_table(self, **kwargs):
table.configure(self.request)
return table

def get_table_kwargs(self):
return {"user": self.request.user} if config.netbox_version < "4.5.4" else {}
get_table_kwargs = netbox_changes.get_base_table_kwargs

def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
return super().get_context_data(**kwargs) | {
Expand Down
6 changes: 4 additions & 2 deletions validity/views/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from validity import di
from validity.forms import RunTestsForm
from validity.netbox_changes import get_logs
from validity.netbox_changes import get_base_table_kwargs, get_logs
from validity.scripts import Launcher, RunTestsParams, ScriptParams
from validity.tables import ScriptResultTable
from .base import LauncherMixin
Expand Down Expand Up @@ -52,10 +52,12 @@ class ScriptResultView(PermissionRequiredMixin, TableMixin, ObjectView):

def get_table(self, job, request, bulk_actions=False):
logs = [entry | {"index": i} for i, entry in enumerate(get_logs(job), start=1)]
table = self.table_class(logs, user=request.user)
table = self.table_class(logs, **self.get_table_kwargs())
table.configure(request)
return table

get_table_kwargs = get_base_table_kwargs

def get(self, request, **kwargs):
job = self.get_object(**kwargs)
table = self.get_table(job, request) if job.completed else None
Expand Down
Loading