summaryrefslogtreecommitdiffstats
path: root/awx/main/wsrelay.py
diff options
context:
space:
mode:
authorRick Elrod <rick@elrod.me>2023-05-30 18:31:49 +0200
committerJohn Westcott IV <32551173+john-westcott-iv@users.noreply.github.com>2023-06-14 23:40:15 +0200
commitf9bd780d62d5dd8502921c6008949591426e0642 (patch)
tree44ce09debb67d0c43ea502ec9210b84c321522c2 /awx/main/wsrelay.py
parentReplacing psycopg2.copy_expert with psycopg3.copy (diff)
downloadawx-f9bd780d62d5dd8502921c6008949591426e0642.tar.xz
awx-f9bd780d62d5dd8502921c6008949591426e0642.zip
[wsrelay] Port back to psycopg3
Signed-off-by: Rick Elrod <rick@elrod.me>
Diffstat (limited to '')
-rw-r--r--awx/main/wsrelay.py100
1 files changed, 52 insertions, 48 deletions
diff --git a/awx/main/wsrelay.py b/awx/main/wsrelay.py
index 9ce5d9e22d..20b04433d0 100644
--- a/awx/main/wsrelay.py
+++ b/awx/main/wsrelay.py
@@ -12,7 +12,7 @@ from channels.layers import get_channel_layer
from django.conf import settings
from django.apps import apps
-import asyncpg
+import psycopg
from awx.main.analytics.broadcast_websocket import (
RelayWebsocketStats,
@@ -209,51 +209,55 @@ class WebSocketRelayManager(object):
# hostname -> ip
self.known_hosts: Dict[str, str] = dict()
- async def on_ws_heartbeat(self, conn, pid, channel, payload):
- try:
- if not payload or channel != "web_ws_heartbeat":
- return
-
+ async def on_ws_heartbeat(self, conn):
+ await conn.execute("LISTEN web_ws_heartbeat")
+ async for notif in conn.notifies():
+ if notif is None:
+ continue
try:
- payload = json.loads(payload)
- except json.JSONDecodeError:
- logmsg = "Failed to decode message from pg_notify channel `web_ws_heartbeat`"
- if logger.isEnabledFor(logging.DEBUG):
- logmsg = "{} {}".format(logmsg, payload)
- logger.warning(logmsg)
- return
-
- # Skip if the message comes from the same host we are running on
- # In this case, we'll be sharing a redis, no need to relay.
- if payload.get("hostname") == self.local_hostname:
- return
-
- if payload.get("action") == "online":
- hostname = payload.get("hostname")
- ip = payload.get("ip")
- if ip is None:
- # If we don't get an IP, just try the hostname, maybe it resolves
- ip = hostname
- if ip is None:
- logger.warning(f"Received invalid online ws_heartbeat, missing hostname and ip: {payload}")
+ if not notif.payload or notif.channel != "web_ws_heartbeat":
+ return
+
+ try:
+ payload = json.loads(notif.payload)
+ except json.JSONDecodeError:
+ logmsg = "Failed to decode message from pg_notify channel `web_ws_heartbeat`"
+ if logger.isEnabledFor(logging.DEBUG):
+ logmsg = "{} {}".format(logmsg, payload)
+ logger.warning(logmsg)
return
- self.known_hosts[hostname] = ip
- logger.debug(f"Web host {hostname} ({ip}) online heartbeat received.")
- elif payload.get("action") == "offline":
- hostname = payload.get("hostname")
- ip = payload.get("ip")
- if ip is None:
- # If we don't get an IP, just try the hostname, maybe it resolves
- ip = hostname
- if ip is None:
- logger.warning(f"Received invalid offline ws_heartbeat, missing hostname and ip: {payload}")
+
+ # Skip if the message comes from the same host we are running on
+ # In this case, we'll be sharing a redis, no need to relay.
+ if payload.get("hostname") == self.local_hostname:
return
- self.cleanup_offline_host(ip)
- logger.debug(f"Web host {hostname} ({ip}) offline heartbeat received.")
- except Exception as e:
- # This catch-all is the same as the one above. asyncio will eat the exception
- # but we want to know about it.
- logger.exception(f"on_ws_heartbeat exception: {e}")
+
+ if payload.get("action") == "online":
+ hostname = payload.get("hostname")
+ ip = payload.get("ip")
+ if ip is None:
+ # If we don't get an IP, just try the hostname, maybe it resolves
+ ip = hostname
+ if ip is None:
+ logger.warning(f"Received invalid online ws_heartbeat, missing hostname and ip: {payload}")
+ return
+ self.known_hosts[hostname] = ip
+ logger.debug(f"Web host {hostname} ({ip}) online heartbeat received.")
+ elif payload.get("action") == "offline":
+ hostname = payload.get("hostname")
+ ip = payload.get("ip")
+ if ip is None:
+ # If we don't get an IP, just try the hostname, maybe it resolves
+ ip = hostname
+ if ip is None:
+ logger.warning(f"Received invalid offline ws_heartbeat, missing hostname and ip: {payload}")
+ return
+ self.cleanup_offline_host(ip)
+ logger.debug(f"Web host {hostname} ({ip}) offline heartbeat received.")
+ except Exception as e:
+ # This catch-all is the same as the one above. asyncio will eat the exception
+ # but we want to know about it.
+ logger.exception(f"on_ws_heartbeat exception: {e}")
def cleanup_offline_host(self, hostname):
"""
@@ -282,16 +286,16 @@ class WebSocketRelayManager(object):
# Set up a pg_notify consumer for allowing web nodes to "provision" and "deprovision" themselves gracefully.
database_conf = settings.DATABASES['default']
- async_conn = await asyncpg.connect(
- database=database_conf['NAME'],
+ async_conn = await psycopg.AsyncConnection.connect(
+ dbname=database_conf['NAME'],
host=database_conf['HOST'],
user=database_conf['USER'],
password=database_conf['PASSWORD'],
port=database_conf['PORT'],
- # We cannot include these because asyncpg doesn't allow all the options that psycopg does.
- # **database_conf.get("OPTIONS", {}),
+ **database_conf.get("OPTIONS", {}),
)
- await async_conn.add_listener("web_ws_heartbeat", self.on_ws_heartbeat)
+ await async_conn.set_autocommit(True)
+ event_loop.create_task(self.on_ws_heartbeat(async_conn))
# Establishes a websocket connection to /websocket/relay on all API servers
while True: