77
88import asyncio
99import collections
10+ import enum
1011import functools
1112import getpass
1213import os
2829from . import protocol
2930
3031
32+ class SSLMode(enum.IntEnum):
33+ disable = 0
34+ allow = 1
35+ prefer = 2
36+ require = 3
37+ verify_ca = 4
38+ verify_full = 5
39+
40+ @classmethod
41+ def parse(cls, sslmode):
42+ if isinstance(sslmode, cls):
43+ return sslmode
44+ return getattr(cls, sslmode.replace('-', '_'))
45+
46+
3147_ConnectionParameters = collections.namedtuple(
3248 'ConnectionParameters',
3349 [
3450 'user',
3551 'password',
3652 'database',
3753 'ssl',
38- 'ssl_is_advisory ',
54+ 'sslmode ',
3955 'connect_timeout',
4056 'server_settings',
4157 ])
@@ -402,46 +418,29 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
402418 if ssl is None and have_tcp_addrs:
403419 ssl = 'prefer'
404420
405- # ssl_is_advisory is only allowed to come from the sslmode parameter.
406- ssl_is_advisory = None
407- if isinstance(ssl, str):
408- SSLMODES = {
409- 'disable': 0,
410- 'allow': 1,
411- 'prefer': 2,
412- 'require': 3,
413- 'verify-ca': 4,
414- 'verify-full': 5,
415- }
421+ if isinstance(ssl, (str, SSLMode)):
416422 try:
417- sslmode = SSLMODES[ ssl]
418- except KeyError :
419- modes = ', '.join(SSLMODES.keys() )
423+ sslmode = SSLMode.parse( ssl)
424+ except AttributeError :
425+ modes = ', '.join(m.name.replace('_', '-') for m in SSLMode )
420426 raise exceptions.InterfaceError(
421427 '`sslmode` parameter must be one of: {}'.format(modes))
422428
423- # sslmode 'allow' is currently handled as 'prefer' because we're
424- # missing the "retry with SSL" behavior for 'allow', but do have the
425- # "retry without SSL" behavior for 'prefer'.
426- # Not changing 'allow' to 'prefer' here would be effectively the same
427- # as changing 'allow' to 'disable'.
428- if sslmode == SSLMODES['allow']:
429- sslmode = SSLMODES['prefer']
430-
431429 # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
432430 # Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
433- if sslmode <= SSLMODES[' allow'] :
431+ if sslmode < SSLMode. allow:
434432 ssl = False
435- ssl_is_advisory = sslmode >= SSLMODES['allow']
436433 else:
437434 ssl = ssl_module.create_default_context()
438- ssl.check_hostname = sslmode >= SSLMODES['verify-full']
435+ ssl.check_hostname = sslmode >= SSLMode.verify_full
439436 ssl.verify_mode = ssl_module.CERT_REQUIRED
440- if sslmode <= SSLMODES[' require'] :
437+ if sslmode <= SSLMode. require:
441438 ssl.verify_mode = ssl_module.CERT_NONE
442- ssl_is_advisory = sslmode <= SSLMODES['prefer']
443439 elif ssl is True:
444440 ssl = ssl_module.create_default_context()
441+ sslmode = SSLMode.verify_full
442+ else:
443+ sslmode = SSLMode.disable
445444
446445 if server_settings is not None and (
447446 not isinstance(server_settings, dict) or
@@ -453,7 +452,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
453452
454453 params = _ConnectionParameters(
455454 user=user, password=password, database=database, ssl=ssl,
456- ssl_is_advisory=ssl_is_advisory , connect_timeout=connect_timeout,
455+ sslmode=sslmode , connect_timeout=connect_timeout,
457456 server_settings=server_settings)
458457
459458 return addrs, params
@@ -520,9 +519,8 @@ def data_received(self, data):
520519 data == b'N'):
521520 # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522521 # since the only way to get ssl_is_advisory is from
523- # sslmode=prefer (or sslmode=allow). But be extra sure to
524- # disallow insecure connections when the ssl context asks for
525- # real security.
522+ # sslmode=prefer. But be extra sure to disallow insecure
523+ # connections when the ssl context asks for real security.
526524 self.on_data.set_result(False)
527525 else:
528526 self.on_data.set_exception(
@@ -566,6 +564,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
566564 new_tr = tr
567565
568566 pg_proto = protocol_factory()
567+ pg_proto.is_ssl = do_ssl_upgrade
569568 pg_proto.connection_made(new_tr)
570569 new_tr.set_protocol(pg_proto)
571570
@@ -584,7 +583,9 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
584583 tr.close()
585584
586585 try:
587- return await conn_factory(sock=sock)
586+ new_tr, pg_proto = await conn_factory(sock=sock)
587+ pg_proto.is_ssl = do_ssl_upgrade
588+ return new_tr, pg_proto
588589 except (Exception, asyncio.CancelledError):
589590 sock.close()
590591 raise
@@ -605,8 +606,6 @@ async def _connect_addr(
605606 if timeout <= 0:
606607 raise asyncio.TimeoutError
607608
608- connected = _create_future(loop)
609-
610609 params_input = params
611610 if callable(params.password):
612611 if inspect.iscoroutinefunction(params.password):
@@ -615,6 +614,49 @@ async def _connect_addr(
615614 password = params.password()
616615
617616 params = params._replace(password=password)
617+ args = (addr, loop, config, connection_class, record_class, params_input)
618+
619+ # prepare the params (which attempt has ssl) for the 2 attempts
620+ if params.sslmode == SSLMode.allow:
621+ params_retry = params
622+ params = params._replace(ssl=None)
623+ elif params.sslmode == SSLMode.prefer:
624+ params_retry = params._replace(ssl=None)
625+ else:
626+ # skip retry if we don't have to
627+ return await __connect_addr(params, timeout, False, *args)
628+
629+ # first attempt
630+ before = time.monotonic()
631+ try:
632+ return await __connect_addr(params, timeout, True, *args)
633+ except _Retry:
634+ pass
635+
636+ # second attempt
637+ timeout -= time.monotonic() - before
638+ if timeout <= 0:
639+ raise asyncio.TimeoutError
640+ else:
641+ return await __connect_addr(params_retry, timeout, False, *args)
642+
643+
644+ class _Retry(Exception):
645+ pass
646+
647+
648+ async def __connect_addr(
649+ params,
650+ timeout,
651+ retry,
652+ addr,
653+ loop,
654+ config,
655+ connection_class,
656+ record_class,
657+ params_input,
658+ ):
659+ connected = _create_future(loop)
618660
619661 proto_factory = lambda: protocol.Protocol(
620662 addr, connected, params, record_class, loop)
@@ -625,7 +667,7 @@ async def _connect_addr(
625667 elif params.ssl:
626668 connector = _create_ssl_connection(
627669 proto_factory, *addr, loop=loop, ssl_context=params.ssl,
628- ssl_is_advisory=params.ssl_is_advisory )
670+ ssl_is_advisory=params.sslmode == SSLMode.prefer )
629671 else:
630672 connector = loop.create_connection(proto_factory, *addr)
631673
@@ -638,6 +680,35 @@ async def _connect_addr(
638680 if timeout <= 0:
639681 raise asyncio.TimeoutError
640682 await compat.wait_for(connected, timeout=timeout)
683+ except (
684+ exceptions.InvalidAuthorizationSpecificationError,
685+ exceptions.ConnectionDoesNotExistError, # seen on Windows
686+ ):
687+ tr.close()
688+
689+ # retry=True here is a redundant check because we don't want to
690+ # accidentally raise the internal _Retry to the outer world
691+ if retry and (
692+ params.sslmode == SSLMode.allow and not pr.is_ssl or
693+ params.sslmode == SSLMode.prefer and pr.is_ssl
694+ ):
695+ # Trigger retry when:
696+ # 1. First attempt with sslmode=allow, ssl=None failed
697+ # 2. First attempt with sslmode=prefer, ssl=ctx failed while the
698+ # server claimed to support SSL (returning "S" for SSLRequest)
699+ # (likely because pg_hba.conf rejected the connection)
700+ raise _Retry()
701+
702+ else:
703+ # but will NOT retry if:
704+ # 1. First attempt with sslmode=prefer failed but the server
705+ # doesn't support SSL (returning 'N' for SSLRequest), because
706+ # we already tried to connect without SSL thru ssl_is_advisory
707+ # 2. Second attempt with sslmode=prefer, ssl=None failed
708+ # 3. Second attempt with sslmode=allow, ssl=ctx failed
709+ # 4. Any other sslmode
710+ raise
711+
641712 except (Exception, asyncio.CancelledError):
642713 tr.close()
643714 raise
@@ -684,6 +755,7 @@ class CancelProto(asyncio.Protocol):
684755
685756 def __init__(self):
686757 self.on_disconnect = _create_future(loop)
758+ self.is_ssl = False
687759
688760 def connection_lost(self, exc):
689761 if not self.on_disconnect.done():
@@ -692,13 +764,13 @@ def connection_lost(self, exc):
692764 if isinstance(addr, str):
693765 tr, pr = await loop.create_unix_connection(CancelProto, addr)
694766 else:
695- if params.ssl:
767+ if params.ssl and params.sslmode != SSLMode.allow :
696768 tr, pr = await _create_ssl_connection(
697769 CancelProto,
698770 *addr,
699771 loop=loop,
700772 ssl_context=params.ssl,
701- ssl_is_advisory=params.ssl_is_advisory )
773+ ssl_is_advisory=params.sslmode == SSLMode.prefer )
702774 else:
703775 tr, pr = await loop.create_connection(
704776 CancelProto, *addr)
0 commit comments