|
| 1 | +"""Input sanitization utilities for XSS and injection prevention""" |
| 2 | +import re |
| 3 | +import html |
| 4 | +from functools import wraps |
| 5 | +from flask import request, g |
| 6 | + |
| 7 | +# HTML tags that should never appear in user input |
| 8 | +DANGEROUS_PATTERNS = [ |
| 9 | + re.compile(r'<script[^>]*>.*?</script>', re.IGNORECASE | re.DOTALL), |
| 10 | + re.compile(r'javascript:', re.IGNORECASE), |
| 11 | + re.compile(r'on\w+\s*=', re.IGNORECASE), # onclick=, onerror=, etc. |
| 12 | + re.compile(r'<iframe[^>]*>', re.IGNORECASE), |
| 13 | + re.compile(r'<object[^>]*>', re.IGNORECASE), |
| 14 | + re.compile(r'<embed[^>]*>', re.IGNORECASE), |
| 15 | +] |
| 16 | + |
| 17 | +# Fields that should NOT be sanitized (e.g., passwords, file contents) |
| 18 | +SKIP_SANITIZATION_FIELDS = {'password', 'password_hash', 'current_password', 'new_password', 'confirm_password'} |
| 19 | + |
| 20 | + |
| 21 | +def sanitize_string(value): |
| 22 | + """ |
| 23 | + Sanitize a string value by escaping HTML entities. |
| 24 | +
|
| 25 | + Args: |
| 26 | + value: String to sanitize |
| 27 | +
|
| 28 | + Returns: |
| 29 | + Sanitized string with HTML entities escaped |
| 30 | + """ |
| 31 | + if not isinstance(value, str): |
| 32 | + return value |
| 33 | + |
| 34 | + # Escape HTML entities to prevent XSS |
| 35 | + sanitized = html.escape(value, quote=True) |
| 36 | + |
| 37 | + return sanitized |
| 38 | + |
| 39 | + |
| 40 | +def contains_dangerous_pattern(value): |
| 41 | + """ |
| 42 | + Check if a string contains potentially dangerous patterns. |
| 43 | +
|
| 44 | + Args: |
| 45 | + value: String to check |
| 46 | +
|
| 47 | + Returns: |
| 48 | + bool: True if dangerous pattern found |
| 49 | + """ |
| 50 | + if not isinstance(value, str): |
| 51 | + return False |
| 52 | + |
| 53 | + for pattern in DANGEROUS_PATTERNS: |
| 54 | + if pattern.search(value): |
| 55 | + return True |
| 56 | + return False |
| 57 | + |
| 58 | + |
| 59 | +def sanitize_dict(data, skip_fields=None): |
| 60 | + """ |
| 61 | + Recursively sanitize all string values in a dictionary. |
| 62 | +
|
| 63 | + Args: |
| 64 | + data: Dictionary to sanitize |
| 65 | + skip_fields: Set of field names to skip |
| 66 | +
|
| 67 | + Returns: |
| 68 | + Sanitized dictionary |
| 69 | + """ |
| 70 | + if skip_fields is None: |
| 71 | + skip_fields = SKIP_SANITIZATION_FIELDS |
| 72 | + |
| 73 | + if isinstance(data, dict): |
| 74 | + return { |
| 75 | + key: sanitize_dict(value, skip_fields) if key not in skip_fields else value |
| 76 | + for key, value in data.items() |
| 77 | + } |
| 78 | + elif isinstance(data, list): |
| 79 | + return [sanitize_dict(item, skip_fields) for item in data] |
| 80 | + elif isinstance(data, str): |
| 81 | + return sanitize_string(data) |
| 82 | + else: |
| 83 | + return data |
| 84 | + |
| 85 | + |
| 86 | +def get_sanitized_input(): |
| 87 | + """ |
| 88 | + Get sanitized versions of request form data, args, and JSON. |
| 89 | + Stores results in Flask's g object for reuse within request. |
| 90 | +
|
| 91 | + Returns: |
| 92 | + dict with 'form', 'args', and 'json' keys containing sanitized data |
| 93 | + """ |
| 94 | + if hasattr(g, '_sanitized_input'): |
| 95 | + return g._sanitized_input |
| 96 | + |
| 97 | + sanitized = { |
| 98 | + 'form': sanitize_dict(request.form.to_dict()) if request.form else {}, |
| 99 | + 'args': sanitize_dict(request.args.to_dict()) if request.args else {}, |
| 100 | + 'json': sanitize_dict(request.get_json(silent=True) or {}) if request.is_json else {} |
| 101 | + } |
| 102 | + |
| 103 | + g._sanitized_input = sanitized |
| 104 | + return sanitized |
| 105 | + |
| 106 | + |
| 107 | +def log_dangerous_input(field_name, value, logger=None): |
| 108 | + """ |
| 109 | + Log when dangerous input is detected (for security monitoring). |
| 110 | +
|
| 111 | + Args: |
| 112 | + field_name: Name of the field containing dangerous input |
| 113 | + value: The dangerous value (truncated for logging) |
| 114 | + logger: Logger instance to use |
| 115 | + """ |
| 116 | + if logger: |
| 117 | + truncated = value[:100] + '...' if len(value) > 100 else value |
| 118 | + logger.warning(f"Potentially dangerous input detected in field '{field_name}': {truncated}") |
| 119 | + |
| 120 | + |
| 121 | +class InputSanitizer: |
| 122 | + """ |
| 123 | + Middleware class for input sanitization. |
| 124 | + Can be configured with custom settings. |
| 125 | + """ |
| 126 | + |
| 127 | + def __init__(self, app=None, logger=None): |
| 128 | + self.app = app |
| 129 | + self.logger = logger |
| 130 | + self.enabled = True |
| 131 | + |
| 132 | + if app is not None: |
| 133 | + self.init_app(app) |
| 134 | + |
| 135 | + def init_app(self, app): |
| 136 | + """Initialize the sanitizer with a Flask app.""" |
| 137 | + self.app = app |
| 138 | + |
| 139 | + @app.before_request |
| 140 | + def sanitize_request_input(): |
| 141 | + """Pre-process and flag dangerous input before route handlers.""" |
| 142 | + if not self.enabled: |
| 143 | + return |
| 144 | + |
| 145 | + # Check form data |
| 146 | + if request.form: |
| 147 | + for key, value in request.form.items(): |
| 148 | + if key not in SKIP_SANITIZATION_FIELDS and contains_dangerous_pattern(value): |
| 149 | + log_dangerous_input(key, value, self.logger) |
| 150 | + g._has_dangerous_input = True |
| 151 | + |
| 152 | + # Check query args |
| 153 | + if request.args: |
| 154 | + for key, value in request.args.items(): |
| 155 | + if contains_dangerous_pattern(value): |
| 156 | + log_dangerous_input(key, value, self.logger) |
| 157 | + g._has_dangerous_input = True |
| 158 | + |
| 159 | + # Check JSON body |
| 160 | + if request.is_json: |
| 161 | + json_data = request.get_json(silent=True) |
| 162 | + if json_data: |
| 163 | + self._check_json_recursive(json_data) |
| 164 | + |
| 165 | + # Pre-compute sanitized input for use in routes |
| 166 | + get_sanitized_input() |
| 167 | + |
| 168 | + def _check_json_recursive(self, data, path=''): |
| 169 | + """Recursively check JSON data for dangerous patterns.""" |
| 170 | + if isinstance(data, dict): |
| 171 | + for key, value in data.items(): |
| 172 | + current_path = f"{path}.{key}" if path else key |
| 173 | + if key not in SKIP_SANITIZATION_FIELDS: |
| 174 | + self._check_json_recursive(value, current_path) |
| 175 | + elif isinstance(data, list): |
| 176 | + for i, item in enumerate(data): |
| 177 | + self._check_json_recursive(item, f"{path}[{i}]") |
| 178 | + elif isinstance(data, str): |
| 179 | + if contains_dangerous_pattern(data): |
| 180 | + log_dangerous_input(path, data, self.logger) |
| 181 | + g._has_dangerous_input = True |
0 commit comments