)
+[
]()
+[
]()
-
-:License: BSD
-
-====
-
-Overview
-====
+## Overview
SCRAM is a web based service to assist in automation of security data. There is a web interface as well as a REST API available.
+
The idea is to create actiontypes which allow you to take actions on the IPs/cidr networks you provide.
-Components
-====
+## Components
+
SCRAM utilizes ``docker compose`` to run the following stack in production:
- nginx (as a webserver and static asset server)
@@ -35,8 +26,7 @@ SCRAM utilizes ``docker compose`` to run the following stack in production:
A predefined actiontype of "block" exists which utilizes bgp nullrouting to effectivley block any traffic you want to apply.
You can add any other actiontypes via the admin page of the web interface dynamically, but keep in mind translator support would need to be added as well.
-Installation
-====
+## Installation
To get a basic implementation up and running locally:
@@ -48,6 +38,9 @@ To get a basic implementation up and running locally:
- Create ``$scram_home/.envs/.production/.postgres`` a template exists in the docs/templates directory
- Make sure to set the right credentials
- By default this template assumes you have a service defined in docker compose file called postgres. If you use another postgres server, make sure to update that setting as well
+- Create a ``.env`` file with the necessary environment variables:
+ - [comment]: # This chooses if you want to use oidc or local accounts. This can be local or oidc only. Default: `local`
+ - scram_auth_method: "local"
- ``make build``
- ``make toggle-prod``
- This will turn off debug mode in django and start using nginx to reverse proxy for the app
@@ -56,8 +49,7 @@ To get a basic implementation up and running locally:
- ``make django-open``
-*** Copyright Notice ***
-====
+## Copyright Notice
Security Catch and Release Automation Manager (SCRAM) Copyright (c) 2022,
The Regents of the University of California, through Lawrence Berkeley
diff --git a/docs/__init__.py b/docs/__init__.py
index 8772c827..d34efb05 100644
--- a/docs/__init__.py
+++ b/docs/__init__.py
@@ -1 +1 @@
-# Included so that Django's startproject comment runs against the docs directory
+"""Included so that Django's startproject comment runs against the docs directory."""
diff --git a/docs/conf.py b/docs/conf.py
deleted file mode 100644
index d4d4bad7..00000000
--- a/docs/conf.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# This file only contains a selection of the most common options. For a full
-# list see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-# -- Path setup --------------------------------------------------------------
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-
-import os
-import sys
-
-import django
-
-if os.getenv("READTHEDOCS", default=False) == "True":
- sys.path.insert(0, os.path.abspath(".."))
- os.environ["DJANGO_READ_DOT_ENV_FILE"] = "True"
- os.environ["USE_DOCKER"] = "no"
-else:
- sys.path.insert(0, os.path.abspath("/app"))
-os.environ["DATABASE_URL"] = "sqlite:///readthedocs.db"
-os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local")
-django.setup()
-
-# -- Project information -----------------------------------------------------
-
-project = "SCRAM"
-copyright = """2021, Sam Oehlert"""
-author = "Sam Oehlert"
-
-
-# -- General configuration ---------------------------------------------------
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-extensions = [
- "sphinx.ext.autodoc",
- "sphinx.ext.napoleon",
-]
-
-# Add any paths that contain templates here, relative to this directory.
-# templates_path = ["_templates"]
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-# This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
-
-# -- Options for HTML output -------------------------------------------------
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-#
-html_theme = "alabaster"
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-# html_static_path = ["_static"]
diff --git a/docs/cookiecutter.md b/docs/cookiecutter.md
deleted file mode 100644
index 17445af4..00000000
--- a/docs/cookiecutter.md
+++ /dev/null
@@ -1,71 +0,0 @@
-This documentation was initially provided in the README.rst from the cookiecutter used to generate the project
-
-Settings
---------
-
-Moved to settings_.
-
-.. _settings: http://cookiecutter-django.readthedocs.io/en/latest/settings.html
-
-Basic Commands
---------------
-
-Setting Up Your Users
-^^^^^^^^^^^^^^^^^^^^^
-
-* To create a **normal user account**, just go to Sign Up and fill out the form. Once you submit it, you'll see a "Verify Your E-mail Address" page. Go to your console to see a simulated email verification message. Copy the link into your browser. Now the user's email should be verified and ready to go.
-
-* To create an **superuser account**, use this command::
-
- $ python manage.py createsuperuser
-
-For convenience, you can keep your normal user logged in on Chrome and your superuser logged in on Firefox (or similar), so that you can see how the site behaves for both kinds of users.
-
-Type checks
-^^^^^^^^^^^
-
-Running type checks with mypy:
-
-::
-
- $ mypy scram
-
-Test coverage
-^^^^^^^^^^^^^
-
-To run the tests, check your test coverage, and generate an HTML coverage report::
-
- $ coverage run -m pytest
- $ coverage html
- $ open htmlcov/index.html
-
-Running tests with py.test
-~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-::
-
- $ pytest
-
-Live reloading and Sass CSS compilation
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Moved to `Live reloading and SASS compilation`_.
-
-.. _`Live reloading and SASS compilation`: http://cookiecutter-django.readthedocs.io/en/latest/live-reloading-and-sass-compilation.html
-
-
-
-Deployment
-----------
-
-The following details how to deploy this application.
-
-
-
-Docker
-^^^^^^
-
-See detailed `cookiecutter-django Docker documentation`_.
-
-.. _`cookiecutter-django Docker documentation`: http://cookiecutter-django.readthedocs.io/en/latest/deployment-with-docker.html
-
diff --git a/docs/environment_variables.md b/docs/environment_variables.md
new file mode 100644
index 00000000..c44629a7
--- /dev/null
+++ b/docs/environment_variables.md
@@ -0,0 +1,76 @@
+## Environment Variables to Set for Deployment
+[comment]: # Which branch of SCRAM to use (you probably want to set it to a release tag)
+scram_code_branch:
+#### Systems
+[comment]: # Email of the main admin
+scram_manager_email:
+[comment]: # Set to true for production mode; set to false to set up the compose.override.local.yml stack
+scram_prod: true
+[comment]: # Set to true if you want ansible to install a scram user
+scram_install_user: true
+[comment]: # What group to put `scram` user in
+scram_group: 'scram'
+[comment]: # What username to use for `scram` user
+scram_user: ''
+[comment]: # WHat uid to use for `scram` user
+scram_uid: ''
+[comment]: # What directory to use for base of the repo
+scram_home: '/usr/local/scram'
+[comment]: # IP or DNS record for your postgres host
+scram_postgres_host:
+[comment]: # What postgres user to use
+scram_postgres_user: ''
+
+#### Authentication
+[comment]: # This chooses if you want to use oidc or local accounts. This can be local or oidc only. Default: `local`
+scram_auth_method: "local"
+[comment]: # This client id (username) for your oidc connection. Only need to set this if you are trying to do oidc.
+scram_oidc_client_id:
+
+#### Networking
+[comment]: # What is the peering interface docker uses for gobgp to talk to the router
+scram_peering_iface: 'ens192'
+[comment]: # The v6 network of your peering connection
+scram_v4_subnet: '10.0.0.0/24'
+[comment]: # The v4 IP of the peering connection for the router side
+scram_v4_gateway: '10.0.0.1'
+[comment]: # The v4 IP of the peering connection for gobgp side
+scram_v4_address: '10.0.0.2'
+[comment]: # The v6 network of your peering connection
+scram_v6_subnet: '2001:db8::/64'
+[comment]: # The v6 IP of the peering connection for the router side
+scram_v6_gateway: '2001:db8::2'
+[comment]: # The v6 IP of the peering connection for the gobgp side
+scram_v6_address: '2001:db8::3'
+[comment]: # The AS you want to use for gobgp
+scram_as:
+[comment]: # A string representing your gobgp instance. Often seen as the local IP of the gobgp instance
+scram_router_id:
+[comment]: #
+scram_peer_as:
+[comment]: # The AS you want to use for gobgp side (can this be the same as `scram_as`?)
+scram_local_as:
+[comment]: # The fqdn of the server hosting this - to be used for nginx
+scram_nginx_host:
+[comment]: # List of allowed hosts per the django setting "ALLOWED_HOSTS". This should be a list of strings in shell
+[comment]: # `django` is required for the websockets to work
+[comment]: # Our Ansible assumes `django` + `scram_nginx_host`
+scram_django_allowed_hosts: "django"
+[comment]: # The fqdn of the server hosting this - to be used for nginx
+scram_server_alias:
+[comment]: # Do you want to set an md5 for authentication of bgp
+scram_bgp_md5_enabled: false
+[comment]: # The neighbor config of your gobgp config
+scram_neighbors:
+[comment]: # The v6 address of your neighbor
+ - neighbor_address: 2001:db8::2
+[comment]: # This is a v6 address so don't use v4
+ ipv4: false
+[comment]: # This is a v6 address so use v6
+ ipv6: true
+[comment]: # The v4 address of your neighbor
+ - neighbor_address: 10.0.0.200
+[comment]: # This is a v4 address so use v4
+ ipv4: true
+[comment]: # This is a v4 address so don't use v6
+ ipv6: false
diff --git a/docs/make.bat b/docs/make.bat
deleted file mode 100644
index a22d92b7..00000000
--- a/docs/make.bat
+++ /dev/null
@@ -1,46 +0,0 @@
-@ECHO OFF
-
-pushd %~dp0
-
-REM Command file for Sphinx documentation
-
-
-if "%SPHINXBUILD%" == "" (
- set SPHINXBUILD=sphinx-build -c .
-)
-set SOURCEDIR=_source
-set BUILDDIR=_build
-set APP=..\scram
-
-if "%1" == "" goto help
-
-%SPHINXBUILD% >NUL 2>NUL
-if errorlevel 9009 (
- echo.
- echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
- echo.installed, then set the SPHINXBUILD environment variable to point
- echo.to the full path of the 'sphinx-build' executable. Alternatively you
- echo.may add the Sphinx directory to PATH.
- echo.
- echo.Install sphinx-autobuild for live serving.
- echo.If you don't have Sphinx installed, grab it from
- echo.http://sphinx-doc.org/
- exit /b 1
-)
-
-%SPHINXBUILD% -b %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
-goto end
-
-:livehtml
-sphinx-autobuild -b html --open-browser -p 7000 --watch %APP% -c . %SOURCEDIR% %BUILDDIR%/html
-GOTO :EOF
-
-:apidocs
-sphinx-apidoc -o %SOURCEDIR%/api %APP%
-GOTO :EOF
-
-:help
-%SPHINXBUILD% -b help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
-
-:end
-popd
diff --git a/docs/reference/django.md b/docs/reference/django.md
new file mode 100644
index 00000000..be734554
--- /dev/null
+++ b/docs/reference/django.md
@@ -0,0 +1,3 @@
+# Route Manager
+
+::: scram.route_manager
diff --git a/docs/reference/index.md b/docs/reference/index.md
new file mode 100644
index 00000000..fc771185
--- /dev/null
+++ b/docs/reference/index.md
@@ -0,0 +1,7 @@
+# Module Reference
+
+The SCRAM ecosystem consists of two parts:
+
+A Django app, [route_manager](/reference/django)
+
+A translator service, [translator](/reference/translator)
diff --git a/docs/reference/translator.md b/docs/reference/translator.md
new file mode 100644
index 00000000..897d84e2
--- /dev/null
+++ b/docs/reference/translator.md
@@ -0,0 +1,3 @@
+# Translator
+
+::: translator
diff --git a/manage.py b/manage.py
index 37181fe1..cbc80d8d 100755
--- a/manage.py
+++ b/manage.py
@@ -1,27 +1,25 @@
#!/usr/bin/env python
+"""Django's command-line utility for administrative tasks."""
+
import os
import sys
from pathlib import Path
-if __name__ == "__main__":
- os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local")
+def main():
+ """Run administrative tasks."""
+ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local")
try:
- from django.core.management import execute_from_command_line
- except ImportError:
- # The above import may fail for some other reason. Ensure that the
- # issue is really that Django is missing to avoid masking other
- # exceptions on Python 2.
- try:
- import django # noqa
- except ImportError:
- raise ImportError(
- "Couldn't import Django. Are you sure it's installed and "
- "available on your PYTHONPATH environment variable? Did you "
- "forget to activate a virtual environment?"
- )
-
- raise
+ from django.core.management import execute_from_command_line # noqa: PLC0415
+ except ImportError as exc:
+ msg = (
+ "Couldn't import Django. Are you sure it's installed and "
+ "available on your PYTHONPATH environment variable? Did you "
+ "forget to activate a virtual environment?"
+ )
+ raise ImportError(
+ msg,
+ ) from exc
# This allows easy placement of apps within the interior
# scram directory.
@@ -29,3 +27,7 @@
sys.path.append(str(current_path / "scram"))
execute_from_command_line(sys.argv)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/merge_production_dotenvs_in_dotenv.py b/merge_production_dotenvs_in_dotenv.py
deleted file mode 100644
index b66558c3..00000000
--- a/merge_production_dotenvs_in_dotenv.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import os
-from pathlib import Path
-from typing import Sequence
-
-import pytest
-
-ROOT_DIR_PATH = Path(__file__).parent.resolve()
-PRODUCTION_DOTENVS_DIR_PATH = ROOT_DIR_PATH / ".envs" / ".production"
-PRODUCTION_DOTENV_FILE_PATHS = [
- PRODUCTION_DOTENVS_DIR_PATH / ".django",
- PRODUCTION_DOTENVS_DIR_PATH / ".postgres",
-]
-DOTENV_FILE_PATH = ROOT_DIR_PATH / ".env"
-
-
-def merge(output_file_path: str, merged_file_paths: Sequence[str], append_linesep: bool = True) -> None:
- with open(output_file_path, "w") as output_file:
- for merged_file_path in merged_file_paths:
- with open(merged_file_path, "r") as merged_file:
- merged_file_content = merged_file.read()
- output_file.write(merged_file_content)
- if append_linesep:
- output_file.write(os.linesep)
-
-
-def main():
- merge(DOTENV_FILE_PATH, PRODUCTION_DOTENV_FILE_PATHS)
-
-
-@pytest.mark.parametrize("merged_file_count", range(3))
-@pytest.mark.parametrize("append_linesep", [True, False])
-def test_merge(tmpdir_factory, merged_file_count: int, append_linesep: bool):
- tmp_dir_path = Path(str(tmpdir_factory.getbasetemp()))
-
- output_file_path = tmp_dir_path / ".env"
-
- expected_output_file_content = ""
- merged_file_paths = []
- for i in range(merged_file_count):
- merged_file_ord = i + 1
-
- merged_filename = ".service{}".format(merged_file_ord)
- merged_file_path = tmp_dir_path / merged_filename
-
- merged_file_content = merged_filename * merged_file_ord
-
- with open(merged_file_path, "w+") as file:
- file.write(merged_file_content)
-
- expected_output_file_content += merged_file_content
- if append_linesep:
- expected_output_file_content += os.linesep
-
- merged_file_paths.append(merged_file_path)
-
- merge(output_file_path, merged_file_paths, append_linesep)
-
- with open(output_file_path, "r") as output_file:
- actual_output_file_content = output_file.read()
-
- assert actual_output_file_content == expected_output_file_content
-
-
-if __name__ == "__main__":
- main()
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 00000000..704db80c
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,16 @@
+site_name: Security Catch and Release Automation Manager
+
+theme:
+ name: "material"
+
+plugins:
+ - mkdocstrings:
+ handlers:
+ python:
+ import:
+ - url: https://docs.djangoproject.com/en/4.2/_objects/
+ base_url: https://docs.djangoproject.com/en/4.2/
+ options:
+ show_submodules: true
+ - search
+ - section-index
diff --git a/pyproject.toml b/pyproject.toml
index 54476b25..5c332104 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,27 +9,107 @@ python_files = [
# ==== Coverage ====
[tool.coverage.run]
-omit = ["*/migrations/*", "*/tests/*"]
+include = ["scram/*", "config/*", "translator/*"]
+omit = ["**/migrations/*", "scram/contrib/*", "*/tests/*"]
plugins = ["django_coverage_plugin"]
+branch = true
+data_file = "coverage.coverage"
+
+[tool.coverage.report]
+exclude_also = [
+ "if debug:",
+ "if self.debug:",
+ "if settings.DEBUG",
+ "raise AssertionError",
+ "raise NotImplementedError",
+ "if __name__ == .__main__.:",
+ ]
+
+# ===== ruff ====
+[tool.ruff]
+exclude = [
+ "migrations",
+]
-
-# ==== black ====
-[tool.black]
line-length = 119
-target-version = ['py311']
+target-version = 'py312'
+preview = true
+
+[tool.ruff.lint]
+select = [
+ "A", # builtins
+ "ASYNC", # async
+ "B", # bugbear
+ "BLE", # blind-except
+ "C4", # comprehensions
+ "C90", # complexity
+ "COM", # commas
+ "D", # pydocstyle
+ "DJ", # django
+ "DOC", # pydoclint
+ "DTZ", # datetimez
+ "E", # pycodestyle
+ "EM", # errmsg
+ "ERA", # eradicate
+ "F", # pyflakes
+ "FBT", # boolean-trap
+ "FLY", # flynt
+ "G", # logging-format
+ "I", # isort
+ "ICN", # import-conventions
+ "ISC", # implicit-str-concat
+ "LOG", # logging
+ "N", # pep8-naming
+ "PERF", # perflint
+ "PIE", # pie
+ "PL", # pylint
+ "PTH", # use-pathlib
+ "Q", # quotes
+ "RET", # return
+ "RSE", # raise
+ "RUF", # Ruff
+ "S", # bandit
+ "SIM", # simplify
+ "SLF", # self
+ "SLOT", # slots
+ "T20", # print
+ "TRY", # tryceratops
+ "UP", # pyupgrade
+]
+ignore = [
+ "COM812", # handled by the formatter
+ "DOC501", # add possible exceptions to the docstring (TODO)
+ "ISC001", # handled by the formatter
+ "RUF012", # need more widespread typing
+ "SIM102", # Use a single `if` statement instead of nested `if` statements
+ "SIM108", # Use ternary operator instead of `if`-`else`-block
+]
+[tool.ruff.lint.mccabe]
+max-complexity = 7 # our current code adheres to this without too much effort
-# ==== isort ====
-[tool.isort]
-profile = "black"
-line_length = 119
-known_first_party = [
- "scram",
- "config",
+[tool.ruff.lint.per-file-ignores]
+"**/{tests}/*" = [
+ "DOC201", # documenting return values
+ "DOC402", # documenting yield values
+ "PLR6301", # could be a static method
+ "S101", # use of assert
+ "S106", # hardcoded password
+]
+"test.py" = [
+ "S105", # hardcoded password as argument
+]
+"scram/users/**" = [
+ "DOC201", # documenting return values
+ "FBT001", # minimal issue; don't need to mess with in the User app
+ "PLR2004", # magic values when checking HTTP status codes
+]
+"**/views.py" = [
+ "DOC201", # documenting return values; it's fairly obvious in a View
]
-skip = ["venv/"]
-skip_glob = ["**/migrations/*.py"]
+[tool.ruff.lint.pydocstyle]
+convention = "google"
# ==== mypy ====
[tool.mypy]
@@ -51,53 +131,3 @@ ignore_errors = true
[tool.django-stubs]
django_settings_module = "config.settings.test"
-
-
-# ==== PyLint ====
-[tool.pylint.MASTER]
-load-plugins = [
- "pylint_django",
-]
-django-settings-module = "config.settings.local"
-
-[tool.pylint.FORMAT]
-max-line-length = 119
-
-[tool.pylint."MESSAGES CONTROL"]
-disable = [
- "missing-docstring",
- "invalid-name",
-]
-
-[tool.pylint.DESIGN]
-max-parents = 13
-
-[tool.pylint.TYPECHECK]
-generated-members = [
- "REQUEST",
- "acl_users",
- "aq_parent",
- "[a-zA-Z]+_set{1,2}",
- "save",
- "delete",
-]
-
-
-# ==== djLint ====
-[tool.djlint]
-blank_line_after_tag = "load,extends"
-close_void_tags = true
-format_css = true
-format_js = true
-# TODO: remove T002 when fixed https://github.com/Riverside-Healthcare/djLint/issues/687
-ignore = "H006,H021,H023,H025,H029,H030,H031,T002,T003"
-include = "H017,H035"
-indent = 2
-max_line_length = 119
-profile = "django"
-
-[tool.djlint.css]
-indent_size = 2
-
-[tool.djlint.js]
-indent_size = 2
diff --git a/pytest.ini b/pytest.ini
deleted file mode 100644
index c2b3a233..00000000
--- a/pytest.ini
+++ /dev/null
@@ -1,3 +0,0 @@
-[pytest]
-addopts = --ds=config.settings.test --reuse-db
-python_files = tests.py test_*.py
diff --git a/requirements/base.txt b/requirements/base.txt
index cd489161..b43b582a 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -1,38 +1,31 @@
-pytz==2022.1 # https://github.com/stub42/pytz
-django-netfields==1.2.2 # https://pypi.org/project/django-netfields/
-python-slugify==6.1.2 # https://github.com/un33k/python-slugify
-Pillow==9.1.1 # https://github.com/python-pillow/Pillow
argon2-cffi==21.3.0 # https://github.com/hynek/argon2_cffi
-whitenoise==6.2.0 # https://github.com/evansd/whitenoise
-celery==5.2.7 # pyup: < 6.0 # https://github.com/celery/celery
-django-celery-beat==2.3.0 # https://github.com/celery/django-celery-beat
+django-celery-beat # https://github.com/celery/django-celery-beat
+django-netfields # https://pypi.org/project/django-netfields/
flower==1.0.0 # https://github.com/mher/flower
+python-slugify==6.1.2 # https://github.com/un33k/python-slugify
uvicorn[standard]==0.17.6 # https://github.com/encode/uvicorn
-asgiref~=3.5.0 # https://github.com/django/asgiref
+whitenoise==6.2.0 # https://github.com/evansd/whitenoise
# Django
# ------------------------------------------------------------------------------
-channels==3.0.4
channels_redis==3.4.0
-django-redis==5.3.0
-django==3.2.13 # pyup: < 4.0 # https://www.djangoproject.com/
-django-grip==3.4.0
-django-eventstream==4.5.1 # https://github.com/fanout/django-eventstream
+crispy-bootstrap5==0.6 # https://github.com/django-crispy-forms/crispy-bootstrap5
+django~=4.2 # https://www.djangoproject.com/
+
+django-allauth==0.51.0 # https://github.com/pennersr/django-allauth
django-environ==0.9.0 # https://github.com/joke2k/django-environ
+django-eventstream==4.5.1 # https://github.com/fanout/django-eventstream
django-model-utils==4.2.0 # https://github.com/jazzband/django-model-utils
-django-allauth==0.51.0 # https://github.com/pennersr/django-allauth
-django-crispy-forms==1.14.0 # https://github.com/django-crispy-forms/django-crispy-forms
-crispy-bootstrap5==0.6 # https://github.com/django-crispy-forms/crispy-bootstrap5
+django-redis==5.3.0
+django-simple-history~=3.1.1
# Django REST Framework
-djangorestframework==3.13.1 # https://github.com/encode/django-rest-framework
+djangorestframework~=3.15 # https://github.com/encode/django-rest-framework
django-cors-headers==3.13.0 # https://github.com/adamchainz/django-cors-headers
-# DRF-spectacular for api documentation
-drf-spectacular==0.22.1 # https://github.com/tfranzel/drf-spectacular
+drf-spectacular # https://github.com/tfranzel/drf-spectacular
# OIDC
# ------------------------------------------------------------------------------
mozilla_django_oidc==2.0.0 # https://github.com/mozilla/mozilla-django-oidc
websockets~=10.3
-django-simple-history~=3.1.1
diff --git a/requirements/local.txt b/requirements/local.txt
index edc15a30..134a2213 100644
--- a/requirements/local.txt
+++ b/requirements/local.txt
@@ -2,28 +2,32 @@
Werkzeug[watchdog]==2.0.3 # https://github.com/pallets/werkzeug
ipdb==0.13.9 # https://github.com/gotcha/ipdb
-psycopg2-binary==2.9.3 # https://github.com/psycopg/psycopg2
+psycopg2-binary==2.9.10 # https://github.com/psycopg/psycopg2
watchgod==0.8.2 # https://github.com/samuelcolvin/watchgod
# Testing
# ------------------------------------------------------------------------------
-mypy==0.950 # https://github.com/python/mypy
django-stubs==1.11.0 # https://github.com/typeddjango/django-stubs
-pytest==7.1.2 # https://github.com/pytest-dev/pytest
-pytest-sugar==0.9.4 # https://github.com/Frozenball/pytest-sugar
+pytest-sugar==0.9.6 # https://github.com/Frozenball/pytest-sugar
behave-django==1.4.0 # https://github.com/behave/behave-django
-django-debug-toolbar~=3.2 # https://github.com/jazzband/django-debug-toolbar
# Documentation
# ------------------------------------------------------------------------------
-sphinx==5.0.1 # https://github.com/sphinx-doc/sphinx
-sphinx-autobuild==2021.3.14 # https://github.com/GaretJax/sphinx-autobuild
+mkdocs==1.6.1
+mkdocs-autorefs==1.2.0
+mkdocs-gen-files==0.5.0
+mkdocs-get-deps==0.2.0
+mkdocs-literate-nav==0.6.1
+mkdocs-material==9.5.44
+mkdocs-material-extensions==1.3.1
+mkdocs-section-index==0.3.9
+mkdocstrings==0.27.0
+mkdocstrings-python==1.12.2
# Code quality
# ------------------------------------------------------------------------------
-flake8==4.0.1 # https://github.com/PyCQA/flake8
+flake8-docstrings==1.7.0 # https://github.com/pycqa/flake8-docstrings
flake8-isort==4.1.1 # https://github.com/gforcada/flake8-isort
-coverage==6.4.1 # https://github.com/nedbat/coveragepy
black==22.3.0 # https://github.com/psf/black
pylint-django==2.5.3 # https://github.com/PyCQA/pylint-django
pylint-celery==0.3 # https://github.com/PyCQA/pylint-celery
@@ -35,10 +39,8 @@ factory-boy==3.2.1 # https://github.com/FactoryBoy/factory_boy
django-debug-toolbar==3.4.0 # https://github.com/jazzband/django-debug-toolbar
django-extensions==3.1.5 # https://github.com/django-extensions/django-extensions
-django-coverage-plugin==2.0.3 # https://github.com/nedbat/django_coverage_plugin
+django-coverage-plugin==3.1.0 # https://github.com/nedbat/django_coverage_plugin
pytest-django==4.5.2 # https://github.com/pytest-dev/pytest-django
-pytest~=7.1.2
-behave~=1.2.6 # https://behave.readthedocs.io/en/stable/
# Debugging
# ------------------------------------------------------------------------------
diff --git a/scram/__init__.py b/scram/__init__.py
index eed836da..8f5a0220 100644
--- a/scram/__init__.py
+++ b/scram/__init__.py
@@ -1,2 +1,4 @@
-__version__ = "0.1.0"
-__version_info__ = tuple([int(num) if num.isdigit() else num for num in __version__.replace("-", ".", 1).split(".")])
+"""The Django project for Security Catch and Release Automation Manager (SCRAM)."""
+
+__version__ = "1.1.1"
+__version_info__ = tuple(int(num) if num.isdigit() else num for num in __version__.replace("-", ".", 1).split("."))
diff --git a/scram/conftest.py b/scram/conftest.py
index 52d755b4..29a5bc61 100644
--- a/scram/conftest.py
+++ b/scram/conftest.py
@@ -1,3 +1,5 @@
+"""Some simple tests for the User app."""
+
import pytest
from scram.users.models import User
@@ -6,9 +8,11 @@
@pytest.fixture(autouse=True)
def media_storage(settings, tmpdir):
+ """Configure the test to use the temp directory."""
settings.MEDIA_ROOT = tmpdir.strpath
@pytest.fixture
def user() -> User:
+ """Return the UserFactory."""
return UserFactory()
diff --git a/scram/contrib/__init__.py b/scram/contrib/__init__.py
index 1c7ecc89..adb76b31 100644
--- a/scram/contrib/__init__.py
+++ b/scram/contrib/__init__.py
@@ -1,5 +1,4 @@
-"""
-To understand why this file is here, please read:
+"""To understand why this file is here, please read the CookieCutter documentation.
http://cookiecutter-django.readthedocs.io/en/latest/faq.html#why-is-there-a-django-contrib-sites-directory-in-cookiecutter-django
"""
diff --git a/scram/contrib/sites/__init__.py b/scram/contrib/sites/__init__.py
index 1c7ecc89..adb76b31 100644
--- a/scram/contrib/sites/__init__.py
+++ b/scram/contrib/sites/__init__.py
@@ -1,5 +1,4 @@
-"""
-To understand why this file is here, please read:
+"""To understand why this file is here, please read the CookieCutter documentation.
http://cookiecutter-django.readthedocs.io/en/latest/faq.html#why-is-there-a-django-contrib-sites-directory-in-cookiecutter-django
"""
diff --git a/scram/local_auth/__init__.py b/scram/local_auth/__init__.py
new file mode 100644
index 00000000..113a7bfa
--- /dev/null
+++ b/scram/local_auth/__init__.py
@@ -0,0 +1 @@
+"""Local_auth is the app that holds urls we want to use with local Django auth."""
diff --git a/scram/local_auth/urls.py b/scram/local_auth/urls.py
new file mode 100644
index 00000000..56c7b727
--- /dev/null
+++ b/scram/local_auth/urls.py
@@ -0,0 +1,15 @@
+"""Register URLs for local auth known to Django, and the View that will handle each."""
+
+from django.contrib.auth.views import LoginView, LogoutView
+from django.urls import path
+
+app_name = "local_auth"
+
+urlpatterns = [
+ path(
+ "login/",
+ LoginView.as_view(template_name="local_auth/login.html", success_url="route_manager:home"),
+ name="login",
+ ),
+ path("logout/", LogoutView.as_view(), name="logout"),
+]
diff --git a/scram/route_manager/__init__.py b/scram/route_manager/__init__.py
index e69de29b..db37df27 100644
--- a/scram/route_manager/__init__.py
+++ b/scram/route_manager/__init__.py
@@ -0,0 +1 @@
+"""Route Manager is the core app, and it handles adding and removing routes that will be blocked."""
diff --git a/scram/route_manager/admin.py b/scram/route_manager/admin.py
index 96a6e324..172dd6c7 100644
--- a/scram/route_manager/admin.py
+++ b/scram/route_manager/admin.py
@@ -1,3 +1,5 @@
+"""Register models in the Admin site."""
+
from django.contrib import admin
from simple_history.admin import SimpleHistoryAdmin
@@ -6,6 +8,8 @@
@admin.register(ActionType)
class ActionTypeAdmin(SimpleHistoryAdmin):
+ """Configure the ActionType and how it shows up in the Admin site."""
+
list_filter = ("available",)
list_display = ("name", "available")
diff --git a/scram/route_manager/api/__init__.py b/scram/route_manager/api/__init__.py
index e69de29b..26fc9ceb 100644
--- a/scram/route_manager/api/__init__.py
+++ b/scram/route_manager/api/__init__.py
@@ -0,0 +1 @@
+"""The API, which leverages Django Request Framework."""
diff --git a/scram/route_manager/api/exceptions.py b/scram/route_manager/api/exceptions.py
index 5483055b..0445dbe4 100644
--- a/scram/route_manager/api/exceptions.py
+++ b/scram/route_manager/api/exceptions.py
@@ -1,22 +1,30 @@
+"""Custom exceptions for the API."""
+
from django.conf import settings
from rest_framework.exceptions import APIException
class PrefixTooLarge(APIException):
+ """The CIDR prefix that was specified is larger than the prefix allowed in the settings."""
+
v4_min_prefix = getattr(settings, "V4_MINPREFIX", 0)
v6_min_prefix = getattr(settings, "V6_MINPREFIX", 0)
status_code = 400
- default_detail = f"You've supplied too large of a network. settings.V4_MINPREFIX = {v4_min_prefix} settings.V6_MINPREFIX = {v6_min_prefix}" # noqa: 501
+ default_detail = f"You've supplied too large of a network. settings.V4_MINPREFIX = {v4_min_prefix} settings.V6_MINPREFIX = {v6_min_prefix}" # noqa: E501
default_code = "prefix_too_large"
class IgnoredRoute(APIException):
+ """An operation attempted to add a route that overlaps with a route on the ignore list."""
+
status_code = 400
default_detail = "This CIDR is on the ignore list. You are not allowed to add it here."
default_code = "ignored_route"
class ActiontypeNotAllowed(APIException):
+ """An operation attempted to perform an action on behalf of a client that is unauthorized to perform that type."""
+
status_code = 403
default_detail = "This client is not allowed to use this actiontype"
default_code = "actiontype_not_allowed"
diff --git a/scram/route_manager/api/serializers.py b/scram/route_manager/api/serializers.py
index dec9df15..f1bc1907 100644
--- a/scram/route_manager/api/serializers.py
+++ b/scram/route_manager/api/serializers.py
@@ -1,5 +1,8 @@
+"""Serializers provide mappings between the API and the underlying model."""
+
import logging
+from drf_spectacular.utils import extend_schema_field
from netfields import rest_framework
from rest_framework import serializers
from rest_framework.fields import CurrentUserDefault
@@ -10,16 +13,29 @@
logger = logging.getLogger(__name__)
+@extend_schema_field(field={"type": "string", "format": "cidr"})
+class CustomCidrAddressField(rest_framework.CidrAddressField):
+ """Define a wrapper field so swagger can properly handle the inherited field."""
+
+
class ActionTypeSerializer(serializers.ModelSerializer):
+ """Map the serializer to the model via Meta."""
+
class Meta:
+ """Maps to the ActionType model, and specifies the fields exposed by the API."""
+
model = ActionType
fields = ["pk", "name", "available"]
class RouteSerializer(serializers.ModelSerializer):
- route = rest_framework.CidrAddressField()
+ """Exposes route as a CIDR field."""
+
+ route = CustomCidrAddressField()
class Meta:
+ """Maps to the Route model, and specifies the fields exposed by the API."""
+
model = Route
fields = [
"route",
@@ -27,16 +43,24 @@ class Meta:
class ClientSerializer(serializers.ModelSerializer):
+ """Map the serializer to the model via Meta."""
+
class Meta:
+ """Maps to the Client model, and specifies the fields exposed by the API."""
+
model = Client
fields = ["hostname", "uuid"]
class EntrySerializer(serializers.HyperlinkedModelSerializer):
+ """Due to the use of ForeignKeys, this follows some relationships to make sense via the API."""
+
url = serializers.HyperlinkedIdentityField(
- view_name="api:v1:entry-detail", lookup_url_kwarg="pk", lookup_field="route"
+ view_name="api:v1:entry-detail",
+ lookup_url_kwarg="pk",
+ lookup_field="route",
)
- route = rest_framework.CidrAddressField()
+ route = CustomCidrAddressField()
actiontype = serializers.CharField(default="block")
if CurrentUserDefault():
# This is set if we are calling this serializer from WUI
@@ -46,28 +70,44 @@ class EntrySerializer(serializers.HyperlinkedModelSerializer):
comment = serializers.CharField()
class Meta:
+ """Maps to the Entry model, and specifies the fields exposed by the API."""
+
model = Entry
fields = ["route", "actiontype", "url", "comment", "who"]
- def get_comment(self, obj):
+ @staticmethod
+ def get_comment(obj):
+ """Provide a nicer name for change reason.
+
+ Returns:
+ string: The change reason that modified the Entry.
+ """
return obj.get_change_reason()
- def create(self, validated_data):
+ @staticmethod
+ def create(validated_data):
+ """Implement custom logic and validates creating a new route.""" # noqa: DOC201
valid_route = validated_data.pop("route")
actiontype = validated_data.pop("actiontype")
comment = validated_data.pop("comment")
- route_instance, created = Route.objects.get_or_create(route=valid_route)
+ route_instance, _ = Route.objects.get_or_create(route=valid_route)
actiontype_instance = ActionType.objects.get(name=actiontype)
- entry_instance, created = Entry.objects.get_or_create(route=route_instance, actiontype=actiontype_instance)
+ entry_instance, _ = Entry.objects.get_or_create(route=route_instance, actiontype=actiontype_instance)
- logger.debug(f"{comment}")
+ logger.debug("Created entry with comment: %s", comment)
update_change_reason(entry_instance, comment)
return entry_instance
class IgnoreEntrySerializer(serializers.ModelSerializer):
+ """Map the route to the right field type."""
+
+ route = CustomCidrAddressField()
+
class Meta:
+ """Maps to the IgnoreEntry model, and specifies the fields exposed by the API."""
+
model = IgnoreEntry
fields = ["route", "comment"]
diff --git a/scram/route_manager/api/views.py b/scram/route_manager/api/views.py
index 2972e6dd..70d30b34 100644
--- a/scram/route_manager/api/views.py
+++ b/scram/route_manager/api/views.py
@@ -1,3 +1,5 @@
+"""Views provide mappings between the underlying model and how they're listed in the API."""
+
import ipaddress
import logging
@@ -8,6 +10,7 @@
from django.db.models import Q
from django.http import Http404
from django.utils.dateparse import parse_datetime
+from drf_spectacular.utils import extend_schema
from rest_framework import status, viewsets
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
@@ -17,23 +20,42 @@
from .serializers import ActionTypeSerializer, ClientSerializer, EntrySerializer, IgnoreEntrySerializer
channel_layer = get_channel_layer()
+logger = logging.getLogger(__name__)
+@extend_schema(
+ description="API endpoint for actiontypes",
+ responses={200: ActionTypeSerializer},
+)
class ActionTypeViewSet(viewsets.ReadOnlyModelViewSet):
+ """Lookup ActionTypes by name when authenticated, and bind to the serializer."""
+
queryset = ActionType.objects.all()
permission_classes = (IsAuthenticated,)
serializer_class = ActionTypeSerializer
lookup_field = "name"
+@extend_schema(
+ description="API endpoint for ignore entries",
+ responses={200: IgnoreEntrySerializer},
+)
class IgnoreEntryViewSet(viewsets.ModelViewSet):
+ """Lookup IgnoreEntries by route when authenticated, and bind to the serializer."""
+
queryset = IgnoreEntry.objects.all()
permission_classes = (IsAuthenticated,)
serializer_class = IgnoreEntrySerializer
lookup_field = "route"
+@extend_schema(
+ description="API endpoint for clients",
+ responses={200: ClientSerializer},
+)
class ClientViewSet(viewsets.ModelViewSet):
+ """Lookup Client by hostname on POSTs regardless of authentication, and bind to the serializer."""
+
queryset = Client.objects.all()
# We want to allow a client to be registered from anywhere
permission_classes = (AllowAny,)
@@ -42,21 +64,56 @@ class ClientViewSet(viewsets.ModelViewSet):
http_method_names = ["post"]
+@extend_schema(
+ description="API endpoint for entries",
+ responses={200: EntrySerializer},
+)
class EntryViewSet(viewsets.ModelViewSet):
+ """Lookup Entry when authenticated, and bind to the serializer."""
+
queryset = Entry.objects.filter(is_active=True)
permission_classes = (IsAuthenticated,)
serializer_class = EntrySerializer
lookup_value_regex = ".*"
http_method_names = ["get", "post", "head", "delete"]
- # Ovveride the permissions classes for POST method since we want to accept Entry creates from any client
- # Note: We make authorization decisions on whether to actually create the object in the perform_create method later
def get_permissions(self):
+ """Override the permissions classes for POST method since we want to accept Entry creates from any client.
+
+ Note: We make authorization decisions on whether to actually create the object in the perform_create method
+ later.
+ """
if self.request.method == "POST":
return [AllowAny()]
return super().get_permissions()
+ def check_client_authorization(self, actiontype):
+ """Ensure that a given client is authorized to use a given actiontype."""
+ uuid = self.request.data.get("uuid")
+ if uuid:
+ authorized_actiontypes = Client.objects.filter(uuid=uuid).values_list(
+ "authorized_actiontypes__name",
+ flat=True,
+ )
+ authorized_client = Client.objects.filter(uuid=uuid).values("is_authorized")
+ if not authorized_client or actiontype not in authorized_actiontypes:
+ logger.debug("Client: %s, actiontypes: %s", uuid, authorized_actiontypes)
+ logger.info("%s is not allowed to add an entry to the %s list.", uuid, actiontype)
+ raise ActiontypeNotAllowed
+ elif not self.request.user.has_perm("route_manager.can_add_entry"):
+ raise PermissionDenied
+
+ @staticmethod
+ def check_ignore_list(route):
+ """Ensure that we're not trying to block something from the ignore list."""
+ overlapping_ignore = IgnoreEntry.objects.filter(route__net_overlaps=route)
+ if overlapping_ignore.count():
+ ignore_entries = [str(ignore_entry["route"]) for ignore_entry in overlapping_ignore.values()]
+ logger.info("Cannot proceed adding %s. The ignore list contains %s.", route, ignore_entries)
+ raise IgnoredRoute
+
def perform_create(self, serializer):
+ """Create a new Entry, causing that route to receive the actiontype (i.e. block)."""
actiontype = serializer.validated_data["actiontype"]
route = serializer.validated_data["route"]
if self.request.user.username:
@@ -69,63 +126,45 @@ def perform_create(self, serializer):
tmp_exp = self.request.data.get("expiration", "")
try:
- expiration = parse_datetime(tmp_exp) # noqa: F841
+ expiration = parse_datetime(tmp_exp)
except ValueError:
- logging.info(f"Could not parse expiration DateTime string {tmp_exp!r}.")
+ logger.warning("Could not parse expiration DateTime string: %s", tmp_exp)
# Make sure we put in an acceptable sized prefix
min_prefix = getattr(settings, f"V{route.version}_MINPREFIX", 0)
if route.prefixlen < min_prefix:
- raise PrefixTooLarge()
-
- # Make sure this client is authorized to add this entry with this actiontype
- if self.request.data.get("uuid"):
- client_uuid = self.request.data["uuid"]
- authorized_actiontypes = Client.objects.filter(uuid=client_uuid).values_list(
- "authorized_actiontypes__name", flat=True
+ raise PrefixTooLarge
+
+ self.check_client_authorization(actiontype)
+ self.check_ignore_list(route)
+
+ elements = WebSocketSequenceElement.objects.filter(action_type__name=actiontype).order_by("order_num")
+ if not elements:
+ logger.warning("No elements found for actiontype: %s", actiontype)
+
+ for element in elements:
+ msg = element.websocketmessage
+ msg.msg_data[msg.msg_data_route_field] = str(route)
+ # Must match a channel name defined in asgi.py
+ async_to_sync(channel_layer.group_send)(
+ f"translator_{actiontype}",
+ {"type": msg.msg_type, "message": msg.msg_data},
)
- authorized_client = Client.objects.filter(uuid=client_uuid).values("is_authorized")
- if not authorized_client or actiontype not in authorized_actiontypes:
- logging.debug(f"Client {client_uuid} actiontypes: {authorized_actiontypes}")
- logging.info(f"{client_uuid} is not allowed to add an entry to the {actiontype} list")
- raise ActiontypeNotAllowed()
- elif not self.request.user.has_perm("route_manager.can_add_entry"):
- raise PermissionDenied()
- # Don't process if we have the entry in the ignorelist
- overlapping_ignore = IgnoreEntry.objects.filter(route__net_overlaps=route)
- if overlapping_ignore.count():
- ignore_entries = []
- for ignore_entry in overlapping_ignore.values():
- ignore_entries.append(str(ignore_entry["route"]))
- logging.info(f"Cannot proceed adding {route}. The ignore list contains {ignore_entries}")
- raise IgnoredRoute
- else:
- elements = WebSocketSequenceElement.objects.filter(action_type__name=actiontype).order_by("order_num")
- if not elements:
- logging.warning(f"No elements found for actiontype={actiontype}.")
-
- for element in elements:
- msg = element.websocketmessage
- msg.msg_data[msg.msg_data_route_field] = str(route)
- # Must match a channel name defined in asgi.py
- async_to_sync(channel_layer.group_send)(
- f"translator_{actiontype}", {"type": msg.msg_type, "message": msg.msg_data}
- )
-
- serializer.save()
-
- entry = Entry.objects.get(route__route=route, actiontype__name=actiontype)
- if expiration:
- entry.expiration = expiration
- entry.who = who
- entry.is_active = True
- entry.comment = comment
- logging.info(f"{entry}")
- entry.save()
+ serializer.save()
+
+ entry = Entry.objects.get(route__route=route, actiontype__name=actiontype)
+ if expiration:
+ entry.expiration = expiration
+ entry.who = who
+ entry.is_active = True
+ entry.comment = comment
+ logger.info("Created entry: %s", entry)
+ entry.save()
@staticmethod
def find_entries(arg, active_filter=None):
+ """Query entries either by pk or overlapping route."""
if not arg:
return Entry.objects.none()
@@ -133,13 +172,13 @@ def find_entries(arg, active_filter=None):
try:
pk = int(arg)
query = Q(pk=pk)
- except ValueError:
+ except ValueError as exc:
# Maybe a CIDR? We want the ValueError at this point, if not.
cidr = ipaddress.ip_network(arg, strict=False)
min_prefix = getattr(settings, f"V{cidr.version}_MINPREFIX", 0)
if cidr.prefixlen < min_prefix:
- raise PrefixTooLarge()
+ raise PrefixTooLarge from exc
query = Q(route__route__net_overlaps=cidr)
@@ -149,6 +188,7 @@ def find_entries(arg, active_filter=None):
return Entry.objects.filter(query)
def retrieve(self, request, pk=None, **kwargs):
+ """Retrieve a single route."""
entries = self.find_entries(pk, active_filter=True)
# TODO: What happens if we get multiple? Is that ok? I think yes, and return them all?
if entries.count() != 1:
@@ -157,6 +197,7 @@ def retrieve(self, request, pk=None, **kwargs):
return Response(serializer.data)
def destroy(self, request, pk=None, *args, **kwargs):
+ """Only delete active (e.g. announced) entries."""
for entry in self.find_entries(pk, active_filter=True):
entry.delete()
diff --git a/scram/route_manager/apps.py b/scram/route_manager/apps.py
index 319a8d17..bf13bf9e 100644
--- a/scram/route_manager/apps.py
+++ b/scram/route_manager/apps.py
@@ -1,5 +1,9 @@
+"""Register ourselves with Django."""
+
from django.apps import AppConfig
class RouteManagerConfig(AppConfig):
+ """Define the name of the module that's the main app."""
+
name = "scram.route_manager"
diff --git a/scram/route_manager/authentication_backends.py b/scram/route_manager/authentication_backends.py
index 1c549376..1831d788 100644
--- a/scram/route_manager/authentication_backends.py
+++ b/scram/route_manager/authentication_backends.py
@@ -1,47 +1,49 @@
+"""Define one or more custom auth backends."""
+
from django.conf import settings
from django.contrib.auth.models import Group
from mozilla_django_oidc.auth import OIDCAuthenticationBackend
-class ESnetAuthBackend(OIDCAuthenticationBackend):
- def update_groups(self, user, claims):
- """Sets the users group(s) to whatever is in the claims."""
+def groups_overlap(a, b):
+ """Helper function to see if a and b have any overlap.
- claimed_groups = claims.get("groups", [])
-
- effective_groups = []
- is_admin = False
+ Returns:
+ bool: True if there's any overlap between a and b.
+ """
+ return not set(a).isdisjoint(b)
- ro_group = Group.objects.get(name="readonly")
- rw_group = Group.objects.get(name="readwrite")
- for g in claimed_groups:
- # If any of the user's groups are in DENIED_GROUPS, deny them and stop processing immediately
- if g in settings.SCRAM_DENIED_GROUPS:
- effective_groups = []
- is_admin = False
- break
-
- if g in settings.SCRAM_ADMIN_GROUPS:
- is_admin = True
+class ESnetAuthBackend(OIDCAuthenticationBackend):
+ """Extend the OIDC backend with a custom permission model."""
- if g in settings.SCRAM_READONLY_GROUPS:
- if ro_group not in effective_groups:
- effective_groups.append(ro_group)
+ @staticmethod
+ def update_groups(user, claims):
+ """Set the user's group(s) to whatever is in the claims."""
+ effective_groups = []
+ claimed_groups = claims.get("groups", [])
- if g in settings.SCRAM_READWRITE_GROUPS:
- if rw_group not in effective_groups:
- effective_groups.append(rw_group)
+ if groups_overlap(claimed_groups, settings.SCRAM_DENIED_GROUPS):
+ is_admin = False
+ # Don't even look at anything else if they're denied
+ else:
+ is_admin = groups_overlap(claimed_groups, settings.SCRAM_ADMIN_GROUPS)
+ if groups_overlap(claimed_groups, settings.SCRAM_READWRITE_GROUPS):
+ effective_groups.append(Group.objects.get(name="readwrite"))
+ if groups_overlap(claimed_groups, settings.SCRAM_READONLY_GROUPS):
+ effective_groups.append(Group.objects.get(name="readonly"))
user.groups.set(effective_groups)
user.is_staff = user.is_superuser = is_admin
user.save()
def create_user(self, claims):
- user = super(ESnetAuthBackend, self).create_user(claims)
+ """Wrap the superclass's user creation.""" # noqa: DOC201
+ user = super().create_user(claims)
return self.update_user(user, claims)
def update_user(self, user, claims):
+ """Determine the user name from the claims and update said user's groups.""" # noqa: DOC201
user.name = claims.get("given_name", "") + " " + claims.get("family_name", "")
user.username = claims.get("preferred_username", "")
if claims.get("groups", False):
diff --git a/scram/route_manager/context_processors.py b/scram/route_manager/context_processors.py
index 3781c3d2..0daf06c9 100644
--- a/scram/route_manager/context_processors.py
+++ b/scram/route_manager/context_processors.py
@@ -1,8 +1,15 @@
+"""Define custom functions that take a request and add to the context before template rendering."""
+
from django.conf import settings
from django.urls import reverse
def login_logout(request):
+ """Pass through the relevant URLs from the settings.
+
+ Returns:
+ dict: login and logout URLs
+ """
login_url = reverse(settings.LOGIN_URL)
logout_url = reverse(settings.LOGOUT_URL)
return {"login": login_url, "logout": logout_url}
diff --git a/scram/route_manager/models.py b/scram/route_manager/models.py
index ef063ead..aa72f7ba 100644
--- a/scram/route_manager/models.py
+++ b/scram/route_manager/models.py
@@ -1,3 +1,5 @@
+"""Define the models used in the route_manager app."""
+
import logging
import uuid as uuid_lib
@@ -8,35 +10,41 @@
from netfields import CidrAddressField
from simple_history.models import HistoricalRecords
+logger = logging.getLogger(__name__)
+
class Route(models.Model):
- """Model describing a route"""
+ """Define a route as a CIDR route and a UUID."""
route = CidrAddressField(unique=True)
uuid = models.UUIDField(db_index=True, default=uuid_lib.uuid4, editable=False)
- def get_absolute_url(self):
- return reverse("")
-
def __str__(self):
+ """Don't display the UUID, only the route.""" # noqa: DOC201
return str(self.route)
+ @staticmethod
+ def get_absolute_url():
+ """Ensure we use UUID on the API side instead.""" # noqa: DOC201
+ return reverse("")
+
class ActionType(models.Model):
- """Defines an action that can be done with a given route. e.g. Block, shunt, redirect, etc."""
+ """Define a type of action that can be done with a given route. e.g. Block, shunt, redirect, etc."""
name = models.CharField(help_text="One-word description of the action", max_length=30)
available = models.BooleanField(help_text="Is this a valid choice for new entries?", default=True)
history = HistoricalRecords()
def __str__(self):
+ """Display clearly whether the action is currently available.""" # noqa: DOC201
if not self.available:
return f"{self.name} (Inactive)"
return self.name
class WebSocketMessage(models.Model):
- """Defines a single message sent to downstream translators via WebSocket."""
+ """Define a single message sent to downstream translators via WebSocket."""
msg_type = models.CharField("The type of the message", max_length=50)
msg_data = models.JSONField("The JSON payload. See also msg_data_route_field.", default=dict)
@@ -47,16 +55,17 @@ class WebSocketMessage(models.Model):
)
def __str__(self):
+ """Display clearly what the fields are used for.""" # noqa: DOC201
return f"{self.msg_type}: {self.msg_data} with the route in key {self.msg_data_route_field}"
class WebSocketSequenceElement(models.Model):
- """In a sequence of messages, defines a single element."""
+ """In a sequence of messages, define a single element."""
websocketmessage = models.ForeignKey("WebSocketMessage", on_delete=models.CASCADE)
order_num = models.SmallIntegerField(
"Sequences are sent from the smallest order_num to the highest. "
- + "Messages with the same order_num could be sent in any order",
+ "Messages with the same order_num could be sent in any order",
default=0,
)
@@ -70,9 +79,10 @@ class WebSocketSequenceElement(models.Model):
action_type = models.ForeignKey("ActionType", on_delete=models.CASCADE)
def __str__(self):
+ """Summarize the fields into something short and readable.""" # noqa: DOC201
return (
f"{self.websocketmessage} as order={self.order_num} for "
- + f"{self.verb} actions on actiontype={self.action_type}"
+ f"{self.verb} actions on actiontype={self.action_type}"
)
@@ -81,7 +91,7 @@ class Entry(models.Model):
route = models.ForeignKey("Route", on_delete=models.PROTECT)
actiontype = models.ForeignKey("ActionType", on_delete=models.PROTECT)
- comment = models.TextField(blank=True, null=True)
+ comment = models.TextField(blank=True, default="")
is_active = models.BooleanField(default=True)
# TODO: fix name if this works
history = HistoricalRecords()
@@ -91,60 +101,71 @@ class Entry(models.Model):
expiration_reason = models.CharField(
help_text="Optional reason for the expiration",
max_length=200,
- null=True,
blank=True,
+ default="",
)
- def delete(self, *args, **kwargs):
- if not self.is_active:
- # We've already expired this route, don't send another message
- return
- else:
- # We don't actually delete records; we set them to inactive and then tell the translator to remove them
- logging.info(f"Deactivating {self.route}")
- self.is_active = False
- self.save()
-
- # Unblock it
- async_to_sync(channel_layer.group_send)(
- f"translator_{self.actiontype}",
- {
- "type": "translator_remove",
- "message": {"route": str(self.route)},
- },
- )
-
class Meta:
+ """Ensure that multiple routes can be added as long as they have different action types."""
+
unique_together = ["route", "actiontype"]
verbose_name_plural = "Entries"
def __str__(self):
+ """Summarize the most important fields to something easily readable.""" # noqa: DOC201
desc = f"{self.route} ({self.actiontype})"
if not self.is_active:
desc += " (inactive)"
return desc
+ def delete(self, *args, **kwargs):
+ """Set inactive instead of deleting, as we want to ensure a history of entries."""
+ if not self.is_active:
+ # We've already expired this route, don't send another message
+ return
+ # We don't actually delete records; we set them to inactive and then tell the translator to remove them
+ logger.info("Deactivating %s", self.route)
+ self.is_active = False
+ self.save()
+
+ # Unblock it
+ async_to_sync(channel_layer.group_send)(
+ f"translator_{self.actiontype}",
+ {
+ "type": "translator_remove",
+ "message": {"route": str(self.route)},
+ },
+ )
+
def get_change_reason(self):
+ """Traverse some complex relationships to determine the most recent change reason.
+
+ Returns:
+ str: The most recent change reason
+ """
hist_mgr = getattr(self, self._meta.simple_history_manager_attribute)
return hist_mgr.order_by("-history_date").first().history_change_reason
class IgnoreEntry(models.Model):
- """For cidrs you NEVER want to block ie don't shoot yourself in the foot list"""
+ """Define CIDRs you NEVER want to block (i.e. the "don't shoot yourself in the foot" list)."""
route = CidrAddressField(unique=True)
comment = models.CharField(max_length=100)
history = HistoricalRecords()
class Meta:
+ """Ensure the plural is grammatically correct."""
+
verbose_name_plural = "Ignored Entries"
def __str__(self):
+ """Only display the route.""" # noqa: DOC201
return str(self.route)
class Client(models.Model):
- """Any client that would like to hit the API to add entries (e.g. Zeek)"""
+ """Any client that would like to hit the API to add entries (e.g. Zeek)."""
hostname = models.CharField(max_length=50, unique=True)
uuid = models.UUIDField()
@@ -153,6 +174,7 @@ class Client(models.Model):
authorized_actiontypes = models.ManyToManyField(ActionType)
def __str__(self):
+ """Only display the hostname.""" # noqa: DOC201
return str(self.hostname)
diff --git a/scram/route_manager/tests/__init__.py b/scram/route_manager/tests/__init__.py
index e69de29b..fb031a32 100644
--- a/scram/route_manager/tests/__init__.py
+++ b/scram/route_manager/tests/__init__.py
@@ -0,0 +1 @@
+"""Define tests executed by pytest."""
diff --git a/scram/route_manager/tests/acceptance/environment.py b/scram/route_manager/tests/acceptance/environment.py
index 5bae6990..0d31198b 100644
--- a/scram/route_manager/tests/acceptance/environment.py
+++ b/scram/route_manager/tests/acceptance/environment.py
@@ -1,9 +1,12 @@
+"""Setup the environment for tests."""
+
from rest_framework.test import APIClient
from scram.users.tests.factories import UserFactory
def django_ready(context):
+ """Create a user that tests can use for authenticated/unauthenticated testing."""
# Create a user
UserFactory(username="user", password="password")
diff --git a/scram/route_manager/tests/acceptance/steps/common.py b/scram/route_manager/tests/acceptance/steps/common.py
index 1d502e96..f688bd64 100644
--- a/scram/route_manager/tests/acceptance/steps/common.py
+++ b/scram/route_manager/tests/acceptance/steps/common.py
@@ -1,32 +1,37 @@
+"""Define steps used exclusively by the Behave tests."""
+
import datetime
import time
-import django.conf as conf
from asgiref.sync import async_to_sync
from behave import given, step, then, when
from channels.layers import get_channel_layer
+from django import conf
from django.urls import reverse
from scram.route_manager.models import ActionType, Client, WebSocketMessage, WebSocketSequenceElement
@given("a {name} actiontype is defined")
-def step_impl(context, name):
+def create_actiontype(context, name):
+ """Create an actiontype of that name."""
context.channel_layer = get_channel_layer()
async_to_sync(context.channel_layer.group_send)(
- f"translator_{name}", {"type": "translator_remove_all", "message": {}}
+ f"translator_{name}",
+ {"type": "translator_remove_all", "message": {}},
)
- at, created = ActionType.objects.get_or_create(name=name)
- wsm, created = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route")
+ at, _ = ActionType.objects.get_or_create(name=name)
+ wsm, _ = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route")
wsm.save()
- wsse, created = WebSocketSequenceElement.objects.get_or_create(websocketmessage=wsm, verb="A", action_type=at)
+ wsse, _ = WebSocketSequenceElement.objects.get_or_create(websocketmessage=wsm, verb="A", action_type=at)
wsse.save()
@given("a client with {name} authorization")
-def step_impl(context, name):
- at, created = ActionType.objects.get_or_create(name=name)
+def create_authed_client(context, name):
+ """Create a client and authorize it for that action type."""
+ at, _ = ActionType.objects.get_or_create(name=name)
authorized_client = Client.objects.create(
hostname="authorized_client.es.net",
uuid="0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
@@ -36,7 +41,8 @@ def step_impl(context, name):
@given("a client without {name} authorization")
-def step_impl(context, name):
+def create_unauthed_client(context, name):
+ """Create a client that has no authorized action types."""
unauthorized_client = Client.objects.create(
hostname="unauthorized_client.es.net",
uuid="91e134a5-77cf-4560-9797-6bbdbffde9f8",
@@ -45,23 +51,27 @@ def step_impl(context, name):
@when("we're logged in")
-def step_impl(context):
+def login(context):
+ """Login."""
context.test.client.login(username="user", password="password")
@when("the CIDR prefix limits are {v4_minprefix:d} and {v6_minprefix:d}")
-def step_impl(context, v4_minprefix, v6_minprefix):
+def set_cidr_limit(context, v4_minprefix, v6_minprefix):
+ """Override our settings with the provided values."""
conf.settings.V4_MINPREFIX = v4_minprefix
conf.settings.V6_MINPREFIX = v6_minprefix
@then("we get a {status_code:d} status code")
-def step_impl(context, status_code):
+def check_status_code(context, status_code):
+ """Ensure the status code response matches the expected value."""
context.test.assertEqual(context.response.status_code, status_code)
@when("we add the entry {value:S}")
-def step_impl(context, value):
+def add_entry(context, value):
+ """Block the provided route."""
context.response = context.test.client.post(
reverse("api:v1:entry-list"),
{
@@ -77,7 +87,8 @@ def step_impl(context, value):
@when("we add the entry {value:S} with comment {comment}")
-def step_impl(context, value, comment):
+def add_entry_with_comment(context, value, comment):
+ """Block the provided route and add a comment."""
context.response = context.test.client.post(
reverse("api:v1:entry-list"),
{
@@ -92,7 +103,8 @@ def step_impl(context, value, comment):
@when("we add the entry {value:S} with expiration {exp:S}")
-def step_impl(context, value, exp):
+def add_entry_with_absolute_expiration(context, value, exp):
+ """Block the provided route and add an absolute expiration datetime."""
context.response = context.test.client.post(
reverse("api:v1:entry-list"),
{
@@ -108,9 +120,10 @@ def step_impl(context, value, exp):
@when("we add the entry {value:S} with expiration in {secs:d} seconds")
-def step_impl(context, value, secs):
+def add_entry_with_relative_expiration(context, value, secs):
+ """Block the provided route and add a relative expiration."""
td = datetime.timedelta(seconds=secs)
- expiration = datetime.datetime.now() + td
+ expiration = datetime.datetime.now(tz=datetime.UTC) + td
context.response = context.test.client.post(
reverse("api:v1:entry-list"),
@@ -127,34 +140,41 @@ def step_impl(context, value, secs):
@step("we wait {secs:d} seconds")
-def step_impl(context, secs):
+def wait(context, secs):
+ """Wait to allow messages to propagate."""
time.sleep(secs)
@then("we remove expired entries")
-def step_impl(context):
+def remove_expired(context):
+ """Call the function that removes expired entries."""
context.response = context.test.client.get(reverse("route_manager:process-expired"))
@when("we add the ignore entry {value:S}")
-def step_impl(context, value):
+def add_ignore_entry(context, value):
+ """Add an IgnoreEntry with the specified route."""
context.response = context.test.client.post(
- reverse("api:v1:ignoreentry-list"), {"route": value, "comment": "test api"}
+ reverse("api:v1:ignoreentry-list"),
+ {"route": value, "comment": "test api"},
)
@when("we remove the {model} {value}")
-def step_impl(context, model, value):
+def remove_an_object(context, model, value):
+ """Remove any model object with the matching value."""
context.response = context.test.client.delete(reverse(f"api:v1:{model.lower()}-detail", args=[value]))
@when("we list the {model}s")
-def step_impl(context, model):
+def list_objects(context, model):
+ """List all objects of an arbitrary model."""
context.response = context.test.client.get(reverse(f"api:v1:{model.lower()}-list"))
@when("we update the {model} {value_from} to {value_to}")
-def step_impl(context, model, value_from, value_to):
+def update_object(context, model, value_from, value_to):
+ """Modify any model object with the matching value to the new value instead."""
context.response = context.test.client.patch(
reverse(f"api:v1:{model.lower()}-detail", args=[value_from]),
{model.lower(): value_to},
@@ -162,7 +182,8 @@ def step_impl(context, model, value_from, value_to):
@then("the number of {model}s is {num:d}")
-def step_impl(context, model, num):
+def count_objects(context, model, num):
+ """Count the number of objects of an arbitrary model."""
objs = context.test.client.get(reverse(f"api:v1:{model.lower()}-list"))
context.test.assertEqual(len(objs.json()), num)
@@ -171,7 +192,8 @@ def step_impl(context, model, num):
@then("{value} is one of our list of {model}s")
-def step_impl(context, value, model):
+def check_object(context, value, model):
+ """Ensure that the arbitrary model has an object with the specified value."""
objs = context.test.client.get(reverse(f"api:v1:{model.lower()}-list"))
found = False
@@ -186,7 +208,8 @@ def step_impl(context, value, model):
@when("we register a client named {hostname} with the uuid of {uuid}")
-def step_impl(context, hostname, uuid):
+def add_client(context, hostname, uuid):
+ """Create a client with a specific UUID."""
context.response = context.test.client.post(
reverse("api:v1:client-list"),
{
diff --git a/scram/route_manager/tests/acceptance/steps/ip.py b/scram/route_manager/tests/acceptance/steps/ip.py
index a346e954..34536a86 100644
--- a/scram/route_manager/tests/acceptance/steps/ip.py
+++ b/scram/route_manager/tests/acceptance/steps/ip.py
@@ -1,12 +1,14 @@
+"""Define steps used for IP-related logic by the Behave tests."""
+
import ipaddress
from behave import then, when
from django.urls import reverse
-# This does a CIDR match
@then("{route} is contained in our list of {model}s")
-def step_impl(context, route, model):
+def check_route(context, route, model):
+ """Perform a CIDR match on the matching object."""
objs = context.test.client.get(reverse(f"api:v1:{model.lower()}-list"))
ip_target = ipaddress.ip_address(route)
@@ -21,7 +23,8 @@ def step_impl(context, route, model):
@when("we query for {ip}")
-def step_impl(context, ip):
+def check_ip(context, ip):
+ """Find an Entry for the specified IP."""
try:
context.response = context.test.client.get(reverse("api:v1:entry-detail", args=[ip]))
context.queryException = None
@@ -31,12 +34,14 @@ def step_impl(context, ip):
@then("we get a ValueError")
-def step_impl(context):
+def check_error(context):
+ """Ensure we received a ValueError exception."""
assert isinstance(context.queryException, ValueError)
@then("the change entry for {value:S} is {comment}")
-def step_impl(context, value, comment):
+def check_comment(context, value, comment):
+ """Verify the comment for the Entry."""
try:
objs = context.test.client.get(reverse("api:v1:entry-detail", args=[value]))
context.test.assertEqual(objs.json()[0]["comment"], comment)
diff --git a/scram/route_manager/tests/acceptance/steps/translator.py b/scram/route_manager/tests/acceptance/steps/translator.py
index 85621f21..ff478a9a 100644
--- a/scram/route_manager/tests/acceptance/steps/translator.py
+++ b/scram/route_manager/tests/acceptance/steps/translator.py
@@ -1,3 +1,5 @@
+"""Define steps used for translator integration testing by the Behave tests."""
+
from behave import then
from behave.api.async_step import async_run_until_complete
from channels.testing import WebsocketCommunicator
@@ -5,9 +7,10 @@
from config.asgi import ws_application
-async def query_translator(route, actiontype, is_announced=True):
+async def query_translator(route, actiontype, is_announced):
+ """Ensure the specified route is currently either blocked or unblocked."""
communicator = WebsocketCommunicator(ws_application, f"/ws/route_manager/webui_{actiontype}/")
- connected, subprotocol = await communicator.connect()
+ connected, _ = await communicator.connect()
assert connected
await communicator.send_json_to({"type": "wui_check_req", "message": {"route": route}})
@@ -20,11 +23,13 @@ async def query_translator(route, actiontype, is_announced=True):
@then("{route} is announced by {actiontype} translators")
@async_run_until_complete
-async def step_impl(context, route, actiontype):
- await query_translator(route, actiontype)
+async def check_blocked(context, route, actiontype):
+ """Ensure the specified route is currently blocked."""
+ await query_translator(route, actiontype, is_announced=True)
@then("{route} is not announced by {actiontype} translators")
@async_run_until_complete
-async def step_impl(context, route, actiontype):
+async def check_unblocked(context, route, actiontype):
+ """Ensure the specified route is currently unblocked."""
await query_translator(route, actiontype, is_announced=False)
diff --git a/scram/route_manager/tests/functional_tests.py b/scram/route_manager/tests/functional_tests.py
index c4312ecf..a87ce21e 100644
--- a/scram/route_manager/tests/functional_tests.py
+++ b/scram/route_manager/tests/functional_tests.py
@@ -1,9 +1,10 @@
+"""Use the Django web client to perform end-to-end, WebUI-based testing."""
+
import unittest
class HomePageTest(unittest.TestCase):
-
- pass
+ """Ensure the home page works."""
if __name__ == "__main__":
diff --git a/scram/route_manager/tests/test_api.py b/scram/route_manager/tests/test_api.py
index 8152b877..1a81e907 100644
--- a/scram/route_manager/tests/test_api.py
+++ b/scram/route_manager/tests/test_api.py
@@ -1,3 +1,5 @@
+"""Use pytest to unit test the API."""
+
from django.contrib.auth import get_user_model
from django.urls import reverse
from rest_framework import status
@@ -7,7 +9,10 @@
class TestAddRemoveIP(APITestCase):
+ """Ensure that we can block IPs, and that duplicate blocks don't generate an error."""
+
def setUp(self):
+ """Set up the environment for our tests."""
self.url = reverse("api:v1:entry-list")
self.superuser = get_user_model().objects.create_superuser("admin", "admin@es.net", "admintestpassword")
self.client.login(username="admin", password="admintestpassword")
@@ -19,6 +24,7 @@ def setUp(self):
self.authorized_client.authorized_actiontypes.set([1])
def test_block_ipv4(self):
+ """Block a v4 IP."""
response = self.client.post(
self.url,
{
@@ -31,6 +37,7 @@ def test_block_ipv4(self):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
def test_block_duplicate_ipv4(self):
+ """Block an existing v4 IP and ensure we don't get an error."""
self.client.post(
self.url,
{
@@ -52,6 +59,7 @@ def test_block_duplicate_ipv4(self):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
def test_block_ipv6(self):
+ """Block a v6 IP."""
response = self.client.post(
self.url,
{
@@ -64,6 +72,7 @@ def test_block_ipv6(self):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
def test_block_duplicate_ipv6(self):
+ """Block an existing v6 IP and ensure we don't get an error."""
self.client.post(
self.url,
{
@@ -86,11 +95,15 @@ def test_block_duplicate_ipv6(self):
class TestUnauthenticatedAccess(APITestCase):
+ """Ensure that an unathenticated client can't do anything."""
+
def setUp(self):
+ """Define some helper variables."""
self.entry_url = reverse("api:v1:entry-list")
self.ignore_url = reverse("api:v1:ignoreentry-list")
def test_unauthenticated_users_have_no_create_access(self):
+ """Ensure an unauthenticated client can't add an Entry."""
response = self.client.post(
self.entry_url,
{
@@ -104,9 +117,11 @@ def test_unauthenticated_users_have_no_create_access(self):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_unauthenticated_users_have_no_ignore_create_access(self):
+ """Ensure an unauthenticated client can't add an IgnoreEntry."""
response = self.client.post(self.ignore_url, {"route": "192.0.2.4"}, format="json")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_unauthenticated_users_have_no_list_access(self):
+ """Ensure an unauthenticated client can't list Entries."""
response = self.client.get(self.entry_url, format="json")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
diff --git a/scram/route_manager/tests/test_authorization.py b/scram/route_manager/tests/test_authorization.py
index 6f4269d1..9a0d4bff 100644
--- a/scram/route_manager/tests/test_authorization.py
+++ b/scram/route_manager/tests/test_authorization.py
@@ -1,3 +1,5 @@
+"""Define tests for authorization and permissions."""
+
from django.conf import settings
from django.contrib.auth.models import Group
from django.test import Client, TestCase
@@ -10,7 +12,10 @@
class AuthzTest(TestCase):
+ """Define tests using the built-in authentication."""
+
def setUp(self):
+ """Define several users for our tests."""
self.client = Client()
self.unauthorized_user = User.objects.create(username="unauthorized")
@@ -49,6 +54,7 @@ def setUp(self):
)
def create_entry(self):
+ """Ensure the admin user can create an Entry."""
self.client.force_login(self.admin_user)
self.client.post(
reverse("route_manager:add"),
@@ -63,8 +69,7 @@ def create_entry(self):
return Entry.objects.latest("id").id
def test_unauthorized_add_entry(self):
- """Unauthorized users should not be able to add an entry"""
-
+ """Unauthorized users should not be able to add an Entry."""
for user in self.write_blocked_users:
if user:
self.client.force_login(user)
@@ -79,6 +84,7 @@ def test_unauthorized_add_entry(self):
self.assertEqual(response.status_code, 302)
def test_authorized_add_entry(self):
+ """Test authorized users with various permissions to ensure they can add an Entry."""
for user in self.write_allowed_users:
self.client.force_login(user)
response = self.client.post(
@@ -92,6 +98,7 @@ def test_authorized_add_entry(self):
self.assertEqual(response.status_code, 200)
def test_unauthorized_detail_view(self):
+ """Ensure that unauthorized users can't view the blocked IPs."""
pk = self.create_entry()
for user in self.detail_blocked_users:
@@ -101,6 +108,7 @@ def test_unauthorized_detail_view(self):
self.assertIn(response.status_code, [302, 403], msg=f"username={user}")
def test_authorized_detail_view(self):
+ """Test authorized users with various permissions to ensure they can view block details."""
pk = self.create_entry()
for user in self.detail_allowed_users:
@@ -110,7 +118,6 @@ def test_authorized_detail_view(self):
def test_unauthorized_after_group_removal(self):
"""The user has r/w access, then when we remove them from the r/w group, they no longer do."""
-
test_user = User.objects.create(username="tmp_readwrite")
test_user.groups.set([self.readwrite_group])
test_user.save()
@@ -125,8 +132,11 @@ def test_unauthorized_after_group_removal(self):
self.assertEqual(response.status_code, 302)
-class OidcTest(TestCase):
+class ESnetAuthBackendTest(TestCase):
+ """Define tests using OIDC authentication with our ESnetAuthBackend."""
+
def setUp(self):
+ """Create a sample OIDC user."""
self.client = Client()
self.claims = {
"given_name": "Edward",
@@ -136,8 +146,7 @@ def setUp(self):
}
def test_unauthorized(self):
- """A user with no groups should have no access"""
-
+ """A user with no groups should have no access."""
claims = dict(self.claims)
user = ESnetAuthBackend().create_user(claims)
@@ -146,8 +155,7 @@ def test_unauthorized(self):
self.assertEqual(list(user.user_permissions.all()), [])
def test_readonly(self):
- """Test r/o groups"""
-
+ """Test r/o groups."""
claims = dict(self.claims)
claims["groups"] = [settings.SCRAM_READONLY_GROUPS[0]]
user = ESnetAuthBackend().create_user(claims)
@@ -158,8 +166,7 @@ def test_readonly(self):
self.assertFalse(user.has_perm("route_manager.add_entry"))
def test_readwrite(self):
- """Test r/w groups"""
-
+ """Test r/w groups."""
claims = dict(self.claims)
claims["groups"] = [settings.SCRAM_READWRITE_GROUPS[0]]
user = ESnetAuthBackend().create_user(claims)
@@ -171,8 +178,7 @@ def test_readwrite(self):
self.assertTrue(user.has_perm("route_manager.add_entry"))
def test_admin(self):
- """Test admin_groups"""
-
+ """Test admin_groups."""
claims = dict(self.claims)
claims["groups"] = [settings.SCRAM_ADMIN_GROUPS[0]]
user = ESnetAuthBackend().create_user(claims)
@@ -183,8 +189,7 @@ def test_admin(self):
self.assertTrue(user.has_perm("route_manager.add_entry"))
def test_authorized_removal(self):
- """Have an authorized user, then downgrade them and make sure they're unauthorized"""
-
+ """Have an authorized user, then downgrade them and make sure they're unauthorized."""
claims = dict(self.claims)
claims["groups"] = [settings.SCRAM_ADMIN_GROUPS[0]]
user = ESnetAuthBackend().create_user(claims)
@@ -230,9 +235,8 @@ def test_authorized_removal(self):
def test_disabled(self):
"""Pass all the groups, user should be disabled as it takes precedence."""
-
claims = dict(self.claims)
- claims["groups"] = [settings.SCRAM_GROUPS]
+ claims["groups"] = settings.SCRAM_GROUPS
user = ESnetAuthBackend().create_user(claims)
self.assertFalse(user.is_staff)
diff --git a/scram/route_manager/tests/test_autocreate_admin.py b/scram/route_manager/tests/test_autocreate_admin.py
new file mode 100644
index 00000000..88270f4e
--- /dev/null
+++ b/scram/route_manager/tests/test_autocreate_admin.py
@@ -0,0 +1,50 @@
+"""Test the auto-creation of an admin user."""
+
+import pytest
+from django.contrib.messages import get_messages
+from django.test import Client
+from django.urls import reverse
+
+from scram.users.models import User
+
+LEVEL_SUCCESS = 25
+LEVEL_INFO = 20
+
+
+@pytest.mark.django_db
+def test_autocreate_admin(settings):
+ """Test that an admin user is auto-created when AUTOCREATE_ADMIN is True."""
+ settings.AUTOCREATE_ADMIN = True
+ client = Client()
+ response = client.get(reverse("route_manager:home"))
+ assert response.status_code == 200 # noqa: PLR2004
+ assert User.objects.count() == 1
+ user = User.objects.get(username="admin")
+ assert user.is_superuser
+ assert user.email == "admin@example.com"
+ messages = list(get_messages(response.wsgi_request))
+ assert len(messages) == 2 # noqa: PLR2004
+ assert messages[0].level == LEVEL_SUCCESS
+ assert messages[1].level == LEVEL_INFO
+
+
+@pytest.mark.django_db
+def test_autocreate_admin_disabled(settings):
+ """Test that an admin user is not auto-created when AUTOCREATE_ADMIN is False."""
+ settings.AUTOCREATE_ADMIN = False
+ client = Client()
+ response = client.get(reverse("route_manager:home"))
+ assert response.status_code == 200 # noqa: PLR2004
+ assert User.objects.count() == 0
+
+
+@pytest.mark.django_db
+def test_autocreate_admin_existing_user(settings):
+ """Test that an admin user is not auto-created when an existing user is present."""
+ settings.AUTOCREATE_ADMIN = True
+ User.objects.create_user("testuser", "test@example.com", "password")
+ client = Client()
+ response = client.get(reverse("route_manager:home"))
+ assert response.status_code == 200 # noqa: PLR2004
+ assert User.objects.count() == 1
+ assert not User.objects.filter(username="admin").exists()
diff --git a/scram/route_manager/tests/test_history.py b/scram/route_manager/tests/test_history.py
index f00f16bb..d2848e83 100644
--- a/scram/route_manager/tests/test_history.py
+++ b/scram/route_manager/tests/test_history.py
@@ -1,3 +1,5 @@
+"""Define tests for the history feature."""
+
from django.test import TestCase
from simple_history.utils import get_change_reason_from_object, update_change_reason
@@ -5,20 +7,27 @@
class TestActiontypeHistory(TestCase):
+ """Test the history on an action type."""
+
def setUp(self):
+ """Set up the test environment."""
self.atype = ActionType.objects.create(name="Block")
def test_comments(self):
+ """Ensure we can go back and set a reason."""
self.atype.name = "Nullroute"
- self.atype._change_reason = "Use more descriptive name"
+ self.atype._change_reason = "Use more descriptive name" # noqa SLF001
self.atype.save()
self.assertIsNotNone(get_change_reason_from_object(self.atype))
class TestEntryHistory(TestCase):
+ """Test the history on an Entry."""
+
routes = ["192.0.2.16/32", "198.51.100.16/28"]
def setUp(self):
+ """Set up the test environment."""
self.atype = ActionType.objects.create(name="Block")
for r in self.routes:
route = Route.objects.create(route=r)
@@ -28,6 +37,7 @@ def setUp(self):
self.assertEqual(entry.get_change_reason(), create_reason)
def test_comments(self):
+ """Ensure we can update the reason."""
for r in self.routes:
route_old = Route.objects.get(route=r)
e = Entry.objects.get(route=route_old)
@@ -37,7 +47,7 @@ def test_comments(self):
e.route = Route.objects.create(route=route_new)
change_reason = "I meant 32, not 16."
- e._change_reason = change_reason
+ e._change_reason = change_reason # noqa SLF001
e.save()
self.assertEqual(len(e.history.all()), 2)
diff --git a/scram/route_manager/tests/test_swagger.py b/scram/route_manager/tests/test_swagger.py
new file mode 100644
index 00000000..ce955db9
--- /dev/null
+++ b/scram/route_manager/tests/test_swagger.py
@@ -0,0 +1,30 @@
+"""Test the swagger API endpoints."""
+
+import pytest
+from django.urls import reverse
+
+
+@pytest.mark.django_db
+def test_swagger_api(client):
+ """Test that the Swagger API endpoint returns a successful response."""
+ url = reverse("swagger-ui")
+ response = client.get(url)
+ assert response.status_code == 200 # noqa: PLR2004
+
+
+@pytest.mark.django_db
+def test_redoc_api(client):
+ """Test that the Redoc API endpoint returns a successful response."""
+ url = reverse("redoc")
+ response = client.get(url)
+ assert response.status_code == 200 # noqa: PLR2004
+
+
+@pytest.mark.django_db
+def test_schema_api(client):
+ """Test that the Schema API endpoint returns a successful response."""
+ url = reverse("schema")
+ response = client.get(url)
+ assert response.status_code == 200 # noqa: PLR2004
+ expected_strings = [b"/entries/", b"/actiontypes/", b"/ignore_entries/", b"/users/"]
+ assert all(string in response.content for string in expected_strings)
diff --git a/scram/route_manager/tests/test_views.py b/scram/route_manager/tests/test_views.py
index 415cdae0..1433135a 100644
--- a/scram/route_manager/tests/test_views.py
+++ b/scram/route_manager/tests/test_views.py
@@ -1,10 +1,58 @@
+"""Define simple tests for the template-based Views."""
+
+from django.conf import settings
from django.test import TestCase
-from django.urls import resolve
+from django.urls import resolve, reverse
from scram.route_manager.views import home_page
class HomePageTest(TestCase):
+ """Test how the home page renders."""
+
def test_root_url_resolves_to_home_page_view(self):
+ """Ensure we can find the home page."""
found = resolve("/")
self.assertEqual(found.func, home_page)
+
+
+class HomePageFirstVisitTest(TestCase):
+ """Test how the home page renders the first time we view it."""
+
+ def setUp(self):
+ """Get the home page."""
+ self.response = self.client.get(reverse("route_manager:home"))
+
+ def test_first_homepage_view_has_userinfo(self):
+ """The first time we view the home page, a user was created for us."""
+ self.assertContains(self.response, b"An admin user was created for you.")
+
+ def test_first_homepage_view_is_logged_in(self):
+ """The first time we view the home page, we're logged in."""
+ self.assertContains(self.response, b'type="submit">Logout')
+
+
+class HomePageLogoutTest(TestCase):
+ """Verify that once logged out, we can't view anything."""
+
+ def test_homepage_logout_links_missing(self):
+ """After logout, we can't see anything."""
+ response = self.client.get(reverse("route_manager:home"))
+ response = self.client.post(reverse(settings.LOGOUT_URL), follow=True)
+ self.assertEqual(response.status_code, 200)
+ response = self.client.get(reverse("route_manager:home"))
+
+ self.assertNotContains(response, b"An admin user was created for you.")
+ self.assertNotContains(response, b'type="submit">Logout')
+ self.assertNotContains(response, b">Admin")
+
+
+class NotFoundTest(TestCase):
+ """Verify that our custom 404 page is being served up."""
+
+ def test_404(self):
+ """Grab a bad URL."""
+ response = self.client.get("/foobarbaz")
+ self.assertContains(
+ response, b'The page you are looking for was not found.
', status_code=404
+ )
diff --git a/scram/route_manager/tests/test_websockets.py b/scram/route_manager/tests/test_websockets.py
index 367359b8..ca9cd528 100644
--- a/scram/route_manager/tests/test_websockets.py
+++ b/scram/route_manager/tests/test_websockets.py
@@ -1,3 +1,5 @@
+"""Define unit tests for the websockets-based communication."""
+
import json
from asyncio import gather
from contextlib import asynccontextmanager
@@ -15,7 +17,7 @@
@asynccontextmanager
async def get_communicators(actiontypes, should_match, *args, **kwds):
- """Creates a set of communicators, and then handles tear-down.
+ """Create a set of communicators, and then handle tear-down.
Given two lists of the same length, a set of actiontypes, and set of boolean values,
creates that many communicators, one for each actiontype-bool pair.
@@ -24,15 +26,13 @@ async def get_communicators(actiontypes, should_match, *args, **kwds):
Returns a list of (communicator, should_match bool) pairs.
"""
- assert len(actiontypes) == len(should_match)
-
router = URLRouter(websocket_urlpatterns)
communicators = [
WebsocketCommunicator(router, f"/ws/route_manager/translator_{actiontype}/") for actiontype in actiontypes
]
- response = zip(communicators, should_match)
+ response = zip(communicators, should_match, strict=True)
- for communicator, should_match in response:
+ for communicator, _ in response:
connected, _ = await communicator.connect()
assert connected
@@ -40,7 +40,7 @@ async def get_communicators(actiontypes, should_match, *args, **kwds):
yield response
finally:
- for communicator, should_match in response:
+ for communicator, _ in response:
await communicator.disconnect()
@@ -48,6 +48,7 @@ class TestTranslatorBaseCase(TestCase):
"""Base case that other test cases build on top of. Three translators in one group, test one v4 and one v6."""
def setUp(self):
+ """Set up our test environment."""
# TODO: This is copied from test_api; should de-dupe this.
self.url = reverse("api:v1:entry-list")
self.superuser = get_user_model().objects.create_superuser("admin", "admin@example.net", "admintestpassword")
@@ -68,7 +69,9 @@ def setUp(self):
wsm, _ = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route")
_, _ = WebSocketSequenceElement.objects.get_or_create(
- websocketmessage=wsm, verb="A", action_type=self.actiontype
+ websocketmessage=wsm,
+ verb="A",
+ action_type=self.actiontype,
)
# Set some defaults; some child classes override this
@@ -77,10 +80,10 @@ def setUp(self):
self.generate_add_msgs = [lambda ip, mask: {"type": "translator_add", "message": {"route": f"{ip}/{mask}"}}]
# Now we run any local setup actions by the child classes
- self.local_setUp()
+ self.local_setup()
- def local_setUp(self):
- # Allow child classes to override this if desired
+ def local_setup(self):
+ """Allow child classes to override this if desired."""
return
async def get_messages(self, communicator, messages, should_match):
@@ -95,6 +98,7 @@ async def get_nothings(self, communicator):
assert await communicator.receive_nothing(timeout=0.1, interval=0.01) is False
async def add_ip(self, ip, mask):
+ """Ensure we can add an IP to block."""
async with get_communicators(self.actiontypes, self.should_match) as communicators:
await self.api_create_entry(ip)
@@ -119,6 +123,7 @@ async def ensure_no_more_msgs(self, communicators):
# Django ensures that the create is synchronous, so we have some extra steps to do
@sync_to_async
def api_create_entry(self, route):
+ """Ensure we can create an Entry via the API."""
return self.client.post(
self.url,
{
@@ -131,12 +136,14 @@ def api_create_entry(self, route):
)
async def test_add_v4(self):
+ """Test adding a few v4 routes."""
await self.add_ip("192.0.2.224", 32)
await self.add_ip("192.0.2.225", 32)
await self.add_ip("192.0.2.226", 32)
await self.add_ip("198.51.100.224", 32)
async def test_add_v6(self):
+ """Test adding a few v6 routes."""
await self.add_ip("2001:DB8:FDF0::", 128)
await self.add_ip("2001:DB8:FDF0::D", 128)
await self.add_ip("2001:DB8:FDF0::DB", 128)
@@ -144,9 +151,10 @@ async def test_add_v6(self):
class TranslatorDontCrossTheStreamsTestCase(TestTranslatorBaseCase):
- """Two translators in the same group, two in another group, single IP, ensure we get only the messages we expect."""
+ """Two translators in one group, two in another group, single IP, ensure we get only the messages we expect."""
- def local_setUp(self):
+ def local_setup(self):
+ """Define the actions and what we expect."""
self.actiontypes = ["block", "block", "noop", "noop"]
self.should_match = [True, True, False, False]
@@ -154,14 +162,21 @@ def local_setUp(self):
class TranslatorSequenceTestCase(TestTranslatorBaseCase):
"""Test a sequence of WebSocket messages."""
- def local_setUp(self):
+ def local_setup(self):
+ """Define the messages we want to send."""
wsm2 = WebSocketMessage.objects.create(msg_type="translator_add", msg_data_route_field="foo")
_ = WebSocketSequenceElement.objects.create(
- websocketmessage=wsm2, verb="A", action_type=self.actiontype, order_num=20
+ websocketmessage=wsm2,
+ verb="A",
+ action_type=self.actiontype,
+ order_num=20,
)
wsm3 = WebSocketMessage.objects.create(msg_type="translator_add", msg_data_route_field="bar")
_ = WebSocketSequenceElement.objects.create(
- websocketmessage=wsm3, verb="A", action_type=self.actiontype, order_num=2
+ websocketmessage=wsm3,
+ verb="A",
+ action_type=self.actiontype,
+ order_num=2,
)
self.generate_add_msgs = [
@@ -174,7 +189,8 @@ def local_setUp(self):
class TranslatorParametersTestCase(TestTranslatorBaseCase):
"""Additional parameters in the JSONField."""
- def local_setUp(self):
+ def local_setup(self):
+ """Define the message we want to send."""
wsm = WebSocketMessage.objects.get(msg_type="translator_add", msg_data_route_field="route")
wsm.msg_data = {"asn": 65550, "community": 100, "route": "Ensure this gets overwritten."}
wsm.save()
diff --git a/scram/route_manager/urls.py b/scram/route_manager/urls.py
index dde55639..2735cc64 100644
--- a/scram/route_manager/urls.py
+++ b/scram/route_manager/urls.py
@@ -1,3 +1,5 @@
+"""Register URLs known to Django, and the View that will handle each."""
+
from django.urls import path
from . import views
diff --git a/scram/route_manager/views.py b/scram/route_manager/views.py
index 54ff321a..fa8b7116 100644
--- a/scram/route_manager/views.py
+++ b/scram/route_manager/views.py
@@ -1,3 +1,5 @@
+"""Define the Views that will handle the HTTP requests."""
+
import ipaddress
import json
@@ -22,7 +24,10 @@
channel_layer = get_channel_layer()
-def home_page(request, prefilter=Entry.objects.all()):
+def home_page(request, prefilter=None):
+ """Return the home page, autocreating a user if none exists."""
+ if not prefilter:
+ prefilter = Entry.objects.all()
num_entries = settings.RECENT_LIMIT
if request.user.has_perms(("route_manager.view_entry", "route_manager.add_entry")):
readwrite = True
@@ -60,6 +65,7 @@ def home_page(request, prefilter=Entry.objects.all()):
def search_entries(request):
+ """Wrap the home page with a specified CIDR to restrict Entries to."""
# Using ipaddress because we needed to turn off strict mode
# (which netfields uses by default with seemingly no toggle)
# This caused searches with host bits set to 500 which is bad UX see: 68854ee1ad4789a62863083d521bddbc96ab7025
@@ -77,11 +83,14 @@ def search_entries(request):
@require_POST
@permission_required(["route_manager.view_entry", "route_manager.delete_entry"])
def delete_entry(request, pk):
+ """Wrap delete via the API and redirect to the home page."""
delete_entry_api(request, pk)
return redirect("route_manager:home")
class EntryDetailView(PermissionRequiredMixin, DetailView):
+ """Define a view for the API to use."""
+
permission_required = ["route_manager.view_entry"]
model = Entry
template_name = "route_manager/entry_detail.html"
@@ -92,35 +101,35 @@ class EntryDetailView(PermissionRequiredMixin, DetailView):
@permission_required(["route_manager.view_entry", "route_manager.add_entry"])
def add_entry(request):
+ """Send a WebSocket message when adding a new entry."""
with transaction.atomic():
res = add_entry_api(request)
- if res.status_code == 201:
+ if res.status_code == 201: # noqa: PLR2004
messages.add_message(
request,
messages.SUCCESS,
"Entry successfully added",
)
- elif res.status_code == 400:
+ elif res.status_code == 400: # noqa: PLR2004
errors = []
if isinstance(res.data, rest_framework.utils.serializer_helpers.ReturnDict):
for k, v in res.data.items():
- for error in v:
- errors.append(f"'{k}' error: {str(error)}")
+ errors.extend(f"'{k}' error: {error!s}" for error in v)
else:
- for k, v in res.data.items():
- errors.append(f"error: {str(v)}")
+ errors.extend(f"error: {v!s}" for v in res.data.values())
messages.add_message(request, messages.ERROR, "
".join(errors))
- elif res.status_code == 403:
+ elif res.status_code == 403: # noqa: PLR2004
messages.add_message(request, messages.ERROR, "Permission Denied")
else:
messages.add_message(request, messages.WARNING, f"Something went wrong: {res.status_code}")
with transaction.atomic():
home = home_page(request)
- return home
+ return home # noqa RET504
def process_expired(request):
+ """For entries with an expiration, set them to inactive if expired. Return some simple stats."""
current_time = timezone.now()
with transaction.atomic():
entries_start = Entry.objects.filter(is_active=True).count()
@@ -133,17 +142,21 @@ def process_expired(request):
{
"entries_deleted": entries_start - entries_end,
"active_entries": entries_end,
- }
+ },
),
content_type="application/json",
)
class EntryListView(ListView):
+ """Define a view for the API to use."""
+
model = Entry
template_name = "route_manager/entry_list.html"
- def get_context_data(self, **kwargs):
+ @staticmethod
+ def get_context_data(**kwargs):
+ """Group entries by action type."""
context = {"entries": {}}
for at in ActionType.objects.all():
queryset = Entry.objects.filter(actiontype=at).order_by("-pk")
diff --git a/scram/templates/account/account_inactive.html b/scram/templates/account/account_inactive.html
deleted file mode 100644
index 17c21577..00000000
--- a/scram/templates/account/account_inactive.html
+++ /dev/null
@@ -1,12 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-
-{% block head_title %}{% trans "Account Inactive" %}{% endblock %}
-
-{% block inner %}
-{% trans "Account Inactive" %}
-
-{% trans "This account is inactive." %}
-{% endblock %}
-
diff --git a/scram/templates/account/base.html b/scram/templates/account/base.html
deleted file mode 100644
index 97f4ddb8..00000000
--- a/scram/templates/account/base.html
+++ /dev/null
@@ -1,10 +0,0 @@
-{% extends "../base.html" %}
-{% block title %}{% block head_title %}{% endblock head_title %}{% endblock title %}
-
-{% block content %}
-
-
- {% block inner %}{% endblock %}
-
-
-{% endblock %}
diff --git a/scram/templates/account/email.html b/scram/templates/account/email.html
deleted file mode 100644
index 8eef4159..00000000
--- a/scram/templates/account/email.html
+++ /dev/null
@@ -1,80 +0,0 @@
-
-{% extends "account/base.html" %}
-
-{% load i18n %}
-{% load crispy_forms_tags %}
-
-{% block head_title %}{% trans "Account" %}{% endblock %}
-
-{% block inner %}
-{% trans "E-mail Addresses" %}
-
-{% if user.emailaddress_set.all %}
-{% trans 'The following e-mail addresses are associated with your account:' %}
-
-
-
-{% else %}
-{% trans 'Warning:'%} {% trans "You currently do not have any e-mail address set up. You should really add an e-mail address so you can receive notifications, reset your password, etc." %}
-
-{% endif %}
-
-
- {% trans "Add E-mail Address" %}
-
-
-
-{% endblock %}
-
-
-{% block inline_javascript %}
-{{ block.super }}
-
-{% endblock %}
-
diff --git a/scram/templates/account/email_confirm.html b/scram/templates/account/email_confirm.html
deleted file mode 100644
index 46c78126..00000000
--- a/scram/templates/account/email_confirm.html
+++ /dev/null
@@ -1,32 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-{% load account %}
-
-{% block head_title %}{% trans "Confirm E-mail Address" %}{% endblock %}
-
-
-{% block inner %}
-{% trans "Confirm E-mail Address" %}
-
-{% if confirmation %}
-
-{% user_display confirmation.email_address.user as user_display %}
-
-{% blocktrans with confirmation.email_address.email as email %}Please confirm that {{ email }} is an e-mail address for user {{ user_display }}.{% endblocktrans %}
-
-
-
-{% else %}
-
-{% url 'account_email' as email_url %}
-
-{% blocktrans %}This e-mail confirmation link expired or is invalid. Please issue a new e-mail confirmation request.{% endblocktrans %}
-
-{% endif %}
-
-{% endblock %}
-
diff --git a/scram/templates/account/login.html b/scram/templates/account/login.html
deleted file mode 100644
index 2cadea6a..00000000
--- a/scram/templates/account/login.html
+++ /dev/null
@@ -1,48 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-{% load account socialaccount %}
-{% load crispy_forms_tags %}
-
-{% block head_title %}{% trans "Sign In" %}{% endblock %}
-
-{% block inner %}
-
-{% trans "Sign In" %}
-
-{% get_providers as socialaccount_providers %}
-
-{% if socialaccount_providers %}
-{% blocktrans with site.name as site_name %}Please sign in with one
-of your existing third party accounts. Or, sign up
-for a {{ site_name }} account and sign in below:{% endblocktrans %}
-
-
-
-
- {% include "socialaccount/snippets/provider_list.html" with process="login" %}
-
-
-
{% trans 'or' %}
-
-
-
-{% include "socialaccount/snippets/login_extra.html" %}
-
-{% else %}
-{% blocktrans %}If you have not created an account yet, then please
-sign up first.{% endblocktrans %}
-{% endif %}
-
-
-
-{% endblock %}
-
diff --git a/scram/templates/account/logout.html b/scram/templates/account/logout.html
deleted file mode 100644
index 8e2e6754..00000000
--- a/scram/templates/account/logout.html
+++ /dev/null
@@ -1,22 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-
-{% block head_title %}{% trans "Sign Out" %}{% endblock %}
-
-{% block inner %}
-{% trans "Sign Out" %}
-
-{% trans 'Are you sure you want to sign out?' %}
-
-
-
-
-{% endblock %}
-
diff --git a/scram/templates/account/password_change.html b/scram/templates/account/password_change.html
deleted file mode 100644
index b72ca068..00000000
--- a/scram/templates/account/password_change.html
+++ /dev/null
@@ -1,17 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-{% load crispy_forms_tags %}
-
-{% block head_title %}{% trans "Change Password" %}{% endblock %}
-
-{% block inner %}
- {% trans "Change Password" %}
-
-
-{% endblock %}
-
diff --git a/scram/templates/account/password_reset.html b/scram/templates/account/password_reset.html
deleted file mode 100644
index c98f28c1..00000000
--- a/scram/templates/account/password_reset.html
+++ /dev/null
@@ -1,25 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-{% load account %}
-{% load crispy_forms_tags %}
-
-{% block head_title %}{% trans "Password Reset" %}{% endblock %}
-
-{% block inner %}
-
- {% trans "Password Reset" %}
- {% if user.is_authenticated %}
- {% include "/snippets/already_logged_in.html" %}
- {% endif %}
-
- {% trans "Forgotten your password? Enter your e-mail address below, and we'll send you an e-mail allowing you to reset it." %}
-
-
-
- {% blocktrans %}Please contact us if you have any trouble resetting your password.{% endblocktrans %}
-{% endblock %}
diff --git a/scram/templates/account/password_reset_done.html b/scram/templates/account/password_reset_done.html
deleted file mode 100644
index 835156ca..00000000
--- a/scram/templates/account/password_reset_done.html
+++ /dev/null
@@ -1,16 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-{% load account %}
-
-{% block head_title %}{% trans "Password Reset" %}{% endblock %}
-
-{% block inner %}
- {% trans "Password Reset" %}
-
- {% if user.is_authenticated %}
- {% include "/snippets/already_logged_in.html" %}
- {% endif %}
-
- {% blocktrans %}We have sent you an e-mail. Please contact us if you do not receive it within a few minutes.{% endblocktrans %}
-{% endblock %}
diff --git a/scram/templates/account/password_reset_from_key.html b/scram/templates/account/password_reset_from_key.html
deleted file mode 100644
index 2e2cd194..00000000
--- a/scram/templates/account/password_reset_from_key.html
+++ /dev/null
@@ -1,24 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-{% load crispy_forms_tags %}
-{% block head_title %}{% trans "Change Password" %}{% endblock %}
-
-{% block inner %}
- {% if token_fail %}{% trans "Bad Token" %}{% else %}{% trans "Change Password" %}{% endif %}
-
- {% if token_fail %}
- {% url 'account_reset_password' as passwd_reset_url %}
- {% blocktrans %}The password reset link was invalid, possibly because it has already been used. Please request a new password reset.{% endblocktrans %}
- {% else %}
- {% if form %}
-
- {% else %}
- {% trans 'Your password is now changed.' %}
- {% endif %}
- {% endif %}
-{% endblock %}
diff --git a/scram/templates/account/password_reset_from_key_done.html b/scram/templates/account/password_reset_from_key_done.html
deleted file mode 100644
index 89be086f..00000000
--- a/scram/templates/account/password_reset_from_key_done.html
+++ /dev/null
@@ -1,10 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-{% block head_title %}{% trans "Change Password" %}{% endblock %}
-
-{% block inner %}
- {% trans "Change Password" %}
- {% trans 'Your password is now changed.' %}
-{% endblock %}
-
diff --git a/scram/templates/account/password_set.html b/scram/templates/account/password_set.html
deleted file mode 100644
index 22322235..00000000
--- a/scram/templates/account/password_set.html
+++ /dev/null
@@ -1,17 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-{% load crispy_forms_tags %}
-
-{% block head_title %}{% trans "Set Password" %}{% endblock %}
-
-{% block inner %}
- {% trans "Set Password" %}
-
-
-{% endblock %}
-
diff --git a/scram/templates/account/signup.html b/scram/templates/account/signup.html
deleted file mode 100644
index 6a2954eb..00000000
--- a/scram/templates/account/signup.html
+++ /dev/null
@@ -1,23 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-{% load crispy_forms_tags %}
-
-{% block head_title %}{% trans "Signup" %}{% endblock %}
-
-{% block inner %}
-{% trans "Sign Up" %}
-
-{% blocktrans %}Already have an account? Then please sign in.{% endblocktrans %}
-
-
-
-{% endblock %}
-
diff --git a/scram/templates/account/signup_closed.html b/scram/templates/account/signup_closed.html
deleted file mode 100644
index 2322f176..00000000
--- a/scram/templates/account/signup_closed.html
+++ /dev/null
@@ -1,12 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-
-{% block head_title %}{% trans "Sign Up Closed" %}{% endblock %}
-
-{% block inner %}
-{% trans "Sign Up Closed" %}
-
-{% trans "We are sorry, but the sign up is currently closed." %}
-{% endblock %}
-
diff --git a/scram/templates/account/verification_sent.html b/scram/templates/account/verification_sent.html
deleted file mode 100644
index ad093fd4..00000000
--- a/scram/templates/account/verification_sent.html
+++ /dev/null
@@ -1,13 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-
-{% block head_title %}{% trans "Verify Your E-mail Address" %}{% endblock %}
-
-{% block inner %}
- {% trans "Verify Your E-mail Address" %}
-
- {% blocktrans %}We have sent an e-mail to you for verification. Follow the link provided to finalize the signup process. Please contact us if you do not receive it within a few minutes.{% endblocktrans %}
-
-{% endblock %}
-
diff --git a/scram/templates/account/verified_email_required.html b/scram/templates/account/verified_email_required.html
deleted file mode 100644
index 09d4fde7..00000000
--- a/scram/templates/account/verified_email_required.html
+++ /dev/null
@@ -1,24 +0,0 @@
-{% extends "account/base.html" %}
-
-{% load i18n %}
-
-{% block head_title %}{% trans "Verify Your E-mail Address" %}{% endblock %}
-
-{% block inner %}
-{% trans "Verify Your E-mail Address" %}
-
-{% url 'account_email' as email_url %}
-
-{% blocktrans %}This part of the site requires us to verify that
-you are who you claim to be. For this purpose, we require that you
-verify ownership of your e-mail address. {% endblocktrans %}
-
-{% blocktrans %}We have sent an e-mail to you for
-verification. Please click on the link inside this e-mail. Please
-contact us if you do not receive it within a few minutes.{% endblocktrans %}
-
-{% blocktrans %}Note: you can still change your e-mail address.{% endblocktrans %}
-
-
-{% endblock %}
-
diff --git a/scram/templates/local_auth/login.html b/scram/templates/local_auth/login.html
new file mode 100644
index 00000000..e2355d27
--- /dev/null
+++ b/scram/templates/local_auth/login.html
@@ -0,0 +1,12 @@
+{% extends 'base.html' %}
+
+{% block title %}Login{% endblock %}
+
+{% block content %}
+ Login
+
+{% endblock %}
diff --git a/scram/users/__init__.py b/scram/users/__init__.py
index e69de29b..8728ced7 100644
--- a/scram/users/__init__.py
+++ b/scram/users/__init__.py
@@ -0,0 +1 @@
+"""The Users app defines users and a permission scheme for them."""
diff --git a/scram/users/admin.py b/scram/users/admin.py
index c69eb541..ae35bc3f 100644
--- a/scram/users/admin.py
+++ b/scram/users/admin.py
@@ -1,3 +1,5 @@
+"""Register Admin models with the Django Admin site."""
+
from django.contrib import admin
from django.contrib.auth import admin as auth_admin
from django.contrib.auth import get_user_model
@@ -10,6 +12,7 @@
@admin.register(User)
class UserAdmin(auth_admin.UserAdmin):
+ """Allow managing Users via the Django admin site."""
form = UserChangeForm
add_form = UserCreationForm
diff --git a/scram/users/api/serializers.py b/scram/users/api/serializers.py
index c95bfe62..f8d22c88 100644
--- a/scram/users/api/serializers.py
+++ b/scram/users/api/serializers.py
@@ -1,3 +1,5 @@
+"""Serializers provide mappings between the API and the underlying model."""
+
from django.contrib.auth import get_user_model
from rest_framework import serializers
@@ -5,7 +7,11 @@
class UserSerializer(serializers.ModelSerializer):
+ """Map to the User model."""
+
class Meta:
+ """Specify the fields exposed by the API."""
+
model = User
fields = ["username", "name", "url"]
diff --git a/scram/users/api/views.py b/scram/users/api/views.py
index 288ea7ab..0e17cb54 100644
--- a/scram/users/api/views.py
+++ b/scram/users/api/views.py
@@ -1,3 +1,5 @@
+"""Views provide mappings between the underlying model and how they're listed in the API."""
+
from django.contrib.auth import get_user_model
from rest_framework import status
from rest_framework.decorators import action
@@ -11,14 +13,19 @@
class UserViewSet(RetrieveModelMixin, ListModelMixin, UpdateModelMixin, GenericViewSet):
+ """Lookup Users by username."""
+
serializer_class = UserSerializer
queryset = User.objects.all()
lookup_field = "username"
def get_queryset(self, *args, **kwargs):
+ """Query on User ID."""
return self.queryset.filter(id=self.request.user.id)
+ @staticmethod
@action(detail=False, methods=["GET"])
- def me(self, request):
+ def me(request):
+ """Return the current user."""
serializer = UserSerializer(request.user, context={"request": request})
return Response(status=status.HTTP_200_OK, data=serializer.data)
diff --git a/scram/users/apps.py b/scram/users/apps.py
index 2ce72578..e6cc9c6b 100644
--- a/scram/users/apps.py
+++ b/scram/users/apps.py
@@ -1,3 +1,5 @@
+"""Register ourselves with Django."""
+
import logging
from django.apps import AppConfig
@@ -7,10 +9,14 @@
class UsersConfig(AppConfig):
+ """Define the name of the module for the Users app."""
+
name = "scram.users"
verbose_name = _("Users")
- def ready(self):
+ @staticmethod
+ def ready():
+ """Check if signals are registered for User events."""
try:
import scram.users.signals # noqa F401
except ImportError:
diff --git a/scram/users/forms.py b/scram/users/forms.py
index e2792fc2..b079f2a4 100644
--- a/scram/users/forms.py
+++ b/scram/users/forms.py
@@ -1,3 +1,5 @@
+"""Define forms for the Admin site."""
+
from django.contrib.auth import forms as admin_forms
from django.contrib.auth import get_user_model
from django.utils.translation import gettext_lazy as _
@@ -6,12 +8,20 @@
class UserChangeForm(admin_forms.UserChangeForm):
+ """Define a form to edit a User."""
+
class Meta(admin_forms.UserChangeForm.Meta):
+ """Map to the User model."""
+
model = User
class UserCreationForm(admin_forms.UserCreationForm):
+ """Define a form to create a User."""
+
class Meta(admin_forms.UserCreationForm.Meta):
+ """Map to the User model and provide custom error messages."""
+
model = User
error_messages = {"username": {"unique": _("This username has already been taken.")}}
diff --git a/scram/users/models.py b/scram/users/models.py
index c2e1914f..81a905ca 100644
--- a/scram/users/models.py
+++ b/scram/users/models.py
@@ -1,3 +1,5 @@
+"""Define models for the User application."""
+
from django.contrib.auth.models import AbstractUser
from django.db.models import CharField
from django.urls import reverse
diff --git a/scram/users/tests/__init__.py b/scram/users/tests/__init__.py
index e69de29b..f1101332 100644
--- a/scram/users/tests/__init__.py
+++ b/scram/users/tests/__init__.py
@@ -0,0 +1 @@
+"""Define tests for the User application."""
diff --git a/scram/users/tests/factories.py b/scram/users/tests/factories.py
index edd306cb..10331878 100644
--- a/scram/users/tests/factories.py
+++ b/scram/users/tests/factories.py
@@ -1,4 +1,7 @@
-from typing import Any, Sequence
+"""Define Factory tests for the Users application."""
+
+from collections.abc import Sequence
+from typing import Any
from django.contrib.auth import get_user_model
from factory import Faker, post_generation
@@ -6,6 +9,7 @@
class UserFactory(DjangoModelFactory):
+ """Test the UserFactory."""
username = Faker("user_name")
email = Faker("email")
@@ -13,6 +17,7 @@ class UserFactory(DjangoModelFactory):
@post_generation
def password(self, create: bool, extracted: Sequence[Any], **kwargs):
+ """Test password assignment."""
password = (
extracted
if extracted
@@ -28,5 +33,7 @@ def password(self, create: bool, extracted: Sequence[Any], **kwargs):
self.set_password(password)
class Meta:
+ """Map to User model."""
+
model = get_user_model()
django_get_or_create = ["username"]
diff --git a/scram/users/tests/test_admin.py b/scram/users/tests/test_admin.py
index 66f20cda..fd49e7ae 100644
--- a/scram/users/tests/test_admin.py
+++ b/scram/users/tests/test_admin.py
@@ -1,3 +1,5 @@
+"""Tests for the Admin site for the Users application."""
+
import pytest
from django.urls import reverse
@@ -7,17 +9,22 @@
class TestUserAdmin:
+ """Test various User operations via the Admin site."""
+
def test_changelist(self, admin_client):
+ """Ensure we can view the changelist."""
url = reverse("admin:users_user_changelist")
response = admin_client.get(url)
assert response.status_code == 200
def test_search(self, admin_client):
+ """Ensure we can view the search."""
url = reverse("admin:users_user_changelist")
response = admin_client.get(url, data={"q": "test"})
assert response.status_code == 200
def test_add(self, admin_client):
+ """Ensure we can add a user."""
url = reverse("admin:users_user_add")
response = admin_client.get(url)
assert response.status_code == 200
@@ -34,6 +41,7 @@ def test_add(self, admin_client):
assert User.objects.filter(username="test").exists()
def test_view_user(self, admin_client):
+ """Ensure we can view a user."""
user = User.objects.get(username="admin")
url = reverse("admin:users_user_change", kwargs={"object_id": user.pk})
response = admin_client.get(url)
diff --git a/scram/users/tests/test_drf_urls.py b/scram/users/tests/test_drf_urls.py
index 6e28c266..8e587486 100644
--- a/scram/users/tests/test_drf_urls.py
+++ b/scram/users/tests/test_drf_urls.py
@@ -1,3 +1,5 @@
+"""Test Django Request Framework use of the Users application."""
+
import pytest
from django.urls import resolve, reverse
@@ -7,15 +9,18 @@
def test_user_detail(user: User):
+ """Ensure we can view details for a single User."""
assert reverse("api:v1:user-detail", kwargs={"username": user.username}) == f"/api/v1/users/{user.username}/"
assert resolve(f"/api/v1/users/{user.username}/").view_name == "api:v1:user-detail"
def test_user_list():
+ """Ensure we can list all Users."""
assert reverse("api:v1:user-list") == "/api/v1/users/"
assert resolve("/api/v1/users/").view_name == "api:v1:user-list"
def test_user_me():
+ """Ensure we can view info for the current user."""
assert reverse("api:v1:user-me") == "/api/v1/users/me/"
assert resolve("/api/v1/users/me/").view_name == "api:v1:user-me"
diff --git a/scram/users/tests/test_drf_views.py b/scram/users/tests/test_drf_views.py
index 6e7b19e5..bbe56598 100644
--- a/scram/users/tests/test_drf_views.py
+++ b/scram/users/tests/test_drf_views.py
@@ -1,3 +1,5 @@
+"""Test Django Request Framework Views in the Users application."""
+
import pytest
from django.test import RequestFactory
@@ -8,7 +10,10 @@
class TestUserViewSet:
+ """Test a couple simple View operations."""
+
def test_get_queryset(self, user: User, rf: RequestFactory):
+ """Ensure we can view an arbitrary URL."""
view = UserViewSet()
request = rf.get("/fake-url/")
request.user = user
@@ -18,6 +23,7 @@ def test_get_queryset(self, user: User, rf: RequestFactory):
assert user in view.get_queryset()
def test_me(self, user: User, rf: RequestFactory):
+ """Ensure we can view info on the current user."""
view = UserViewSet()
request = rf.get("/fake-url/")
request.user = user
diff --git a/scram/users/tests/test_forms.py b/scram/users/tests/test_forms.py
index 8041e621..3fe65da4 100644
--- a/scram/users/tests/test_forms.py
+++ b/scram/users/tests/test_forms.py
@@ -1,8 +1,7 @@
-"""
-Module for all Form Tests.
-"""
+"""Module for all Form Tests."""
+
import pytest
-from django.utils.translation import ugettext_lazy as _
+from django.utils.translation import gettext_lazy as _
from scram.users.forms import UserCreationForm
from scram.users.models import User
@@ -11,18 +10,15 @@
class TestUserCreationForm:
- """
- Test class for all tests related to the UserCreationForm
- """
+ """Test class for all tests related to the UserCreationForm."""
def test_username_validation_error_msg(self, user: User):
- """
- Tests UserCreation Form's unique validator functions correctly by testing:
- 1) A new user with an existing username cannot be added.
- 2) Only 1 error is raised by the UserCreation Form
- 3) The desired error message is raised
- """
+ """Tests UserCreation Form's unique validator functions correctly by testing 3 things.
+ 1) A new user with an existing username cannot be added.
+ 2) Only 1 error is raised by the UserCreation Form
+ 3) The desired error message is raised
+ """
# The user already exists,
# hence cannot be created.
form = UserCreationForm(
@@ -30,7 +26,7 @@ def test_username_validation_error_msg(self, user: User):
"username": user.username,
"password1": user.password,
"password2": user.password,
- }
+ },
)
assert not form.is_valid()
diff --git a/scram/users/tests/test_models.py b/scram/users/tests/test_models.py
index 9920c0ff..a224cd7d 100644
--- a/scram/users/tests/test_models.py
+++ b/scram/users/tests/test_models.py
@@ -1,3 +1,5 @@
+"""Define tests for the Models in the Users application."""
+
import pytest
from scram.users.models import User
@@ -6,4 +8,5 @@
def test_user_get_absolute_url(user: User):
+ """Ensure w ecan convert a User object into a URL with info about that user."""
assert user.get_absolute_url() == f"/users/{user.username}/"
diff --git a/scram/users/tests/test_urls.py b/scram/users/tests/test_urls.py
index dd36a987..aa5a57ad 100644
--- a/scram/users/tests/test_urls.py
+++ b/scram/users/tests/test_urls.py
@@ -1,3 +1,5 @@
+"""Define tests for the URL resolution for the Users application."""
+
import pytest
from django.urls import resolve, reverse
@@ -7,15 +9,18 @@
def test_detail(user: User):
+ """Ensure we can get the URL to view details about a single User."""
assert reverse("users:detail", kwargs={"username": user.username}) == f"/users/{user.username}/"
assert resolve(f"/users/{user.username}/").view_name == "users:detail"
def test_update():
+ """Ensure we can get the URL to update a User."""
assert reverse("users:update") == "/users/~update/"
assert resolve("/users/~update/").view_name == "users:update"
def test_redirect():
+ """Ensure we can get the URL to redirect to a User."""
assert reverse("users:redirect") == "/users/~redirect/"
assert resolve("/users/~redirect/").view_name == "users:redirect"
diff --git a/scram/users/tests/test_views.py b/scram/users/tests/test_views.py
index 708f35e5..2df1929a 100644
--- a/scram/users/tests/test_views.py
+++ b/scram/users/tests/test_views.py
@@ -1,3 +1,5 @@
+"""Define tests for the Views of the Users application."""
+
import pytest
from django.conf import settings
from django.contrib import messages
@@ -16,15 +18,18 @@
class TestUserUpdateView:
- """
- TODO:
+ """Define tests related to the Update View.
+
+ Todo:
extracting view initialization code as class-scoped fixture
would be great if only pytest-django supported non-function-scoped
fixture db access -- this is a work-in-progress for now:
https://github.com/pytest-dev/pytest-django/pull/258
+
"""
def test_get_success_url(self, user: User, rf: RequestFactory):
+ """Ensure we can view an arbitrary URL."""
view = UserUpdateView()
request = rf.get("/fake-url/")
request.user = user
@@ -34,6 +39,7 @@ def test_get_success_url(self, user: User, rf: RequestFactory):
assert view.get_success_url() == f"/users/{user.username}/"
def test_get_object(self, user: User, rf: RequestFactory):
+ """Ensure we can retrieve the User object."""
view = UserUpdateView()
request = rf.get("/fake-url/")
request.user = user
@@ -43,12 +49,13 @@ def test_get_object(self, user: User, rf: RequestFactory):
assert view.get_object() == user
def test_form_valid(self, user: User, rf: RequestFactory):
+ """Ensure the form validation works."""
view = UserUpdateView()
request = rf.get("/fake-url/")
# Add the session/message middleware to the request
- SessionMiddleware().process_request(request)
- MessageMiddleware().process_request(request)
+ SessionMiddleware(get_response=request).process_request(request)
+ MessageMiddleware(get_response=request).process_request(request)
request.user = user
view.request = request
@@ -63,7 +70,10 @@ def test_form_valid(self, user: User, rf: RequestFactory):
class TestUserRedirectView:
+ """Define tests related to the Redirect View."""
+
def test_get_redirect_url(self, user: User, rf: RequestFactory):
+ """Ensure the redirection URL works."""
view = UserRedirectView()
request = rf.get("/fake-url")
request.user = user
@@ -74,7 +84,10 @@ def test_get_redirect_url(self, user: User, rf: RequestFactory):
class TestUserDetailView:
+ """Define tests related to the User Detail View."""
+
def test_authenticated(self, user: User, rf: RequestFactory):
+ """Ensure an authenticated user has access."""
request = rf.get("/fake-url/")
request.user = UserFactory()
@@ -83,6 +96,7 @@ def test_authenticated(self, user: User, rf: RequestFactory):
assert response.status_code == 200
def test_not_authenticated(self, user: User, rf: RequestFactory):
+ """Ensure an unauthenticated user gets redirected to login."""
request = rf.get("/fake-url/")
request.user = AnonymousUser()
diff --git a/scram/users/urls.py b/scram/users/urls.py
index 64f9d9cb..cf8647f3 100644
--- a/scram/users/urls.py
+++ b/scram/users/urls.py
@@ -1,3 +1,5 @@
+"""Register URLs known to Django, and the View that will handle each."""
+
from django.urls import path
from scram.users.views import user_detail_view, user_redirect_view, user_update_view
diff --git a/scram/users/views.py b/scram/users/views.py
index 2b077d42..e05f5cec 100644
--- a/scram/users/views.py
+++ b/scram/users/views.py
@@ -1,3 +1,5 @@
+"""Define the Views that will handle the HTTP requests."""
+
from django.contrib.auth import get_user_model
from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.messages.views import SuccessMessageMixin
@@ -9,6 +11,7 @@
class UserDetailView(LoginRequiredMixin, DetailView):
+ """Map this view to the User model, and rely on the DetailView generic view."""
model = User
slug_field = "username"
@@ -19,15 +22,18 @@ class UserDetailView(LoginRequiredMixin, DetailView):
class UserUpdateView(LoginRequiredMixin, SuccessMessageMixin, UpdateView):
+ """Map this view to the User model, and rely on the UpdateView generic view."""
model = User
fields = ["name"]
success_message = _("Information successfully updated")
def get_success_url(self):
+ """Return the User detail view."""
return reverse("users:detail", kwargs={"username": self.request.user.username})
def get_object(self):
+ """Return the User object."""
return self.request.user
@@ -35,10 +41,12 @@ def get_object(self):
class UserRedirectView(LoginRequiredMixin, RedirectView):
+ """Map this view to the User model, and rely on the RedirectView generic view."""
permanent = False
def get_redirect_url(self):
+ """Return the User detail view."""
return reverse("users:detail", kwargs={"username": self.request.user.username})
diff --git a/scram/utils/__init__.py b/scram/utils/__init__.py
index e69de29b..e5a29f6c 100644
--- a/scram/utils/__init__.py
+++ b/scram/utils/__init__.py
@@ -0,0 +1 @@
+"""A module for defining utility functions and code."""
diff --git a/scram/utils/context_processors.py b/scram/utils/context_processors.py
index 3c535141..0bbf8786 100644
--- a/scram/utils/context_processors.py
+++ b/scram/utils/context_processors.py
@@ -1,8 +1,14 @@
+"""Define custom functions that take a request and add to the context before template rendering."""
+
from django.conf import settings
def settings_context(_request):
- """Settings available by default to the templates context."""
+ """Define settings available by default to the templates context.
+
+ Returns:
+ dict: Whether or not we have DEBUG on
+ """
# Note: we intentionally do NOT expose the entire settings
# to prevent accidental leaking of sensitive information
return {"DEBUG": settings.DEBUG}
diff --git a/setup.cfg b/setup.cfg
index 84ec37bb..2c082581 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -24,12 +24,6 @@ django_settings_module = config.settings.test
# Django migrations should not produce any errors:
ignore_errors = True
-[coverage:run]
-include = scram/*
-omit = *migrations*, *tests*
-plugins =
- django_coverage_plugin
-
[behave]
paths = scram/route_manager/tests/acceptance
stderr_capture = no
diff --git a/translator/exceptions.py b/translator/exceptions.py
index f1a54299..b6d2cfd6 100644
--- a/translator/exceptions.py
+++ b/translator/exceptions.py
@@ -1,9 +1,5 @@
-"""
-This module holds all of the exceptions we want to raise in our translators.
-"""
+"""Define all of the exceptions we want to raise in our translators."""
class ASNError(TypeError):
- """
- ASNError provides an error class to use when there is an issue with an Autonomous System Number.
- """
+ """ASNError provides an error class to use when there is an issue with an Autonomous System Number."""
diff --git a/translator/gobgp.py b/translator/gobgp.py
index 23ae2572..f2b8dee9 100644
--- a/translator/gobgp.py
+++ b/translator/gobgp.py
@@ -1,3 +1,5 @@
+"""A translator interface for GoBGP (https://github.com/osrg/gobgp)."""
+
import logging
import attribute_pb2
@@ -13,23 +15,32 @@
DEFAULT_COMMUNITY = 666
DEFAULT_V4_NEXTHOP = "192.0.2.199"
DEFAULT_V6_NEXTHOP = "100::1"
+MAX_SMALL_ASN = 2**16
+MAX_SMALL_COMM = 2**16
+IPV6 = 6
logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+class GoBGP:
+ """Represents a GoBGP instance."""
-class GoBGP(object):
def __init__(self, url):
+ """Configure the channel used for communication."""
channel = grpc.insecure_channel(url)
self.stub = gobgp_pb2_grpc.GobgpApiStub(channel)
- def _get_family_AFI(self, ip_version):
- if ip_version == 6:
+ @staticmethod
+ def _get_family_afi(ip_version):
+ if ip_version == IPV6:
return gobgp_pb2.Family.AFI_IP6
- else:
- return gobgp_pb2.Family.AFI_IP
+ return gobgp_pb2.Family.AFI_IP
- def _build_path(self, ip, event_data={}):
+ def _build_path(self, ip, event_data=None): # noqa: PLR0914
# Grab ASN and Community from our event_data, or use the defaults
+ if not event_data:
+ event_data = {}
asn = event_data.get("asn", DEFAULT_ASN)
community = event_data.get("community", DEFAULT_COMMUNITY)
ip_version = ip.ip.version
@@ -42,7 +53,7 @@ def _build_path(self, ip, event_data={}):
origin.Pack(
attribute_pb2.OriginAttribute(
origin=2,
- )
+ ),
)
# IP prefix and its associated length
@@ -51,27 +62,27 @@ def _build_path(self, ip, event_data={}):
attribute_pb2.IPAddressPrefix(
prefix_len=ip.network.prefixlen,
prefix=str(ip.ip),
- )
+ ),
)
# Set the next hop to the correct value depending on IP family
next_hop = Any()
- family_afi = self._get_family_AFI(ip_version)
- if ip_version == 6:
+ family_afi = self._get_family_afi(ip_version)
+ if ip_version == IPV6:
next_hops = event_data.get("next_hop", DEFAULT_V6_NEXTHOP)
next_hop.Pack(
attribute_pb2.MpReachNLRIAttribute(
family=gobgp_pb2.Family(afi=family_afi, safi=gobgp_pb2.Family.SAFI_UNICAST),
next_hops=[next_hops],
nlris=[nlri],
- )
+ ),
)
else:
next_hops = event_data.get("next_hop", DEFAULT_V4_NEXTHOP)
next_hop.Pack(
attribute_pb2.NextHopAttribute(
next_hop=next_hops,
- )
+ ),
)
# Set our AS Path
@@ -89,13 +100,13 @@ def _build_path(self, ip, event_data={}):
communities = Any()
# Standard community
# Since we pack both into the community string we need to make sure they will both fit
- if asn < 65536 and community < 65536:
+ if asn < MAX_SMALL_ASN and community < MAX_SMALL_COMM:
# We bitshift ASN left by 16 so that there is room to add the community on the end of it. This is because
# GoBGP wants the community sent as a single integer.
comm_id = (asn << 16) + community
communities.Pack(attribute_pb2.CommunitiesAttribute(communities=[comm_id]))
else:
- logging.info(f"LargeCommunity Used - ASN:{asn} Community: {community}")
+ logger.info("LargeCommunity Used - ASN: %s. Community: %s", asn, community)
global_admin = asn
local_data1 = community
# set to 0 because there's no use case for it, but we need a local_data2 for gobgp to read any of it
@@ -116,7 +127,8 @@ def _build_path(self, ip, event_data={}):
)
def add_path(self, ip, event_data):
- logging.info(f"Blocking {ip}")
+ """Announce a single route."""
+ logger.info("Blocking %s", ip)
try:
path = self._build_path(ip, event_data)
@@ -125,15 +137,17 @@ def add_path(self, ip, event_data):
_TIMEOUT_SECONDS,
)
except ASNError as e:
- logging.warning(f"ASN assertion failed with error: {e}")
+ logger.warning("ASN assertion failed with error: %s", e)
def del_all_paths(self):
- logging.warning("Withdrawing ALL routes")
+ """Remove all routes from being announced."""
+ logger.warning("Withdrawing ALL routes")
self.stub.DeletePath(gobgp_pb2.DeletePathRequest(table_type=gobgp_pb2.GLOBAL), _TIMEOUT_SECONDS)
def del_path(self, ip, event_data):
- logging.info(f"Unblocking {ip}")
+ """Remove a single route from being announced."""
+ logger.info("Unblocking %s", ip)
try:
path = self._build_path(ip, event_data)
self.stub.DeletePath(
@@ -141,11 +155,16 @@ def del_path(self, ip, event_data):
_TIMEOUT_SECONDS,
)
except ASNError as e:
- logging.warning(f"ASN assertion failed with error: {e}")
+ logger.warning("ASN assertion failed with error: %s", e)
def get_prefixes(self, ip):
+ """Retrieve the routes that match a prefix and are announced.
+
+ Returns:
+ list: The routes that overlap with the prefix and are currently announced.
+ """
prefixes = [gobgp_pb2.TableLookupPrefix(prefix=str(ip.ip))]
- family_afi = self._get_family_AFI(ip.ip.version)
+ family_afi = self._get_family_afi(ip.ip.version)
result = self.stub.ListPath(
gobgp_pb2.ListPathRequest(
table_type=gobgp_pb2.GLOBAL,
@@ -157,4 +176,5 @@ def get_prefixes(self, ip):
return list(result)
def is_blocked(self, ip):
+ """Return True if at least one route matching the prefix is being announced."""
return len(self.get_prefixes(ip)) > 0
diff --git a/translator/shared.py b/translator/shared.py
index 1fe048bd..78555b2d 100644
--- a/translator/shared.py
+++ b/translator/shared.py
@@ -1,13 +1,12 @@
-"""
-This module provides a location for code that we want to share between all translators.
-"""
+"""Provide a location for code that we want to share between all translators."""
from exceptions import ASNError
+MAX_ASN_VAL = 2**32 - 1
+
def asn_is_valid(asn: int) -> bool:
- """
- asn_is_valid makes sure that an ASN passed in is a valid 2 or 4 Byte ASN.
+ """asn_is_valid makes sure that an ASN passed in is a valid 2 or 4 Byte ASN.
Args:
asn (int): The Autonomous System Number that we want to validate
@@ -17,11 +16,14 @@ def asn_is_valid(asn: int) -> bool:
Returns:
bool: _description_
+
"""
if not isinstance(asn, int):
- raise ASNError(f"ASN {asn} is not an Integer, has type {type(asn)}")
- if not 0 < asn < 4294967295:
+ msg = f"ASN {asn} is not an Integer, has type {type(asn)}"
+ raise ASNError(msg)
+ if not 0 < asn < MAX_ASN_VAL:
# This is the max as stated in rfc6996
- raise ASNError(f"ASN {asn} is out of range. Must be between 0 and 4294967295")
+ msg = f"ASN {asn} is out of range. Must be between 0 and 4294967295"
+ raise ASNError(msg)
return True
diff --git a/translator/tests/acceptance/environment.py b/translator/tests/acceptance/environment.py
index 1f06b08d..b5c57b41 100644
--- a/translator/tests/acceptance/environment.py
+++ b/translator/tests/acceptance/environment.py
@@ -1,6 +1,9 @@
+"""Configure the test environment before executing acceptance tests."""
+
from gobgp import GoBGP
def before_all(context):
+ """Create a GoBGP object."""
context.gobgp = GoBGP("gobgp:50051")
context.config.setup_logging()
diff --git a/translator/tests/acceptance/features/clean_slate.feature b/translator/tests/acceptance/features/clean_slate.feature
new file mode 100644
index 00000000..9de4e337
--- /dev/null
+++ b/translator/tests/acceptance/features/clean_slate.feature
@@ -0,0 +1,8 @@
+Feature: Tests run correctly
+ When running tests, there are no leftover blocks
+
+ Scenario: No leftover v4 blocks
+ Then 0.0.0.0/0 is unblocked
+
+ Scenario: No leftover v6 blocks
+ Then ::/0 is unblocked
diff --git a/translator/tests/acceptance/steps/actions.py b/translator/tests/acceptance/steps/actions.py
index ce340de1..fdb2a95f 100644
--- a/translator/tests/acceptance/steps/actions.py
+++ b/translator/tests/acceptance/steps/actions.py
@@ -1,3 +1,5 @@
+"""Define the steps used by Behave."""
+
import ipaddress
import logging
import time
@@ -10,6 +12,7 @@
@when("we add {route} with {asn} and {community} to the block list")
def add_block(context, route, asn, community):
+ """Block a single IP."""
ip = ipaddress.ip_interface(route)
event_data = {"asn": int(asn), "community": int(community)}
context.gobgp.add_path(ip, event_data)
@@ -17,36 +20,41 @@ def add_block(context, route, asn, community):
@then("we delete {route} with {asn} and {community} from the block list")
def del_block(context, route, asn, community):
+ """Remove a single IP."""
ip = ipaddress.ip_interface(route)
event_data = {"asn": int(asn), "community": int(community)}
context.gobgp.del_path(ip, event_data)
def get_block_status(context, ip):
+ """Check if the IP is currently blocked.
+
+ Returns:
+ bool: The return value. True if the IP is currently blocked, False otherwise.
+ """
# Allow our add/delete requests to settle
time.sleep(1)
ip_obj = ipaddress.ip_interface(ip)
- for path in context.gobgp.get_prefixes(ip_obj):
- if ip_obj in ipaddress.ip_network(path.destination.prefix):
- return True
-
- return False
+ return any(ip_obj in ipaddress.ip_network(path.destination.prefix) for path in context.gobgp.get_prefixes(ip_obj))
@capture
@when("{route} and {community} with invalid {asn} is sent")
def asn_validation_fails(context, route, asn, community):
+ """Ensure the ASN was invalid."""
add_block(context, route, asn, community)
assert context.log_capture.find_event("ASN assertion failed")
@then("{ip} is blocked")
def check_block(context, ip):
+ """Ensure that the IP is currently blocked."""
assert get_block_status(context, ip)
@then("{ip} is unblocked")
def check_unblock(context, ip):
+ """Ensure that the IP is currently unblocked."""
assert not get_block_status(context, ip)
diff --git a/translator/translator.py b/translator/translator.py
index 8521d7ab..d2c34018 100644
--- a/translator/translator.py
+++ b/translator/translator.py
@@ -1,5 +1,7 @@
#!/usr/bin/env python3
+"""Define the main event loop for the translator."""
+
import asyncio
import ipaddress
import json
@@ -9,28 +11,41 @@
import websockets
from gobgp import GoBGP
+logger = logging.getLogger(__name__)
+
+KNOWN_MESSAGES = {
+ "translator_add",
+ "translator_remove",
+ "translator_remove_all",
+ "translator_check",
+}
+
# Here we setup a debugger if this is desired. This obviously should not be run in production.
debug_mode = os.environ.get("DEBUG")
if debug_mode:
def install_deps():
- # Because of how we build translator currently, we don't have a great way to selectively install things at
- # build, so we just do it here! Right now this also includes base.txt, which is unecessary, but in the
- # future when we build a little better, it'll already be setup.
- logging.info("Installing dependencies for debuggers")
+ """Install necessary dependencies for debuggers.
+
+ Because of how we build translator currently, we don't have a great way to selectively
+ install things at build, so we just do it here! Right now this also includes base.txt,
+ which is unecessary, but in the future when we build a little better, it'll already be
+ setup.
+ """
+ logger.info("Installing dependencies for debuggers")
- import subprocess
- import sys
+ import subprocess # noqa: S404, PLC0415
+ import sys # noqa: PLC0415
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "/requirements/local.txt"])
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "/requirements/local.txt"]) # noqa: S603 TODO: add this to the container build
- logging.info("Done installing dependencies for debuggers")
+ logger.info("Done installing dependencies for debuggers")
- logging.info(f"Translator is set to use a debugger. Provided debug mode: {debug_mode}")
+ logger.info("Translator is set to use a debugger. Provided debug mode: %s", debug_mode)
# We have to setup the debugger appropriately for various IDEs. It'd be nice if they all used the same thing but
# sadly, we live in a fallen world.
if debug_mode == "pycharm-pydevd":
- logging.info("Entering debug mode for pycharm, make sure the debug server is running in PyCharm!")
+ logger.info("Entering debug mode for pycharm, make sure the debug server is running in PyCharm!")
install_deps()
@@ -38,66 +53,67 @@ def install_deps():
pydevd_pycharm.settrace("host.docker.internal", port=56782, stdoutToServer=True, stderrToServer=True)
- logging.info("Debugger started.")
+ logger.info("Debugger started.")
elif debug_mode == "debugpy":
- logging.info("Entering debug mode for debugpy (VSCode)")
+ logger.info("Entering debug mode for debugpy (VSCode)")
install_deps()
import debugpy
- debugpy.listen(("0.0.0.0", 56781))
+ debugpy.listen(("0.0.0.0", 56781)) # noqa S104 (doesn't like binding to all interfaces)
- logging.info("Debugger listening on port 56781.")
+ logger.info("Debugger listening on port 56781.")
else:
- logging.warning(f"Invalid debug mode given: {debug_mode}. Debugger not started")
+ logger.warning("Invalid debug mode given: %s. Debugger not started", debug_mode)
# Must match the URL in asgi.py, and needs a trailing slash
hostname = os.environ.get("SCRAM_HOSTNAME", "scram_hostname_not_set")
url = os.environ.get("SCRAM_EVENTS_URL", "ws://django:8000/ws/route_manager/translator_block/")
+async def process(message, websocket, g):
+ """Take a single message form the websocket and hand it off to the appropriate function."""
+ json_message = json.loads(message)
+ event_type = json_message.get("type")
+ event_data = json_message.get("message")
+ if event_type not in KNOWN_MESSAGES:
+ logger.error("Unknown event type received: %s", event_type)
+ # TODO: Maybe only allow this in testing?
+ elif event_type == "translator_remove_all":
+ g.del_all_paths()
+ else:
+ try:
+ ip = ipaddress.ip_interface(event_data["route"])
+ except: # noqa E722
+ logger.exception("Error parsing message: %s", message)
+ return
+
+ if event_type == "translator_add":
+ g.add_path(ip, event_data)
+ elif event_type == "translator_remove":
+ g.del_path(ip, event_data)
+ elif event_type == "translator_check":
+ json_message["type"] = "translator_check_resp"
+ json_message["message"]["is_blocked"] = g.is_blocked(ip)
+ json_message["message"]["translator_name"] = hostname
+ await websocket.send(json.dumps(json_message))
+
+
async def main():
+ """Connect to the websocket and start listening for messages."""
g = GoBGP("gobgp:50051")
async for websocket in websockets.connect(url):
try:
async for message in websocket:
- json_message = json.loads(message)
- event_type = json_message.get("type")
- event_data = json_message.get("message")
- if event_type not in [
- "translator_add",
- "translator_remove",
- "translator_remove_all",
- "translator_check",
- ]:
- logging.error(f"Unknown event type received: {event_type!r}")
- # TODO: Maybe only allow this in testing?
- elif event_type == "translator_remove_all":
- g.del_all_paths()
- else:
- try:
- ip = ipaddress.ip_interface(event_data["route"])
- except: # noqa E722
- logging.error(f"Error parsing message: {message!r}")
- continue
-
- if event_type == "translator_add":
- g.add_path(ip, event_data)
- elif event_type == "translator_remove":
- g.del_path(ip, event_data)
- elif event_type == "translator_check":
- json_message["type"] = "translator_check_resp"
- json_message["message"]["is_blocked"] = g.is_blocked(ip)
- json_message["message"]["translator_name"] = hostname
- await websocket.send(json.dumps(json_message))
+ await process(message, websocket, g)
except websockets.ConnectionClosed:
continue
if __name__ == "__main__":
- logging.info("translator started")
+ logger.info("translator started")
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
loop.close()