Source code for schedula.utils.form

# coding=utf-8
# -*- coding: UTF-8 -*-
#
# Copyright 2015-2024, Vincenzo Arcidiacono;
# Licensed under the EUPL (the 'Licence');
# You may not use this work except in compliance with the Licence.
# You may obtain a copy of the Licence at: http://ec.europa.eu/idabc/eupl

"""
It provides functions to build a form flask app from a dispatcher.

Sub-Modules:

.. currentmodule:: schedula.utils.form

.. autosummary::
    :nosignatures:
    :toctree: form/

    cli
    config
    gapp
    json_secrets
    mail
    server
"""
import io
import os
import gzip
import glob
import hmac
import json
import secrets
import hashlib
import datetime
import mimetypes
import webbrowser
import os.path as osp
from ..web import WebMap
from . import json_secrets
from urllib.parse import urlparse
from jinja2 import TemplateNotFound
from werkzeug.exceptions import NotFound
from itsdangerous import URLSafeTimedSerializer, BadData
from .server import Config, basic_app, default_get_form_context
from flask import (
    render_template, Blueprint, current_app, session, g, request, send_file,
    jsonify
)
from flask_babel import get_locale

__author__ = 'Vincenzo Arcidiacono <vinci1it2000@gmail.com>'

# Folder that contains the static assets shipped next to this module.
static_dir = osp.join(osp.dirname(__file__), 'static')

# Template context with the relative paths (forward slashes, `.gz` suffix
# stripped) of the main JS/CSS bundles, keyed as `main_js` and `main_css`.
# The first file matching `main.*.<kind>.gz` is taken for each kind; an
# `IndexError` is raised at import time when no such file exists.
static_context = {
    f'main_{kind}': bundle[:-3] if bundle.endswith('.gz') else bundle
    for kind, bundle in (
        (
            ext,
            osp.relpath(
                glob.glob(osp.join(
                    static_dir, 'schedula', ext, f'main.*.{ext}.gz'
                ))[0],
                osp.join(static_dir, 'schedula')
            ).replace('\\', '/')
        )
        for ext in ('js', 'css')
    )
}


class FormMap(WebMap):
    """
    Web map that serves the schedula form pages and static files through a
    Flask app and protects the registered views with a CSRF token.

    Attributes named `_get_<name>` can be read through `get_<name>()` (see
    `__getattr__`); they may be plain values, callables, or dictionaries
    keyed by request path (with `None` as the default key).
    """
    _get_basic_app_config = Config

    def get_form_context(self):
        """
        Returns the context of the form.

        It is a copy of the default context, updated with the result of
        `_get_form_context()` when the instance defines it.

        :return:
            Form context.
        :rtype: dict
        """
        context = default_get_form_context().copy()
        if hasattr(self, '_get_form_context'):
            context.update(self._get_form_context())
        return context

    def _get_form_data(self):
        # Default: no initial form data.
        return

    @staticmethod
    def _view(url, *args, **kwargs):
        # Opens the given url with the `webbrowser` module; the extra
        # arguments are accepted but ignored.
        webbrowser.open(url)

    # Fallback values used by `_config` when the Flask app configuration
    # does not define the key. Callables are invoked lazily.
    csrf_defaults = {
        'CSRF_FIELD_NAME': 'CSRF_token',
        'CSRF_SECRET_KEY': lambda: current_app.secret_key,
        'CSRF_TIME_LIMIT': 3600,
        'CSRF_HEADERS': {'X-CSRFToken', 'X-CSRF-Token'},
        'CSRF_AUTO_REFRESH_HEADER': 'N-CSRF-Token',
        'CSRF_ENABLED': True,
        'CSRF_METHODS': {'POST', 'PUT', 'PATCH', 'DELETE'},
        'CSRF_SSL_STRICT': True
    }
    # Error messages raised by `_config` when a mandatory option is `None`.
    csrf_required = {
        'CSRF_FIELD_NAME': 'A field name is required to use CSRF.',
        'CSRF_SECRET_KEY': 'A secret key is required to use CSRF.',
        'CSRF_HEADERS': 'A valid headers is required to use CSRF.',
        'CSRF_METHODS': 'A valid request methods is required to use CSRF.'
    }

    def __init__(self):
        """
        Initializes the map with an empty set of CSRF protected items and
        reads the url prefix from the `SCHEDULA_FORM_URL_PREFIX` environment
        variable (empty string when unset).
        """
        super(FormMap, self).__init__()
        self._csrf_protected = set()
        self.url_prefix = os.environ.get('SCHEDULA_FORM_URL_PREFIX', '')

    def _config(self, config_name):
        """
        Returns a CSRF option from the current app configuration, falling
        back to `csrf_defaults`.

        :param config_name:
            Option name (must be a key of `csrf_defaults`).
        :type config_name: str

        :return:
            Option value; callables are called and their result returned.

        :raises RuntimeError:
            If the value is `None` and the option is in `csrf_required`.
        """
        value = current_app.config.get(
            config_name, self.csrf_defaults[config_name]
        )
        if callable(value):
            value = value()
        if value is None and config_name in self.csrf_required:
            raise RuntimeError(self.csrf_required[config_name])
        return value

    def __getattr__(self, item):
        """
        Resolves `get_<name>` into a callable built from `_get_<name>`.

        When `_get_<name>` is a dictionary, the entry of the current request
        path is used, then the `None` entry, then the class attribute. A
        non-callable result is wrapped into a function returning it.

        :param item:
            Attribute name.
        :type item: str

        :return:
            The resolved callable, or the parent class lookup for any other
            attribute name.
        """
        if item.startswith('get_') and hasattr(self, f'_{item}'):
            attr = getattr(self, f'_{item}')
            if isinstance(attr, dict):
                # The class level fallback is looked up only when needed, so
                # that a missing class attribute does not raise when the
                # dictionary already provides a match.
                if request.path in attr:
                    attr = attr[request.path]
                elif None in attr:
                    attr = attr[None]
                else:
                    attr = getattr(self.__class__, f'_{item}')
            if callable(attr):
                return attr
            return lambda: attr
        return super(FormMap, self).__getattr__(item)

    def render_form(self, form='index'):
        """
        Renders the template `schedula/<form>.html`, falling back to
        `schedula/base.html` when that template does not exist.

        :param form:
            Form name.
        :type form: str

        :return:
            Rendered page.
        :rtype: str
        """
        template = f'schedula/{form}.html'
        context = {
            'name': form, 'form_id': form, 'form': self, 'app': current_app,
            'get_locale': get_locale
        }
        context.update(static_context)
        try:
            return render_template(template, **context)
        except TemplateNotFound:
            # noinspection PyUnresolvedReferences
            return render_template('schedula/base.html', **context)

    def _csrf_token(self):
        """
        Extracts the CSRF token sent with the current request.

        The form field named `CSRF_FIELD_NAME` is checked first, then any
        form field whose name ends with it, and finally the request headers
        listed in `CSRF_HEADERS`.

        :return:
            The first non-empty token found, otherwise `None`.
        :rtype: str | None
        """
        field_name = self._config('CSRF_FIELD_NAME')
        base_token = request.form.get(field_name)
        if base_token:
            return base_token

        # if the form has a prefix, the name will be {prefix}-csrf_token
        for key in request.form:
            if key.endswith(field_name):
                csrf_token = request.form[key]
                if csrf_token:
                    return csrf_token

        # find the token in the headers
        for header_name in self._config('CSRF_HEADERS'):
            csrf_token = request.headers.get(header_name)
            if csrf_token:
                return csrf_token
        return None

    def generate_csrf(self):
        """
        Returns the signed CSRF token of the current request, creating it
        (and the secret stored in the session) when missing.

        The token is cached on `flask.g` under `CSRF_FIELD_NAME`.

        :return:
            Signed token, or `None` when `CSRF_ENABLED` is false.
        :rtype: str | None
        """
        if not self._config('CSRF_ENABLED'):
            return None

        def new_secret():
            return hashlib.sha1(os.urandom(64)).hexdigest()

        field_name = self._config('CSRF_FIELD_NAME')
        if field_name not in g:
            secret_key = self._config('CSRF_SECRET_KEY')
            s = URLSafeTimedSerializer(secret_key, salt='csrf-token')
            if field_name not in session:
                session[field_name] = new_secret()
            try:
                token = s.dumps(session[field_name])
            except TypeError:
                # The stored value cannot be serialized: replace it.
                session[field_name] = new_secret()
                token = s.dumps(session[field_name])
            setattr(g, field_name, token)
        return g.get(field_name)

    def add_headers(self, resp):
        """
        Adds a fresh CSRF token to the response header named by
        `CSRF_AUTO_REFRESH_HEADER` when `validate_csrf` flagged the request
        for refresh.

        :param resp:
            Response to be sent.
        :type resp: flask.Response

        :return:
            The same response.
        :rtype: flask.Response
        """
        if g.get('csrf_refresh'):
            token = self.generate_csrf()
            g.csrf_refresh = False
            if token:
                resp.headers[self._config('CSRF_AUTO_REFRESH_HEADER')] = token
        return resp

    def validate_csrf(self):
        """
        Validates the CSRF token of the current request.

        The check is skipped when CSRF is disabled, the request method is not
        in `CSRF_METHODS`, or neither the endpoint nor its blueprint has been
        registered with `add2csrf_protected`.

        :return:
            `None` when the request is accepted or not subject to the check,
            otherwise a JSON response with an `error` message.
        :rtype: flask.Response | None
        """
        if not self._config('CSRF_ENABLED'):
            return
        if request.method not in self._config('CSRF_METHODS'):
            return
        if not request.endpoint:
            return
        if not (
                ('view', request.endpoint) in self._csrf_protected or
                ('bp', request.blueprint) in self._csrf_protected):
            return

        token = self._csrf_token()
        if not token:
            return jsonify({'error': 'The CSRF token is missing.'})

        field_name = self._config('CSRF_FIELD_NAME')
        if field_name not in session:
            return jsonify({'error': 'The CSRF session token is missing.'})

        secret_key = self._config('CSRF_SECRET_KEY')
        s = URLSafeTimedSerializer(secret_key, salt='csrf-token')
        try:
            token, timestamp = s.loads(token, return_timestamp=True)
        except BadData:
            return jsonify({'error': 'The CSRF token is invalid.'})

        if not hmac.compare_digest(session[field_name], token):
            return jsonify({'error': 'The CSRF tokens do not match.'})

        if request.is_secure and self._config('CSRF_SSL_STRICT'):
            if not request.referrer:
                return jsonify({'error': 'The referrer header is missing.'})
            c = urlparse(request.referrer)
            r = urlparse(f'https://{request.host}/')
            same_origin = (
                c.scheme == r.scheme and
                c.hostname == r.hostname and
                c.port == r.port
            )
            if not same_origin:
                return jsonify({
                    'error': 'The referrer does not match the host.'
                })

        # A token outside the time limit is still accepted, but the request
        # is flagged so that `add_headers` sends back a fresh token.
        time_limit = self._config('CSRF_TIME_LIMIT') or 0
        if time_limit >= 0:
            now = datetime.datetime.now(tz=datetime.timezone.utc)
            if not (0 <= (now - timestamp).total_seconds() <= time_limit):
                g.csrf_refresh = True
        g.csrf_valid = True  # mark this request as CSRF valid

    @staticmethod
    def send_static_file(filename):
        """
        Sends a static file, looking it up first in the app static folder
        and then in the static folder shipped with this module.

        Both the plain file and its `.gz` sibling are considered; the one
        matching the `Accept-Encoding` of the request is tried first and the
        content is (de)compressed when the stored form differs from the one
        accepted by the client. Files whose path starts with `forms` are
        parsed as JSON and re-serialized with `json_secrets.dumps`.

        :param filename:
            Path of the file, relative to the `schedula` static sub-folder.
        :type filename: str

        :return:
            File response.
        :rtype: flask.Response

        :raises werkzeug.exceptions.NotFound:
            If the file does not exist in any static folder, or its path
            resolves outside of them.
        """
        is_form = filename.startswith('forms')
        parts = f'schedula/{filename}'.split('/')
        download_name = parts[-1]
        kw = {
            'conditional': True,
            'download_name': download_name,
            'max_age': current_app.get_send_file_max_age(download_name)
        }
        gzipped = 'gzip' in request.headers.get('Accept-Encoding', '').lower()
        # Try first the stored representation that needs no conversion.
        extensions = ('.gz', '') if gzipped else ('', '.gz')
        # Name used to guess the mimetype and the content encoding.
        served_name = f'{download_name}.gz' if gzipped else download_name

        for i, root in enumerate((current_app.static_folder, static_dir)):
            if not root:  # The app has no static folder.
                continue
            root = osp.abspath(root)
            sdir = osp.join(root, *parts[:-1])
            for ext in extensions:
                fp = osp.abspath(osp.join(sdir, f'{download_name}{ext}'))
                # The path comes from the URL: never serve files located
                # outside the static folder (e.g., through `..` segments).
                try:
                    inside = osp.commonpath((root, fp)) == root
                except ValueError:  # E.g., different drives on Windows.
                    inside = False
                if not inside or not osp.isfile(fp):
                    continue

                with open(fp, 'rb') as file:
                    data = file.read()
                stored_gzipped = bool(ext)
                if is_form:
                    if stored_gzipped:
                        # JSON has to be parsed from the plain content.
                        data = gzip.decompress(data)
                        stored_gzipped = False
                    data = json_secrets.dumps(json.loads(data)).encode()
                if gzipped != stored_gzipped:
                    if gzipped:
                        data = gzip.compress(data)
                    else:
                        data = gzip.decompress(data)

                kw['last_modified'] = os.stat(fp).st_mtime
                mimetype, encoding = mimetypes.guess_type(served_name)
                response = send_file(io.BytesIO(data), **kw)
                if i:
                    # Files shipped with the package are cached for long.
                    response.cache_control.immutable = True
                    response.cache_control.public = True
                    response.cache_control.max_age = 946080000  # 30 years.
                response.content_type = mimetype
                response.content_encoding = encoding
                return response
        raise NotFound

    def add2csrf_protected(self, app=None, item=None):
        """
        Registers what has to be checked by `validate_csrf`.

        :param app:
            Flask app or blueprint. A blueprint is registered by name; for
            an app, every view function without a truthy `csrf_exempt`
            attribute is registered, and a random secret key is set when the
            app has none.
        :type app: flask.Flask | flask.Blueprint

        :param item:
            Explicit item to register, i.e. `('view', endpoint)` or
            `('bp', blueprint_name)`. When given, `app` is not inspected.
        :type item: tuple[str, str]

        :return:
            The given `app`.
        :rtype: flask.Flask | flask.Blueprint
        """
        if item:
            self._csrf_protected.add(item)
        elif isinstance(app, Blueprint):
            self._csrf_protected.add(('bp', app.name))
        else:
            if app.secret_key is None:
                app.secret_key = secrets.token_hex(32)
            for endpoint, func in app.view_functions.items():
                if not getattr(func, 'csrf_exempt', False):
                    self._csrf_protected.add(('view', endpoint))
        return app

    def app(self, root_path=None, depth=1, mute=False, blueprint_name=None,
            **kwargs):
        """
        Builds the app of the parent class and adds the CSRF protection and
        the `schedula` blueprint, which serves the form pages and the static
        files.

        The arguments are forwarded unchanged to the `app` method of the
        parent class.

        :return:
            The configured app.
        """
        app = super(FormMap, self).app(
            root_path=root_path, depth=depth, mute=mute,
            blueprint_name=blueprint_name, **kwargs
        )
        self.add2csrf_protected(app)
        app.before_request(self.validate_csrf)
        app.after_request(self.add_headers)

        bp = Blueprint(
            'schedula', __name__, template_folder='templates'
        )
        bp.add_url_rule('/<form>', 'render_form', self.render_form)
        bp.add_url_rule('/', 'render_form')
        bp.add_url_rule(
            '/static/schedula/<path:filename>', 'static',
            self.send_static_file
        )
        bp.add_url_rule('/static/schedula/<string:filename>', 'static')
        app.register_blueprint(bp)
        return app

    def basic_app(self, root_path, mute=True, blueprint_name=None, **kwargs):
        """
        Builds the basic app of the parent class and, when no blueprint name
        is given, passes it through the `basic_app` function of the `server`
        sub-module.

        The arguments are forwarded unchanged to the `basic_app` method of
        the parent class.

        :return:
            The configured app.
        """
        app = super(FormMap, self).basic_app(
            root_path, mute=mute, blueprint_name=blueprint_name, **kwargs
        )
        if blueprint_name is None:
            app = basic_app(self, app)
        return app