from AAA3A_utils import CogsUtils # isort:skip from redbot.core import commands # isort:skip import discord # isort:skip import typing # isort:skip import hmac import inspect from itsdangerous import BadData, SignatureExpired, URLSafeTimedSerializer from markupsafe import Markup from werkzeug.datastructures import ImmutableMultiDict from werkzeug.utils import cached_property from wtforms import ( BooleanField, Field, Form, FormField, HiddenField, SelectFieldBase, SelectMultipleField, SubmitField, ) # NOQA from wtforms.csrf.core import CSRF from wtforms.fields.core import UnboundField from wtforms.meta import DefaultMeta from wtforms.validators import ValidationError from wtforms.widgets import HiddenInput INITIAL_INIT_FIELD = Field.__init__ async def get_form_class( _self, third_party_cog: commands.Cog, method: typing.Literal["HEAD", "GET", "OPTIONS", "POST", "PATCH", "DELETE"], csrf_token: typing.Tuple[str, str], wtf_csrf_secret_key: bytes, data: typing.Dict[typing.Literal["form", "json"], typing.Dict[str, typing.Any]], **kwargs, ): extra_notifications = [] _Auto = object() def _is_submitted() -> bool: return method in {"POST", "PUT", "PATCH", "DELETE"} class _FlaskFormCSRF(CSRF): def setup_form(self, form) -> typing.List[typing.Tuple[str, UnboundField]]: self.meta = form.meta return super().setup_form(form) def generate_csrf_token(self, csrf_token_field) -> str: return csrf_token[1] def validate_csrf_token(self, form, field) -> None: # At this point, the CSRF token should be already validated by the webserver because of the field name in `request.form`. data = field.data secret_key = self.meta.csrf_secret time_limit = self.meta.csrf_time_limit if not data: raise ValidationError("The CSRF token is missing.") s = URLSafeTimedSerializer(secret_key, salt="wtf-csrf-token") try: token = s.loads(data, max_age=time_limit) except SignatureExpired as e: raise ValidationError("The CSRF token has expired.") from e except BadData as e: raise ValidationError("The CSRF token is invalid.") from e if not hmac.compare_digest(csrf_token[0], token): raise ValidationError("The CSRF tokens do not match.") class FlaskForm(Form): class Meta(DefaultMeta): csrf_class = _FlaskFormCSRF @cached_property def csrf(self) -> bool: return True @cached_property def csrf_secret(self) -> bytes: return wtf_csrf_secret_key @cached_property def csrf_time_limit(self) -> int: return 3600 def wrap_formdata(self, form, formdata) -> typing.Optional[ImmutableMultiDict]: if formdata is _Auto: if _is_submitted(): if data["form"]: return ImmutableMultiDict(data["form"]) elif data["json"]: return ImmutableMultiDict(data["json"]) return None return formdata def __init__(self, formdata=_Auto, **kwargs) -> None: super().__init__(formdata=formdata, **kwargs) def is_submitted(self) -> bool: return _is_submitted() def validate_on_submit(self, extra_validators=None) -> bool: if self.is_submitted() and self.validate(extra_validators=extra_validators): return True if any(field.data for field in self if isinstance(field, SubmitField)) and self.errors: for field_name, error_messages in self.errors.items(): if isinstance(error_messages[0], typing.Dict): for sub_field_name, sub_error_messages in error_messages[0].items(): extra_notifications.append( { "message": f"{field_name}-{sub_field_name}: {' '.join(sub_error_messages)}", "category": "warning", } ) continue extra_notifications.append( { "message": f"{field_name}: {' '.join(error_messages)}", "category": "warning", } ) return False async def validate_dpy_converters(self) -> bool: result = True for field in self: for validator in field.validators: if not isinstance(validator, DpyObjectConverter): continue if isinstance(field, SelectMultipleField): try: field.data = [await validator.convert(value) for value in field.data] except commands.BadArgument as e: extra_notifications.append( {"message": f"{field.name}: {e}", "category": "warning"} ) result = False continue if field.data is None or not field.data.strip(): field.data = "" continue try: field.data = await validator.convert(field.data) except commands.BadArgument as e: extra_notifications.append( {"message": f"{field.name}: {e}", "category": "warning"} ) result = False return result def hidden_tag(self, *fields) -> Markup: def hidden_fields(fields): for f in fields: if isinstance(f, str): f = getattr(self, f, None) if f is None or not isinstance(f.widget, HiddenInput): continue yield f return Markup("\n".join(str(f) for f in hidden_fields(fields or self))) def __str__(self) -> Markup: html_form = [ '