You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

client.py 2.3KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import os
  2. import signal
  3. import subprocess
  4. from django.core.files.temp import NamedTemporaryFile
  5. from django.db.backends.base.client import BaseDatabaseClient
  6. def _escape_pgpass(txt):
  7. """
  8. Escape a fragment of a PostgreSQL .pgpass file.
  9. """
  10. return txt.replace('\\', '\\\\').replace(':', '\\:')
  11. class DatabaseClient(BaseDatabaseClient):
  12. executable_name = 'psql'
  13. @classmethod
  14. def runshell_db(cls, conn_params):
  15. args = [cls.executable_name]
  16. host = conn_params.get('host', '')
  17. port = conn_params.get('port', '')
  18. dbname = conn_params.get('database', '')
  19. user = conn_params.get('user', '')
  20. passwd = conn_params.get('password', '')
  21. if user:
  22. args += ['-U', user]
  23. if host:
  24. args += ['-h', host]
  25. if port:
  26. args += ['-p', str(port)]
  27. args += [dbname]
  28. temp_pgpass = None
  29. sigint_handler = signal.getsignal(signal.SIGINT)
  30. try:
  31. if passwd:
  32. # Create temporary .pgpass file.
  33. temp_pgpass = NamedTemporaryFile(mode='w+')
  34. try:
  35. print(
  36. _escape_pgpass(host) or '*',
  37. str(port) or '*',
  38. _escape_pgpass(dbname) or '*',
  39. _escape_pgpass(user) or '*',
  40. _escape_pgpass(passwd),
  41. file=temp_pgpass,
  42. sep=':',
  43. flush=True,
  44. )
  45. os.environ['PGPASSFILE'] = temp_pgpass.name
  46. except UnicodeEncodeError:
  47. # If the current locale can't encode the data, let the
  48. # user input the password manually.
  49. pass
  50. # Allow SIGINT to pass to psql to abort queries.
  51. signal.signal(signal.SIGINT, signal.SIG_IGN)
  52. subprocess.check_call(args)
  53. finally:
  54. # Restore the original SIGINT handler.
  55. signal.signal(signal.SIGINT, sigint_handler)
  56. if temp_pgpass:
  57. temp_pgpass.close()
  58. if 'PGPASSFILE' in os.environ: # unit tests need cleanup
  59. del os.environ['PGPASSFILE']
  60. def runshell(self):
  61. DatabaseClient.runshell_db(self.connection.get_connection_params())