123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- from __future__ import absolute_import, unicode_literals
-
- import hashlib
- import hmac
- import json
-
- from django import forms
- from django.conf import settings
- from django.core.exceptions import ValidationError
- from django.db import connections
- from django.utils.crypto import constant_time_compare
- from django.utils.encoding import force_bytes
- from django.utils.functional import cached_property
-
- from debug_toolbar.panels.sql.utils import reformat_sql
-
-
class SQLSelectForm(forms.Form):
    """
    Validate params

    sql: The sql statement with interpolated params
    raw_sql: The sql statement with placeholders; must start with SELECT
    params: JSON encoded parameter values
    alias: key into ``django.db.connections`` selecting the database
    duration: time for SQL to execute passed in from toolbar just for redisplay
    hash: HMAC (keyed with SECRET_KEY) over sql, raw_sql, params and alias
          for tamper checking
    """
    sql = forms.CharField()
    raw_sql = forms.CharField()
    params = forms.CharField()
    alias = forms.CharField(required=False, initial='default')
    duration = forms.FloatField()
    hash = forms.CharField()

    def __init__(self, *args, **kwargs):
        """
        Build the form, signing ``initial`` (if given) and hiding every field.

        When ``initial`` is passed as a keyword argument, a copy of it gets a
        ``hash`` entry computed by ``make_hash``. The copy is used so that the
        caller's dict is left untouched.
        """
        initial = kwargs.get('initial')

        if initial is not None:
            # Work on a copy: the caller's dict must not gain a 'hash' key.
            initial = dict(initial)
            initial['hash'] = self.make_hash(initial)
            kwargs['initial'] = initial

        super(SQLSelectForm, self).__init__(*args, **kwargs)

        # Every value round-trips through the page unchanged, so render all
        # fields as hidden inputs.
        for field in self.fields.values():
            field.widget = forms.HiddenInput()

    def clean_raw_sql(self):
        """
        Accept only statements that start with SELECT.

        The check ignores case and leading/trailing whitespace.
        """
        value = self.cleaned_data['raw_sql']

        if not value.lower().strip().startswith('select'):
            raise ValidationError("Only 'select' queries are allowed.")

        return value

    def clean_params(self):
        """Decode the JSON-encoded ``params`` string into Python values."""
        value = self.cleaned_data['params']

        try:
            return json.loads(value)
        except ValueError:
            raise ValidationError('Is not valid JSON')

    def clean_alias(self):
        """Ensure ``alias`` names a configured database connection."""
        value = self.cleaned_data['alias']

        if value not in connections:
            raise ValidationError("Database alias '%s' not found" % value)

        return value

    def clean_hash(self):
        """
        Check the submitted signature against one recomputed from the data.

        The signature is recomputed from the raw submitted ``self.data``, not
        from ``cleaned_data``. This is because ``clean_params`` replaces the
        JSON string with decoded values, while the signature was made over
        the strings that were rendered into the form.
        """
        submitted = self.cleaned_data['hash']

        if not constant_time_compare(submitted, self.make_hash(self.data)):
            raise ValidationError('Tamper alert')

        return submitted

    def reformat_sql(self):
        """Return the cleaned ``sql`` passed through the imported helper.

        The name inside the body resolves to the module-level
        ``reformat_sql`` import, not to this method.
        """
        return reformat_sql(self.cleaned_data['sql'])

    def make_hash(self, data):
        """
        Return a hex HMAC-SHA1, keyed with ``SECRET_KEY``, over the fields
        that decide what query runs and where.

        ``data`` is either the ``initial`` mapping (when rendering) or the
        submitted ``self.data`` (when validating). ``sql``, ``raw_sql``,
        ``params`` and ``alias`` are all signed.

        - ``raw_sql`` is the placeholder statement this form validates for
          execution. Leaving it unsigned lets a client replace it with any
          statement that merely starts with "select" (the flaw reported as
          CVE-2021-30459).
        - ``alias`` selects the connection used by ``connection``/``cursor``.

        Missing keys hash as empty strings, so a malformed submission fails
        the comparison instead of raising KeyError.
        """
        # An unbound form renders the field's initial ('default') when
        # ``initial`` has no alias. Mirror that so the signature matches what
        # the browser will post back.
        alias = data.get('alias') or self.base_fields['alias'].initial
        values = [
            data.get('sql', ''),
            data.get('raw_sql', ''),
            data.get('params', ''),
            alias,
        ]

        m = hmac.new(key=force_bytes(settings.SECRET_KEY), digestmod=hashlib.sha1)
        for value in values:
            encoded = force_bytes(value)
            # Length-prefix each value: plain concatenation would let a client
            # shift characters across field boundaries ("ab"+"c" == "a"+"bc")
            # without changing the MAC.
            m.update(force_bytes('%d:' % len(encoded)))
            m.update(encoded)
        return m.hexdigest()

    @property
    def connection(self):
        """Database connection for the cleaned ``alias``.

        Reads ``cleaned_data``, so the form must have been validated first.
        """
        return connections[self.cleaned_data['alias']]

    @cached_property
    def cursor(self):
        """Cursor on ``connection``, created once per form instance.

        This form never closes it; the caller should.
        """
        return self.connection.cursor()
|