diff --git a/.ci-scripts/create_user.py b/.ci-scripts/create_user.py deleted file mode 100755 index 1b263b8b..00000000 --- a/.ci-scripts/create_user.py +++ /dev/null @@ -1,11 +0,0 @@ -import django - -django.setup() - -from scram.users.models import User # noqa:E402 - -u, created = User.objects.get_or_create(username="admin") -u.set_password("password") -u.is_staff = True -u.is_superuser = True -u.save() diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 457c5db9..00000000 --- a/.coveragerc +++ /dev/null @@ -1,3 +0,0 @@ -[run] -branch = True -data_file = coverage.coverage diff --git a/.github/workflows/behave.yml b/.github/workflows/behave.yml new file mode 100644 index 00000000..34fb68f0 --- /dev/null +++ b/.github/workflows/behave.yml @@ -0,0 +1,91 @@ +--- +name: Run behave + +on: + push: + branches: + - "**" + pull_request: + branches: + - main + - develop + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + behave: + name: Run Behave + runs-on: ubuntu-latest + strategy: + max-parallel: 4 + matrix: + python-version: ["3.11", "3.12"] + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Docker + uses: docker/setup-buildx-action@v3 + + - name: Install Docker Compose + run: | + sudo apt-get update + sudo apt-get install -y docker-compose make + + - name: Check Docker state (pre-build) + run: docker ps + + - name: Build Docker images + run: make build + env: + PYTHON_IMAGE_VER: "${{ matrix.python-version }}" + + - name: Migrate Database + run: make migrate + + - name: Run Application + run: make run + + - name: Check Docker state (pre-test) + run: docker ps + + - name: Run pytest + behave with Coverage + env: + POSTGRES_USER: scram + POSTGRES_DB: test_scram + run: make coverage.xml + + - name: Dump docker logs on failure + if: failure() + uses: jwalton/gh-docker-logs@v2 + + - name: Upload Coverage to Coveralls + if: matrix.python-version == '3.12' + uses: coverallsapp/github-action@v2 + + - name: Upload Coverage to GitHub + if: matrix.python-version == '3.12' + uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: coverage.xml + + - name: Display Coverage Metrics + if: matrix.python-version == '3.12' + uses: 5monkeys/cobertura-action@v14 + with: + minimum_coverage: "50" + + - name: Check Docker state (post-test) + if: always() + run: docker ps + + - name: Stop Services + if: always() + run: make stop + + - name: Clean Up + if: always() + run: make clean diff --git a/.github/workflows/behave_next_python.yml b/.github/workflows/behave_next_python.yml new file mode 100644 index 00000000..1c4b1a3c --- /dev/null +++ b/.github/workflows/behave_next_python.yml @@ -0,0 +1,77 @@ +--- +name: Run behave with unsupported Python versions + +on: + push: + branches: + - '**' + pull_request: + branches: + - main + - develop + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + behave_next_python: + name: Run Behave + runs-on: ubuntu-latest + strategy: + max-parallel: 4 + matrix: + python-version: ['3.13'] + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Docker + uses: docker/setup-buildx-action@v3 + + - name: Install Docker Compose + run: | + sudo apt-get update + sudo apt-get install -y docker-compose make + + - name: Check Docker state (pre-build) + run: docker ps + + - name: Build Docker images + run: make build + env: + PYTHON_IMAGE_VER: "${{ matrix.python-version }}" + + - name: Migrate Database + run: | + make migrate || echo "::warning:: migrate failed on future Python version ${{ matrix.python-version }}." + + - name: Run Application + run: | + make run || echo "::warning:: run failed on future Python version ${{ matrix.python-version }}." + + - name: Check Docker state (pre-test) + run: docker ps + + - name: Run pytest + behave with Coverage + env: + POSTGRES_USER: scram + POSTGRES_DB: test_scram + run: | + make coverage.xml || echo "::warning:: pytest + behave failed on future Python version ${{ matrix.python-version }}." + + - name: Dump docker logs on failure + if: failure() + uses: jwalton/gh-docker-logs@v2 + + - name: Check Docker state (post-test) + if: always() + run: docker ps + + - name: Stop Services + if: always() + run: make stop + + - name: Clean Up + if: always() + run: make clean diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 1e875dc4..223f58d5 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -1,9 +1,9 @@ name: Build sphinx docs on: - # Runs on pushes targeting the default branch push: - branches: ["main"] + branches: + - 'main' # Allows you to run this workflow manually from the Actions tab workflow_dispatch: @@ -45,8 +45,10 @@ jobs: uses: actions/upload-pages-artifact@v3 with: # Upload entire repository - path: 'docs/_build/html' + path: 'site/' - name: Deploy to GitHub Pages + if: github.ref == 'refs/heads/main' id: deployment uses: actions/deploy-pages@v4 + diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 00000000..5d4004d2 --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,82 @@ +--- +name: Run pytest + +on: + push: + branches: + - '**' + pull_request: + branches: + - main + - develop + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + pytest: + name: Run Pytest + runs-on: ubuntu-latest + strategy: + max-parallel: 4 + matrix: + python-version: ['3.11', '3.12'] + + services: + postgres: + image: postgres:latest + env: + POSTGRES_USER: scram + POSTGRES_PASSWORD: '' + POSTGRES_DB: test_scram_${{ matrix.python-version }} + POSTGRES_HOST_AUTH_METHOD: trust + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U scram" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + redis: + image: redis:5.0 + ports: + - 6379:6379 + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements/local.txt + # https://github.com/pytest-dev/pytest-github-actions-annotate-failures/pull/68 isn't yet in a release + pip install git+https://github.com/pytest-dev/pytest-github-actions-annotate-failures.git@6e66cd895fe05cd09be8bad58f5d79110a20385f + + - name: Apply migrations + env: + DATABASE_URL: "postgres://scram:@localhost:5432/test_scram_${{ matrix.python-version }}" + run: | + python manage.py makemigrations --noinput + python manage.py migrate --noinput + + - name: Check for duplicate migrations + env: + DATABASE_URL: "postgres://scram:@localhost:5432/test_scram_${{ matrix.python-version }}" + run: | + if python manage.py makemigrations --dry-run | grep "No changes detected"; then + echo "No duplicate migrations detected." + else + echo "::warning:: Potential duplicate migrations detected. Please review." + fi + + - name: Run Pytest + env: + DATABASE_URL: "postgres://scram:@localhost:5432/test_scram_${{ matrix.python-version }}" + REDIS_HOST: "localhost" + run: pytest diff --git a/.github/workflows/pytest_next_python.yml b/.github/workflows/pytest_next_python.yml new file mode 100644 index 00000000..dbc9d4e1 --- /dev/null +++ b/.github/workflows/pytest_next_python.yml @@ -0,0 +1,88 @@ +--- +name: Run pytest with unsupported Python versions + +on: + push: + branches: + - '**' + pull_request: + branches: + - main + - develop + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + pytest_next_python: + name: Run Pytest + runs-on: ubuntu-latest + strategy: + max-parallel: 4 + matrix: + python-version: ['3.13'] + + services: + postgres: + image: postgres:latest + env: + POSTGRES_USER: scram + POSTGRES_PASSWORD: '' + POSTGRES_DB: test_scram + POSTGRES_HOST_AUTH_METHOD: trust + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U scram" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + redis: + image: redis:5.0 + ports: + - 6379:6379 + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements/local.txt + # https://github.com/pytest-dev/pytest-github-actions-annotate-failures/pull/68 isn't yet in a release + pip install git+https://github.com/pytest-dev/pytest-github-actions-annotate-failures.git@6e66cd895fe05cd09be8bad58f5d79110a20385f + + - name: Apply unapplied migrations + env: + DATABASE_URL: "postgres://scram:@localhost:5432/test_scram" + run: | + python manage.py makemigrations --noinput || true + UNAPPLIED_MIGRATIONS=$(python manage.py showmigrations --plan | grep '\[ \]' | awk '{print $2}') + if [ -n "$UNAPPLIED_MIGRATIONS" ]; then + for migration in $UNAPPLIED_MIGRATIONS; do + python manage.py migrate $migration --fake-initial --noinput + done + else + echo "No unapplied migrations." + fi + + - name: Check for duplicate migrations + run: | + if python manage.py makemigrations --dry-run | grep "No changes detected"; then + echo "No duplicate migrations detected." + else + echo "Warning: Potential duplicate migrations detected. Please review." + fi + + - name: Run Pytest + env: + DATABASE_URL: "postgres://scram:@localhost:5432/test_scram" + REDIS_HOST: "localhost" + run: | + pytest || echo "::warning:: Failed on future Python version ${{ matrix.python-version }}." diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 00000000..31647732 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,29 @@ +name: Run ruff + +on: + push: + branches: + - '**' + pull_request: + branches: + - main + - develop + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + run: pip install ruff + + - name: Fail if we have any style errors + run: ruff check --output-format=github + + - name: Fail if the code is not formatted correctly (like Black) + run: ruff format --diff diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml deleted file mode 100644 index 435a4812..00000000 --- a/.gitlab-ci.yml +++ /dev/null @@ -1,68 +0,0 @@ -include: -- template: Security/Dependency-Scanning.gitlab-ci.yml -- template: Security/Secret-Detection.gitlab-ci.yml -- template: Code-Quality.gitlab-ci.yml -- template: Security/SAST.gitlab-ci.yml - -stages: -- lint -- test - -variables: - POSTGRES_USER: scram - POSTGRES_PASSWORD: '' - POSTGRES_DB: test_scram - POSTGRES_HOST_AUTH_METHOD: trust - -flake8: - stage: lint - image: python:3.8-alpine - before_script: - - pip install -q flake8 - script: - - flake8 - -pytest: - stage: test - image: docker:24.0.6-dind - services: - - docker:dind - variables: - POSTGRES_ENABLED: 1 - before_script: - - apk add make - - docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY - - export COMPOSE_PROJECT_NAME=$CI_PIPELINE_ID - - ".ci-scripts/pull_images.sh" - - make build - - ".ci-scripts/push_images.sh" - - make migrate - - make run - script: - - export COMPOSE_PROJECT_NAME=$CI_PIPELINE_ID - - make coverage.xml - artifacts: - reports: - coverage_report: - coverage_format: cobertura - path: coverage.xml - after_script: - - export COMPOSE_PROJECT_NAME=$CI_PIPELINE_ID - - make stop - - make clean - - -gemnasium-dependency_scanning: - variables: - PIP_REQUIREMENTS_FILE: requirements/base.txt - -code_quality_html: - extends: code_quality - variables: - REPORT_FORMAT: html - artifacts: - paths: - - gl-code-quality-report.html - -sast: - stage: test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 48d2a7b2..08d8da08 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: 'docs|node_modules|migrations|.git|.tox' +exclude: 'docs|migrations|.git|.tox' default_stages: [commit] fail_fast: true @@ -9,17 +9,11 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - - repo: https://github.com/psf/black - rev: 23.9.1 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.0 hooks: - - id: black - - - repo: https://github.com/timothycrosley/isort - rev: 5.12.0 - hooks: - - id: isort - - - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 - hooks: - - id: flake8 + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/Makefile b/Makefile index 74556490..347eb58c 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,7 @@ +# It'd be nice to keep these in sync with the defaults of the Dockerfiles +PYTHON_IMAGE_VER ?= 3.12 +POSTGRES_IMAGE_VER ?= 12.3 + .DEFAULT_GOAL := help ## toggle-prod: configure make to use the production stack @@ -37,7 +41,8 @@ behave-translator: compose.override.yml ## build: rebuilds all your containers or a single one if CONTAINER is specified .Phony: build build: compose.override.yml - @docker compose up -d --no-deps --build $(CONTAINER) + @docker compose build --build-arg PYTHON_IMAGE_VER=$(PYTHON_IMAGE_VER) --build-arg POSTGRES_IMAGE_VER=$(POSTGRES_IMAGE_VER) $(CONTAINER) + @docker compose up -d --no-deps $(CONTAINER) @docker compose restart $(CONTAINER) ## coverage.xml: generate coverage from test runs @@ -85,6 +90,11 @@ down: compose.override.yml exec: compose.override.yml @docker compose exec $(CONTAINER) $(COMMAND) +## gobgp-neighbor: shows the gobgp neighbor information (append neighbor IP for specific information) +.Phony: gobgp-neighbor +gobgp-neighbor: compose.override.yml + @docker compose exec gobgp gobgp neighbor $(NEIGHBOR) + # This automatically builds the help target based on commands prepended with a double hashbang ## help: print this help output .Phony: help @@ -133,3 +143,13 @@ tail-log: compose.override.yml .Phony: type-check type-check: compose.override.yml @docker compose run --rm django mypy scram + +## docs-build: build the documentation +.Phony: docs-build +docs-build: + @docker compose run --rm docs mkdocs build + +## docs-serve: build and run a server with the documentation +.Phony: docs-serve +docs-serve: + @docker compose run --rm docs mkdocs serve -a 0.0.0.0:8888 diff --git a/compose.override.local.yml b/compose.override.local.yml index c2075761..9332114a 100644 --- a/compose.override.local.yml +++ b/compose.override.local.yml @@ -44,12 +44,10 @@ services: networks: default: {} volumes: - - $CI_PROJECT_DIR/docs:/docs:z - - $CI_PROJECT_DIR/config:/app/config:z - - $CI_PROJECT_DIR/scram:/app/scram:z + - $CI_PROJECT_DIR:/app:z ports: - - "7000" - command: /start-docs + - "${DOCS_PORT:-8888}" + command: "mkdocs serve -a 0.0.0.0:${DOCS_PORT:-8888}" redis: ports: diff --git a/compose.override.production.yml b/compose.override.production.yml index 58f47d72..de54d36a 100644 --- a/compose.override.production.yml +++ b/compose.override.production.yml @@ -26,8 +26,9 @@ services: - production_postgres_data_backups:/backups:z env_file: - ./.envs/.production/.postgres + - /etc/vault.d/secrets/kv_root_security.env deploy: - replicas: 0 + replicas: ${POSTGRES_ENABLED:-0} nginx: image: nginx:1.19 diff --git a/compose/local/django/Dockerfile b/compose/local/django/Dockerfile index ca5e2cc7..aaba2ef9 100644 --- a/compose/local/django/Dockerfile +++ b/compose/local/django/Dockerfile @@ -1,5 +1,8 @@ -FROM python:3.8-slim-buster +ARG PYTHON_IMAGE_VER=3.12 +FROM python:${PYTHON_IMAGE_VER}-slim-bookworm + +ENV PIP_ROOT_USER_ACTION ignore ENV PYTHONUNBUFFERED 1 ENV PYTHONDONTWRITEBYTECODE 1 diff --git a/compose/local/django/start b/compose/local/django/start index 2d910375..95f91dff 100644 --- a/compose/local/django/start +++ b/compose/local/django/start @@ -4,5 +4,6 @@ set -o errexit set -o pipefail set -o nounset +mkdir -p /app/staticfiles python manage.py migrate uvicorn config.asgi:application --host 0.0.0.0 --reload diff --git a/compose/local/docs/Dockerfile b/compose/local/docs/Dockerfile index 9ab25f44..49965c75 100644 --- a/compose/local/docs/Dockerfile +++ b/compose/local/docs/Dockerfile @@ -1,5 +1,8 @@ -FROM python:3.8-slim-buster +ARG PYTHON_IMAGE_VER=3.12 +FROM python:${PYTHON_IMAGE_VER}-slim-bookworm + +ENV PIP_ROOT_USER_ACTION ignore ENV PYTHONUNBUFFERED 1 ENV PYTHONDONTWRITEBYTECODE 1 @@ -19,13 +22,11 @@ RUN apt-get update \ && apt-get purge -y --auto-remove -o APT::AutoRemove::RecommendsImportant=false \ && rm -rf /var/lib/apt/lists/* -# Requirements are installed here to ensure they will be cached. -COPY ./requirements /requirements -# All imports needed for autodoc. -RUN pip install -r /requirements/local.txt -r /requirements/production.txt -COPY ./compose/local/docs/start /start-docs -RUN sed -i 's/\r$//g' /start-docs -RUN chmod +x /start-docs +# Only re-run the pip install if these files have changed +COPY requirements/base.txt requirements/local.txt requirements/production.txt /app/requirements/ +RUN pip install -r /app/requirements/local.txt -r /app/requirements/production.txt + +COPY . /app/ -WORKDIR /docs +WORKDIR /app diff --git a/compose/local/docs/start b/compose/local/docs/start deleted file mode 100644 index c562c13a..00000000 --- a/compose/local/docs/start +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -set -o errexit -set -o pipefail -set -o nounset - -make apidocs -make livehtml diff --git a/compose/production/django/Dockerfile b/compose/production/django/Dockerfile index e4a786a2..f9e1407e 100644 --- a/compose/production/django/Dockerfile +++ b/compose/production/django/Dockerfile @@ -1,5 +1,8 @@ +ARG PYTHON_IMAGE_VER=3.12 + + +FROM python:${PYTHON_IMAGE_VER}-slim-bookworm -FROM python:3.8-slim-buster ENV PYTHONUNBUFFERED 1 diff --git a/compose/production/postgres/Dockerfile b/compose/production/postgres/Dockerfile index c4160f1e..88f86a97 100644 --- a/compose/production/postgres/Dockerfile +++ b/compose/production/postgres/Dockerfile @@ -1,4 +1,6 @@ -FROM postgres:12.3 +ARG POSTGRES_IMAGE_VER=12.3 + +FROM postgres:${POSTGRES_IMAGE_VER} COPY ./compose/production/postgres/maintenance /usr/local/bin/maintenance RUN chmod +x /usr/local/bin/maintenance/* diff --git a/config/__init__.py b/config/__init__.py index e69de29b..69052b7a 100644 --- a/config/__init__.py +++ b/config/__init__.py @@ -0,0 +1 @@ +"""Holds the setings and entrypoints.""" diff --git a/config/api_router.py b/config/api_router.py index a5b77f57..a57a12ba 100644 --- a/config/api_router.py +++ b/config/api_router.py @@ -1,3 +1,5 @@ +"""Map the API routes to the views.""" + from rest_framework.routers import DefaultRouter from scram.route_manager.api.views import ActionTypeViewSet, ClientViewSet, EntryViewSet, IgnoreEntryViewSet diff --git a/config/asgi.py b/config/asgi.py index 1b3cd6e9..5e11d1fa 100644 --- a/config/asgi.py +++ b/config/asgi.py @@ -1,5 +1,4 @@ -""" -ASGI config for SCRAM project. +"""ASGI config for SCRAM project. It exposes the ASGI callable as a module-level variable named ``application``. @@ -7,6 +6,7 @@ https://docs.djangoproject.com/en/dev/howto/deployment/asgi/ """ + import logging import os import sys @@ -18,28 +18,30 @@ # TODO: from channels.security.websocket import AllowedHostsOriginValidator from django.core.asgi import get_asgi_application +logger = logging.getLogger(__name__) + # Here we setup a debugger if this is desired. This obviously should not be run in production. -debug_mode = os.environ.get("DEBUG") -if debug_mode: - logging.info(f"Django is set to use a debugger. Provided debug mode: {debug_mode}") - if debug_mode == "pycharm-pydevd": - logging.info("Entering debug mode for pycharm, make sure the debug server is running in PyCharm!") +debug = os.environ.get("DEBUG") +if debug: + logger.info("Django is set to use a debugger. Provided debug mode: %s", debug) + if debug == "pycharm-pydevd": + logger.info("Entering debug mode for pycharm, make sure the debug server is running in PyCharm!") import pydevd_pycharm pydevd_pycharm.settrace("host.docker.internal", port=56783, stdoutToServer=True, stderrToServer=True) - logging.info("Debugger started.") - elif debug_mode == "debugpy": - logging.info("Entering debug mode for debugpy (VSCode)") + logger.info("Debugger started.") + elif debug == "debugpy": + logger.info("Entering debug mode for debugpy (VSCode)") import debugpy - debugpy.listen(("0.0.0.0", 56780)) + debugpy.listen(("0.0.0.0", 56780)) # noqa S104 (doesn't like binding to all interfaces) - logging.info("Debugger listening on port 56780.") + logger.info("Debugger listening on port 56780.") else: - logging.warning(f"Invalid debug mode given: {debug_mode}. Debugger not started") + logger.warning("Invalid debug mode given: %s. Debugger not started", debug) # This allows easy placement of apps within the interior # scram directory. @@ -62,5 +64,5 @@ { "http": django_application, "websocket": ws_application, - } + }, ) diff --git a/config/consumers.py b/config/consumers.py index 11cca2ef..cd163bb7 100644 --- a/config/consumers.py +++ b/config/consumers.py @@ -1,14 +1,22 @@ +"""Define logic for the WebSocket consumers.""" + import logging +from functools import partial from asgiref.sync import sync_to_async from channels.generic.websocket import AsyncJsonWebsocketConsumer from scram.route_manager.models import Entry, WebSocketSequenceElement +logger = logging.getLogger(__name__) + class TranslatorConsumer(AsyncJsonWebsocketConsumer): + """Handle messages from the Translator(s).""" + async def connect(self): - logging.info("Translator connected") + """Handle the initial connection with adding to the right group.""" + logger.info("Translator connected") self.actiontype = self.scope["url_route"]["kwargs"]["actiontype"] self.translator_group = f"translator_{self.actiontype}" @@ -17,10 +25,10 @@ async def connect(self): # Filter WebSocketSequenceElements by actiontype elements = await sync_to_async(list)( - WebSocketSequenceElement.objects.filter(action_type__name=self.actiontype).order_by("order_num") + WebSocketSequenceElement.objects.filter(action_type__name=self.actiontype).order_by("order_num"), ) if not elements: - logging.warning(f"No elements found for actiontype={self.actiontype}.") + logger.warning("No elements found for actiontype=%s.", self.actiontype) return # Avoid lazy evaluation @@ -28,16 +36,17 @@ async def connect(self): for route in routes: for element in elements: - msg = await sync_to_async(lambda: element.websocketmessage)() + msg = await sync_to_async(partial(element.websocketmessage)) msg.msg_data[msg.msg_data_route_field] = str(route) await self.send_json({"type": msg.msg_type, "message": msg.msg_data}) async def disconnect(self, close_code): - logging.info(f"Disconnect received: {close_code}") + """Discard any remaining messages on disconnect.""" + logger.info("Disconnect received: %s", close_code) await self.channel_layer.group_discard(self.translator_group, self.channel_name) async def receive_json(self, content): - """Received a WebSocket message""" + """Handle a WebSocket message.""" if content["type"] == "translator_check_resp": # We received a check response from a translator, forward to web UI. channel = content.pop("channel") @@ -58,14 +67,17 @@ async def _send_event(self, event): class WebUIConsumer(AsyncJsonWebsocketConsumer): + """Handle messages from the Web UI.""" + async def connect(self): + """Handle the initial connection with adding to the right group.""" self.actiontype = self.scope["url_route"]["kwargs"]["actiontype"] self.translator_group = f"translator_{self.actiontype}" await self.accept() - # Receive message from WebSocket async def receive_json(self, content): + """Receive message from WebSocket.""" if content["type"] == "wui_check_req": # Web UI asks us to check; forward to translator(s) await self.channel_layer.group_send( diff --git a/config/routing.py b/config/routing.py index bae0298a..10a63105 100644 --- a/config/routing.py +++ b/config/routing.py @@ -1,3 +1,5 @@ +"""Define URLs for the WebSocket consumers.""" + from django.urls import re_path from . import consumers diff --git a/config/settings/__init__.py b/config/settings/__init__.py index e69de29b..dd91545c 100644 --- a/config/settings/__init__.py +++ b/config/settings/__init__.py @@ -0,0 +1 @@ +"""Define Django settings. Everyone gets base, and then we can override in local or production.""" diff --git a/config/settings/base.py b/config/settings/base.py index b9e1ca01..be90d71e 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -1,11 +1,13 @@ -""" -Base settings to build other settings files upon. -""" +"""Base settings to build other settings files upon.""" + +import logging import os from pathlib import Path import environ +logger = logging.getLogger(__name__) + ROOT_DIR = Path(__file__).resolve(strict=True).parent.parent.parent # scram/ APPS_DIR = ROOT_DIR / "scram" @@ -97,12 +99,6 @@ ] # https://docs.djangoproject.com/en/dev/ref/settings/#auth-user-model AUTH_USER_MODEL = "users.User" -# https://docs.djangoproject.com/en/dev/ref/settings/#login-redirect-url -LOGIN_REDIRECT_URL = "route_manager:home" -# https://docs.djangoproject.com/en/dev/ref/settings/#login-url -LOGIN_URL = "admin:login" -# https://docs.djangoproject.com/en/dev/ref/settings/#logout-url -LOGOUT_URL = "admin:logout" # PASSWORDS # ------------------------------------------------------------------------------ @@ -192,7 +188,7 @@ "scram.route_manager.context_processors.login_logout", ], }, - } + }, ] # https://docs.djangoproject.com/en/dev/ref/settings/#form-renderer @@ -242,14 +238,14 @@ "version": 1, "disable_existing_loggers": False, "formatters": { - "verbose": {"format": "%(levelname)s %(asctime)s %(module)s " "%(process)d %(thread)d %(message)s"} + "verbose": {"format": "%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s"}, }, "handlers": { "console": { "level": "DEBUG", "class": "logging.StreamHandler", "formatter": "verbose", - } + }, }, "root": {"level": "INFO", "handlers": ["console"]}, } @@ -260,7 +256,7 @@ "default": { "BACKEND": "channels_redis.core.RedisChannelLayer", "CONFIG": { - "hosts": [("redis", 6379)], + "hosts": [(os.environ.get("REDIS_HOST", "redis"), 6379)], }, }, } @@ -268,6 +264,8 @@ # django-rest-framework # ------------------------------------------------------------------------------- # django-rest-framework - https://www.django-rest-framework.org/api-guide/settings/ +# Swagger related tooling +INSTALLED_APPS += ["drf_spectacular"] REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ( "rest_framework.authentication.SessionAuthentication", @@ -275,12 +273,70 @@ ), "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",), "TEST_REQUEST_DEFAULT_FORMAT": "json", + "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", } # django-cors-headers - https://github.com/adamchainz/django-cors-headers#setup CORS_URLS_REGEX = r"^/api/.*$" # Your stuff... # ------------------------------------------------------------------------------ +# Are you using local passwords or oidc? +AUTH_METHOD = os.environ.get("SCRAM_AUTH_METHOD", "local").lower() + +# https://docs.djangoproject.com/en/dev/ref/settings/#login-redirect-url +LOGIN_REDIRECT_URL = "route_manager:home" + +# Need to point somewhere otherwise /oidc/logout/ redirects to /oidc/logout/None which 404s +# https://github.com/mozilla/mozilla-django-oidc/issues/118 +# Using `/` because named urls don't work for this package +# https://github.com/mozilla/mozilla-django-oidc/issues/434 +LOGOUT_REDIRECT_URL = "route_manager:home" + +OIDC_OP_JWKS_ENDPOINT = os.environ.get( + "OIDC_OP_JWKS_ENDPOINT", + "https://example.com/auth/realms/example/protocol/openid-connect/certs", +) +OIDC_OP_AUTHORIZATION_ENDPOINT = os.environ.get( + "OIDC_OP_AUTHORIZATION_ENDPOINT", + "https://example.com/auth/realms/example/protocol/openid-connect/auth", +) +OIDC_OP_TOKEN_ENDPOINT = os.environ.get( + "OIDC_OP_TOKEN_ENDPOINT", + "https://example.com/auth/realms/example/protocol/openid-connect/token", +) +OIDC_OP_USER_ENDPOINT = os.environ.get( + "OIDC_OP_USER_ENDPOINT", + "https://example.com/auth/realms/example/protocol/openid-connect/userinfo", +) +OIDC_RP_SIGN_ALGO = "RS256" + +logger.info("Using AUTH METHOD=%s", AUTH_METHOD) +if AUTH_METHOD == "oidc": + # Extend middleware to add OIDC middleware + MIDDLEWARE += ["mozilla_django_oidc.middleware.SessionRefresh"] + + # Extend middleware to add OIDC auth backend + AUTHENTICATION_BACKENDS += ["scram.route_manager.authentication_backends.ESnetAuthBackend"] + + # https://docs.djangoproject.com/en/dev/ref/settings/#login-url + LOGIN_URL = "oidc_authentication_init" + + # https://docs.djangoproject.com/en/dev/ref/settings/#logout-url + LOGOUT_URL = "oidc_logout" + + OIDC_RP_CLIENT_ID = os.environ.get("OIDC_RP_CLIENT_ID") + OIDC_RP_CLIENT_SECRET = os.environ.get("OIDC_RP_CLIENT_SECRET") + +elif AUTH_METHOD == "local": + # https://docs.djangoproject.com/en/dev/ref/settings/#login-url + LOGIN_URL = "local_auth:login" + + # https://docs.djangoproject.com/en/dev/ref/settings/#logout-url + LOGOUT_URL = "local_auth:logout" +else: + msg = f"Invalid authentication method: {AUTH_METHOD}. Please choose 'local' or 'oidc'" + raise ValueError(msg) + # Should we create an admin user for you AUTOCREATE_ADMIN = True @@ -293,9 +349,6 @@ SIMPLE_HISTORY_HISTORY_CHANGE_REASON_USE_TEXT_FIELD = True SIMPLE_HISTORY_ENABLED = True -# Are you using local passwords or oidc? -AUTH_METHOD = os.environ.get("SCRAM_AUTH_METHOD", "local") - # Users in these groups have full privileges, including Django is_superuser SCRAM_ADMIN_GROUPS = ["svc_scram_admin"] @@ -311,8 +364,6 @@ # This is the set of all the groups SCRAM_GROUPS = SCRAM_ADMIN_GROUPS + SCRAM_READWRITE_GROUPS + SCRAM_READONLY_GROUPS + SCRAM_DENIED_GROUPS -# This is the full set of groups - # How many entries to show PER Actiontype on the home page RECENT_LIMIT = 10 # What is the largest cidr range we'll accept entries for diff --git a/config/settings/local.py b/config/settings/local.py index 14ec647b..56029930 100644 --- a/config/settings/local.py +++ b/config/settings/local.py @@ -1,5 +1,5 @@ from .base import * # noqa -from .base import env +from .base import AUTH_METHOD, REST_FRAMEWORK, env # GENERAL # ------------------------------------------------------------------------------ @@ -11,7 +11,7 @@ default="BmZnn8FeNFdaeCod8ky6eBNpTiwO45NzlFyA6kk1xo0g4Mc263gAyscHFCMCeJAi", ) # https://docs.djangoproject.com/en/dev/ref/settings/#allowed-hosts -ALLOWED_HOSTS = ["localhost", "0.0.0.0", "127.0.0.1", "django"] +ALLOWED_HOSTS = ["*"] # CACHES # ------------------------------------------------------------------------------ @@ -20,7 +20,7 @@ "default": { "BACKEND": "django.core.cache.backends.locmem.LocMemCache", "LOCATION": "", - } + }, } # EMAIL @@ -33,11 +33,15 @@ # http://whitenoise.evans.io/en/latest/django.html#using-whitenoise-in-development INSTALLED_APPS = ["whitenoise.runserver_nostatic"] + INSTALLED_APPS # noqa F405 +# django-coverage-plugin +# ------------------------------------------------------------------------------ +# https://github.com/nedbat/django_coverage_plugin?tab=readme-ov-file#django-template-coveragepy-plugin +TEMPLATES[0]["OPTIONS"]["debug"] = True # noqa F405 # django-debug-toolbar # ------------------------------------------------------------------------------ # https://django-debug-toolbar.readthedocs.io/en/latest/installation.html#prerequisites -INSTALLED_APPS += ["debug_toolbar"] # noqa F405 +INSTALLED_APPS += ["debug_toolbar"] # https://django-debug-toolbar.readthedocs.io/en/latest/installation.html#middleware MIDDLEWARE += ["debug_toolbar.middleware.DebugToolbarMiddleware"] # noqa F405 # https://django-debug-toolbar.readthedocs.io/en/latest/configuration.html#debug-toolbar-config @@ -47,7 +51,7 @@ } # https://django-debug-toolbar.readthedocs.io/en/latest/installation.html#internal-ips INTERNAL_IPS = ["127.0.0.1", "10.0.2.2"] -if env("USE_DOCKER") == "yes": +if env("USE_DOCKER", default="no") == "yes": import socket hostname, _, ips = socket.gethostbyname_ex(socket.gethostname()) @@ -56,14 +60,23 @@ # django-extensions # ------------------------------------------------------------------------------ # https://django-extensions.readthedocs.io/en/latest/installation_instructions.html#configuration -INSTALLED_APPS += ["django_extensions"] # noqa F405 +INSTALLED_APPS += ["django_extensions"] # Your stuff... # ------------------------------------------------------------------------------ - -REST_FRAMEWORK = { - "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAdminUser",), -} +REST_FRAMEWORK["DEFAULT_PERMISSION_CLASSES"] = ("rest_framework.permissions.IsAdminUser",) # Behave Django testing framework -INSTALLED_APPS += ["behave_django"] # noqa F405 +INSTALLED_APPS += ["behave_django"] + +# AUTHENTICATION +# ------------------------------------------------------------------------------ +# We shouldn't be using OIDC in local dev mode as of now, but might be worth pursuing later +if AUTH_METHOD == "oidc": + msg = "oidc is not yet implemented" + raise NotImplementedError(msg) + +# https://docs.djangoproject.com/en/dev/ref/settings/#login-url +LOGIN_URL = "admin:login" +# https://docs.djangoproject.com/en/dev/ref/settings/#logout-url +LOGOUT_URL = "admin:logout" diff --git a/config/settings/production.py b/config/settings/production.py index 6d4484cc..4af0de77 100644 --- a/config/settings/production.py +++ b/config/settings/production.py @@ -1,7 +1,5 @@ -import os - from .base import * # noqa -from .base import AUTHENTICATION_BACKENDS, MIDDLEWARE, env +from .base import env # GENERAL # ------------------------------------------------------------------------------ @@ -15,9 +13,10 @@ DATABASES["default"] = env.db("DATABASE_URL") # noqa F405 DATABASES["default"]["ATOMIC_REQUESTS"] = True # noqa F405 DATABASES["default"]["CONN_MAX_AGE"] = env.int("CONN_MAX_AGE", default=60) # noqa F405 -if env("POSTGRES_SSL"): +if env.bool("POSTGRES_SSL", default=True): DATABASES["default"]["OPTIONS"] = {"sslmode": "require"} # noqa F405 - +else: + DATABASES["default"]["OPTIONS"] = {"sslmode": "disable"} # noqa F405 # CACHES # ------------------------------------------------------------------------------ CACHES = { @@ -30,7 +29,7 @@ # https://github.com/jazzband/django-redis#memcached-exceptions-behavior "IGNORE_EXCEPTIONS": True, }, - } + }, } # SECURITY @@ -70,7 +69,7 @@ "django.template.loaders.filesystem.Loader", "django.template.loaders.app_directories.Loader", ], - ) + ), ] # EMAIL @@ -110,7 +109,7 @@ "disable_existing_loggers": False, "filters": {"require_debug_false": {"()": "django.utils.log.RequireDebugFalse"}}, "formatters": { - "verbose": {"format": "%(levelname)s %(asctime)s %(module)s " "%(process)d %(thread)d %(message)s"} + "verbose": {"format": "%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s"}, }, "handlers": { "mail_admins": { @@ -141,44 +140,3 @@ # Your stuff... # ------------------------------------------------------------------------------ -# Extend middleware to add OIDC middleware -MIDDLEWARE += ["mozilla_django_oidc.middleware.SessionRefresh"] # noqa F405 - -# Extend middleware to add OIDC auth backend -AUTHENTICATION_BACKENDS += ["scram.route_manager.authentication_backends.ESnetAuthBackend"] # noqa F405 - -# https://docs.djangoproject.com/en/dev/ref/settings/#login-url -LOGIN_URL = "oidc_authentication_init" - -# https://docs.djangoproject.com/en/dev/ref/settings/#login-redirect-url -LOGIN_REDIRECT_URL = "/" - -# https://docs.djangoproject.com/en/dev/ref/settings/#logout-url -LOGOUT_URL = "oidc_logout" - -# Need to point somewhere otherwise /oidc/logout/ redirects to /oidc/logout/None which 404s -# https://github.com/mozilla/mozilla-django-oidc/issues/118 -# Using `/` because named urls don't work for this package -# https://github.com/mozilla/mozilla-django-oidc/issues/434 -LOGOUT_REDIRECT_URL = "/" - -OIDC_OP_JWKS_ENDPOINT = os.environ.get( - "OIDC_OP_JWKS_ENDPOINT", - "https://example.com/auth/realms/example/protocol/openid-connect/certs", -) -OIDC_OP_AUTHORIZATION_ENDPOINT = os.environ.get( - "OIDC_OP_AUTHORIZATION_ENDPOINT", - "https://example.com/auth/realms/example/protocol/openid-connect/auth", -) -OIDC_OP_TOKEN_ENDPOINT = os.environ.get( - "OIDC_OP_TOKEN_ENDPOINT", - "https://example.com/auth/realms/example/protocol/openid-connect/token", -) -OIDC_OP_USER_ENDPOINT = os.environ.get( - "OIDC_OP_USER_ENDPOINT", - "https://example.com/auth/realms/example/protocol/openid-connect/userinfo", -) -OIDC_RP_SIGN_ALGO = "RS256" - -OIDC_RP_CLIENT_ID = os.environ.get("OIDC_RP_CLIENT_ID") -OIDC_RP_CLIENT_SECRET = os.environ.get("OIDC_RP_CLIENT_SECRET") diff --git a/config/settings/test.py b/config/settings/test.py index 1eb5ed4a..606cee75 100644 --- a/config/settings/test.py +++ b/config/settings/test.py @@ -1,8 +1,4 @@ -""" -With these settings, tests run faster. -""" - -import os +"""With these settings, tests run faster.""" from .base import * # noqa from .base import env @@ -31,8 +27,9 @@ "django.template.loaders.filesystem.Loader", "django.template.loaders.app_directories.Loader", ], - ) + ), ] +TEMPLATES[0]["OPTIONS"]["debug"] = True # noqa F405 # EMAIL # ------------------------------------------------------------------------------ @@ -41,26 +38,11 @@ # Your stuff... # ------------------------------------------------------------------------------ -# Extend middleware to add OIDC middleware -MIDDLEWARE += ["mozilla_django_oidc.middleware.SessionRefresh"] # noqa F405 - -OIDC_OP_JWKS_ENDPOINT = os.environ.get( - "OIDC_OP_JWKS_ENDPOINT", - "https://example.com/auth/realms/example/protocol/openid-connect/certs", -) -OIDC_OP_AUTHORIZATION_ENDPOINT = os.environ.get( - "OIDC_OP_AUTHORIZATION_ENDPOINT", - "https://example.com/auth/realms/example/protocol/openid-connect/auth", -) -OIDC_OP_TOKEN_ENDPOINT = os.environ.get( - "OIDC_OP_TOKEN_ENDPOINT", - "https://example.com/auth/realms/example/protocol/openid-connect/token", -) -OIDC_OP_USER_ENDPOINT = os.environ.get( - "OIDC_OP_USER_ENDPOINT", - "https://example.com/auth/realms/example/protocol/openid-connect/userinfo", -) +# These variables are required by the ESnetAuthBackend called in our OidcTest case +OIDC_OP_JWKS_ENDPOINT = "https://example.com/auth/realms/example/protocol/openid-connect/certs" +OIDC_OP_AUTHORIZATION_ENDPOINT = "https://example.com/auth/realms/example/protocol/openid-connect/auth" +OIDC_OP_TOKEN_ENDPOINT = "https://example.com/auth/realms/example/protocol/openid-connect/token" +OIDC_OP_USER_ENDPOINT = "https://example.com/auth/realms/example/protocol/openid-connect/userinfo" OIDC_RP_SIGN_ALGO = "RS256" - -OIDC_RP_CLIENT_ID = os.environ.get("OIDC_RP_CLIENT_ID", "client_id") -OIDC_RP_CLIENT_SECRET = os.environ.get("OIDC_RP_CLIENT_SECRET", "client_secret") +OIDC_RP_CLIENT_ID = "" +OIDC_RP_CLIENT_SECRET = "" diff --git a/config/urls.py b/config/urls.py index a52cd002..41a45d12 100644 --- a/config/urls.py +++ b/config/urls.py @@ -1,9 +1,12 @@ +"""Define the non-WebSocket URLs for Django.""" + from django.conf import settings from django.conf.urls.static import static from django.contrib import admin from django.contrib.staticfiles.urls import staticfiles_urlpatterns from django.urls import include, path from django.views import defaults as default_views +from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView from rest_framework.authtoken.views import obtain_auth_token from .api_router import app_name @@ -16,7 +19,8 @@ # User management path("users/", include("scram.users.urls", namespace="users")), # Your stuff: custom urls includes go here -] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) + *static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT), +] if settings.DEBUG: # Static file serving when using Gunicorn + Uvicorn for local web socket development urlpatterns += staticfiles_urlpatterns() @@ -28,7 +32,8 @@ import mozilla_django_oidc # noqa: F401 urlpatterns += [path("oidc/", include("mozilla_django_oidc.urls"))] - +elif settings.AUTH_METHOD == "local": + urlpatterns += [path("auth/", include("scram.local_auth.urls", namespace="local_auth"))] # API URLS api_version_urls = ( [ @@ -43,6 +48,13 @@ path("auth-token/", obtain_auth_token), ] +# Swagger OpenAPI URLs +urlpatterns += [ + path("schema/", SpectacularAPIView.as_view(), name="schema"), + path("schema/swagger-ui/", SpectacularSwaggerView.as_view(url_name="schema"), name="swagger-ui"), + path("schema/redoc/", SpectacularRedocView.as_view(url_name="schema"), name="redoc"), +] + if settings.DEBUG: # This allows the error pages to be debugged during development, just visit # these url in browser to see how these error pages look like. @@ -67,4 +79,4 @@ if "debug_toolbar" in settings.INSTALLED_APPS: import debug_toolbar - urlpatterns = [path("__debug__/", include(debug_toolbar.urls))] + urlpatterns + urlpatterns = [path("__debug__/", include(debug_toolbar.urls)), *urlpatterns] diff --git a/config/websocket.py b/config/websocket.py index 81adfbc6..b48dfb8b 100644 --- a/config/websocket.py +++ b/config/websocket.py @@ -1,4 +1,8 @@ +"""TODO: Find out if this is used.""" + + async def websocket_application(scope, receive, send): + """Handle WebSocket messages. I guess.""" while True: event = await receive() diff --git a/config/wsgi.py b/config/wsgi.py index 98830366..73ed363b 100644 --- a/config/wsgi.py +++ b/config/wsgi.py @@ -1,5 +1,4 @@ -""" -WSGI config for SCRAM project. +"""WSGI config for SCRAM project. This module contains the WSGI application used by Django's development server and any production WSGI deployments. It should expose a module-level variable @@ -13,6 +12,7 @@ framework. """ + import os import sys from pathlib import Path @@ -26,13 +26,10 @@ # We defer to a DJANGO_SETTINGS_MODULE already in the environment. This breaks # if running multiple sites in the same mod_wsgi process. To fix this, use # mod_wsgi daemon mode with each site in its own daemon process, or use -# os.environ["DJANGO_SETTINGS_MODULE"] = "config.settings.production" +# os.environ["DJANGO_SETTINGS_MODULE"] = "config.settings.production" # noqa ERA001 os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.production") # This application object is used by any WSGI server configured to use this # file. This includes Django's development server, if the WSGI_APPLICATION # setting points here. application = get_wsgi_application() -# Apply WSGI middleware here. -# from helloworld.wsgi import HelloWorldApplication -# application = HelloWorldApplication(application) diff --git a/docs/Makefile b/docs/Makefile deleted file mode 100644 index f7bac529..00000000 --- a/docs/Makefile +++ /dev/null @@ -1,29 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line, and also -# from the environment for the first two. -SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build -c . -SOURCEDIR = . -BUILDDIR = ./_build -APP = /app - -.PHONY: help livehtml apidocs Makefile - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -# Build, watch and serve docs with live reload -livehtml: - sphinx-autobuild -b html --host 0.0.0.0 --port 7000 --watch $(APP) -c . $(SOURCEDIR) $(BUILDDIR)/html - -# Outputs rst files from django application code -apidocs: - sphinx-apidoc -o $(SOURCEDIR)/api /app - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -b $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/README.rst b/docs/README.md similarity index 80% rename from README.rst rename to docs/README.md index 9be49463..53f356c5 100644 --- a/README.rst +++ b/docs/README.md @@ -1,28 +1,19 @@ -SCRAM -===== +# SCRAM Security Catch and Release Automation Manager -.. image:: https://img.shields.io/badge/built%20with-Cookiecutter%20Django-ff69b4.svg?logo=cookiecutter - :target: https://github.com/pydanny/cookiecutter-django/ - :alt: Built with Cookiecutter Django -.. image:: https://img.shields.io/badge/code%20style-black-000000.svg - :target: https://github.com/ambv/black - :alt: Black code style +[]() +[]() +[]() - -:License: BSD - -==== - -Overview -==== +## Overview SCRAM is a web based service to assist in automation of security data. There is a web interface as well as a REST API available. + The idea is to create actiontypes which allow you to take actions on the IPs/cidr networks you provide. -Components -==== +## Components + SCRAM utilizes ``docker compose`` to run the following stack in production: - nginx (as a webserver and static asset server) @@ -35,8 +26,7 @@ SCRAM utilizes ``docker compose`` to run the following stack in production: A predefined actiontype of "block" exists which utilizes bgp nullrouting to effectivley block any traffic you want to apply. You can add any other actiontypes via the admin page of the web interface dynamically, but keep in mind translator support would need to be added as well. -Installation -==== +## Installation To get a basic implementation up and running locally: @@ -48,6 +38,9 @@ To get a basic implementation up and running locally: - Create ``$scram_home/.envs/.production/.postgres`` a template exists in the docs/templates directory - Make sure to set the right credentials - By default this template assumes you have a service defined in docker compose file called postgres. If you use another postgres server, make sure to update that setting as well +- Create a ``.env`` file with the necessary environment variables: + - [comment]: # This chooses if you want to use oidc or local accounts. This can be local or oidc only. Default: `local` + - scram_auth_method: "local" - ``make build`` - ``make toggle-prod`` - This will turn off debug mode in django and start using nginx to reverse proxy for the app @@ -56,8 +49,7 @@ To get a basic implementation up and running locally: - ``make django-open`` -*** Copyright Notice *** -==== +## Copyright Notice Security Catch and Release Automation Manager (SCRAM) Copyright (c) 2022, The Regents of the University of California, through Lawrence Berkeley diff --git a/docs/__init__.py b/docs/__init__.py index 8772c827..d34efb05 100644 --- a/docs/__init__.py +++ b/docs/__init__.py @@ -1 +1 @@ -# Included so that Django's startproject comment runs against the docs directory +"""Included so that Django's startproject comment runs against the docs directory.""" diff --git a/docs/conf.py b/docs/conf.py deleted file mode 100644 index d4d4bad7..00000000 --- a/docs/conf.py +++ /dev/null @@ -1,63 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. - -import os -import sys - -import django - -if os.getenv("READTHEDOCS", default=False) == "True": - sys.path.insert(0, os.path.abspath("..")) - os.environ["DJANGO_READ_DOT_ENV_FILE"] = "True" - os.environ["USE_DOCKER"] = "no" -else: - sys.path.insert(0, os.path.abspath("/app")) -os.environ["DATABASE_URL"] = "sqlite:///readthedocs.db" -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local") -django.setup() - -# -- Project information ----------------------------------------------------- - -project = "SCRAM" -copyright = """2021, Sam Oehlert""" -author = "Sam Oehlert" - - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - "sphinx.ext.autodoc", - "sphinx.ext.napoleon", -] - -# Add any paths that contain templates here, relative to this directory. -# templates_path = ["_templates"] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = "alabaster" - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ["_static"] diff --git a/docs/cookiecutter.md b/docs/cookiecutter.md deleted file mode 100644 index 17445af4..00000000 --- a/docs/cookiecutter.md +++ /dev/null @@ -1,71 +0,0 @@ -This documentation was initially provided in the README.rst from the cookiecutter used to generate the project - -Settings --------- - -Moved to settings_. - -.. _settings: http://cookiecutter-django.readthedocs.io/en/latest/settings.html - -Basic Commands --------------- - -Setting Up Your Users -^^^^^^^^^^^^^^^^^^^^^ - -* To create a **normal user account**, just go to Sign Up and fill out the form. Once you submit it, you'll see a "Verify Your E-mail Address" page. Go to your console to see a simulated email verification message. Copy the link into your browser. Now the user's email should be verified and ready to go. - -* To create an **superuser account**, use this command:: - - $ python manage.py createsuperuser - -For convenience, you can keep your normal user logged in on Chrome and your superuser logged in on Firefox (or similar), so that you can see how the site behaves for both kinds of users. - -Type checks -^^^^^^^^^^^ - -Running type checks with mypy: - -:: - - $ mypy scram - -Test coverage -^^^^^^^^^^^^^ - -To run the tests, check your test coverage, and generate an HTML coverage report:: - - $ coverage run -m pytest - $ coverage html - $ open htmlcov/index.html - -Running tests with py.test -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -:: - - $ pytest - -Live reloading and Sass CSS compilation -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Moved to `Live reloading and SASS compilation`_. - -.. _`Live reloading and SASS compilation`: http://cookiecutter-django.readthedocs.io/en/latest/live-reloading-and-sass-compilation.html - - - -Deployment ----------- - -The following details how to deploy this application. - - - -Docker -^^^^^^ - -See detailed `cookiecutter-django Docker documentation`_. - -.. _`cookiecutter-django Docker documentation`: http://cookiecutter-django.readthedocs.io/en/latest/deployment-with-docker.html - diff --git a/docs/environment_variables.md b/docs/environment_variables.md new file mode 100644 index 00000000..c44629a7 --- /dev/null +++ b/docs/environment_variables.md @@ -0,0 +1,76 @@ +## Environment Variables to Set for Deployment +[comment]: # Which branch of SCRAM to use (you probably want to set it to a release tag) +scram_code_branch: +#### Systems +[comment]: # Email of the main admin +scram_manager_email: +[comment]: # Set to true for production mode; set to false to set up the compose.override.local.yml stack +scram_prod: true +[comment]: # Set to true if you want ansible to install a scram user +scram_install_user: true +[comment]: # What group to put `scram` user in +scram_group: 'scram' +[comment]: # What username to use for `scram` user +scram_user: '' +[comment]: # WHat uid to use for `scram` user +scram_uid: '' +[comment]: # What directory to use for base of the repo +scram_home: '/usr/local/scram' +[comment]: # IP or DNS record for your postgres host +scram_postgres_host: +[comment]: # What postgres user to use +scram_postgres_user: '' + +#### Authentication +[comment]: # This chooses if you want to use oidc or local accounts. This can be local or oidc only. Default: `local` +scram_auth_method: "local" +[comment]: # This client id (username) for your oidc connection. Only need to set this if you are trying to do oidc. +scram_oidc_client_id: + +#### Networking +[comment]: # What is the peering interface docker uses for gobgp to talk to the router +scram_peering_iface: 'ens192' +[comment]: # The v6 network of your peering connection +scram_v4_subnet: '10.0.0.0/24' +[comment]: # The v4 IP of the peering connection for the router side +scram_v4_gateway: '10.0.0.1' +[comment]: # The v4 IP of the peering connection for gobgp side +scram_v4_address: '10.0.0.2' +[comment]: # The v6 network of your peering connection +scram_v6_subnet: '2001:db8::/64' +[comment]: # The v6 IP of the peering connection for the router side +scram_v6_gateway: '2001:db8::2' +[comment]: # The v6 IP of the peering connection for the gobgp side +scram_v6_address: '2001:db8::3' +[comment]: # The AS you want to use for gobgp +scram_as: +[comment]: # A string representing your gobgp instance. Often seen as the local IP of the gobgp instance +scram_router_id: +[comment]: # +scram_peer_as: +[comment]: # The AS you want to use for gobgp side (can this be the same as `scram_as`?) +scram_local_as: +[comment]: # The fqdn of the server hosting this - to be used for nginx +scram_nginx_host: +[comment]: # List of allowed hosts per the django setting "ALLOWED_HOSTS". This should be a list of strings in shell +[comment]: # `django` is required for the websockets to work +[comment]: # Our Ansible assumes `django` + `scram_nginx_host` +scram_django_allowed_hosts: "django" +[comment]: # The fqdn of the server hosting this - to be used for nginx +scram_server_alias: +[comment]: # Do you want to set an md5 for authentication of bgp +scram_bgp_md5_enabled: false +[comment]: # The neighbor config of your gobgp config +scram_neighbors: +[comment]: # The v6 address of your neighbor + - neighbor_address: 2001:db8::2 +[comment]: # This is a v6 address so don't use v4 + ipv4: false +[comment]: # This is a v6 address so use v6 + ipv6: true +[comment]: # The v4 address of your neighbor + - neighbor_address: 10.0.0.200 +[comment]: # This is a v4 address so use v4 + ipv4: true +[comment]: # This is a v4 address so don't use v6 + ipv6: false diff --git a/docs/make.bat b/docs/make.bat deleted file mode 100644 index a22d92b7..00000000 --- a/docs/make.bat +++ /dev/null @@ -1,46 +0,0 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -c . -) -set SOURCEDIR=_source -set BUILDDIR=_build -set APP=..\scram - -if "%1" == "" goto help - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.Install sphinx-autobuild for live serving. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -%SPHINXBUILD% -b %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:livehtml -sphinx-autobuild -b html --open-browser -p 7000 --watch %APP% -c . %SOURCEDIR% %BUILDDIR%/html -GOTO :EOF - -:apidocs -sphinx-apidoc -o %SOURCEDIR%/api %APP% -GOTO :EOF - -:help -%SPHINXBUILD% -b help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd diff --git a/docs/reference/django.md b/docs/reference/django.md new file mode 100644 index 00000000..be734554 --- /dev/null +++ b/docs/reference/django.md @@ -0,0 +1,3 @@ +# Route Manager + +::: scram.route_manager diff --git a/docs/reference/index.md b/docs/reference/index.md new file mode 100644 index 00000000..fc771185 --- /dev/null +++ b/docs/reference/index.md @@ -0,0 +1,7 @@ +# Module Reference + +The SCRAM ecosystem consists of two parts: + +A Django app, [route_manager](/reference/django) + +A translator service, [translator](/reference/translator) diff --git a/docs/reference/translator.md b/docs/reference/translator.md new file mode 100644 index 00000000..897d84e2 --- /dev/null +++ b/docs/reference/translator.md @@ -0,0 +1,3 @@ +# Translator + +::: translator diff --git a/manage.py b/manage.py index 37181fe1..cbc80d8d 100755 --- a/manage.py +++ b/manage.py @@ -1,27 +1,25 @@ #!/usr/bin/env python +"""Django's command-line utility for administrative tasks.""" + import os import sys from pathlib import Path -if __name__ == "__main__": - os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local") +def main(): + """Run administrative tasks.""" + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local") try: - from django.core.management import execute_from_command_line - except ImportError: - # The above import may fail for some other reason. Ensure that the - # issue is really that Django is missing to avoid masking other - # exceptions on Python 2. - try: - import django # noqa - except ImportError: - raise ImportError( - "Couldn't import Django. Are you sure it's installed and " - "available on your PYTHONPATH environment variable? Did you " - "forget to activate a virtual environment?" - ) - - raise + from django.core.management import execute_from_command_line # noqa: PLC0415 + except ImportError as exc: + msg = ( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) + raise ImportError( + msg, + ) from exc # This allows easy placement of apps within the interior # scram directory. @@ -29,3 +27,7 @@ sys.path.append(str(current_path / "scram")) execute_from_command_line(sys.argv) + + +if __name__ == "__main__": + main() diff --git a/merge_production_dotenvs_in_dotenv.py b/merge_production_dotenvs_in_dotenv.py deleted file mode 100644 index b66558c3..00000000 --- a/merge_production_dotenvs_in_dotenv.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -from pathlib import Path -from typing import Sequence - -import pytest - -ROOT_DIR_PATH = Path(__file__).parent.resolve() -PRODUCTION_DOTENVS_DIR_PATH = ROOT_DIR_PATH / ".envs" / ".production" -PRODUCTION_DOTENV_FILE_PATHS = [ - PRODUCTION_DOTENVS_DIR_PATH / ".django", - PRODUCTION_DOTENVS_DIR_PATH / ".postgres", -] -DOTENV_FILE_PATH = ROOT_DIR_PATH / ".env" - - -def merge(output_file_path: str, merged_file_paths: Sequence[str], append_linesep: bool = True) -> None: - with open(output_file_path, "w") as output_file: - for merged_file_path in merged_file_paths: - with open(merged_file_path, "r") as merged_file: - merged_file_content = merged_file.read() - output_file.write(merged_file_content) - if append_linesep: - output_file.write(os.linesep) - - -def main(): - merge(DOTENV_FILE_PATH, PRODUCTION_DOTENV_FILE_PATHS) - - -@pytest.mark.parametrize("merged_file_count", range(3)) -@pytest.mark.parametrize("append_linesep", [True, False]) -def test_merge(tmpdir_factory, merged_file_count: int, append_linesep: bool): - tmp_dir_path = Path(str(tmpdir_factory.getbasetemp())) - - output_file_path = tmp_dir_path / ".env" - - expected_output_file_content = "" - merged_file_paths = [] - for i in range(merged_file_count): - merged_file_ord = i + 1 - - merged_filename = ".service{}".format(merged_file_ord) - merged_file_path = tmp_dir_path / merged_filename - - merged_file_content = merged_filename * merged_file_ord - - with open(merged_file_path, "w+") as file: - file.write(merged_file_content) - - expected_output_file_content += merged_file_content - if append_linesep: - expected_output_file_content += os.linesep - - merged_file_paths.append(merged_file_path) - - merge(output_file_path, merged_file_paths, append_linesep) - - with open(output_file_path, "r") as output_file: - actual_output_file_content = output_file.read() - - assert actual_output_file_content == expected_output_file_content - - -if __name__ == "__main__": - main() diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 00000000..704db80c --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,16 @@ +site_name: Security Catch and Release Automation Manager + +theme: + name: "material" + +plugins: + - mkdocstrings: + handlers: + python: + import: + - url: https://docs.djangoproject.com/en/4.2/_objects/ + base_url: https://docs.djangoproject.com/en/4.2/ + options: + show_submodules: true + - search + - section-index diff --git a/pyproject.toml b/pyproject.toml index 54476b25..5c332104 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,27 +9,107 @@ python_files = [ # ==== Coverage ==== [tool.coverage.run] -omit = ["*/migrations/*", "*/tests/*"] +include = ["scram/*", "config/*", "translator/*"] +omit = ["**/migrations/*", "scram/contrib/*", "*/tests/*"] plugins = ["django_coverage_plugin"] +branch = true +data_file = "coverage.coverage" + +[tool.coverage.report] +exclude_also = [ + "if debug:", + "if self.debug:", + "if settings.DEBUG", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + ] + +# ===== ruff ==== +[tool.ruff] +exclude = [ + "migrations", +] - -# ==== black ==== -[tool.black] line-length = 119 -target-version = ['py311'] +target-version = 'py312' +preview = true + +[tool.ruff.lint] +select = [ + "A", # builtins + "ASYNC", # async + "B", # bugbear + "BLE", # blind-except + "C4", # comprehensions + "C90", # complexity + "COM", # commas + "D", # pydocstyle + "DJ", # django + "DOC", # pydoclint + "DTZ", # datetimez + "E", # pycodestyle + "EM", # errmsg + "ERA", # eradicate + "F", # pyflakes + "FBT", # boolean-trap + "FLY", # flynt + "G", # logging-format + "I", # isort + "ICN", # import-conventions + "ISC", # implicit-str-concat + "LOG", # logging + "N", # pep8-naming + "PERF", # perflint + "PIE", # pie + "PL", # pylint + "PTH", # use-pathlib + "Q", # quotes + "RET", # return + "RSE", # raise + "RUF", # Ruff + "S", # bandit + "SIM", # simplify + "SLF", # self + "SLOT", # slots + "T20", # print + "TRY", # tryceratops + "UP", # pyupgrade +] +ignore = [ + "COM812", # handled by the formatter + "DOC501", # add possible exceptions to the docstring (TODO) + "ISC001", # handled by the formatter + "RUF012", # need more widespread typing + "SIM102", # Use a single `if` statement instead of nested `if` statements + "SIM108", # Use ternary operator instead of `if`-`else`-block +] +[tool.ruff.lint.mccabe] +max-complexity = 7 # our current code adheres to this without too much effort -# ==== isort ==== -[tool.isort] -profile = "black" -line_length = 119 -known_first_party = [ - "scram", - "config", +[tool.ruff.lint.per-file-ignores] +"**/{tests}/*" = [ + "DOC201", # documenting return values + "DOC402", # documenting yield values + "PLR6301", # could be a static method + "S101", # use of assert + "S106", # hardcoded password +] +"test.py" = [ + "S105", # hardcoded password as argument +] +"scram/users/**" = [ + "DOC201", # documenting return values + "FBT001", # minimal issue; don't need to mess with in the User app + "PLR2004", # magic values when checking HTTP status codes +] +"**/views.py" = [ + "DOC201", # documenting return values; it's fairly obvious in a View ] -skip = ["venv/"] -skip_glob = ["**/migrations/*.py"] +[tool.ruff.lint.pydocstyle] +convention = "google" # ==== mypy ==== [tool.mypy] @@ -51,53 +131,3 @@ ignore_errors = true [tool.django-stubs] django_settings_module = "config.settings.test" - - -# ==== PyLint ==== -[tool.pylint.MASTER] -load-plugins = [ - "pylint_django", -] -django-settings-module = "config.settings.local" - -[tool.pylint.FORMAT] -max-line-length = 119 - -[tool.pylint."MESSAGES CONTROL"] -disable = [ - "missing-docstring", - "invalid-name", -] - -[tool.pylint.DESIGN] -max-parents = 13 - -[tool.pylint.TYPECHECK] -generated-members = [ - "REQUEST", - "acl_users", - "aq_parent", - "[a-zA-Z]+_set{1,2}", - "save", - "delete", -] - - -# ==== djLint ==== -[tool.djlint] -blank_line_after_tag = "load,extends" -close_void_tags = true -format_css = true -format_js = true -# TODO: remove T002 when fixed https://github.com/Riverside-Healthcare/djLint/issues/687 -ignore = "H006,H021,H023,H025,H029,H030,H031,T002,T003" -include = "H017,H035" -indent = 2 -max_line_length = 119 -profile = "django" - -[tool.djlint.css] -indent_size = 2 - -[tool.djlint.js] -indent_size = 2 diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index c2b3a233..00000000 --- a/pytest.ini +++ /dev/null @@ -1,3 +0,0 @@ -[pytest] -addopts = --ds=config.settings.test --reuse-db -python_files = tests.py test_*.py diff --git a/requirements/base.txt b/requirements/base.txt index cd489161..b43b582a 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,38 +1,31 @@ -pytz==2022.1 # https://github.com/stub42/pytz -django-netfields==1.2.2 # https://pypi.org/project/django-netfields/ -python-slugify==6.1.2 # https://github.com/un33k/python-slugify -Pillow==9.1.1 # https://github.com/python-pillow/Pillow argon2-cffi==21.3.0 # https://github.com/hynek/argon2_cffi -whitenoise==6.2.0 # https://github.com/evansd/whitenoise -celery==5.2.7 # pyup: < 6.0 # https://github.com/celery/celery -django-celery-beat==2.3.0 # https://github.com/celery/django-celery-beat +django-celery-beat # https://github.com/celery/django-celery-beat +django-netfields # https://pypi.org/project/django-netfields/ flower==1.0.0 # https://github.com/mher/flower +python-slugify==6.1.2 # https://github.com/un33k/python-slugify uvicorn[standard]==0.17.6 # https://github.com/encode/uvicorn -asgiref~=3.5.0 # https://github.com/django/asgiref +whitenoise==6.2.0 # https://github.com/evansd/whitenoise # Django # ------------------------------------------------------------------------------ -channels==3.0.4 channels_redis==3.4.0 -django-redis==5.3.0 -django==3.2.13 # pyup: < 4.0 # https://www.djangoproject.com/ -django-grip==3.4.0 -django-eventstream==4.5.1 # https://github.com/fanout/django-eventstream +crispy-bootstrap5==0.6 # https://github.com/django-crispy-forms/crispy-bootstrap5 +django~=4.2 # https://www.djangoproject.com/ + +django-allauth==0.51.0 # https://github.com/pennersr/django-allauth django-environ==0.9.0 # https://github.com/joke2k/django-environ +django-eventstream==4.5.1 # https://github.com/fanout/django-eventstream django-model-utils==4.2.0 # https://github.com/jazzband/django-model-utils -django-allauth==0.51.0 # https://github.com/pennersr/django-allauth -django-crispy-forms==1.14.0 # https://github.com/django-crispy-forms/django-crispy-forms -crispy-bootstrap5==0.6 # https://github.com/django-crispy-forms/crispy-bootstrap5 +django-redis==5.3.0 +django-simple-history~=3.1.1 # Django REST Framework -djangorestframework==3.13.1 # https://github.com/encode/django-rest-framework +djangorestframework~=3.15 # https://github.com/encode/django-rest-framework django-cors-headers==3.13.0 # https://github.com/adamchainz/django-cors-headers -# DRF-spectacular for api documentation -drf-spectacular==0.22.1 # https://github.com/tfranzel/drf-spectacular +drf-spectacular # https://github.com/tfranzel/drf-spectacular # OIDC # ------------------------------------------------------------------------------ mozilla_django_oidc==2.0.0 # https://github.com/mozilla/mozilla-django-oidc websockets~=10.3 -django-simple-history~=3.1.1 diff --git a/requirements/local.txt b/requirements/local.txt index edc15a30..134a2213 100644 --- a/requirements/local.txt +++ b/requirements/local.txt @@ -2,28 +2,32 @@ Werkzeug[watchdog]==2.0.3 # https://github.com/pallets/werkzeug ipdb==0.13.9 # https://github.com/gotcha/ipdb -psycopg2-binary==2.9.3 # https://github.com/psycopg/psycopg2 +psycopg2-binary==2.9.10 # https://github.com/psycopg/psycopg2 watchgod==0.8.2 # https://github.com/samuelcolvin/watchgod # Testing # ------------------------------------------------------------------------------ -mypy==0.950 # https://github.com/python/mypy django-stubs==1.11.0 # https://github.com/typeddjango/django-stubs -pytest==7.1.2 # https://github.com/pytest-dev/pytest -pytest-sugar==0.9.4 # https://github.com/Frozenball/pytest-sugar +pytest-sugar==0.9.6 # https://github.com/Frozenball/pytest-sugar behave-django==1.4.0 # https://github.com/behave/behave-django -django-debug-toolbar~=3.2 # https://github.com/jazzband/django-debug-toolbar # Documentation # ------------------------------------------------------------------------------ -sphinx==5.0.1 # https://github.com/sphinx-doc/sphinx -sphinx-autobuild==2021.3.14 # https://github.com/GaretJax/sphinx-autobuild +mkdocs==1.6.1 +mkdocs-autorefs==1.2.0 +mkdocs-gen-files==0.5.0 +mkdocs-get-deps==0.2.0 +mkdocs-literate-nav==0.6.1 +mkdocs-material==9.5.44 +mkdocs-material-extensions==1.3.1 +mkdocs-section-index==0.3.9 +mkdocstrings==0.27.0 +mkdocstrings-python==1.12.2 # Code quality # ------------------------------------------------------------------------------ -flake8==4.0.1 # https://github.com/PyCQA/flake8 +flake8-docstrings==1.7.0 # https://github.com/pycqa/flake8-docstrings flake8-isort==4.1.1 # https://github.com/gforcada/flake8-isort -coverage==6.4.1 # https://github.com/nedbat/coveragepy black==22.3.0 # https://github.com/psf/black pylint-django==2.5.3 # https://github.com/PyCQA/pylint-django pylint-celery==0.3 # https://github.com/PyCQA/pylint-celery @@ -35,10 +39,8 @@ factory-boy==3.2.1 # https://github.com/FactoryBoy/factory_boy django-debug-toolbar==3.4.0 # https://github.com/jazzband/django-debug-toolbar django-extensions==3.1.5 # https://github.com/django-extensions/django-extensions -django-coverage-plugin==2.0.3 # https://github.com/nedbat/django_coverage_plugin +django-coverage-plugin==3.1.0 # https://github.com/nedbat/django_coverage_plugin pytest-django==4.5.2 # https://github.com/pytest-dev/pytest-django -pytest~=7.1.2 -behave~=1.2.6 # https://behave.readthedocs.io/en/stable/ # Debugging # ------------------------------------------------------------------------------ diff --git a/scram/__init__.py b/scram/__init__.py index eed836da..8f5a0220 100644 --- a/scram/__init__.py +++ b/scram/__init__.py @@ -1,2 +1,4 @@ -__version__ = "0.1.0" -__version_info__ = tuple([int(num) if num.isdigit() else num for num in __version__.replace("-", ".", 1).split(".")]) +"""The Django project for Security Catch and Release Automation Manager (SCRAM).""" + +__version__ = "1.1.1" +__version_info__ = tuple(int(num) if num.isdigit() else num for num in __version__.replace("-", ".", 1).split(".")) diff --git a/scram/conftest.py b/scram/conftest.py index 52d755b4..29a5bc61 100644 --- a/scram/conftest.py +++ b/scram/conftest.py @@ -1,3 +1,5 @@ +"""Some simple tests for the User app.""" + import pytest from scram.users.models import User @@ -6,9 +8,11 @@ @pytest.fixture(autouse=True) def media_storage(settings, tmpdir): + """Configure the test to use the temp directory.""" settings.MEDIA_ROOT = tmpdir.strpath @pytest.fixture def user() -> User: + """Return the UserFactory.""" return UserFactory() diff --git a/scram/contrib/__init__.py b/scram/contrib/__init__.py index 1c7ecc89..adb76b31 100644 --- a/scram/contrib/__init__.py +++ b/scram/contrib/__init__.py @@ -1,5 +1,4 @@ -""" -To understand why this file is here, please read: +"""To understand why this file is here, please read the CookieCutter documentation. http://cookiecutter-django.readthedocs.io/en/latest/faq.html#why-is-there-a-django-contrib-sites-directory-in-cookiecutter-django """ diff --git a/scram/contrib/sites/__init__.py b/scram/contrib/sites/__init__.py index 1c7ecc89..adb76b31 100644 --- a/scram/contrib/sites/__init__.py +++ b/scram/contrib/sites/__init__.py @@ -1,5 +1,4 @@ -""" -To understand why this file is here, please read: +"""To understand why this file is here, please read the CookieCutter documentation. http://cookiecutter-django.readthedocs.io/en/latest/faq.html#why-is-there-a-django-contrib-sites-directory-in-cookiecutter-django """ diff --git a/scram/local_auth/__init__.py b/scram/local_auth/__init__.py new file mode 100644 index 00000000..113a7bfa --- /dev/null +++ b/scram/local_auth/__init__.py @@ -0,0 +1 @@ +"""Local_auth is the app that holds urls we want to use with local Django auth.""" diff --git a/scram/local_auth/urls.py b/scram/local_auth/urls.py new file mode 100644 index 00000000..56c7b727 --- /dev/null +++ b/scram/local_auth/urls.py @@ -0,0 +1,15 @@ +"""Register URLs for local auth known to Django, and the View that will handle each.""" + +from django.contrib.auth.views import LoginView, LogoutView +from django.urls import path + +app_name = "local_auth" + +urlpatterns = [ + path( + "login/", + LoginView.as_view(template_name="local_auth/login.html", success_url="route_manager:home"), + name="login", + ), + path("logout/", LogoutView.as_view(), name="logout"), +] diff --git a/scram/route_manager/__init__.py b/scram/route_manager/__init__.py index e69de29b..db37df27 100644 --- a/scram/route_manager/__init__.py +++ b/scram/route_manager/__init__.py @@ -0,0 +1 @@ +"""Route Manager is the core app, and it handles adding and removing routes that will be blocked.""" diff --git a/scram/route_manager/admin.py b/scram/route_manager/admin.py index 96a6e324..172dd6c7 100644 --- a/scram/route_manager/admin.py +++ b/scram/route_manager/admin.py @@ -1,3 +1,5 @@ +"""Register models in the Admin site.""" + from django.contrib import admin from simple_history.admin import SimpleHistoryAdmin @@ -6,6 +8,8 @@ @admin.register(ActionType) class ActionTypeAdmin(SimpleHistoryAdmin): + """Configure the ActionType and how it shows up in the Admin site.""" + list_filter = ("available",) list_display = ("name", "available") diff --git a/scram/route_manager/api/__init__.py b/scram/route_manager/api/__init__.py index e69de29b..26fc9ceb 100644 --- a/scram/route_manager/api/__init__.py +++ b/scram/route_manager/api/__init__.py @@ -0,0 +1 @@ +"""The API, which leverages Django Request Framework.""" diff --git a/scram/route_manager/api/exceptions.py b/scram/route_manager/api/exceptions.py index 5483055b..0445dbe4 100644 --- a/scram/route_manager/api/exceptions.py +++ b/scram/route_manager/api/exceptions.py @@ -1,22 +1,30 @@ +"""Custom exceptions for the API.""" + from django.conf import settings from rest_framework.exceptions import APIException class PrefixTooLarge(APIException): + """The CIDR prefix that was specified is larger than the prefix allowed in the settings.""" + v4_min_prefix = getattr(settings, "V4_MINPREFIX", 0) v6_min_prefix = getattr(settings, "V6_MINPREFIX", 0) status_code = 400 - default_detail = f"You've supplied too large of a network. settings.V4_MINPREFIX = {v4_min_prefix} settings.V6_MINPREFIX = {v6_min_prefix}" # noqa: 501 + default_detail = f"You've supplied too large of a network. settings.V4_MINPREFIX = {v4_min_prefix} settings.V6_MINPREFIX = {v6_min_prefix}" # noqa: E501 default_code = "prefix_too_large" class IgnoredRoute(APIException): + """An operation attempted to add a route that overlaps with a route on the ignore list.""" + status_code = 400 default_detail = "This CIDR is on the ignore list. You are not allowed to add it here." default_code = "ignored_route" class ActiontypeNotAllowed(APIException): + """An operation attempted to perform an action on behalf of a client that is unauthorized to perform that type.""" + status_code = 403 default_detail = "This client is not allowed to use this actiontype" default_code = "actiontype_not_allowed" diff --git a/scram/route_manager/api/serializers.py b/scram/route_manager/api/serializers.py index dec9df15..f1bc1907 100644 --- a/scram/route_manager/api/serializers.py +++ b/scram/route_manager/api/serializers.py @@ -1,5 +1,8 @@ +"""Serializers provide mappings between the API and the underlying model.""" + import logging +from drf_spectacular.utils import extend_schema_field from netfields import rest_framework from rest_framework import serializers from rest_framework.fields import CurrentUserDefault @@ -10,16 +13,29 @@ logger = logging.getLogger(__name__) +@extend_schema_field(field={"type": "string", "format": "cidr"}) +class CustomCidrAddressField(rest_framework.CidrAddressField): + """Define a wrapper field so swagger can properly handle the inherited field.""" + + class ActionTypeSerializer(serializers.ModelSerializer): + """Map the serializer to the model via Meta.""" + class Meta: + """Maps to the ActionType model, and specifies the fields exposed by the API.""" + model = ActionType fields = ["pk", "name", "available"] class RouteSerializer(serializers.ModelSerializer): - route = rest_framework.CidrAddressField() + """Exposes route as a CIDR field.""" + + route = CustomCidrAddressField() class Meta: + """Maps to the Route model, and specifies the fields exposed by the API.""" + model = Route fields = [ "route", @@ -27,16 +43,24 @@ class Meta: class ClientSerializer(serializers.ModelSerializer): + """Map the serializer to the model via Meta.""" + class Meta: + """Maps to the Client model, and specifies the fields exposed by the API.""" + model = Client fields = ["hostname", "uuid"] class EntrySerializer(serializers.HyperlinkedModelSerializer): + """Due to the use of ForeignKeys, this follows some relationships to make sense via the API.""" + url = serializers.HyperlinkedIdentityField( - view_name="api:v1:entry-detail", lookup_url_kwarg="pk", lookup_field="route" + view_name="api:v1:entry-detail", + lookup_url_kwarg="pk", + lookup_field="route", ) - route = rest_framework.CidrAddressField() + route = CustomCidrAddressField() actiontype = serializers.CharField(default="block") if CurrentUserDefault(): # This is set if we are calling this serializer from WUI @@ -46,28 +70,44 @@ class EntrySerializer(serializers.HyperlinkedModelSerializer): comment = serializers.CharField() class Meta: + """Maps to the Entry model, and specifies the fields exposed by the API.""" + model = Entry fields = ["route", "actiontype", "url", "comment", "who"] - def get_comment(self, obj): + @staticmethod + def get_comment(obj): + """Provide a nicer name for change reason. + + Returns: + string: The change reason that modified the Entry. + """ return obj.get_change_reason() - def create(self, validated_data): + @staticmethod + def create(validated_data): + """Implement custom logic and validates creating a new route.""" # noqa: DOC201 valid_route = validated_data.pop("route") actiontype = validated_data.pop("actiontype") comment = validated_data.pop("comment") - route_instance, created = Route.objects.get_or_create(route=valid_route) + route_instance, _ = Route.objects.get_or_create(route=valid_route) actiontype_instance = ActionType.objects.get(name=actiontype) - entry_instance, created = Entry.objects.get_or_create(route=route_instance, actiontype=actiontype_instance) + entry_instance, _ = Entry.objects.get_or_create(route=route_instance, actiontype=actiontype_instance) - logger.debug(f"{comment}") + logger.debug("Created entry with comment: %s", comment) update_change_reason(entry_instance, comment) return entry_instance class IgnoreEntrySerializer(serializers.ModelSerializer): + """Map the route to the right field type.""" + + route = CustomCidrAddressField() + class Meta: + """Maps to the IgnoreEntry model, and specifies the fields exposed by the API.""" + model = IgnoreEntry fields = ["route", "comment"] diff --git a/scram/route_manager/api/views.py b/scram/route_manager/api/views.py index 2972e6dd..70d30b34 100644 --- a/scram/route_manager/api/views.py +++ b/scram/route_manager/api/views.py @@ -1,3 +1,5 @@ +"""Views provide mappings between the underlying model and how they're listed in the API.""" + import ipaddress import logging @@ -8,6 +10,7 @@ from django.db.models import Q from django.http import Http404 from django.utils.dateparse import parse_datetime +from drf_spectacular.utils import extend_schema from rest_framework import status, viewsets from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.response import Response @@ -17,23 +20,42 @@ from .serializers import ActionTypeSerializer, ClientSerializer, EntrySerializer, IgnoreEntrySerializer channel_layer = get_channel_layer() +logger = logging.getLogger(__name__) +@extend_schema( + description="API endpoint for actiontypes", + responses={200: ActionTypeSerializer}, +) class ActionTypeViewSet(viewsets.ReadOnlyModelViewSet): + """Lookup ActionTypes by name when authenticated, and bind to the serializer.""" + queryset = ActionType.objects.all() permission_classes = (IsAuthenticated,) serializer_class = ActionTypeSerializer lookup_field = "name" +@extend_schema( + description="API endpoint for ignore entries", + responses={200: IgnoreEntrySerializer}, +) class IgnoreEntryViewSet(viewsets.ModelViewSet): + """Lookup IgnoreEntries by route when authenticated, and bind to the serializer.""" + queryset = IgnoreEntry.objects.all() permission_classes = (IsAuthenticated,) serializer_class = IgnoreEntrySerializer lookup_field = "route" +@extend_schema( + description="API endpoint for clients", + responses={200: ClientSerializer}, +) class ClientViewSet(viewsets.ModelViewSet): + """Lookup Client by hostname on POSTs regardless of authentication, and bind to the serializer.""" + queryset = Client.objects.all() # We want to allow a client to be registered from anywhere permission_classes = (AllowAny,) @@ -42,21 +64,56 @@ class ClientViewSet(viewsets.ModelViewSet): http_method_names = ["post"] +@extend_schema( + description="API endpoint for entries", + responses={200: EntrySerializer}, +) class EntryViewSet(viewsets.ModelViewSet): + """Lookup Entry when authenticated, and bind to the serializer.""" + queryset = Entry.objects.filter(is_active=True) permission_classes = (IsAuthenticated,) serializer_class = EntrySerializer lookup_value_regex = ".*" http_method_names = ["get", "post", "head", "delete"] - # Ovveride the permissions classes for POST method since we want to accept Entry creates from any client - # Note: We make authorization decisions on whether to actually create the object in the perform_create method later def get_permissions(self): + """Override the permissions classes for POST method since we want to accept Entry creates from any client. + + Note: We make authorization decisions on whether to actually create the object in the perform_create method + later. + """ if self.request.method == "POST": return [AllowAny()] return super().get_permissions() + def check_client_authorization(self, actiontype): + """Ensure that a given client is authorized to use a given actiontype.""" + uuid = self.request.data.get("uuid") + if uuid: + authorized_actiontypes = Client.objects.filter(uuid=uuid).values_list( + "authorized_actiontypes__name", + flat=True, + ) + authorized_client = Client.objects.filter(uuid=uuid).values("is_authorized") + if not authorized_client or actiontype not in authorized_actiontypes: + logger.debug("Client: %s, actiontypes: %s", uuid, authorized_actiontypes) + logger.info("%s is not allowed to add an entry to the %s list.", uuid, actiontype) + raise ActiontypeNotAllowed + elif not self.request.user.has_perm("route_manager.can_add_entry"): + raise PermissionDenied + + @staticmethod + def check_ignore_list(route): + """Ensure that we're not trying to block something from the ignore list.""" + overlapping_ignore = IgnoreEntry.objects.filter(route__net_overlaps=route) + if overlapping_ignore.count(): + ignore_entries = [str(ignore_entry["route"]) for ignore_entry in overlapping_ignore.values()] + logger.info("Cannot proceed adding %s. The ignore list contains %s.", route, ignore_entries) + raise IgnoredRoute + def perform_create(self, serializer): + """Create a new Entry, causing that route to receive the actiontype (i.e. block).""" actiontype = serializer.validated_data["actiontype"] route = serializer.validated_data["route"] if self.request.user.username: @@ -69,63 +126,45 @@ def perform_create(self, serializer): tmp_exp = self.request.data.get("expiration", "") try: - expiration = parse_datetime(tmp_exp) # noqa: F841 + expiration = parse_datetime(tmp_exp) except ValueError: - logging.info(f"Could not parse expiration DateTime string {tmp_exp!r}.") + logger.warning("Could not parse expiration DateTime string: %s", tmp_exp) # Make sure we put in an acceptable sized prefix min_prefix = getattr(settings, f"V{route.version}_MINPREFIX", 0) if route.prefixlen < min_prefix: - raise PrefixTooLarge() - - # Make sure this client is authorized to add this entry with this actiontype - if self.request.data.get("uuid"): - client_uuid = self.request.data["uuid"] - authorized_actiontypes = Client.objects.filter(uuid=client_uuid).values_list( - "authorized_actiontypes__name", flat=True + raise PrefixTooLarge + + self.check_client_authorization(actiontype) + self.check_ignore_list(route) + + elements = WebSocketSequenceElement.objects.filter(action_type__name=actiontype).order_by("order_num") + if not elements: + logger.warning("No elements found for actiontype: %s", actiontype) + + for element in elements: + msg = element.websocketmessage + msg.msg_data[msg.msg_data_route_field] = str(route) + # Must match a channel name defined in asgi.py + async_to_sync(channel_layer.group_send)( + f"translator_{actiontype}", + {"type": msg.msg_type, "message": msg.msg_data}, ) - authorized_client = Client.objects.filter(uuid=client_uuid).values("is_authorized") - if not authorized_client or actiontype not in authorized_actiontypes: - logging.debug(f"Client {client_uuid} actiontypes: {authorized_actiontypes}") - logging.info(f"{client_uuid} is not allowed to add an entry to the {actiontype} list") - raise ActiontypeNotAllowed() - elif not self.request.user.has_perm("route_manager.can_add_entry"): - raise PermissionDenied() - # Don't process if we have the entry in the ignorelist - overlapping_ignore = IgnoreEntry.objects.filter(route__net_overlaps=route) - if overlapping_ignore.count(): - ignore_entries = [] - for ignore_entry in overlapping_ignore.values(): - ignore_entries.append(str(ignore_entry["route"])) - logging.info(f"Cannot proceed adding {route}. The ignore list contains {ignore_entries}") - raise IgnoredRoute - else: - elements = WebSocketSequenceElement.objects.filter(action_type__name=actiontype).order_by("order_num") - if not elements: - logging.warning(f"No elements found for actiontype={actiontype}.") - - for element in elements: - msg = element.websocketmessage - msg.msg_data[msg.msg_data_route_field] = str(route) - # Must match a channel name defined in asgi.py - async_to_sync(channel_layer.group_send)( - f"translator_{actiontype}", {"type": msg.msg_type, "message": msg.msg_data} - ) - - serializer.save() - - entry = Entry.objects.get(route__route=route, actiontype__name=actiontype) - if expiration: - entry.expiration = expiration - entry.who = who - entry.is_active = True - entry.comment = comment - logging.info(f"{entry}") - entry.save() + serializer.save() + + entry = Entry.objects.get(route__route=route, actiontype__name=actiontype) + if expiration: + entry.expiration = expiration + entry.who = who + entry.is_active = True + entry.comment = comment + logger.info("Created entry: %s", entry) + entry.save() @staticmethod def find_entries(arg, active_filter=None): + """Query entries either by pk or overlapping route.""" if not arg: return Entry.objects.none() @@ -133,13 +172,13 @@ def find_entries(arg, active_filter=None): try: pk = int(arg) query = Q(pk=pk) - except ValueError: + except ValueError as exc: # Maybe a CIDR? We want the ValueError at this point, if not. cidr = ipaddress.ip_network(arg, strict=False) min_prefix = getattr(settings, f"V{cidr.version}_MINPREFIX", 0) if cidr.prefixlen < min_prefix: - raise PrefixTooLarge() + raise PrefixTooLarge from exc query = Q(route__route__net_overlaps=cidr) @@ -149,6 +188,7 @@ def find_entries(arg, active_filter=None): return Entry.objects.filter(query) def retrieve(self, request, pk=None, **kwargs): + """Retrieve a single route.""" entries = self.find_entries(pk, active_filter=True) # TODO: What happens if we get multiple? Is that ok? I think yes, and return them all? if entries.count() != 1: @@ -157,6 +197,7 @@ def retrieve(self, request, pk=None, **kwargs): return Response(serializer.data) def destroy(self, request, pk=None, *args, **kwargs): + """Only delete active (e.g. announced) entries.""" for entry in self.find_entries(pk, active_filter=True): entry.delete() diff --git a/scram/route_manager/apps.py b/scram/route_manager/apps.py index 319a8d17..bf13bf9e 100644 --- a/scram/route_manager/apps.py +++ b/scram/route_manager/apps.py @@ -1,5 +1,9 @@ +"""Register ourselves with Django.""" + from django.apps import AppConfig class RouteManagerConfig(AppConfig): + """Define the name of the module that's the main app.""" + name = "scram.route_manager" diff --git a/scram/route_manager/authentication_backends.py b/scram/route_manager/authentication_backends.py index 1c549376..1831d788 100644 --- a/scram/route_manager/authentication_backends.py +++ b/scram/route_manager/authentication_backends.py @@ -1,47 +1,49 @@ +"""Define one or more custom auth backends.""" + from django.conf import settings from django.contrib.auth.models import Group from mozilla_django_oidc.auth import OIDCAuthenticationBackend -class ESnetAuthBackend(OIDCAuthenticationBackend): - def update_groups(self, user, claims): - """Sets the users group(s) to whatever is in the claims.""" +def groups_overlap(a, b): + """Helper function to see if a and b have any overlap. - claimed_groups = claims.get("groups", []) - - effective_groups = [] - is_admin = False + Returns: + bool: True if there's any overlap between a and b. + """ + return not set(a).isdisjoint(b) - ro_group = Group.objects.get(name="readonly") - rw_group = Group.objects.get(name="readwrite") - for g in claimed_groups: - # If any of the user's groups are in DENIED_GROUPS, deny them and stop processing immediately - if g in settings.SCRAM_DENIED_GROUPS: - effective_groups = [] - is_admin = False - break - - if g in settings.SCRAM_ADMIN_GROUPS: - is_admin = True +class ESnetAuthBackend(OIDCAuthenticationBackend): + """Extend the OIDC backend with a custom permission model.""" - if g in settings.SCRAM_READONLY_GROUPS: - if ro_group not in effective_groups: - effective_groups.append(ro_group) + @staticmethod + def update_groups(user, claims): + """Set the user's group(s) to whatever is in the claims.""" + effective_groups = [] + claimed_groups = claims.get("groups", []) - if g in settings.SCRAM_READWRITE_GROUPS: - if rw_group not in effective_groups: - effective_groups.append(rw_group) + if groups_overlap(claimed_groups, settings.SCRAM_DENIED_GROUPS): + is_admin = False + # Don't even look at anything else if they're denied + else: + is_admin = groups_overlap(claimed_groups, settings.SCRAM_ADMIN_GROUPS) + if groups_overlap(claimed_groups, settings.SCRAM_READWRITE_GROUPS): + effective_groups.append(Group.objects.get(name="readwrite")) + if groups_overlap(claimed_groups, settings.SCRAM_READONLY_GROUPS): + effective_groups.append(Group.objects.get(name="readonly")) user.groups.set(effective_groups) user.is_staff = user.is_superuser = is_admin user.save() def create_user(self, claims): - user = super(ESnetAuthBackend, self).create_user(claims) + """Wrap the superclass's user creation.""" # noqa: DOC201 + user = super().create_user(claims) return self.update_user(user, claims) def update_user(self, user, claims): + """Determine the user name from the claims and update said user's groups.""" # noqa: DOC201 user.name = claims.get("given_name", "") + " " + claims.get("family_name", "") user.username = claims.get("preferred_username", "") if claims.get("groups", False): diff --git a/scram/route_manager/context_processors.py b/scram/route_manager/context_processors.py index 3781c3d2..0daf06c9 100644 --- a/scram/route_manager/context_processors.py +++ b/scram/route_manager/context_processors.py @@ -1,8 +1,15 @@ +"""Define custom functions that take a request and add to the context before template rendering.""" + from django.conf import settings from django.urls import reverse def login_logout(request): + """Pass through the relevant URLs from the settings. + + Returns: + dict: login and logout URLs + """ login_url = reverse(settings.LOGIN_URL) logout_url = reverse(settings.LOGOUT_URL) return {"login": login_url, "logout": logout_url} diff --git a/scram/route_manager/models.py b/scram/route_manager/models.py index ef063ead..aa72f7ba 100644 --- a/scram/route_manager/models.py +++ b/scram/route_manager/models.py @@ -1,3 +1,5 @@ +"""Define the models used in the route_manager app.""" + import logging import uuid as uuid_lib @@ -8,35 +10,41 @@ from netfields import CidrAddressField from simple_history.models import HistoricalRecords +logger = logging.getLogger(__name__) + class Route(models.Model): - """Model describing a route""" + """Define a route as a CIDR route and a UUID.""" route = CidrAddressField(unique=True) uuid = models.UUIDField(db_index=True, default=uuid_lib.uuid4, editable=False) - def get_absolute_url(self): - return reverse("") - def __str__(self): + """Don't display the UUID, only the route.""" # noqa: DOC201 return str(self.route) + @staticmethod + def get_absolute_url(): + """Ensure we use UUID on the API side instead.""" # noqa: DOC201 + return reverse("") + class ActionType(models.Model): - """Defines an action that can be done with a given route. e.g. Block, shunt, redirect, etc.""" + """Define a type of action that can be done with a given route. e.g. Block, shunt, redirect, etc.""" name = models.CharField(help_text="One-word description of the action", max_length=30) available = models.BooleanField(help_text="Is this a valid choice for new entries?", default=True) history = HistoricalRecords() def __str__(self): + """Display clearly whether the action is currently available.""" # noqa: DOC201 if not self.available: return f"{self.name} (Inactive)" return self.name class WebSocketMessage(models.Model): - """Defines a single message sent to downstream translators via WebSocket.""" + """Define a single message sent to downstream translators via WebSocket.""" msg_type = models.CharField("The type of the message", max_length=50) msg_data = models.JSONField("The JSON payload. See also msg_data_route_field.", default=dict) @@ -47,16 +55,17 @@ class WebSocketMessage(models.Model): ) def __str__(self): + """Display clearly what the fields are used for.""" # noqa: DOC201 return f"{self.msg_type}: {self.msg_data} with the route in key {self.msg_data_route_field}" class WebSocketSequenceElement(models.Model): - """In a sequence of messages, defines a single element.""" + """In a sequence of messages, define a single element.""" websocketmessage = models.ForeignKey("WebSocketMessage", on_delete=models.CASCADE) order_num = models.SmallIntegerField( "Sequences are sent from the smallest order_num to the highest. " - + "Messages with the same order_num could be sent in any order", + "Messages with the same order_num could be sent in any order", default=0, ) @@ -70,9 +79,10 @@ class WebSocketSequenceElement(models.Model): action_type = models.ForeignKey("ActionType", on_delete=models.CASCADE) def __str__(self): + """Summarize the fields into something short and readable.""" # noqa: DOC201 return ( f"{self.websocketmessage} as order={self.order_num} for " - + f"{self.verb} actions on actiontype={self.action_type}" + f"{self.verb} actions on actiontype={self.action_type}" ) @@ -81,7 +91,7 @@ class Entry(models.Model): route = models.ForeignKey("Route", on_delete=models.PROTECT) actiontype = models.ForeignKey("ActionType", on_delete=models.PROTECT) - comment = models.TextField(blank=True, null=True) + comment = models.TextField(blank=True, default="") is_active = models.BooleanField(default=True) # TODO: fix name if this works history = HistoricalRecords() @@ -91,60 +101,71 @@ class Entry(models.Model): expiration_reason = models.CharField( help_text="Optional reason for the expiration", max_length=200, - null=True, blank=True, + default="", ) - def delete(self, *args, **kwargs): - if not self.is_active: - # We've already expired this route, don't send another message - return - else: - # We don't actually delete records; we set them to inactive and then tell the translator to remove them - logging.info(f"Deactivating {self.route}") - self.is_active = False - self.save() - - # Unblock it - async_to_sync(channel_layer.group_send)( - f"translator_{self.actiontype}", - { - "type": "translator_remove", - "message": {"route": str(self.route)}, - }, - ) - class Meta: + """Ensure that multiple routes can be added as long as they have different action types.""" + unique_together = ["route", "actiontype"] verbose_name_plural = "Entries" def __str__(self): + """Summarize the most important fields to something easily readable.""" # noqa: DOC201 desc = f"{self.route} ({self.actiontype})" if not self.is_active: desc += " (inactive)" return desc + def delete(self, *args, **kwargs): + """Set inactive instead of deleting, as we want to ensure a history of entries.""" + if not self.is_active: + # We've already expired this route, don't send another message + return + # We don't actually delete records; we set them to inactive and then tell the translator to remove them + logger.info("Deactivating %s", self.route) + self.is_active = False + self.save() + + # Unblock it + async_to_sync(channel_layer.group_send)( + f"translator_{self.actiontype}", + { + "type": "translator_remove", + "message": {"route": str(self.route)}, + }, + ) + def get_change_reason(self): + """Traverse some complex relationships to determine the most recent change reason. + + Returns: + str: The most recent change reason + """ hist_mgr = getattr(self, self._meta.simple_history_manager_attribute) return hist_mgr.order_by("-history_date").first().history_change_reason class IgnoreEntry(models.Model): - """For cidrs you NEVER want to block ie don't shoot yourself in the foot list""" + """Define CIDRs you NEVER want to block (i.e. the "don't shoot yourself in the foot" list).""" route = CidrAddressField(unique=True) comment = models.CharField(max_length=100) history = HistoricalRecords() class Meta: + """Ensure the plural is grammatically correct.""" + verbose_name_plural = "Ignored Entries" def __str__(self): + """Only display the route.""" # noqa: DOC201 return str(self.route) class Client(models.Model): - """Any client that would like to hit the API to add entries (e.g. Zeek)""" + """Any client that would like to hit the API to add entries (e.g. Zeek).""" hostname = models.CharField(max_length=50, unique=True) uuid = models.UUIDField() @@ -153,6 +174,7 @@ class Client(models.Model): authorized_actiontypes = models.ManyToManyField(ActionType) def __str__(self): + """Only display the hostname.""" # noqa: DOC201 return str(self.hostname) diff --git a/scram/route_manager/tests/__init__.py b/scram/route_manager/tests/__init__.py index e69de29b..fb031a32 100644 --- a/scram/route_manager/tests/__init__.py +++ b/scram/route_manager/tests/__init__.py @@ -0,0 +1 @@ +"""Define tests executed by pytest.""" diff --git a/scram/route_manager/tests/acceptance/environment.py b/scram/route_manager/tests/acceptance/environment.py index 5bae6990..0d31198b 100644 --- a/scram/route_manager/tests/acceptance/environment.py +++ b/scram/route_manager/tests/acceptance/environment.py @@ -1,9 +1,12 @@ +"""Setup the environment for tests.""" + from rest_framework.test import APIClient from scram.users.tests.factories import UserFactory def django_ready(context): + """Create a user that tests can use for authenticated/unauthenticated testing.""" # Create a user UserFactory(username="user", password="password") diff --git a/scram/route_manager/tests/acceptance/steps/common.py b/scram/route_manager/tests/acceptance/steps/common.py index 1d502e96..f688bd64 100644 --- a/scram/route_manager/tests/acceptance/steps/common.py +++ b/scram/route_manager/tests/acceptance/steps/common.py @@ -1,32 +1,37 @@ +"""Define steps used exclusively by the Behave tests.""" + import datetime import time -import django.conf as conf from asgiref.sync import async_to_sync from behave import given, step, then, when from channels.layers import get_channel_layer +from django import conf from django.urls import reverse from scram.route_manager.models import ActionType, Client, WebSocketMessage, WebSocketSequenceElement @given("a {name} actiontype is defined") -def step_impl(context, name): +def create_actiontype(context, name): + """Create an actiontype of that name.""" context.channel_layer = get_channel_layer() async_to_sync(context.channel_layer.group_send)( - f"translator_{name}", {"type": "translator_remove_all", "message": {}} + f"translator_{name}", + {"type": "translator_remove_all", "message": {}}, ) - at, created = ActionType.objects.get_or_create(name=name) - wsm, created = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route") + at, _ = ActionType.objects.get_or_create(name=name) + wsm, _ = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route") wsm.save() - wsse, created = WebSocketSequenceElement.objects.get_or_create(websocketmessage=wsm, verb="A", action_type=at) + wsse, _ = WebSocketSequenceElement.objects.get_or_create(websocketmessage=wsm, verb="A", action_type=at) wsse.save() @given("a client with {name} authorization") -def step_impl(context, name): - at, created = ActionType.objects.get_or_create(name=name) +def create_authed_client(context, name): + """Create a client and authorize it for that action type.""" + at, _ = ActionType.objects.get_or_create(name=name) authorized_client = Client.objects.create( hostname="authorized_client.es.net", uuid="0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3", @@ -36,7 +41,8 @@ def step_impl(context, name): @given("a client without {name} authorization") -def step_impl(context, name): +def create_unauthed_client(context, name): + """Create a client that has no authorized action types.""" unauthorized_client = Client.objects.create( hostname="unauthorized_client.es.net", uuid="91e134a5-77cf-4560-9797-6bbdbffde9f8", @@ -45,23 +51,27 @@ def step_impl(context, name): @when("we're logged in") -def step_impl(context): +def login(context): + """Login.""" context.test.client.login(username="user", password="password") @when("the CIDR prefix limits are {v4_minprefix:d} and {v6_minprefix:d}") -def step_impl(context, v4_minprefix, v6_minprefix): +def set_cidr_limit(context, v4_minprefix, v6_minprefix): + """Override our settings with the provided values.""" conf.settings.V4_MINPREFIX = v4_minprefix conf.settings.V6_MINPREFIX = v6_minprefix @then("we get a {status_code:d} status code") -def step_impl(context, status_code): +def check_status_code(context, status_code): + """Ensure the status code response matches the expected value.""" context.test.assertEqual(context.response.status_code, status_code) @when("we add the entry {value:S}") -def step_impl(context, value): +def add_entry(context, value): + """Block the provided route.""" context.response = context.test.client.post( reverse("api:v1:entry-list"), { @@ -77,7 +87,8 @@ def step_impl(context, value): @when("we add the entry {value:S} with comment {comment}") -def step_impl(context, value, comment): +def add_entry_with_comment(context, value, comment): + """Block the provided route and add a comment.""" context.response = context.test.client.post( reverse("api:v1:entry-list"), { @@ -92,7 +103,8 @@ def step_impl(context, value, comment): @when("we add the entry {value:S} with expiration {exp:S}") -def step_impl(context, value, exp): +def add_entry_with_absolute_expiration(context, value, exp): + """Block the provided route and add an absolute expiration datetime.""" context.response = context.test.client.post( reverse("api:v1:entry-list"), { @@ -108,9 +120,10 @@ def step_impl(context, value, exp): @when("we add the entry {value:S} with expiration in {secs:d} seconds") -def step_impl(context, value, secs): +def add_entry_with_relative_expiration(context, value, secs): + """Block the provided route and add a relative expiration.""" td = datetime.timedelta(seconds=secs) - expiration = datetime.datetime.now() + td + expiration = datetime.datetime.now(tz=datetime.UTC) + td context.response = context.test.client.post( reverse("api:v1:entry-list"), @@ -127,34 +140,41 @@ def step_impl(context, value, secs): @step("we wait {secs:d} seconds") -def step_impl(context, secs): +def wait(context, secs): + """Wait to allow messages to propagate.""" time.sleep(secs) @then("we remove expired entries") -def step_impl(context): +def remove_expired(context): + """Call the function that removes expired entries.""" context.response = context.test.client.get(reverse("route_manager:process-expired")) @when("we add the ignore entry {value:S}") -def step_impl(context, value): +def add_ignore_entry(context, value): + """Add an IgnoreEntry with the specified route.""" context.response = context.test.client.post( - reverse("api:v1:ignoreentry-list"), {"route": value, "comment": "test api"} + reverse("api:v1:ignoreentry-list"), + {"route": value, "comment": "test api"}, ) @when("we remove the {model} {value}") -def step_impl(context, model, value): +def remove_an_object(context, model, value): + """Remove any model object with the matching value.""" context.response = context.test.client.delete(reverse(f"api:v1:{model.lower()}-detail", args=[value])) @when("we list the {model}s") -def step_impl(context, model): +def list_objects(context, model): + """List all objects of an arbitrary model.""" context.response = context.test.client.get(reverse(f"api:v1:{model.lower()}-list")) @when("we update the {model} {value_from} to {value_to}") -def step_impl(context, model, value_from, value_to): +def update_object(context, model, value_from, value_to): + """Modify any model object with the matching value to the new value instead.""" context.response = context.test.client.patch( reverse(f"api:v1:{model.lower()}-detail", args=[value_from]), {model.lower(): value_to}, @@ -162,7 +182,8 @@ def step_impl(context, model, value_from, value_to): @then("the number of {model}s is {num:d}") -def step_impl(context, model, num): +def count_objects(context, model, num): + """Count the number of objects of an arbitrary model.""" objs = context.test.client.get(reverse(f"api:v1:{model.lower()}-list")) context.test.assertEqual(len(objs.json()), num) @@ -171,7 +192,8 @@ def step_impl(context, model, num): @then("{value} is one of our list of {model}s") -def step_impl(context, value, model): +def check_object(context, value, model): + """Ensure that the arbitrary model has an object with the specified value.""" objs = context.test.client.get(reverse(f"api:v1:{model.lower()}-list")) found = False @@ -186,7 +208,8 @@ def step_impl(context, value, model): @when("we register a client named {hostname} with the uuid of {uuid}") -def step_impl(context, hostname, uuid): +def add_client(context, hostname, uuid): + """Create a client with a specific UUID.""" context.response = context.test.client.post( reverse("api:v1:client-list"), { diff --git a/scram/route_manager/tests/acceptance/steps/ip.py b/scram/route_manager/tests/acceptance/steps/ip.py index a346e954..34536a86 100644 --- a/scram/route_manager/tests/acceptance/steps/ip.py +++ b/scram/route_manager/tests/acceptance/steps/ip.py @@ -1,12 +1,14 @@ +"""Define steps used for IP-related logic by the Behave tests.""" + import ipaddress from behave import then, when from django.urls import reverse -# This does a CIDR match @then("{route} is contained in our list of {model}s") -def step_impl(context, route, model): +def check_route(context, route, model): + """Perform a CIDR match on the matching object.""" objs = context.test.client.get(reverse(f"api:v1:{model.lower()}-list")) ip_target = ipaddress.ip_address(route) @@ -21,7 +23,8 @@ def step_impl(context, route, model): @when("we query for {ip}") -def step_impl(context, ip): +def check_ip(context, ip): + """Find an Entry for the specified IP.""" try: context.response = context.test.client.get(reverse("api:v1:entry-detail", args=[ip])) context.queryException = None @@ -31,12 +34,14 @@ def step_impl(context, ip): @then("we get a ValueError") -def step_impl(context): +def check_error(context): + """Ensure we received a ValueError exception.""" assert isinstance(context.queryException, ValueError) @then("the change entry for {value:S} is {comment}") -def step_impl(context, value, comment): +def check_comment(context, value, comment): + """Verify the comment for the Entry.""" try: objs = context.test.client.get(reverse("api:v1:entry-detail", args=[value])) context.test.assertEqual(objs.json()[0]["comment"], comment) diff --git a/scram/route_manager/tests/acceptance/steps/translator.py b/scram/route_manager/tests/acceptance/steps/translator.py index 85621f21..ff478a9a 100644 --- a/scram/route_manager/tests/acceptance/steps/translator.py +++ b/scram/route_manager/tests/acceptance/steps/translator.py @@ -1,3 +1,5 @@ +"""Define steps used for translator integration testing by the Behave tests.""" + from behave import then from behave.api.async_step import async_run_until_complete from channels.testing import WebsocketCommunicator @@ -5,9 +7,10 @@ from config.asgi import ws_application -async def query_translator(route, actiontype, is_announced=True): +async def query_translator(route, actiontype, is_announced): + """Ensure the specified route is currently either blocked or unblocked.""" communicator = WebsocketCommunicator(ws_application, f"/ws/route_manager/webui_{actiontype}/") - connected, subprotocol = await communicator.connect() + connected, _ = await communicator.connect() assert connected await communicator.send_json_to({"type": "wui_check_req", "message": {"route": route}}) @@ -20,11 +23,13 @@ async def query_translator(route, actiontype, is_announced=True): @then("{route} is announced by {actiontype} translators") @async_run_until_complete -async def step_impl(context, route, actiontype): - await query_translator(route, actiontype) +async def check_blocked(context, route, actiontype): + """Ensure the specified route is currently blocked.""" + await query_translator(route, actiontype, is_announced=True) @then("{route} is not announced by {actiontype} translators") @async_run_until_complete -async def step_impl(context, route, actiontype): +async def check_unblocked(context, route, actiontype): + """Ensure the specified route is currently unblocked.""" await query_translator(route, actiontype, is_announced=False) diff --git a/scram/route_manager/tests/functional_tests.py b/scram/route_manager/tests/functional_tests.py index c4312ecf..a87ce21e 100644 --- a/scram/route_manager/tests/functional_tests.py +++ b/scram/route_manager/tests/functional_tests.py @@ -1,9 +1,10 @@ +"""Use the Django web client to perform end-to-end, WebUI-based testing.""" + import unittest class HomePageTest(unittest.TestCase): - - pass + """Ensure the home page works.""" if __name__ == "__main__": diff --git a/scram/route_manager/tests/test_api.py b/scram/route_manager/tests/test_api.py index 8152b877..1a81e907 100644 --- a/scram/route_manager/tests/test_api.py +++ b/scram/route_manager/tests/test_api.py @@ -1,3 +1,5 @@ +"""Use pytest to unit test the API.""" + from django.contrib.auth import get_user_model from django.urls import reverse from rest_framework import status @@ -7,7 +9,10 @@ class TestAddRemoveIP(APITestCase): + """Ensure that we can block IPs, and that duplicate blocks don't generate an error.""" + def setUp(self): + """Set up the environment for our tests.""" self.url = reverse("api:v1:entry-list") self.superuser = get_user_model().objects.create_superuser("admin", "admin@es.net", "admintestpassword") self.client.login(username="admin", password="admintestpassword") @@ -19,6 +24,7 @@ def setUp(self): self.authorized_client.authorized_actiontypes.set([1]) def test_block_ipv4(self): + """Block a v4 IP.""" response = self.client.post( self.url, { @@ -31,6 +37,7 @@ def test_block_ipv4(self): self.assertEqual(response.status_code, status.HTTP_201_CREATED) def test_block_duplicate_ipv4(self): + """Block an existing v4 IP and ensure we don't get an error.""" self.client.post( self.url, { @@ -52,6 +59,7 @@ def test_block_duplicate_ipv4(self): self.assertEqual(response.status_code, status.HTTP_201_CREATED) def test_block_ipv6(self): + """Block a v6 IP.""" response = self.client.post( self.url, { @@ -64,6 +72,7 @@ def test_block_ipv6(self): self.assertEqual(response.status_code, status.HTTP_201_CREATED) def test_block_duplicate_ipv6(self): + """Block an existing v6 IP and ensure we don't get an error.""" self.client.post( self.url, { @@ -86,11 +95,15 @@ def test_block_duplicate_ipv6(self): class TestUnauthenticatedAccess(APITestCase): + """Ensure that an unathenticated client can't do anything.""" + def setUp(self): + """Define some helper variables.""" self.entry_url = reverse("api:v1:entry-list") self.ignore_url = reverse("api:v1:ignoreentry-list") def test_unauthenticated_users_have_no_create_access(self): + """Ensure an unauthenticated client can't add an Entry.""" response = self.client.post( self.entry_url, { @@ -104,9 +117,11 @@ def test_unauthenticated_users_have_no_create_access(self): self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_unauthenticated_users_have_no_ignore_create_access(self): + """Ensure an unauthenticated client can't add an IgnoreEntry.""" response = self.client.post(self.ignore_url, {"route": "192.0.2.4"}, format="json") self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_unauthenticated_users_have_no_list_access(self): + """Ensure an unauthenticated client can't list Entries.""" response = self.client.get(self.entry_url, format="json") self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) diff --git a/scram/route_manager/tests/test_authorization.py b/scram/route_manager/tests/test_authorization.py index 6f4269d1..9a0d4bff 100644 --- a/scram/route_manager/tests/test_authorization.py +++ b/scram/route_manager/tests/test_authorization.py @@ -1,3 +1,5 @@ +"""Define tests for authorization and permissions.""" + from django.conf import settings from django.contrib.auth.models import Group from django.test import Client, TestCase @@ -10,7 +12,10 @@ class AuthzTest(TestCase): + """Define tests using the built-in authentication.""" + def setUp(self): + """Define several users for our tests.""" self.client = Client() self.unauthorized_user = User.objects.create(username="unauthorized") @@ -49,6 +54,7 @@ def setUp(self): ) def create_entry(self): + """Ensure the admin user can create an Entry.""" self.client.force_login(self.admin_user) self.client.post( reverse("route_manager:add"), @@ -63,8 +69,7 @@ def create_entry(self): return Entry.objects.latest("id").id def test_unauthorized_add_entry(self): - """Unauthorized users should not be able to add an entry""" - + """Unauthorized users should not be able to add an Entry.""" for user in self.write_blocked_users: if user: self.client.force_login(user) @@ -79,6 +84,7 @@ def test_unauthorized_add_entry(self): self.assertEqual(response.status_code, 302) def test_authorized_add_entry(self): + """Test authorized users with various permissions to ensure they can add an Entry.""" for user in self.write_allowed_users: self.client.force_login(user) response = self.client.post( @@ -92,6 +98,7 @@ def test_authorized_add_entry(self): self.assertEqual(response.status_code, 200) def test_unauthorized_detail_view(self): + """Ensure that unauthorized users can't view the blocked IPs.""" pk = self.create_entry() for user in self.detail_blocked_users: @@ -101,6 +108,7 @@ def test_unauthorized_detail_view(self): self.assertIn(response.status_code, [302, 403], msg=f"username={user}") def test_authorized_detail_view(self): + """Test authorized users with various permissions to ensure they can view block details.""" pk = self.create_entry() for user in self.detail_allowed_users: @@ -110,7 +118,6 @@ def test_authorized_detail_view(self): def test_unauthorized_after_group_removal(self): """The user has r/w access, then when we remove them from the r/w group, they no longer do.""" - test_user = User.objects.create(username="tmp_readwrite") test_user.groups.set([self.readwrite_group]) test_user.save() @@ -125,8 +132,11 @@ def test_unauthorized_after_group_removal(self): self.assertEqual(response.status_code, 302) -class OidcTest(TestCase): +class ESnetAuthBackendTest(TestCase): + """Define tests using OIDC authentication with our ESnetAuthBackend.""" + def setUp(self): + """Create a sample OIDC user.""" self.client = Client() self.claims = { "given_name": "Edward", @@ -136,8 +146,7 @@ def setUp(self): } def test_unauthorized(self): - """A user with no groups should have no access""" - + """A user with no groups should have no access.""" claims = dict(self.claims) user = ESnetAuthBackend().create_user(claims) @@ -146,8 +155,7 @@ def test_unauthorized(self): self.assertEqual(list(user.user_permissions.all()), []) def test_readonly(self): - """Test r/o groups""" - + """Test r/o groups.""" claims = dict(self.claims) claims["groups"] = [settings.SCRAM_READONLY_GROUPS[0]] user = ESnetAuthBackend().create_user(claims) @@ -158,8 +166,7 @@ def test_readonly(self): self.assertFalse(user.has_perm("route_manager.add_entry")) def test_readwrite(self): - """Test r/w groups""" - + """Test r/w groups.""" claims = dict(self.claims) claims["groups"] = [settings.SCRAM_READWRITE_GROUPS[0]] user = ESnetAuthBackend().create_user(claims) @@ -171,8 +178,7 @@ def test_readwrite(self): self.assertTrue(user.has_perm("route_manager.add_entry")) def test_admin(self): - """Test admin_groups""" - + """Test admin_groups.""" claims = dict(self.claims) claims["groups"] = [settings.SCRAM_ADMIN_GROUPS[0]] user = ESnetAuthBackend().create_user(claims) @@ -183,8 +189,7 @@ def test_admin(self): self.assertTrue(user.has_perm("route_manager.add_entry")) def test_authorized_removal(self): - """Have an authorized user, then downgrade them and make sure they're unauthorized""" - + """Have an authorized user, then downgrade them and make sure they're unauthorized.""" claims = dict(self.claims) claims["groups"] = [settings.SCRAM_ADMIN_GROUPS[0]] user = ESnetAuthBackend().create_user(claims) @@ -230,9 +235,8 @@ def test_authorized_removal(self): def test_disabled(self): """Pass all the groups, user should be disabled as it takes precedence.""" - claims = dict(self.claims) - claims["groups"] = [settings.SCRAM_GROUPS] + claims["groups"] = settings.SCRAM_GROUPS user = ESnetAuthBackend().create_user(claims) self.assertFalse(user.is_staff) diff --git a/scram/route_manager/tests/test_autocreate_admin.py b/scram/route_manager/tests/test_autocreate_admin.py new file mode 100644 index 00000000..88270f4e --- /dev/null +++ b/scram/route_manager/tests/test_autocreate_admin.py @@ -0,0 +1,50 @@ +"""Test the auto-creation of an admin user.""" + +import pytest +from django.contrib.messages import get_messages +from django.test import Client +from django.urls import reverse + +from scram.users.models import User + +LEVEL_SUCCESS = 25 +LEVEL_INFO = 20 + + +@pytest.mark.django_db +def test_autocreate_admin(settings): + """Test that an admin user is auto-created when AUTOCREATE_ADMIN is True.""" + settings.AUTOCREATE_ADMIN = True + client = Client() + response = client.get(reverse("route_manager:home")) + assert response.status_code == 200 # noqa: PLR2004 + assert User.objects.count() == 1 + user = User.objects.get(username="admin") + assert user.is_superuser + assert user.email == "admin@example.com" + messages = list(get_messages(response.wsgi_request)) + assert len(messages) == 2 # noqa: PLR2004 + assert messages[0].level == LEVEL_SUCCESS + assert messages[1].level == LEVEL_INFO + + +@pytest.mark.django_db +def test_autocreate_admin_disabled(settings): + """Test that an admin user is not auto-created when AUTOCREATE_ADMIN is False.""" + settings.AUTOCREATE_ADMIN = False + client = Client() + response = client.get(reverse("route_manager:home")) + assert response.status_code == 200 # noqa: PLR2004 + assert User.objects.count() == 0 + + +@pytest.mark.django_db +def test_autocreate_admin_existing_user(settings): + """Test that an admin user is not auto-created when an existing user is present.""" + settings.AUTOCREATE_ADMIN = True + User.objects.create_user("testuser", "test@example.com", "password") + client = Client() + response = client.get(reverse("route_manager:home")) + assert response.status_code == 200 # noqa: PLR2004 + assert User.objects.count() == 1 + assert not User.objects.filter(username="admin").exists() diff --git a/scram/route_manager/tests/test_history.py b/scram/route_manager/tests/test_history.py index f00f16bb..d2848e83 100644 --- a/scram/route_manager/tests/test_history.py +++ b/scram/route_manager/tests/test_history.py @@ -1,3 +1,5 @@ +"""Define tests for the history feature.""" + from django.test import TestCase from simple_history.utils import get_change_reason_from_object, update_change_reason @@ -5,20 +7,27 @@ class TestActiontypeHistory(TestCase): + """Test the history on an action type.""" + def setUp(self): + """Set up the test environment.""" self.atype = ActionType.objects.create(name="Block") def test_comments(self): + """Ensure we can go back and set a reason.""" self.atype.name = "Nullroute" - self.atype._change_reason = "Use more descriptive name" + self.atype._change_reason = "Use more descriptive name" # noqa SLF001 self.atype.save() self.assertIsNotNone(get_change_reason_from_object(self.atype)) class TestEntryHistory(TestCase): + """Test the history on an Entry.""" + routes = ["192.0.2.16/32", "198.51.100.16/28"] def setUp(self): + """Set up the test environment.""" self.atype = ActionType.objects.create(name="Block") for r in self.routes: route = Route.objects.create(route=r) @@ -28,6 +37,7 @@ def setUp(self): self.assertEqual(entry.get_change_reason(), create_reason) def test_comments(self): + """Ensure we can update the reason.""" for r in self.routes: route_old = Route.objects.get(route=r) e = Entry.objects.get(route=route_old) @@ -37,7 +47,7 @@ def test_comments(self): e.route = Route.objects.create(route=route_new) change_reason = "I meant 32, not 16." - e._change_reason = change_reason + e._change_reason = change_reason # noqa SLF001 e.save() self.assertEqual(len(e.history.all()), 2) diff --git a/scram/route_manager/tests/test_swagger.py b/scram/route_manager/tests/test_swagger.py new file mode 100644 index 00000000..ce955db9 --- /dev/null +++ b/scram/route_manager/tests/test_swagger.py @@ -0,0 +1,30 @@ +"""Test the swagger API endpoints.""" + +import pytest +from django.urls import reverse + + +@pytest.mark.django_db +def test_swagger_api(client): + """Test that the Swagger API endpoint returns a successful response.""" + url = reverse("swagger-ui") + response = client.get(url) + assert response.status_code == 200 # noqa: PLR2004 + + +@pytest.mark.django_db +def test_redoc_api(client): + """Test that the Redoc API endpoint returns a successful response.""" + url = reverse("redoc") + response = client.get(url) + assert response.status_code == 200 # noqa: PLR2004 + + +@pytest.mark.django_db +def test_schema_api(client): + """Test that the Schema API endpoint returns a successful response.""" + url = reverse("schema") + response = client.get(url) + assert response.status_code == 200 # noqa: PLR2004 + expected_strings = [b"/entries/", b"/actiontypes/", b"/ignore_entries/", b"/users/"] + assert all(string in response.content for string in expected_strings) diff --git a/scram/route_manager/tests/test_views.py b/scram/route_manager/tests/test_views.py index 415cdae0..1433135a 100644 --- a/scram/route_manager/tests/test_views.py +++ b/scram/route_manager/tests/test_views.py @@ -1,10 +1,58 @@ +"""Define simple tests for the template-based Views.""" + +from django.conf import settings from django.test import TestCase -from django.urls import resolve +from django.urls import resolve, reverse from scram.route_manager.views import home_page class HomePageTest(TestCase): + """Test how the home page renders.""" + def test_root_url_resolves_to_home_page_view(self): + """Ensure we can find the home page.""" found = resolve("/") self.assertEqual(found.func, home_page) + + +class HomePageFirstVisitTest(TestCase): + """Test how the home page renders the first time we view it.""" + + def setUp(self): + """Get the home page.""" + self.response = self.client.get(reverse("route_manager:home")) + + def test_first_homepage_view_has_userinfo(self): + """The first time we view the home page, a user was created for us.""" + self.assertContains(self.response, b"An admin user was created for you.") + + def test_first_homepage_view_is_logged_in(self): + """The first time we view the home page, we're logged in.""" + self.assertContains(self.response, b'type="submit">Logout') + + +class HomePageLogoutTest(TestCase): + """Verify that once logged out, we can't view anything.""" + + def test_homepage_logout_links_missing(self): + """After logout, we can't see anything.""" + response = self.client.get(reverse("route_manager:home")) + response = self.client.post(reverse(settings.LOGOUT_URL), follow=True) + self.assertEqual(response.status_code, 200) + response = self.client.get(reverse("route_manager:home")) + + self.assertNotContains(response, b"An admin user was created for you.") + self.assertNotContains(response, b'type="submit">Logout') + self.assertNotContains(response, b">Admin") + + +class NotFoundTest(TestCase): + """Verify that our custom 404 page is being served up.""" + + def test_404(self): + """Grab a bad URL.""" + response = self.client.get("/foobarbaz") + self.assertContains( + response, b'
The page you are looking for was not found.
', status_code=404 + ) diff --git a/scram/route_manager/tests/test_websockets.py b/scram/route_manager/tests/test_websockets.py index 367359b8..ca9cd528 100644 --- a/scram/route_manager/tests/test_websockets.py +++ b/scram/route_manager/tests/test_websockets.py @@ -1,3 +1,5 @@ +"""Define unit tests for the websockets-based communication.""" + import json from asyncio import gather from contextlib import asynccontextmanager @@ -15,7 +17,7 @@ @asynccontextmanager async def get_communicators(actiontypes, should_match, *args, **kwds): - """Creates a set of communicators, and then handles tear-down. + """Create a set of communicators, and then handle tear-down. Given two lists of the same length, a set of actiontypes, and set of boolean values, creates that many communicators, one for each actiontype-bool pair. @@ -24,15 +26,13 @@ async def get_communicators(actiontypes, should_match, *args, **kwds): Returns a list of (communicator, should_match bool) pairs. """ - assert len(actiontypes) == len(should_match) - router = URLRouter(websocket_urlpatterns) communicators = [ WebsocketCommunicator(router, f"/ws/route_manager/translator_{actiontype}/") for actiontype in actiontypes ] - response = zip(communicators, should_match) + response = zip(communicators, should_match, strict=True) - for communicator, should_match in response: + for communicator, _ in response: connected, _ = await communicator.connect() assert connected @@ -40,7 +40,7 @@ async def get_communicators(actiontypes, should_match, *args, **kwds): yield response finally: - for communicator, should_match in response: + for communicator, _ in response: await communicator.disconnect() @@ -48,6 +48,7 @@ class TestTranslatorBaseCase(TestCase): """Base case that other test cases build on top of. Three translators in one group, test one v4 and one v6.""" def setUp(self): + """Set up our test environment.""" # TODO: This is copied from test_api; should de-dupe this. self.url = reverse("api:v1:entry-list") self.superuser = get_user_model().objects.create_superuser("admin", "admin@example.net", "admintestpassword") @@ -68,7 +69,9 @@ def setUp(self): wsm, _ = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route") _, _ = WebSocketSequenceElement.objects.get_or_create( - websocketmessage=wsm, verb="A", action_type=self.actiontype + websocketmessage=wsm, + verb="A", + action_type=self.actiontype, ) # Set some defaults; some child classes override this @@ -77,10 +80,10 @@ def setUp(self): self.generate_add_msgs = [lambda ip, mask: {"type": "translator_add", "message": {"route": f"{ip}/{mask}"}}] # Now we run any local setup actions by the child classes - self.local_setUp() + self.local_setup() - def local_setUp(self): - # Allow child classes to override this if desired + def local_setup(self): + """Allow child classes to override this if desired.""" return async def get_messages(self, communicator, messages, should_match): @@ -95,6 +98,7 @@ async def get_nothings(self, communicator): assert await communicator.receive_nothing(timeout=0.1, interval=0.01) is False async def add_ip(self, ip, mask): + """Ensure we can add an IP to block.""" async with get_communicators(self.actiontypes, self.should_match) as communicators: await self.api_create_entry(ip) @@ -119,6 +123,7 @@ async def ensure_no_more_msgs(self, communicators): # Django ensures that the create is synchronous, so we have some extra steps to do @sync_to_async def api_create_entry(self, route): + """Ensure we can create an Entry via the API.""" return self.client.post( self.url, { @@ -131,12 +136,14 @@ def api_create_entry(self, route): ) async def test_add_v4(self): + """Test adding a few v4 routes.""" await self.add_ip("192.0.2.224", 32) await self.add_ip("192.0.2.225", 32) await self.add_ip("192.0.2.226", 32) await self.add_ip("198.51.100.224", 32) async def test_add_v6(self): + """Test adding a few v6 routes.""" await self.add_ip("2001:DB8:FDF0::", 128) await self.add_ip("2001:DB8:FDF0::D", 128) await self.add_ip("2001:DB8:FDF0::DB", 128) @@ -144,9 +151,10 @@ async def test_add_v6(self): class TranslatorDontCrossTheStreamsTestCase(TestTranslatorBaseCase): - """Two translators in the same group, two in another group, single IP, ensure we get only the messages we expect.""" + """Two translators in one group, two in another group, single IP, ensure we get only the messages we expect.""" - def local_setUp(self): + def local_setup(self): + """Define the actions and what we expect.""" self.actiontypes = ["block", "block", "noop", "noop"] self.should_match = [True, True, False, False] @@ -154,14 +162,21 @@ def local_setUp(self): class TranslatorSequenceTestCase(TestTranslatorBaseCase): """Test a sequence of WebSocket messages.""" - def local_setUp(self): + def local_setup(self): + """Define the messages we want to send.""" wsm2 = WebSocketMessage.objects.create(msg_type="translator_add", msg_data_route_field="foo") _ = WebSocketSequenceElement.objects.create( - websocketmessage=wsm2, verb="A", action_type=self.actiontype, order_num=20 + websocketmessage=wsm2, + verb="A", + action_type=self.actiontype, + order_num=20, ) wsm3 = WebSocketMessage.objects.create(msg_type="translator_add", msg_data_route_field="bar") _ = WebSocketSequenceElement.objects.create( - websocketmessage=wsm3, verb="A", action_type=self.actiontype, order_num=2 + websocketmessage=wsm3, + verb="A", + action_type=self.actiontype, + order_num=2, ) self.generate_add_msgs = [ @@ -174,7 +189,8 @@ def local_setUp(self): class TranslatorParametersTestCase(TestTranslatorBaseCase): """Additional parameters in the JSONField.""" - def local_setUp(self): + def local_setup(self): + """Define the message we want to send.""" wsm = WebSocketMessage.objects.get(msg_type="translator_add", msg_data_route_field="route") wsm.msg_data = {"asn": 65550, "community": 100, "route": "Ensure this gets overwritten."} wsm.save() diff --git a/scram/route_manager/urls.py b/scram/route_manager/urls.py index dde55639..2735cc64 100644 --- a/scram/route_manager/urls.py +++ b/scram/route_manager/urls.py @@ -1,3 +1,5 @@ +"""Register URLs known to Django, and the View that will handle each.""" + from django.urls import path from . import views diff --git a/scram/route_manager/views.py b/scram/route_manager/views.py index 54ff321a..fa8b7116 100644 --- a/scram/route_manager/views.py +++ b/scram/route_manager/views.py @@ -1,3 +1,5 @@ +"""Define the Views that will handle the HTTP requests.""" + import ipaddress import json @@ -22,7 +24,10 @@ channel_layer = get_channel_layer() -def home_page(request, prefilter=Entry.objects.all()): +def home_page(request, prefilter=None): + """Return the home page, autocreating a user if none exists.""" + if not prefilter: + prefilter = Entry.objects.all() num_entries = settings.RECENT_LIMIT if request.user.has_perms(("route_manager.view_entry", "route_manager.add_entry")): readwrite = True @@ -60,6 +65,7 @@ def home_page(request, prefilter=Entry.objects.all()): def search_entries(request): + """Wrap the home page with a specified CIDR to restrict Entries to.""" # Using ipaddress because we needed to turn off strict mode # (which netfields uses by default with seemingly no toggle) # This caused searches with host bits set to 500 which is bad UX see: 68854ee1ad4789a62863083d521bddbc96ab7025 @@ -77,11 +83,14 @@ def search_entries(request): @require_POST @permission_required(["route_manager.view_entry", "route_manager.delete_entry"]) def delete_entry(request, pk): + """Wrap delete via the API and redirect to the home page.""" delete_entry_api(request, pk) return redirect("route_manager:home") class EntryDetailView(PermissionRequiredMixin, DetailView): + """Define a view for the API to use.""" + permission_required = ["route_manager.view_entry"] model = Entry template_name = "route_manager/entry_detail.html" @@ -92,35 +101,35 @@ class EntryDetailView(PermissionRequiredMixin, DetailView): @permission_required(["route_manager.view_entry", "route_manager.add_entry"]) def add_entry(request): + """Send a WebSocket message when adding a new entry.""" with transaction.atomic(): res = add_entry_api(request) - if res.status_code == 201: + if res.status_code == 201: # noqa: PLR2004 messages.add_message( request, messages.SUCCESS, "Entry successfully added", ) - elif res.status_code == 400: + elif res.status_code == 400: # noqa: PLR2004 errors = [] if isinstance(res.data, rest_framework.utils.serializer_helpers.ReturnDict): for k, v in res.data.items(): - for error in v: - errors.append(f"'{k}' error: {str(error)}") + errors.extend(f"'{k}' error: {error!s}" for error in v) else: - for k, v in res.data.items(): - errors.append(f"error: {str(v)}") + errors.extend(f"error: {v!s}" for v in res.data.values()) messages.add_message(request, messages.ERROR, "
".join(errors)) - elif res.status_code == 403: + elif res.status_code == 403: # noqa: PLR2004 messages.add_message(request, messages.ERROR, "Permission Denied") else: messages.add_message(request, messages.WARNING, f"Something went wrong: {res.status_code}") with transaction.atomic(): home = home_page(request) - return home + return home # noqa RET504 def process_expired(request): + """For entries with an expiration, set them to inactive if expired. Return some simple stats.""" current_time = timezone.now() with transaction.atomic(): entries_start = Entry.objects.filter(is_active=True).count() @@ -133,17 +142,21 @@ def process_expired(request): { "entries_deleted": entries_start - entries_end, "active_entries": entries_end, - } + }, ), content_type="application/json", ) class EntryListView(ListView): + """Define a view for the API to use.""" + model = Entry template_name = "route_manager/entry_list.html" - def get_context_data(self, **kwargs): + @staticmethod + def get_context_data(**kwargs): + """Group entries by action type.""" context = {"entries": {}} for at in ActionType.objects.all(): queryset = Entry.objects.filter(actiontype=at).order_by("-pk") diff --git a/scram/templates/account/account_inactive.html b/scram/templates/account/account_inactive.html deleted file mode 100644 index 17c21577..00000000 --- a/scram/templates/account/account_inactive.html +++ /dev/null @@ -1,12 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} - -{% block head_title %}{% trans "Account Inactive" %}{% endblock %} - -{% block inner %} -

{% trans "Account Inactive" %}

- -

{% trans "This account is inactive." %}

-{% endblock %} - diff --git a/scram/templates/account/base.html b/scram/templates/account/base.html deleted file mode 100644 index 97f4ddb8..00000000 --- a/scram/templates/account/base.html +++ /dev/null @@ -1,10 +0,0 @@ -{% extends "../base.html" %} -{% block title %}{% block head_title %}{% endblock head_title %}{% endblock title %} - -{% block content %} -
-
- {% block inner %}{% endblock %} -
-
-{% endblock %} diff --git a/scram/templates/account/email.html b/scram/templates/account/email.html deleted file mode 100644 index 8eef4159..00000000 --- a/scram/templates/account/email.html +++ /dev/null @@ -1,80 +0,0 @@ - -{% extends "account/base.html" %} - -{% load i18n %} -{% load crispy_forms_tags %} - -{% block head_title %}{% trans "Account" %}{% endblock %} - -{% block inner %} -

{% trans "E-mail Addresses" %}

- -{% if user.emailaddress_set.all %} -

{% trans 'The following e-mail addresses are associated with your account:' %}

- - - -{% else %} -

{% trans 'Warning:'%} {% trans "You currently do not have any e-mail address set up. You should really add an e-mail address so you can receive notifications, reset your password, etc." %}

- -{% endif %} - - -

{% trans "Add E-mail Address" %}

- -
- {% csrf_token %} - {{ form|crispy }} - -
- -{% endblock %} - - -{% block inline_javascript %} -{{ block.super }} - -{% endblock %} - diff --git a/scram/templates/account/email_confirm.html b/scram/templates/account/email_confirm.html deleted file mode 100644 index 46c78126..00000000 --- a/scram/templates/account/email_confirm.html +++ /dev/null @@ -1,32 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} -{% load account %} - -{% block head_title %}{% trans "Confirm E-mail Address" %}{% endblock %} - - -{% block inner %} -

{% trans "Confirm E-mail Address" %}

- -{% if confirmation %} - -{% user_display confirmation.email_address.user as user_display %} - -

{% blocktrans with confirmation.email_address.email as email %}Please confirm that {{ email }} is an e-mail address for user {{ user_display }}.{% endblocktrans %}

- -
-{% csrf_token %} - -
- -{% else %} - -{% url 'account_email' as email_url %} - -

{% blocktrans %}This e-mail confirmation link expired or is invalid. Please issue a new e-mail confirmation request.{% endblocktrans %}

- -{% endif %} - -{% endblock %} - diff --git a/scram/templates/account/login.html b/scram/templates/account/login.html deleted file mode 100644 index 2cadea6a..00000000 --- a/scram/templates/account/login.html +++ /dev/null @@ -1,48 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} -{% load account socialaccount %} -{% load crispy_forms_tags %} - -{% block head_title %}{% trans "Sign In" %}{% endblock %} - -{% block inner %} - -

{% trans "Sign In" %}

- -{% get_providers as socialaccount_providers %} - -{% if socialaccount_providers %} -

{% blocktrans with site.name as site_name %}Please sign in with one -of your existing third party accounts. Or, sign up -for a {{ site_name }} account and sign in below:{% endblocktrans %}

- -
- -
    - {% include "socialaccount/snippets/provider_list.html" with process="login" %} -
- - - -
- -{% include "socialaccount/snippets/login_extra.html" %} - -{% else %} -

{% blocktrans %}If you have not created an account yet, then please -sign up first.{% endblocktrans %}

-{% endif %} - - - -{% endblock %} - diff --git a/scram/templates/account/logout.html b/scram/templates/account/logout.html deleted file mode 100644 index 8e2e6754..00000000 --- a/scram/templates/account/logout.html +++ /dev/null @@ -1,22 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} - -{% block head_title %}{% trans "Sign Out" %}{% endblock %} - -{% block inner %} -

{% trans "Sign Out" %}

- -

{% trans 'Are you sure you want to sign out?' %}

- -
- {% csrf_token %} - {% if redirect_field_value %} - - {% endif %} - -
- - -{% endblock %} - diff --git a/scram/templates/account/password_change.html b/scram/templates/account/password_change.html deleted file mode 100644 index b72ca068..00000000 --- a/scram/templates/account/password_change.html +++ /dev/null @@ -1,17 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} -{% load crispy_forms_tags %} - -{% block head_title %}{% trans "Change Password" %}{% endblock %} - -{% block inner %} -

{% trans "Change Password" %}

- -
- {% csrf_token %} - {{ form|crispy }} - -
-{% endblock %} - diff --git a/scram/templates/account/password_reset.html b/scram/templates/account/password_reset.html deleted file mode 100644 index c98f28c1..00000000 --- a/scram/templates/account/password_reset.html +++ /dev/null @@ -1,25 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} -{% load account %} -{% load crispy_forms_tags %} - -{% block head_title %}{% trans "Password Reset" %}{% endblock %} - -{% block inner %} - -

{% trans "Password Reset" %}

- {% if user.is_authenticated %} - {% include "/snippets/already_logged_in.html" %} - {% endif %} - -

{% trans "Forgotten your password? Enter your e-mail address below, and we'll send you an e-mail allowing you to reset it." %}

- -
- {% csrf_token %} - {{ form|crispy }} - -
- -

{% blocktrans %}Please contact us if you have any trouble resetting your password.{% endblocktrans %}

-{% endblock %} diff --git a/scram/templates/account/password_reset_done.html b/scram/templates/account/password_reset_done.html deleted file mode 100644 index 835156ca..00000000 --- a/scram/templates/account/password_reset_done.html +++ /dev/null @@ -1,16 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} -{% load account %} - -{% block head_title %}{% trans "Password Reset" %}{% endblock %} - -{% block inner %} -

{% trans "Password Reset" %}

- - {% if user.is_authenticated %} - {% include "/snippets/already_logged_in.html" %} - {% endif %} - -

{% blocktrans %}We have sent you an e-mail. Please contact us if you do not receive it within a few minutes.{% endblocktrans %}

-{% endblock %} diff --git a/scram/templates/account/password_reset_from_key.html b/scram/templates/account/password_reset_from_key.html deleted file mode 100644 index 2e2cd194..00000000 --- a/scram/templates/account/password_reset_from_key.html +++ /dev/null @@ -1,24 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} -{% load crispy_forms_tags %} -{% block head_title %}{% trans "Change Password" %}{% endblock %} - -{% block inner %} -

{% if token_fail %}{% trans "Bad Token" %}{% else %}{% trans "Change Password" %}{% endif %}

- - {% if token_fail %} - {% url 'account_reset_password' as passwd_reset_url %} -

{% blocktrans %}The password reset link was invalid, possibly because it has already been used. Please request a new password reset.{% endblocktrans %}

- {% else %} - {% if form %} -
- {% csrf_token %} - {{ form|crispy }} - -
- {% else %} -

{% trans 'Your password is now changed.' %}

- {% endif %} - {% endif %} -{% endblock %} diff --git a/scram/templates/account/password_reset_from_key_done.html b/scram/templates/account/password_reset_from_key_done.html deleted file mode 100644 index 89be086f..00000000 --- a/scram/templates/account/password_reset_from_key_done.html +++ /dev/null @@ -1,10 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} -{% block head_title %}{% trans "Change Password" %}{% endblock %} - -{% block inner %} -

{% trans "Change Password" %}

-

{% trans 'Your password is now changed.' %}

-{% endblock %} - diff --git a/scram/templates/account/password_set.html b/scram/templates/account/password_set.html deleted file mode 100644 index 22322235..00000000 --- a/scram/templates/account/password_set.html +++ /dev/null @@ -1,17 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} -{% load crispy_forms_tags %} - -{% block head_title %}{% trans "Set Password" %}{% endblock %} - -{% block inner %} -

{% trans "Set Password" %}

- -
- {% csrf_token %} - {{ form|crispy }} - -
-{% endblock %} - diff --git a/scram/templates/account/signup.html b/scram/templates/account/signup.html deleted file mode 100644 index 6a2954eb..00000000 --- a/scram/templates/account/signup.html +++ /dev/null @@ -1,23 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} -{% load crispy_forms_tags %} - -{% block head_title %}{% trans "Signup" %}{% endblock %} - -{% block inner %} -

{% trans "Sign Up" %}

- -

{% blocktrans %}Already have an account? Then please sign in.{% endblocktrans %}

- - - -{% endblock %} - diff --git a/scram/templates/account/signup_closed.html b/scram/templates/account/signup_closed.html deleted file mode 100644 index 2322f176..00000000 --- a/scram/templates/account/signup_closed.html +++ /dev/null @@ -1,12 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} - -{% block head_title %}{% trans "Sign Up Closed" %}{% endblock %} - -{% block inner %} -

{% trans "Sign Up Closed" %}

- -

{% trans "We are sorry, but the sign up is currently closed." %}

-{% endblock %} - diff --git a/scram/templates/account/verification_sent.html b/scram/templates/account/verification_sent.html deleted file mode 100644 index ad093fd4..00000000 --- a/scram/templates/account/verification_sent.html +++ /dev/null @@ -1,13 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} - -{% block head_title %}{% trans "Verify Your E-mail Address" %}{% endblock %} - -{% block inner %} -

{% trans "Verify Your E-mail Address" %}

- -

{% blocktrans %}We have sent an e-mail to you for verification. Follow the link provided to finalize the signup process. Please contact us if you do not receive it within a few minutes.{% endblocktrans %}

- -{% endblock %} - diff --git a/scram/templates/account/verified_email_required.html b/scram/templates/account/verified_email_required.html deleted file mode 100644 index 09d4fde7..00000000 --- a/scram/templates/account/verified_email_required.html +++ /dev/null @@ -1,24 +0,0 @@ -{% extends "account/base.html" %} - -{% load i18n %} - -{% block head_title %}{% trans "Verify Your E-mail Address" %}{% endblock %} - -{% block inner %} -

{% trans "Verify Your E-mail Address" %}

- -{% url 'account_email' as email_url %} - -

{% blocktrans %}This part of the site requires us to verify that -you are who you claim to be. For this purpose, we require that you -verify ownership of your e-mail address. {% endblocktrans %}

- -

{% blocktrans %}We have sent an e-mail to you for -verification. Please click on the link inside this e-mail. Please -contact us if you do not receive it within a few minutes.{% endblocktrans %}

- -

{% blocktrans %}Note: you can still change your e-mail address.{% endblocktrans %}

- - -{% endblock %} - diff --git a/scram/templates/local_auth/login.html b/scram/templates/local_auth/login.html new file mode 100644 index 00000000..e2355d27 --- /dev/null +++ b/scram/templates/local_auth/login.html @@ -0,0 +1,12 @@ +{% extends 'base.html' %} + +{% block title %}Login{% endblock %} + +{% block content %} +

Login

+
+ {% csrf_token %} + {{ form.as_p }} + +
+{% endblock %} diff --git a/scram/users/__init__.py b/scram/users/__init__.py index e69de29b..8728ced7 100644 --- a/scram/users/__init__.py +++ b/scram/users/__init__.py @@ -0,0 +1 @@ +"""The Users app defines users and a permission scheme for them.""" diff --git a/scram/users/admin.py b/scram/users/admin.py index c69eb541..ae35bc3f 100644 --- a/scram/users/admin.py +++ b/scram/users/admin.py @@ -1,3 +1,5 @@ +"""Register Admin models with the Django Admin site.""" + from django.contrib import admin from django.contrib.auth import admin as auth_admin from django.contrib.auth import get_user_model @@ -10,6 +12,7 @@ @admin.register(User) class UserAdmin(auth_admin.UserAdmin): + """Allow managing Users via the Django admin site.""" form = UserChangeForm add_form = UserCreationForm diff --git a/scram/users/api/serializers.py b/scram/users/api/serializers.py index c95bfe62..f8d22c88 100644 --- a/scram/users/api/serializers.py +++ b/scram/users/api/serializers.py @@ -1,3 +1,5 @@ +"""Serializers provide mappings between the API and the underlying model.""" + from django.contrib.auth import get_user_model from rest_framework import serializers @@ -5,7 +7,11 @@ class UserSerializer(serializers.ModelSerializer): + """Map to the User model.""" + class Meta: + """Specify the fields exposed by the API.""" + model = User fields = ["username", "name", "url"] diff --git a/scram/users/api/views.py b/scram/users/api/views.py index 288ea7ab..0e17cb54 100644 --- a/scram/users/api/views.py +++ b/scram/users/api/views.py @@ -1,3 +1,5 @@ +"""Views provide mappings between the underlying model and how they're listed in the API.""" + from django.contrib.auth import get_user_model from rest_framework import status from rest_framework.decorators import action @@ -11,14 +13,19 @@ class UserViewSet(RetrieveModelMixin, ListModelMixin, UpdateModelMixin, GenericViewSet): + """Lookup Users by username.""" + serializer_class = UserSerializer queryset = User.objects.all() lookup_field = "username" def get_queryset(self, *args, **kwargs): + """Query on User ID.""" return self.queryset.filter(id=self.request.user.id) + @staticmethod @action(detail=False, methods=["GET"]) - def me(self, request): + def me(request): + """Return the current user.""" serializer = UserSerializer(request.user, context={"request": request}) return Response(status=status.HTTP_200_OK, data=serializer.data) diff --git a/scram/users/apps.py b/scram/users/apps.py index 2ce72578..e6cc9c6b 100644 --- a/scram/users/apps.py +++ b/scram/users/apps.py @@ -1,3 +1,5 @@ +"""Register ourselves with Django.""" + import logging from django.apps import AppConfig @@ -7,10 +9,14 @@ class UsersConfig(AppConfig): + """Define the name of the module for the Users app.""" + name = "scram.users" verbose_name = _("Users") - def ready(self): + @staticmethod + def ready(): + """Check if signals are registered for User events.""" try: import scram.users.signals # noqa F401 except ImportError: diff --git a/scram/users/forms.py b/scram/users/forms.py index e2792fc2..b079f2a4 100644 --- a/scram/users/forms.py +++ b/scram/users/forms.py @@ -1,3 +1,5 @@ +"""Define forms for the Admin site.""" + from django.contrib.auth import forms as admin_forms from django.contrib.auth import get_user_model from django.utils.translation import gettext_lazy as _ @@ -6,12 +8,20 @@ class UserChangeForm(admin_forms.UserChangeForm): + """Define a form to edit a User.""" + class Meta(admin_forms.UserChangeForm.Meta): + """Map to the User model.""" + model = User class UserCreationForm(admin_forms.UserCreationForm): + """Define a form to create a User.""" + class Meta(admin_forms.UserCreationForm.Meta): + """Map to the User model and provide custom error messages.""" + model = User error_messages = {"username": {"unique": _("This username has already been taken.")}} diff --git a/scram/users/models.py b/scram/users/models.py index c2e1914f..81a905ca 100644 --- a/scram/users/models.py +++ b/scram/users/models.py @@ -1,3 +1,5 @@ +"""Define models for the User application.""" + from django.contrib.auth.models import AbstractUser from django.db.models import CharField from django.urls import reverse diff --git a/scram/users/tests/__init__.py b/scram/users/tests/__init__.py index e69de29b..f1101332 100644 --- a/scram/users/tests/__init__.py +++ b/scram/users/tests/__init__.py @@ -0,0 +1 @@ +"""Define tests for the User application.""" diff --git a/scram/users/tests/factories.py b/scram/users/tests/factories.py index edd306cb..10331878 100644 --- a/scram/users/tests/factories.py +++ b/scram/users/tests/factories.py @@ -1,4 +1,7 @@ -from typing import Any, Sequence +"""Define Factory tests for the Users application.""" + +from collections.abc import Sequence +from typing import Any from django.contrib.auth import get_user_model from factory import Faker, post_generation @@ -6,6 +9,7 @@ class UserFactory(DjangoModelFactory): + """Test the UserFactory.""" username = Faker("user_name") email = Faker("email") @@ -13,6 +17,7 @@ class UserFactory(DjangoModelFactory): @post_generation def password(self, create: bool, extracted: Sequence[Any], **kwargs): + """Test password assignment.""" password = ( extracted if extracted @@ -28,5 +33,7 @@ def password(self, create: bool, extracted: Sequence[Any], **kwargs): self.set_password(password) class Meta: + """Map to User model.""" + model = get_user_model() django_get_or_create = ["username"] diff --git a/scram/users/tests/test_admin.py b/scram/users/tests/test_admin.py index 66f20cda..fd49e7ae 100644 --- a/scram/users/tests/test_admin.py +++ b/scram/users/tests/test_admin.py @@ -1,3 +1,5 @@ +"""Tests for the Admin site for the Users application.""" + import pytest from django.urls import reverse @@ -7,17 +9,22 @@ class TestUserAdmin: + """Test various User operations via the Admin site.""" + def test_changelist(self, admin_client): + """Ensure we can view the changelist.""" url = reverse("admin:users_user_changelist") response = admin_client.get(url) assert response.status_code == 200 def test_search(self, admin_client): + """Ensure we can view the search.""" url = reverse("admin:users_user_changelist") response = admin_client.get(url, data={"q": "test"}) assert response.status_code == 200 def test_add(self, admin_client): + """Ensure we can add a user.""" url = reverse("admin:users_user_add") response = admin_client.get(url) assert response.status_code == 200 @@ -34,6 +41,7 @@ def test_add(self, admin_client): assert User.objects.filter(username="test").exists() def test_view_user(self, admin_client): + """Ensure we can view a user.""" user = User.objects.get(username="admin") url = reverse("admin:users_user_change", kwargs={"object_id": user.pk}) response = admin_client.get(url) diff --git a/scram/users/tests/test_drf_urls.py b/scram/users/tests/test_drf_urls.py index 6e28c266..8e587486 100644 --- a/scram/users/tests/test_drf_urls.py +++ b/scram/users/tests/test_drf_urls.py @@ -1,3 +1,5 @@ +"""Test Django Request Framework use of the Users application.""" + import pytest from django.urls import resolve, reverse @@ -7,15 +9,18 @@ def test_user_detail(user: User): + """Ensure we can view details for a single User.""" assert reverse("api:v1:user-detail", kwargs={"username": user.username}) == f"/api/v1/users/{user.username}/" assert resolve(f"/api/v1/users/{user.username}/").view_name == "api:v1:user-detail" def test_user_list(): + """Ensure we can list all Users.""" assert reverse("api:v1:user-list") == "/api/v1/users/" assert resolve("/api/v1/users/").view_name == "api:v1:user-list" def test_user_me(): + """Ensure we can view info for the current user.""" assert reverse("api:v1:user-me") == "/api/v1/users/me/" assert resolve("/api/v1/users/me/").view_name == "api:v1:user-me" diff --git a/scram/users/tests/test_drf_views.py b/scram/users/tests/test_drf_views.py index 6e7b19e5..bbe56598 100644 --- a/scram/users/tests/test_drf_views.py +++ b/scram/users/tests/test_drf_views.py @@ -1,3 +1,5 @@ +"""Test Django Request Framework Views in the Users application.""" + import pytest from django.test import RequestFactory @@ -8,7 +10,10 @@ class TestUserViewSet: + """Test a couple simple View operations.""" + def test_get_queryset(self, user: User, rf: RequestFactory): + """Ensure we can view an arbitrary URL.""" view = UserViewSet() request = rf.get("/fake-url/") request.user = user @@ -18,6 +23,7 @@ def test_get_queryset(self, user: User, rf: RequestFactory): assert user in view.get_queryset() def test_me(self, user: User, rf: RequestFactory): + """Ensure we can view info on the current user.""" view = UserViewSet() request = rf.get("/fake-url/") request.user = user diff --git a/scram/users/tests/test_forms.py b/scram/users/tests/test_forms.py index 8041e621..3fe65da4 100644 --- a/scram/users/tests/test_forms.py +++ b/scram/users/tests/test_forms.py @@ -1,8 +1,7 @@ -""" -Module for all Form Tests. -""" +"""Module for all Form Tests.""" + import pytest -from django.utils.translation import ugettext_lazy as _ +from django.utils.translation import gettext_lazy as _ from scram.users.forms import UserCreationForm from scram.users.models import User @@ -11,18 +10,15 @@ class TestUserCreationForm: - """ - Test class for all tests related to the UserCreationForm - """ + """Test class for all tests related to the UserCreationForm.""" def test_username_validation_error_msg(self, user: User): - """ - Tests UserCreation Form's unique validator functions correctly by testing: - 1) A new user with an existing username cannot be added. - 2) Only 1 error is raised by the UserCreation Form - 3) The desired error message is raised - """ + """Tests UserCreation Form's unique validator functions correctly by testing 3 things. + 1) A new user with an existing username cannot be added. + 2) Only 1 error is raised by the UserCreation Form + 3) The desired error message is raised + """ # The user already exists, # hence cannot be created. form = UserCreationForm( @@ -30,7 +26,7 @@ def test_username_validation_error_msg(self, user: User): "username": user.username, "password1": user.password, "password2": user.password, - } + }, ) assert not form.is_valid() diff --git a/scram/users/tests/test_models.py b/scram/users/tests/test_models.py index 9920c0ff..a224cd7d 100644 --- a/scram/users/tests/test_models.py +++ b/scram/users/tests/test_models.py @@ -1,3 +1,5 @@ +"""Define tests for the Models in the Users application.""" + import pytest from scram.users.models import User @@ -6,4 +8,5 @@ def test_user_get_absolute_url(user: User): + """Ensure w ecan convert a User object into a URL with info about that user.""" assert user.get_absolute_url() == f"/users/{user.username}/" diff --git a/scram/users/tests/test_urls.py b/scram/users/tests/test_urls.py index dd36a987..aa5a57ad 100644 --- a/scram/users/tests/test_urls.py +++ b/scram/users/tests/test_urls.py @@ -1,3 +1,5 @@ +"""Define tests for the URL resolution for the Users application.""" + import pytest from django.urls import resolve, reverse @@ -7,15 +9,18 @@ def test_detail(user: User): + """Ensure we can get the URL to view details about a single User.""" assert reverse("users:detail", kwargs={"username": user.username}) == f"/users/{user.username}/" assert resolve(f"/users/{user.username}/").view_name == "users:detail" def test_update(): + """Ensure we can get the URL to update a User.""" assert reverse("users:update") == "/users/~update/" assert resolve("/users/~update/").view_name == "users:update" def test_redirect(): + """Ensure we can get the URL to redirect to a User.""" assert reverse("users:redirect") == "/users/~redirect/" assert resolve("/users/~redirect/").view_name == "users:redirect" diff --git a/scram/users/tests/test_views.py b/scram/users/tests/test_views.py index 708f35e5..2df1929a 100644 --- a/scram/users/tests/test_views.py +++ b/scram/users/tests/test_views.py @@ -1,3 +1,5 @@ +"""Define tests for the Views of the Users application.""" + import pytest from django.conf import settings from django.contrib import messages @@ -16,15 +18,18 @@ class TestUserUpdateView: - """ - TODO: + """Define tests related to the Update View. + + Todo: extracting view initialization code as class-scoped fixture would be great if only pytest-django supported non-function-scoped fixture db access -- this is a work-in-progress for now: https://github.com/pytest-dev/pytest-django/pull/258 + """ def test_get_success_url(self, user: User, rf: RequestFactory): + """Ensure we can view an arbitrary URL.""" view = UserUpdateView() request = rf.get("/fake-url/") request.user = user @@ -34,6 +39,7 @@ def test_get_success_url(self, user: User, rf: RequestFactory): assert view.get_success_url() == f"/users/{user.username}/" def test_get_object(self, user: User, rf: RequestFactory): + """Ensure we can retrieve the User object.""" view = UserUpdateView() request = rf.get("/fake-url/") request.user = user @@ -43,12 +49,13 @@ def test_get_object(self, user: User, rf: RequestFactory): assert view.get_object() == user def test_form_valid(self, user: User, rf: RequestFactory): + """Ensure the form validation works.""" view = UserUpdateView() request = rf.get("/fake-url/") # Add the session/message middleware to the request - SessionMiddleware().process_request(request) - MessageMiddleware().process_request(request) + SessionMiddleware(get_response=request).process_request(request) + MessageMiddleware(get_response=request).process_request(request) request.user = user view.request = request @@ -63,7 +70,10 @@ def test_form_valid(self, user: User, rf: RequestFactory): class TestUserRedirectView: + """Define tests related to the Redirect View.""" + def test_get_redirect_url(self, user: User, rf: RequestFactory): + """Ensure the redirection URL works.""" view = UserRedirectView() request = rf.get("/fake-url") request.user = user @@ -74,7 +84,10 @@ def test_get_redirect_url(self, user: User, rf: RequestFactory): class TestUserDetailView: + """Define tests related to the User Detail View.""" + def test_authenticated(self, user: User, rf: RequestFactory): + """Ensure an authenticated user has access.""" request = rf.get("/fake-url/") request.user = UserFactory() @@ -83,6 +96,7 @@ def test_authenticated(self, user: User, rf: RequestFactory): assert response.status_code == 200 def test_not_authenticated(self, user: User, rf: RequestFactory): + """Ensure an unauthenticated user gets redirected to login.""" request = rf.get("/fake-url/") request.user = AnonymousUser() diff --git a/scram/users/urls.py b/scram/users/urls.py index 64f9d9cb..cf8647f3 100644 --- a/scram/users/urls.py +++ b/scram/users/urls.py @@ -1,3 +1,5 @@ +"""Register URLs known to Django, and the View that will handle each.""" + from django.urls import path from scram.users.views import user_detail_view, user_redirect_view, user_update_view diff --git a/scram/users/views.py b/scram/users/views.py index 2b077d42..e05f5cec 100644 --- a/scram/users/views.py +++ b/scram/users/views.py @@ -1,3 +1,5 @@ +"""Define the Views that will handle the HTTP requests.""" + from django.contrib.auth import get_user_model from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.messages.views import SuccessMessageMixin @@ -9,6 +11,7 @@ class UserDetailView(LoginRequiredMixin, DetailView): + """Map this view to the User model, and rely on the DetailView generic view.""" model = User slug_field = "username" @@ -19,15 +22,18 @@ class UserDetailView(LoginRequiredMixin, DetailView): class UserUpdateView(LoginRequiredMixin, SuccessMessageMixin, UpdateView): + """Map this view to the User model, and rely on the UpdateView generic view.""" model = User fields = ["name"] success_message = _("Information successfully updated") def get_success_url(self): + """Return the User detail view.""" return reverse("users:detail", kwargs={"username": self.request.user.username}) def get_object(self): + """Return the User object.""" return self.request.user @@ -35,10 +41,12 @@ def get_object(self): class UserRedirectView(LoginRequiredMixin, RedirectView): + """Map this view to the User model, and rely on the RedirectView generic view.""" permanent = False def get_redirect_url(self): + """Return the User detail view.""" return reverse("users:detail", kwargs={"username": self.request.user.username}) diff --git a/scram/utils/__init__.py b/scram/utils/__init__.py index e69de29b..e5a29f6c 100644 --- a/scram/utils/__init__.py +++ b/scram/utils/__init__.py @@ -0,0 +1 @@ +"""A module for defining utility functions and code.""" diff --git a/scram/utils/context_processors.py b/scram/utils/context_processors.py index 3c535141..0bbf8786 100644 --- a/scram/utils/context_processors.py +++ b/scram/utils/context_processors.py @@ -1,8 +1,14 @@ +"""Define custom functions that take a request and add to the context before template rendering.""" + from django.conf import settings def settings_context(_request): - """Settings available by default to the templates context.""" + """Define settings available by default to the templates context. + + Returns: + dict: Whether or not we have DEBUG on + """ # Note: we intentionally do NOT expose the entire settings # to prevent accidental leaking of sensitive information return {"DEBUG": settings.DEBUG} diff --git a/setup.cfg b/setup.cfg index 84ec37bb..2c082581 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,12 +24,6 @@ django_settings_module = config.settings.test # Django migrations should not produce any errors: ignore_errors = True -[coverage:run] -include = scram/* -omit = *migrations*, *tests* -plugins = - django_coverage_plugin - [behave] paths = scram/route_manager/tests/acceptance stderr_capture = no diff --git a/translator/exceptions.py b/translator/exceptions.py index f1a54299..b6d2cfd6 100644 --- a/translator/exceptions.py +++ b/translator/exceptions.py @@ -1,9 +1,5 @@ -""" -This module holds all of the exceptions we want to raise in our translators. -""" +"""Define all of the exceptions we want to raise in our translators.""" class ASNError(TypeError): - """ - ASNError provides an error class to use when there is an issue with an Autonomous System Number. - """ + """ASNError provides an error class to use when there is an issue with an Autonomous System Number.""" diff --git a/translator/gobgp.py b/translator/gobgp.py index 23ae2572..f2b8dee9 100644 --- a/translator/gobgp.py +++ b/translator/gobgp.py @@ -1,3 +1,5 @@ +"""A translator interface for GoBGP (https://github.com/osrg/gobgp).""" + import logging import attribute_pb2 @@ -13,23 +15,32 @@ DEFAULT_COMMUNITY = 666 DEFAULT_V4_NEXTHOP = "192.0.2.199" DEFAULT_V6_NEXTHOP = "100::1" +MAX_SMALL_ASN = 2**16 +MAX_SMALL_COMM = 2**16 +IPV6 = 6 logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +class GoBGP: + """Represents a GoBGP instance.""" -class GoBGP(object): def __init__(self, url): + """Configure the channel used for communication.""" channel = grpc.insecure_channel(url) self.stub = gobgp_pb2_grpc.GobgpApiStub(channel) - def _get_family_AFI(self, ip_version): - if ip_version == 6: + @staticmethod + def _get_family_afi(ip_version): + if ip_version == IPV6: return gobgp_pb2.Family.AFI_IP6 - else: - return gobgp_pb2.Family.AFI_IP + return gobgp_pb2.Family.AFI_IP - def _build_path(self, ip, event_data={}): + def _build_path(self, ip, event_data=None): # noqa: PLR0914 # Grab ASN and Community from our event_data, or use the defaults + if not event_data: + event_data = {} asn = event_data.get("asn", DEFAULT_ASN) community = event_data.get("community", DEFAULT_COMMUNITY) ip_version = ip.ip.version @@ -42,7 +53,7 @@ def _build_path(self, ip, event_data={}): origin.Pack( attribute_pb2.OriginAttribute( origin=2, - ) + ), ) # IP prefix and its associated length @@ -51,27 +62,27 @@ def _build_path(self, ip, event_data={}): attribute_pb2.IPAddressPrefix( prefix_len=ip.network.prefixlen, prefix=str(ip.ip), - ) + ), ) # Set the next hop to the correct value depending on IP family next_hop = Any() - family_afi = self._get_family_AFI(ip_version) - if ip_version == 6: + family_afi = self._get_family_afi(ip_version) + if ip_version == IPV6: next_hops = event_data.get("next_hop", DEFAULT_V6_NEXTHOP) next_hop.Pack( attribute_pb2.MpReachNLRIAttribute( family=gobgp_pb2.Family(afi=family_afi, safi=gobgp_pb2.Family.SAFI_UNICAST), next_hops=[next_hops], nlris=[nlri], - ) + ), ) else: next_hops = event_data.get("next_hop", DEFAULT_V4_NEXTHOP) next_hop.Pack( attribute_pb2.NextHopAttribute( next_hop=next_hops, - ) + ), ) # Set our AS Path @@ -89,13 +100,13 @@ def _build_path(self, ip, event_data={}): communities = Any() # Standard community # Since we pack both into the community string we need to make sure they will both fit - if asn < 65536 and community < 65536: + if asn < MAX_SMALL_ASN and community < MAX_SMALL_COMM: # We bitshift ASN left by 16 so that there is room to add the community on the end of it. This is because # GoBGP wants the community sent as a single integer. comm_id = (asn << 16) + community communities.Pack(attribute_pb2.CommunitiesAttribute(communities=[comm_id])) else: - logging.info(f"LargeCommunity Used - ASN:{asn} Community: {community}") + logger.info("LargeCommunity Used - ASN: %s. Community: %s", asn, community) global_admin = asn local_data1 = community # set to 0 because there's no use case for it, but we need a local_data2 for gobgp to read any of it @@ -116,7 +127,8 @@ def _build_path(self, ip, event_data={}): ) def add_path(self, ip, event_data): - logging.info(f"Blocking {ip}") + """Announce a single route.""" + logger.info("Blocking %s", ip) try: path = self._build_path(ip, event_data) @@ -125,15 +137,17 @@ def add_path(self, ip, event_data): _TIMEOUT_SECONDS, ) except ASNError as e: - logging.warning(f"ASN assertion failed with error: {e}") + logger.warning("ASN assertion failed with error: %s", e) def del_all_paths(self): - logging.warning("Withdrawing ALL routes") + """Remove all routes from being announced.""" + logger.warning("Withdrawing ALL routes") self.stub.DeletePath(gobgp_pb2.DeletePathRequest(table_type=gobgp_pb2.GLOBAL), _TIMEOUT_SECONDS) def del_path(self, ip, event_data): - logging.info(f"Unblocking {ip}") + """Remove a single route from being announced.""" + logger.info("Unblocking %s", ip) try: path = self._build_path(ip, event_data) self.stub.DeletePath( @@ -141,11 +155,16 @@ def del_path(self, ip, event_data): _TIMEOUT_SECONDS, ) except ASNError as e: - logging.warning(f"ASN assertion failed with error: {e}") + logger.warning("ASN assertion failed with error: %s", e) def get_prefixes(self, ip): + """Retrieve the routes that match a prefix and are announced. + + Returns: + list: The routes that overlap with the prefix and are currently announced. + """ prefixes = [gobgp_pb2.TableLookupPrefix(prefix=str(ip.ip))] - family_afi = self._get_family_AFI(ip.ip.version) + family_afi = self._get_family_afi(ip.ip.version) result = self.stub.ListPath( gobgp_pb2.ListPathRequest( table_type=gobgp_pb2.GLOBAL, @@ -157,4 +176,5 @@ def get_prefixes(self, ip): return list(result) def is_blocked(self, ip): + """Return True if at least one route matching the prefix is being announced.""" return len(self.get_prefixes(ip)) > 0 diff --git a/translator/shared.py b/translator/shared.py index 1fe048bd..78555b2d 100644 --- a/translator/shared.py +++ b/translator/shared.py @@ -1,13 +1,12 @@ -""" -This module provides a location for code that we want to share between all translators. -""" +"""Provide a location for code that we want to share between all translators.""" from exceptions import ASNError +MAX_ASN_VAL = 2**32 - 1 + def asn_is_valid(asn: int) -> bool: - """ - asn_is_valid makes sure that an ASN passed in is a valid 2 or 4 Byte ASN. + """asn_is_valid makes sure that an ASN passed in is a valid 2 or 4 Byte ASN. Args: asn (int): The Autonomous System Number that we want to validate @@ -17,11 +16,14 @@ def asn_is_valid(asn: int) -> bool: Returns: bool: _description_ + """ if not isinstance(asn, int): - raise ASNError(f"ASN {asn} is not an Integer, has type {type(asn)}") - if not 0 < asn < 4294967295: + msg = f"ASN {asn} is not an Integer, has type {type(asn)}" + raise ASNError(msg) + if not 0 < asn < MAX_ASN_VAL: # This is the max as stated in rfc6996 - raise ASNError(f"ASN {asn} is out of range. Must be between 0 and 4294967295") + msg = f"ASN {asn} is out of range. Must be between 0 and 4294967295" + raise ASNError(msg) return True diff --git a/translator/tests/acceptance/environment.py b/translator/tests/acceptance/environment.py index 1f06b08d..b5c57b41 100644 --- a/translator/tests/acceptance/environment.py +++ b/translator/tests/acceptance/environment.py @@ -1,6 +1,9 @@ +"""Configure the test environment before executing acceptance tests.""" + from gobgp import GoBGP def before_all(context): + """Create a GoBGP object.""" context.gobgp = GoBGP("gobgp:50051") context.config.setup_logging() diff --git a/translator/tests/acceptance/features/clean_slate.feature b/translator/tests/acceptance/features/clean_slate.feature new file mode 100644 index 00000000..9de4e337 --- /dev/null +++ b/translator/tests/acceptance/features/clean_slate.feature @@ -0,0 +1,8 @@ +Feature: Tests run correctly + When running tests, there are no leftover blocks + + Scenario: No leftover v4 blocks + Then 0.0.0.0/0 is unblocked + + Scenario: No leftover v6 blocks + Then ::/0 is unblocked diff --git a/translator/tests/acceptance/steps/actions.py b/translator/tests/acceptance/steps/actions.py index ce340de1..fdb2a95f 100644 --- a/translator/tests/acceptance/steps/actions.py +++ b/translator/tests/acceptance/steps/actions.py @@ -1,3 +1,5 @@ +"""Define the steps used by Behave.""" + import ipaddress import logging import time @@ -10,6 +12,7 @@ @when("we add {route} with {asn} and {community} to the block list") def add_block(context, route, asn, community): + """Block a single IP.""" ip = ipaddress.ip_interface(route) event_data = {"asn": int(asn), "community": int(community)} context.gobgp.add_path(ip, event_data) @@ -17,36 +20,41 @@ def add_block(context, route, asn, community): @then("we delete {route} with {asn} and {community} from the block list") def del_block(context, route, asn, community): + """Remove a single IP.""" ip = ipaddress.ip_interface(route) event_data = {"asn": int(asn), "community": int(community)} context.gobgp.del_path(ip, event_data) def get_block_status(context, ip): + """Check if the IP is currently blocked. + + Returns: + bool: The return value. True if the IP is currently blocked, False otherwise. + """ # Allow our add/delete requests to settle time.sleep(1) ip_obj = ipaddress.ip_interface(ip) - for path in context.gobgp.get_prefixes(ip_obj): - if ip_obj in ipaddress.ip_network(path.destination.prefix): - return True - - return False + return any(ip_obj in ipaddress.ip_network(path.destination.prefix) for path in context.gobgp.get_prefixes(ip_obj)) @capture @when("{route} and {community} with invalid {asn} is sent") def asn_validation_fails(context, route, asn, community): + """Ensure the ASN was invalid.""" add_block(context, route, asn, community) assert context.log_capture.find_event("ASN assertion failed") @then("{ip} is blocked") def check_block(context, ip): + """Ensure that the IP is currently blocked.""" assert get_block_status(context, ip) @then("{ip} is unblocked") def check_unblock(context, ip): + """Ensure that the IP is currently unblocked.""" assert not get_block_status(context, ip) diff --git a/translator/translator.py b/translator/translator.py index 8521d7ab..d2c34018 100644 --- a/translator/translator.py +++ b/translator/translator.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +"""Define the main event loop for the translator.""" + import asyncio import ipaddress import json @@ -9,28 +11,41 @@ import websockets from gobgp import GoBGP +logger = logging.getLogger(__name__) + +KNOWN_MESSAGES = { + "translator_add", + "translator_remove", + "translator_remove_all", + "translator_check", +} + # Here we setup a debugger if this is desired. This obviously should not be run in production. debug_mode = os.environ.get("DEBUG") if debug_mode: def install_deps(): - # Because of how we build translator currently, we don't have a great way to selectively install things at - # build, so we just do it here! Right now this also includes base.txt, which is unecessary, but in the - # future when we build a little better, it'll already be setup. - logging.info("Installing dependencies for debuggers") + """Install necessary dependencies for debuggers. + + Because of how we build translator currently, we don't have a great way to selectively + install things at build, so we just do it here! Right now this also includes base.txt, + which is unecessary, but in the future when we build a little better, it'll already be + setup. + """ + logger.info("Installing dependencies for debuggers") - import subprocess - import sys + import subprocess # noqa: S404, PLC0415 + import sys # noqa: PLC0415 - subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "/requirements/local.txt"]) + subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "/requirements/local.txt"]) # noqa: S603 TODO: add this to the container build - logging.info("Done installing dependencies for debuggers") + logger.info("Done installing dependencies for debuggers") - logging.info(f"Translator is set to use a debugger. Provided debug mode: {debug_mode}") + logger.info("Translator is set to use a debugger. Provided debug mode: %s", debug_mode) # We have to setup the debugger appropriately for various IDEs. It'd be nice if they all used the same thing but # sadly, we live in a fallen world. if debug_mode == "pycharm-pydevd": - logging.info("Entering debug mode for pycharm, make sure the debug server is running in PyCharm!") + logger.info("Entering debug mode for pycharm, make sure the debug server is running in PyCharm!") install_deps() @@ -38,66 +53,67 @@ def install_deps(): pydevd_pycharm.settrace("host.docker.internal", port=56782, stdoutToServer=True, stderrToServer=True) - logging.info("Debugger started.") + logger.info("Debugger started.") elif debug_mode == "debugpy": - logging.info("Entering debug mode for debugpy (VSCode)") + logger.info("Entering debug mode for debugpy (VSCode)") install_deps() import debugpy - debugpy.listen(("0.0.0.0", 56781)) + debugpy.listen(("0.0.0.0", 56781)) # noqa S104 (doesn't like binding to all interfaces) - logging.info("Debugger listening on port 56781.") + logger.info("Debugger listening on port 56781.") else: - logging.warning(f"Invalid debug mode given: {debug_mode}. Debugger not started") + logger.warning("Invalid debug mode given: %s. Debugger not started", debug_mode) # Must match the URL in asgi.py, and needs a trailing slash hostname = os.environ.get("SCRAM_HOSTNAME", "scram_hostname_not_set") url = os.environ.get("SCRAM_EVENTS_URL", "ws://django:8000/ws/route_manager/translator_block/") +async def process(message, websocket, g): + """Take a single message form the websocket and hand it off to the appropriate function.""" + json_message = json.loads(message) + event_type = json_message.get("type") + event_data = json_message.get("message") + if event_type not in KNOWN_MESSAGES: + logger.error("Unknown event type received: %s", event_type) + # TODO: Maybe only allow this in testing? + elif event_type == "translator_remove_all": + g.del_all_paths() + else: + try: + ip = ipaddress.ip_interface(event_data["route"]) + except: # noqa E722 + logger.exception("Error parsing message: %s", message) + return + + if event_type == "translator_add": + g.add_path(ip, event_data) + elif event_type == "translator_remove": + g.del_path(ip, event_data) + elif event_type == "translator_check": + json_message["type"] = "translator_check_resp" + json_message["message"]["is_blocked"] = g.is_blocked(ip) + json_message["message"]["translator_name"] = hostname + await websocket.send(json.dumps(json_message)) + + async def main(): + """Connect to the websocket and start listening for messages.""" g = GoBGP("gobgp:50051") async for websocket in websockets.connect(url): try: async for message in websocket: - json_message = json.loads(message) - event_type = json_message.get("type") - event_data = json_message.get("message") - if event_type not in [ - "translator_add", - "translator_remove", - "translator_remove_all", - "translator_check", - ]: - logging.error(f"Unknown event type received: {event_type!r}") - # TODO: Maybe only allow this in testing? - elif event_type == "translator_remove_all": - g.del_all_paths() - else: - try: - ip = ipaddress.ip_interface(event_data["route"]) - except: # noqa E722 - logging.error(f"Error parsing message: {message!r}") - continue - - if event_type == "translator_add": - g.add_path(ip, event_data) - elif event_type == "translator_remove": - g.del_path(ip, event_data) - elif event_type == "translator_check": - json_message["type"] = "translator_check_resp" - json_message["message"]["is_blocked"] = g.is_blocked(ip) - json_message["message"]["translator_name"] = hostname - await websocket.send(json.dumps(json_message)) + await process(message, websocket, g) except websockets.ConnectionClosed: continue if __name__ == "__main__": - logging.info("translator started") + logger.info("translator started") loop = asyncio.get_event_loop() loop.run_until_complete(main()) loop.close()