From 6918fd1efe1a77d51d4f385885461f5215c10fe0 Mon Sep 17 00:00:00 2001 From: Andreas Date: Tue, 2 Dec 2025 22:36:25 +0100 Subject: [PATCH] Rework socket activation --- server.py | 169 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 111 insertions(+), 58 deletions(-) diff --git a/server.py b/server.py index f8c6718a..e9ae8bde 100644 --- a/server.py +++ b/server.py @@ -44,6 +44,33 @@ from protocol import BinaryEventTypes # Import cache control middleware from middleware.cache_middleware import cache_control +# Helper -------------------------------------------------------------------- +def _fd_to_socket(fd: int) -> socket.socket: + """ + Turn a raw file‑descriptor that came from systemd into a *non‑blocking* + ``socket.socket`` instance that preserves the original family, type and + protocol. + + The function mirrors the logic used by `sd_listen_fds` in the systemd + libraries (see libsystemd/libsystemd/sd-daemon.c). It queries the + kernel for the socket's actual family, type and protocol using + ``getsockopt`` on the fd. + """ + # Query the kernel for the family / type / protocol of the fd. + # These syscalls are cheap and work for IPv4, IPv6 and Unix sockets. + family = socket.getsockopt(fd, socket.SOL_SOCKET, socket.SO_DOMAIN) + sock_type = socket.getsockopt(fd, socket.SOL_SOCKET, socket.SO_TYPE) + proto = socket.getsockopt(fd, socket.SOL_SOCKET, socket.SO_PROTOCOL) + + # Build the socket object *without* creating a new fd. + # ``socket.socket`` with ``fileno=fd`` re‑uses the existing descriptor. + # No ``closefd``-style option is passed: the returned socket object + # owns the fd and closes it when the object is closed or + # garbage-collected, so the caller must keep a reference to it. 
+ sock = socket.socket(family=family, type=sock_type, proto=proto, fileno=fd) + sock.setblocking(False) # aiohttp expects non‑blocking + return sock + async def send_socket_catch_exception(function, message): try: await function(message) @@ -1030,93 +1057,119 @@ class PromptServer(): await self.start_multi_address([(address, port)], call_on_start=call_on_start) async def start_multi_address(self, addresses, call_on_start=None, verbose=True): + """ + Starts the aiohttp server. If systemd activation is detected the + provided ``addresses`` are ignored and the sockets returned by + ``_get_systemd_sockets`` are used instead. + """ runner = web.AppRunner(self.app, access_log=None) await runner.setup() ssl_ctx = None scheme = "http" if args.tls_keyfile and args.tls_certfile: - ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE) + ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, + verify_mode=ssl.CERT_NONE) ssl_ctx.load_cert_chain(certfile=args.tls_certfile, - keyfile=args.tls_keyfile) + keyfile=args.tls_keyfile) scheme = "https" - - - # Check for systemd socket activation + + # ----------------------------------------------------------------- + # Systemd activation ------------------------------------------------ systemd_sockets = self._get_systemd_sockets() - if systemd_sockets: - # Use systemd sockets instead of the provided addresses if verbose: - logging.info("Using systemd socket activation") - - # Start sites using the systemd sockets + logging.info("Systemd socket activation detected – using supplied socket(s)") + for sock in systemd_sockets: - # Get socket info for logging + # ``sock.getsockname()`` can be 2‑tuple (IPv4) or 4‑tuple (IPv6) + # or a string (Unix). Normalise for logging. 
try: - sock_name = sock.getsockname() - address = sock_name[0] if sock_name else "unknown" - port = sock_name[1] if sock_name else "unknown" - except: - address = "unknown" - port = "unknown" - + raw_name = sock.getsockname() + if isinstance(raw_name, tuple): + host = raw_name[0] + port = raw_name[1] + if sock.family == socket.AF_INET6: + host = f"[{host}]" + else: + # Unix domain socket – just show the path. + host, port = raw_name, "" + except Exception: + host, port = "unknown", "unknown" + if verbose: - logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port)) - - # Create site using the socket + logging.info(f"GUI reachable at: {scheme}://{host}:{port}") + site = web.SockSite(runner, sock, ssl_context=ssl_ctx) await site.start() - - # Store the socket info for reference + + # Keep a reference – useful for debugging / graceful shutdown self.systemd_sockets = systemd_sockets else: - # Fallback to original behavior + # ----------------------------------------------------------------- + # Classic TCPSite fallback ----------------------------------------- if verbose: - logging.info("Starting server\n") - - for addr in addresses: - address = addr[0] - port = addr[1] + logging.info("Systemd activation not detected – falling back to manual bind") + for address, port in addresses: site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) await site.start() - if not hasattr(self, 'address'): + if not hasattr(self, "address"): self.address = address self.port = port - if ':' in address: - address_print = "[{}]".format(address) - else: - address_print = address - + # Nicely format IPv6 literals for the log line. 
+ address_print = f"[{address}]" if ":" in address else address if verbose: - logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port)) + logging.info(f"GUI reachable at: {scheme}://{address_print}:{port}") - # Call the callback if provided + # ----------------------------------------------------------------- if call_on_start: await call_on_start() + + def _get_systemd_sockets(self) -> List[socket.socket]: + """ + Detect systemd socket activation and return a list of ready‑to‑use + ``socket.socket`` objects. - def _get_systemd_sockets(self): - """Get sockets from systemd socket activation""" - sockets = [] - - # Check if systemd socket activation is being used - if 'LISTEN_FDS' in os.environ and 'LISTEN_PID' in os.environ: - listen_fds = int(os.environ['LISTEN_FDS']) - listen_pid = int(os.environ['LISTEN_PID']) - - # Verify this is our process - if listen_pid == os.getpid(): - # Create sockets from file descriptors 3, 4, 5, etc. - for i in range(listen_fds): - try: - # File descriptor starts at 3 for systemd - fd = 3 + i - sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) - sockets.append(sock) - except Exception as e: - logging.error(f"Failed to create socket from fd {fd}: {e}") - + The function follows the systemd protocol: + * `LISTEN_FDS` tells how many file descriptors have been passed. + * `LISTEN_PID` must match ``os.getpid()``. + * File descriptors start at 3. + * After we have consumed them we *unset* the environment variables so + that any child processes don’t accidentally try to reuse the same + fds. 
+ """ + sockets: List[socket.socket] = [] + + if "LISTEN_FDS" not in os.environ or "LISTEN_PID" not in os.environ: + return sockets + + try: + listen_fds = int(os.getenv("LISTEN_FDS", "0")) + listen_pid = int(os.getenv("LISTEN_PID", "0")) + except ValueError: + logging.error("Invalid LISTEN_FDS/LISTEN_PID values") + return sockets + + if listen_pid != os.getpid(): + # The activation was meant for a *different* process (e.g. a + # parent that forked before exec). Ignore it. + return sockets + + for i in range(listen_fds): + fd = 3 + i + try: + sock = _fd_to_socket(fd) + sockets.append(sock) + logging.debug(f"systemd socket #{i} – fd {fd} – family {sock.family} " + f"type {sock.type} proto {sock.proto}") + except OSError as exc: + logging.error(f"Failed to convert fd {fd} into a socket: {exc}") + + # Remove the activation env‑vars – child processes will otherwise think + # they have been passed sockets as well. + os.unsetenv("LISTEN_FDS") + os.unsetenv("LISTEN_PID") return sockets def add_on_prompt_handler(self, handler):