From a2f1dc2bc64d0f2d2a7d7239896e671543c7b3f1 Mon Sep 17 00:00:00 2001 From: Andreas Date: Tue, 2 Dec 2025 22:20:22 +0100 Subject: [PATCH] First attempt at socket activation --- server.py | 100 ++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 78 insertions(+), 22 deletions(-) diff --git a/server.py b/server.py index fca5050b..f8c6718a 100644 --- a/server.py +++ b/server.py @@ -37,7 +37,7 @@ from app.user_manager import UserManager from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from app.subgraph_manager import SubgraphManager -from typing import Optional, Union +from typing import Optional, Union, List, Tuple, Any from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes @@ -1039,29 +1039,85 @@ class PromptServer(): ssl_ctx.load_cert_chain(certfile=args.tls_certfile, keyfile=args.tls_keyfile) scheme = "https" - - if verbose: - logging.info("Starting server\n") - for addr in addresses: - address = addr[0] - port = addr[1] - site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) - await site.start() - - if not hasattr(self, 'address'): - self.address = address #TODO: remove this - self.port = port - - if ':' in address: - address_print = "[{}]".format(address) - else: - address_print = address - + + + # Check for systemd socket activation + systemd_sockets = self._get_systemd_sockets() + + if systemd_sockets: + # Use systemd sockets instead of the provided addresses if verbose: - logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port)) + logging.info("Using systemd socket activation") + + # Start sites using the systemd sockets + for sock in systemd_sockets: + # Get socket info for logging + try: + sock_name = sock.getsockname() + address = sock_name[0] if sock_name else "unknown" + port = sock_name[1] if sock_name else "unknown" + except: + address = "unknown" + port = "unknown" + + if verbose: + logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port)) + + # Create site using the socket + site = web.SockSite(runner, sock, ssl_context=ssl_ctx) + await site.start() + + # Store the socket info for reference + self.systemd_sockets = systemd_sockets + else: + # Fallback to original behavior + if verbose: + logging.info("Starting server\n") + + for addr in addresses: + address = addr[0] + port = addr[1] + site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) + await site.start() - if call_on_start is not None: - call_on_start(scheme, self.address, self.port) + if not hasattr(self, 'address'): + self.address = address + self.port = port + + if ':' in address: + address_print = "[{}]".format(address) + else: + address_print = address + + if verbose: + logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port)) + + # Call the callback if provided + if call_on_start: + await call_on_start() + + def _get_systemd_sockets(self): + """Get sockets from systemd socket activation""" + sockets = [] + + # Check if systemd socket activation is being used + if 'LISTEN_FDS' in os.environ and 'LISTEN_PID' in os.environ: + listen_fds = int(os.environ['LISTEN_FDS']) + listen_pid = int(os.environ['LISTEN_PID']) + + # Verify this is our process + if listen_pid == os.getpid(): + # Create sockets from file descriptors 3, 4, 5, etc. + for i in range(listen_fds): + try: + # File descriptor starts at 3 for systemd + fd = 3 + i + sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) + sockets.append(sock) + except Exception as e: + logging.error(f"Failed to create socket from fd {fd}: {e}") + + return sockets def add_on_prompt_handler(self, handler): self.on_prompt_handlers.append(handler)