Rework socket activation
This commit is contained in:
169
server.py
169
server.py
@@ -44,6 +44,33 @@ from protocol import BinaryEventTypes
|
||||
# Import cache control middleware
|
||||
from middleware.cache_middleware import cache_control
|
||||
|
||||
# Helper --------------------------------------------------------------------
|
||||
def _fd_to_socket(fd: int) -> socket.socket:
|
||||
"""
|
||||
Turn a raw file‑descriptor that came from systemd into a *non‑blocking*
|
||||
``socket.socket`` instance that preserves the original family, type and
|
||||
protocol.
|
||||
|
||||
The function mirrors the logic used by `sd_listen_fds` in the systemd
|
||||
libraries (see libsystemd/libsystemd/sd-daemon.c). It queries the
|
||||
kernel for the socket's actual family, type and protocol using
|
||||
``getsockopt`` on the fd.
|
||||
"""
|
||||
# Query the kernel for the family / type / protocol of the fd.
|
||||
# These syscalls are cheap and work for IPv4, IPv6 and Unix sockets.
|
||||
family = socket.getsockopt(fd, socket.SOL_SOCKET, socket.SO_DOMAIN)
|
||||
sock_type = socket.getsockopt(fd, socket.SOL_SOCKET, socket.SO_TYPE)
|
||||
proto = socket.getsockopt(fd, socket.SOL_SOCKET, socket.SO_PROTOCOL)
|
||||
|
||||
# Build the socket object *without* creating a new fd.
|
||||
# ``socket.socket`` with ``fileno=fd`` re‑uses the existing descriptor.
|
||||
# ``closefd=False`` tells Python not to close the fd when the socket
|
||||
# object is garbage‑collected – systemd will close it when the service
|
||||
# exits.
|
||||
sock = socket.socket(family=family, type=sock_type, proto=proto, fileno=fd)
|
||||
sock.setblocking(False) # aiohttp expects non‑blocking
|
||||
return sock
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
try:
|
||||
await function(message)
|
||||
@@ -1030,93 +1057,119 @@ class PromptServer():
|
||||
await self.start_multi_address([(address, port)], call_on_start=call_on_start)
|
||||
|
||||
async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
|
||||
"""
|
||||
Starts the aiohttp server. If systemd activation is detected the
|
||||
provided ``addresses`` are ignored and the sockets returned by
|
||||
``_get_systemd_sockets`` are used instead.
|
||||
"""
|
||||
runner = web.AppRunner(self.app, access_log=None)
|
||||
await runner.setup()
|
||||
ssl_ctx = None
|
||||
scheme = "http"
|
||||
if args.tls_keyfile and args.tls_certfile:
|
||||
ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER, verify_mode=ssl.CERT_NONE)
|
||||
ssl_ctx = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_SERVER,
|
||||
verify_mode=ssl.CERT_NONE)
|
||||
ssl_ctx.load_cert_chain(certfile=args.tls_certfile,
|
||||
keyfile=args.tls_keyfile)
|
||||
keyfile=args.tls_keyfile)
|
||||
scheme = "https"
|
||||
|
||||
|
||||
# Check for systemd socket activation
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Systemd activation ------------------------------------------------
|
||||
systemd_sockets = self._get_systemd_sockets()
|
||||
|
||||
if systemd_sockets:
|
||||
# Use systemd sockets instead of the provided addresses
|
||||
if verbose:
|
||||
logging.info("Using systemd socket activation")
|
||||
|
||||
# Start sites using the systemd sockets
|
||||
logging.info("Systemd socket activation detected – using supplied socket(s)")
|
||||
|
||||
for sock in systemd_sockets:
|
||||
# Get socket info for logging
|
||||
# ``sock.getsockname()`` can be 2‑tuple (IPv4) or 4‑tuple (IPv6)
|
||||
# or a string (Unix). Normalise for logging.
|
||||
try:
|
||||
sock_name = sock.getsockname()
|
||||
address = sock_name[0] if sock_name else "unknown"
|
||||
port = sock_name[1] if sock_name else "unknown"
|
||||
except:
|
||||
address = "unknown"
|
||||
port = "unknown"
|
||||
|
||||
raw_name = sock.getsockname()
|
||||
if isinstance(raw_name, tuple):
|
||||
host = raw_name[0]
|
||||
port = raw_name[1]
|
||||
if sock.family == socket.AF_INET6:
|
||||
host = f"[{host}]"
|
||||
else:
|
||||
# Unix domain socket – just show the path.
|
||||
host, port = raw_name, ""
|
||||
except Exception:
|
||||
host, port = "unknown", "unknown"
|
||||
|
||||
if verbose:
|
||||
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address, port))
|
||||
|
||||
# Create site using the socket
|
||||
logging.info(f"GUI reachable at: {scheme}://{host}:{port}")
|
||||
|
||||
site = web.SockSite(runner, sock, ssl_context=ssl_ctx)
|
||||
await site.start()
|
||||
|
||||
# Store the socket info for reference
|
||||
|
||||
# Keep a reference – useful for debugging / graceful shutdown
|
||||
self.systemd_sockets = systemd_sockets
|
||||
else:
|
||||
# Fallback to original behavior
|
||||
# -----------------------------------------------------------------
|
||||
# Classic TCPSite fallback -----------------------------------------
|
||||
if verbose:
|
||||
logging.info("Starting server\n")
|
||||
|
||||
for addr in addresses:
|
||||
address = addr[0]
|
||||
port = addr[1]
|
||||
logging.info("Systemd activation not detected – falling back to manual bind")
|
||||
for address, port in addresses:
|
||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||
await site.start()
|
||||
|
||||
if not hasattr(self, 'address'):
|
||||
if not hasattr(self, "address"):
|
||||
self.address = address
|
||||
self.port = port
|
||||
|
||||
if ':' in address:
|
||||
address_print = "[{}]".format(address)
|
||||
else:
|
||||
address_print = address
|
||||
|
||||
# Nicely format IPv6 literals for the log line.
|
||||
address_print = f"[{address}]" if ":" in address else address
|
||||
if verbose:
|
||||
logging.info("To see the GUI go to: {}://{}:{}".format(scheme, address_print, port))
|
||||
logging.info(f"GUI reachable at: {scheme}://{address_print}:{port}")
|
||||
|
||||
# Call the callback if provided
|
||||
# -----------------------------------------------------------------
|
||||
if call_on_start:
|
||||
await call_on_start()
|
||||
|
||||
def _get_systemd_sockets(self) -> List[socket.socket]:
|
||||
"""
|
||||
Detect systemd socket activation and return a list of ready‑to‑use
|
||||
``socket.socket`` objects.
|
||||
|
||||
def _get_systemd_sockets(self):
|
||||
"""Get sockets from systemd socket activation"""
|
||||
sockets = []
|
||||
|
||||
# Check if systemd socket activation is being used
|
||||
if 'LISTEN_FDS' in os.environ and 'LISTEN_PID' in os.environ:
|
||||
listen_fds = int(os.environ['LISTEN_FDS'])
|
||||
listen_pid = int(os.environ['LISTEN_PID'])
|
||||
|
||||
# Verify this is our process
|
||||
if listen_pid == os.getpid():
|
||||
# Create sockets from file descriptors 3, 4, 5, etc.
|
||||
for i in range(listen_fds):
|
||||
try:
|
||||
# File descriptor starts at 3 for systemd
|
||||
fd = 3 + i
|
||||
sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
|
||||
sockets.append(sock)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to create socket from fd {fd}: {e}")
|
||||
|
||||
The function follows the systemd protocol:
|
||||
* `LISTEN_FDS` tells how many file descriptors have been passed.
|
||||
* `LISTEN_PID` must match ``os.getpid()``.
|
||||
* File descriptors start at 3.
|
||||
* After we have consumed them we *unset* the environment variables so
|
||||
that any child processes don’t accidentally try to reuse the same
|
||||
fds.
|
||||
"""
|
||||
sockets: List[socket.socket] = []
|
||||
|
||||
if "LISTEN_FDS" not in os.environ or "LISTEN_PID" not in os.environ:
|
||||
return sockets
|
||||
|
||||
try:
|
||||
listen_fds = int(os.getenv("LISTEN_FDS", "0"))
|
||||
listen_pid = int(os.getenv("LISTEN_PID", "0"))
|
||||
except ValueError:
|
||||
logging.error("Invalid LISTEN_FDS/LISTEN_PID values")
|
||||
return sockets
|
||||
|
||||
if listen_pid != os.getpid():
|
||||
# The activation was meant for a *different* process (e.g. a
|
||||
# parent that forked before exec). Ignore it.
|
||||
return sockets
|
||||
|
||||
for i in range(listen_fds):
|
||||
fd = 3 + i
|
||||
try:
|
||||
sock = _fd_to_socket(fd)
|
||||
sockets.append(sock)
|
||||
logging.debug(f"systemd socket #{i} – fd {fd} – family {sock.family} "
|
||||
f"type {sock.type} proto {sock.proto}")
|
||||
except OSError as exc:
|
||||
logging.error(f"Failed to convert fd {fd} into a socket: {exc}")
|
||||
|
||||
# Remove the activation env‑vars – child processes will otherwise think
|
||||
# they have been passed sockets as well.
|
||||
os.unsetenv("LISTEN_FDS")
|
||||
os.unsetenv("LISTEN_PID")
|
||||
return sockets
|
||||
|
||||
def add_on_prompt_handler(self, handler):
|
||||
|
||||
Reference in New Issue
Block a user