Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

032 advertised address optimization #1136

Draft
wants to merge 6 commits into
base: 5.0
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions src/neo4j/_async/io/_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@

if t.TYPE_CHECKING:
from ..._api import TelemetryAPI
from ...addressing import Address


# Set up logger
Expand Down Expand Up @@ -135,9 +136,12 @@ class AsyncBolt:
# results for it.
most_recent_qid = None

address_callback = None
advertised_address: Address | None = None

def __init__(
self,
unresolved_address,
address,
sock,
max_connection_lifetime,
*,
Expand All @@ -149,12 +153,12 @@ def __init__(
notifications_disabled_classifications=None,
telemetry_disabled=False,
):
self.unresolved_address = unresolved_address
self._address = address
self.socket = sock
self.local_port = self.socket.getsockname()[1]
self.server_info = ServerInfo(
ResolvedAddress(
sock.getpeername(), host_name=unresolved_address.host
sock.getpeername(), host_name=address._unresolved.host
),
self.PROTOCOL_VERSION,
)
Expand Down Expand Up @@ -200,6 +204,15 @@ def __del__(self):
if not asyncio.iscoroutinefunction(self.close):
self.close()

@property
def address(self):
return self._address

@address.setter
def address(self, value):
self._address = value
self.server_info._address = value._unresolved

@abc.abstractmethod
def _get_server_state_manager(self) -> ServerStateManagerBase: ...

Expand Down Expand Up @@ -308,6 +321,7 @@ def protocol_handlers(cls, protocol_version=None):
AsyncBolt5x5,
AsyncBolt5x6,
AsyncBolt5x7,
AsyncBolt5x8,
)

handlers = {
Expand All @@ -325,6 +339,7 @@ def protocol_handlers(cls, protocol_version=None):
AsyncBolt5x5.PROTOCOL_VERSION: AsyncBolt5x5,
AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6,
AsyncBolt5x7.PROTOCOL_VERSION: AsyncBolt5x7,
AsyncBolt5x8.PROTOCOL_VERSION: AsyncBolt5x8,
}

if protocol_version is None:
Expand Down Expand Up @@ -461,7 +476,10 @@ async def open(

# avoid new lines after imports for better readability and conciseness
# fmt: off
if protocol_version == (5, 7):
if protocol_version == (5, 8):
from ._bolt5 import AsyncBolt5x8
bolt_cls = AsyncBolt5x8
elif protocol_version == (5, 7):
from ._bolt5 import AsyncBolt5x7
bolt_cls = AsyncBolt5x7
elif protocol_version == (5, 6):
Expand Down Expand Up @@ -954,12 +972,12 @@ async def send_all(self):
if self.closed():
raise ServiceUnavailable(
"Failed to write to closed connection "
f"{self.unresolved_address!r} ({self.server_info.address!r})"
f"{self.address!r} ({self.server_info.address!r})"
)
if self.defunct():
raise ServiceUnavailable(
"Failed to write to defunct connection "
f"{self.unresolved_address!r} ({self.server_info.address!r})"
f"{self.address!r} ({self.server_info.address!r})"
)

await self._send_all()
Expand All @@ -977,12 +995,12 @@ async def fetch_message(self):
if self._closed:
raise ServiceUnavailable(
"Failed to read from closed connection "
f"{self.unresolved_address!r} ({self.server_info.address!r})"
f"{self.address!r} ({self.server_info.address!r})"
)
if self._defunct:
raise ServiceUnavailable(
"Failed to read from defunct connection "
f"{self.unresolved_address!r} ({self.server_info.address!r})"
f"{self.address!r} ({self.server_info.address!r})"
)
if not self.responses:
return 0, 0
Expand Down Expand Up @@ -1014,14 +1032,14 @@ async def fetch_all(self):
async def _set_defunct_read(self, error=None, silent=False):
message = (
"Failed to read from defunct connection "
f"{self.unresolved_address!r} ({self.server_info.address!r})"
f"{self.address!r} ({self.server_info.address!r})"
)
await self._set_defunct(message, error=error, silent=silent)

async def _set_defunct_write(self, error=None, silent=False):
message = (
"Failed to write data to connection "
f"{self.unresolved_address!r} ({self.server_info.address!r})"
f"{self.address!r} ({self.server_info.address!r})"
)
await self._set_defunct(message, error=error, silent=silent)

Expand Down Expand Up @@ -1060,7 +1078,7 @@ async def _set_defunct(self, message, error=None, silent=False):
# connection again.
await self.close()
if self.pool and not self._get_server_state_manager().failed():
await self.pool.deactivate(address=self.unresolved_address)
await self.pool.deactivate(address=self.address)

# Iterate through the outstanding responses, and if any correspond
# to COMMIT requests then raise an error to signal that we are
Expand Down
6 changes: 3 additions & 3 deletions src/neo4j/_async/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,12 +579,12 @@ async def _process_message(self, tag, fields):
await response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
if self.pool:
await self.pool.deactivate(address=self.unresolved_address)
await self.pool.deactivate(address=self.address)
raise
except (NotALeader, ForbiddenOnReadOnlyDatabase):
if self.pool:
await self.pool.on_write_failure(
address=self.unresolved_address,
address=self.address,
database=self.last_database,
)
raise
Expand All @@ -595,7 +595,7 @@ async def _process_message(self, tag, fields):
sig_int = ord(summary_signature)
raise BoltProtocolError(
f"Unexpected response message with signature {sig_int:02X}",
self.unresolved_address,
self.address,
)

return len(details), 1
6 changes: 3 additions & 3 deletions src/neo4j/_async/io/_bolt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,12 +494,12 @@ async def _process_message(self, tag, fields):
await response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
if self.pool:
await self.pool.deactivate(address=self.unresolved_address)
await self.pool.deactivate(address=self.address)
raise
except (NotALeader, ForbiddenOnReadOnlyDatabase):
if self.pool:
await self.pool.on_write_failure(
address=self.unresolved_address,
address=self.address,
database=self.last_database,
)
raise
Expand All @@ -511,7 +511,7 @@ async def _process_message(self, tag, fields):
sig_int = ord(summary_signature)
raise BoltProtocolError(
f"Unexpected response message with signature {sig_int:02X}",
self.unresolved_address,
self.address,
)

return len(details), 1
Expand Down
57 changes: 51 additions & 6 deletions src/neo4j/_async/io/_bolt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ..._codec.hydration import v2 as hydration_v2
from ..._exceptions import BoltProtocolError
from ..._meta import BOLT_AGENT_DICT
from ...addressing import Address
from ...api import (
READ_ACCESS,
Version,
Expand Down Expand Up @@ -496,12 +497,12 @@ async def _process_message(self, tag, fields):
await response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
if self.pool:
await self.pool.deactivate(address=self.unresolved_address)
await self.pool.deactivate(address=self.address)
raise
except (NotALeader, ForbiddenOnReadOnlyDatabase):
if self.pool:
await self.pool.on_write_failure(
address=self.unresolved_address,
address=self.address,
database=self.last_database,
)
raise
Expand All @@ -513,7 +514,7 @@ async def _process_message(self, tag, fields):
sig_int = ord(summary_signature)
raise BoltProtocolError(
f"Unexpected response message with signature {sig_int:02X}",
self.unresolved_address,
self.address,
)

return len(details), 1
Expand Down Expand Up @@ -1204,12 +1205,12 @@ async def _process_message(self, tag, fields):
await response.on_failure(summary_metadata or {})
except (ServiceUnavailable, DatabaseUnavailable):
if self.pool:
await self.pool.deactivate(address=self.unresolved_address)
await self.pool.deactivate(address=self.address)
raise
except (NotALeader, ForbiddenOnReadOnlyDatabase):
if self.pool:
await self.pool.on_write_failure(
address=self.unresolved_address,
address=self.address,
database=self.last_database,
)
raise
Expand All @@ -1221,7 +1222,51 @@ async def _process_message(self, tag, fields):
sig_int = ord(summary_signature)
raise BoltProtocolError(
f"Unexpected response message with signature {sig_int:02X}",
self.unresolved_address,
self.address,
)

return len(details), 1


class AsyncBolt5x8(AsyncBolt5x7):
PROTOCOL_VERSION = Version(5, 8)

def logon(self, dehydration_hooks=None, hydration_hooks=None):
dehydration_hooks, hydration_hooks = self._default_hydration_hooks(
dehydration_hooks, hydration_hooks
)
logged_auth_dict = dict(self.auth_dict)
if "credentials" in logged_auth_dict:
logged_auth_dict["credentials"] = "*******"
log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict)
self._append(
b"\x6a",
(self.auth_dict,),
response=LogonResponse(
self, "logon", hydration_hooks, on_success=self._logon_success
),
dehydration_hooks=dehydration_hooks,
)

async def _logon_success(self, meta: object) -> None:
if not isinstance(meta, dict):
log.warning(
"[#%04X] _: <NON-FATAL PROTOCOL VIOLATION> "
"LOGON expected dictionary metadata, got %r",
self.local_port,
meta,
)
return
address = meta.get("advertised_address", ...)
if address is ...:
return
if not isinstance(address, str):
log.warning(
"[#%04X] _: <NON-FATAL PROTOCOL VIOLATION> "
"LOGON expected string advertised_address, got %r",
self.local_port,
address,
)
return
self.advertised_address = Address.parse(address, default_port=7687)
await AsyncUtil.callback(self.address_callback, self)
Loading