Rework socket activation

This commit is contained in:
2025-12-02 22:36:25 +01:00
parent a2f1dc2bc6
commit 6918fd1efe

169
server.py
View File

@@ -44,6 +44,33 @@ from protocol import BinaryEventTypes
# Import cache control middleware # Import cache control middleware
from middleware.cache_middleware import cache_control from middleware.cache_middleware import cache_control
# Helper --------------------------------------------------------------------
def _fd_to_socket(fd: int) -> socket.socket:
"""
Turn a raw filedescriptor that came from systemd into a *nonblocking*
``socket.socket`` instance that preserves the original family, type and
protocol.
The function mirrors the logic used by `sd_listen_fds` in the systemd
libraries (see libsystemd/libsystemd/sd-daemon.c). It queries the
kernel for the socket's actual family, type and protocol using
``getsockopt`` on the fd.
"""
# Query the kernel for the family / type / protocol of the fd.
# These syscalls are cheap and work for IPv4, IPv6 and Unix sockets.
family = socket.getsockopt(fd, socket.SOL_SOCKET, socket.SO_DOMAIN)
sock_type = socket.getsockopt(fd, socket.SOL_SOCKET, socket.SO_TYPE)
proto = socket.getsockopt(fd, socket.SOL_SOCKET, socket.SO_PROTOCOL)
# Build the socket object *without* creating a new fd.
# ``socket.socket`` with ``fileno=fd`` reuses the existing descriptor.
# ``closefd=False`` tells Python not to close the fd when the socket
# object is garbagecollected systemd will close it when the service
# exits.
sock = socket.socket(family=family, type=sock_type, proto=proto, fileno=fd)
sock.setblocking(False) # aiohttp expects nonblocking
return sock
async def send_socket_catch_exception(function, message): async def send_socket_catch_exception(function, message):
try: try:
await function(message) await function(message)
@@ -1030,93 +1057,119 @@ class PromptServer():
await self.start_multi_address([(address, port)], call_on_start=call_on_start) await self.start_multi_address([(address, port)], call_on_start=call_on_start)
async def start_multi_address(self, addresses, call_on_start=None, verbose=True): async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
"""
Starts the aiohttp server. If systemd activation is detected the
provided ``addresses`` are ignored and the sockets returned by
``_get_systemd_sockets`` are used instead.
"""
runner = web.AppRunner(self.app, access_log=None) runner = web.AppRunner(self.app, access_log=None)
await runner.setup() await runner.setup()
ssl_ctx = None ssl_ctx = None
scheme = "http" scheme = "http"
if args.tls_keyfile and args.tls_certfile: if args.tls_keyfile and args.tls_certfile:
ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE) ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER,
verify_mode=ssl.CERT_NONE)
ssl_ctx.load_cert_chain(certfile=args.tls_certfile, ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
keyfile=args.tls_keyfile) keyfile=args.tls_keyfile)
scheme = "https" scheme = "https"
# -----------------------------------------------------------------
# Check for systemd socket activation # Systemd activation ------------------------------------------------
systemd_sockets = self._get_systemd_sockets() systemd_sockets = self._get_systemd_sockets()
if systemd_sockets: if systemd_sockets:
# Use systemd sockets instead of the provided addresses
if verbose: if verbose:
logging.info("Using systemd socket activation") logging.info("Systemd socket activation detected using supplied socket(s)")
# Start sites using the systemd sockets
for sock in systemd_sockets: for sock in systemd_sockets:
# Get socket info for logging # ``sock.getsockname()`` can be 2tuple (IPv4) or 4tuple (IPv6)
# or a string (Unix). Normalise for logging.
try: try:
sock_name = sock.getsockname() raw_name = sock.getsockname()
address = sock_name[0] if sock_name else "unknown" if isinstance(raw_name, tuple):
port = sock_name[1] if sock_name else "unknown" host = raw_name[0]
except: port = raw_name[1]
address = "unknown" if sock.family == socket.AF_INET6:
port = "unknown" host = f"[{host}]"
else:
# Unix domain socket just show the path.
host, port = raw_name, ""
except Exception:
host, port = "unknown", "unknown"
if verbose: if verbose:
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port)) logging.info(f"GUI reachable at: {scheme}://{host}:{port}")
# Create site using the socket
site = web.SockSite(runner, sock, ssl_context=ssl_ctx) site = web.SockSite(runner, sock, ssl_context=ssl_ctx)
await site.start() await site.start()
# Store the socket info for reference # Keep a reference useful for debugging / graceful shutdown
self.systemd_sockets = systemd_sockets self.systemd_sockets = systemd_sockets
else: else:
# Fallback to original behavior # -----------------------------------------------------------------
# Classic TCPSite fallback -----------------------------------------
if verbose: if verbose:
logging.info("Starting server\n") logging.info("Systemd activation not detected falling back to manual bind")
for address, port in addresses:
for addr in addresses:
address = addr[0]
port = addr[1]
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start() await site.start()
if not hasattr(self, 'address'): if not hasattr(self, "address"):
self.address = address self.address = address
self.port = port self.port = port
if ':' in address: # Nicely format IPv6 literals for the log line.
address_print = "[{}]".format(address) address_print = f"[{address}]" if ":" in address else address
else:
address_print = address
if verbose: if verbose:
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port)) logging.info(f"GUI reachable at: {scheme}://{address_print}:{port}")
# Call the callback if provided # -----------------------------------------------------------------
if call_on_start: if call_on_start:
await call_on_start() await call_on_start()
def _get_systemd_sockets(self) -> List[socket.socket]:
"""
Detect systemd socket activation and return a list of readytouse
``socket.socket`` objects.
def _get_systemd_sockets(self): The function follows the systemd protocol:
"""Get sockets from systemd socket activation""" * `LISTEN_FDS` tells how many file descriptors have been passed.
sockets = [] * `LISTEN_PID` must match ``os.getpid()``.
* File descriptors start at 3.
# Check if systemd socket activation is being used * After we have consumed them we *unset* the environment variables so
if 'LISTEN_FDS' in os.environ and 'LISTEN_PID' in os.environ: that any child processes dont accidentally try to reuse the same
listen_fds = int(os.environ['LISTEN_FDS']) fds.
listen_pid = int(os.environ['LISTEN_PID']) """
sockets: List[socket.socket] = []
# Verify this is our process
if listen_pid == os.getpid(): if "LISTEN_FDS" not in os.environ or "LISTEN_PID" not in os.environ:
# Create sockets from file descriptors 3, 4, 5, etc. return sockets
for i in range(listen_fds):
try: try:
# File descriptor starts at 3 for systemd listen_fds = int(os.getenv("LISTEN_FDS", "0"))
fd = 3 + i listen_pid = int(os.getenv("LISTEN_PID", "0"))
sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) except ValueError:
sockets.append(sock) logging.error("Invalid LISTEN_FDS/LISTEN_PID values")
except Exception as e: return sockets
logging.error(f"Failed to create socket from fd {fd}: {e}")
if listen_pid != os.getpid():
# The activation was meant for a *different* process (e.g. a
# parent that forked before exec). Ignore it.
return sockets
for i in range(listen_fds):
fd = 3 + i
try:
sock = _fd_to_socket(fd)
sockets.append(sock)
logging.debug(f"systemd socket #{i} fd {fd} family {sock.family} "
f"type {sock.type} proto {sock.proto}")
except OSError as exc:
logging.error(f"Failed to convert fd {fd} into a socket: {exc}")
# Remove the activation envvars child processes will otherwise think
# they have been passed sockets as well.
os.unsetenv("LISTEN_FDS")
os.unsetenv("LISTEN_PID")
return sockets return sockets
def add_on_prompt_handler(self, handler): def add_on_prompt_handler(self, handler):