"""Handle clients' connections."""
from __future__ import annotations
import logging
import select
import socket as s
import time
from typing import TYPE_CHECKING
from cable_club import watcher
from cable_club.data import Writer, models
from cable_club.network import Client
from cable_club.network.states import Connected, Finding
if TYPE_CHECKING:
from cable_club.config import Config
_logger = logging.getLogger(__name__)
[docs]
class Server:
    """Model the server's logic.

    A ``select``-based TCP server. One listening socket accepts connections;
    every accepted connection is tracked in ``clients`` (socket -> ``Client``).
    Received bytes are split into newline-terminated messages and each message
    is handed to the client's current state object, which returns the next
    state.

    Attributes:
        config: the configuration object passed to the constructor.
        clients: every registered connection, keyed by its socket.
        rules_files: second value returned by ``watcher.rules_changed``; it is
            fed back into the next change check.
        rules: the value returned by ``watcher.load_rules``; sent to both
            players when they are matched.
        refresh_rules_at: ``time.monotonic()`` timestamp of the next rules
            check.
        socket: the listening socket; only set once ``run()`` has started.
    """

    def __init__(self, config: Config) -> None:
        """Initialize an instance.

        Configures the data models with ``config`` and loads the rules found
        in ``config.rules_dir``. No socket is opened here; see ``run()``.
        """
        self.config = config
        models.configure(config)
        # "now" means the very first loop iteration already checks for changes.
        self.refresh_rules_at = time.monotonic()
        self.clients: dict[s.socket, Client] = {}
        # Start from an empty previous listing; only the new listing is kept.
        _, self.rules_files = watcher.rules_changed(self.config.rules_dir, {})
        self.rules = watcher.load_rules(self.config.rules_dir, self.rules_files)

    def select(self) -> tuple[list[s.socket], list[s.socket], list[s.socket]]:
        """Thin wrapper on top of select.

        Collects all readers and writers and hands them off to select.select().
        Readers are every client socket plus the listening socket; writers are
        only the clients that have data pending in their send buffer. The
        readers are also watched for exceptional conditions. Waits at most one
        second.

        Returns:
            The ``(readable, writable, errored)`` lists from select.select().
        """
        reads = [*self.clients, self.socket]
        writes = [sock for sock, client in self.clients.items() if client.send_buffer]
        return select.select(reads, writes, reads, 1.0)

    def run(self) -> None:
        """Execute the server's logic (blocking busy loop).

        Binds and listens on ``config.host``:``config.port`` and then loops:
        reload the rules if due, wait for socket activity, handle errored
        sockets, flush pending writes and finally process reads. Returns when
        a ``KeyboardInterrupt`` is received; the listening socket is closed
        when the ``with`` block exits.

        Raises:
            RuntimeError: if the listening socket itself is reported as errored.
        """
        with s.socket(s.AF_INET, s.SOCK_STREAM) as self.socket:
            self.socket.setsockopt(s.SOL_SOCKET, s.SO_REUSEADDR, 1)
            self.socket.bind((self.config.host, self.config.port))
            _logger.info("Started Server on %s:%d", self.config.host, self.config.port)
            self.socket.listen()
            try:
                while True:
                    self.maybe_reload_rules()
                    read, write, errors = self.select()
                    self.handle_errors(errors)
                    self.write_to_all(write)
                    self.read_all(read)
            except KeyboardInterrupt:
                _logger.info("Stopping Server")

    def maybe_reload_rules(self) -> None:
        """Check the rules folder for updates.

        This happens every config.rules_refresh_rate seconds (approx). When the
        watcher reports a change, both the file listing and the rules are
        replaced; in every case the next check is rescheduled.
        """
        if time.monotonic() < self.refresh_rules_at:
            return
        reload_rules, rules_files = watcher.rules_changed(
            self.config.rules_dir,
            self.rules_files,
        )
        if reload_rules:
            self.rules_files = rules_files
            self.rules = watcher.load_rules(self.config.rules_dir, self.rules_files)
        self.refresh_rules_at = time.monotonic() + self.config.rules_refresh_rate

    def handle_error(self, socket: s.socket) -> None:
        """Handle a single error socket.

        Raises:
            RuntimeError: if ``socket`` is the listening socket. Any other
                socket is simply disconnected.
        """
        if socket is self.socket:
            msg = "Error on listening socket."
            raise RuntimeError(msg)
        self.disconnect(socket)

    def handle_errors(self, sockets: list[s.socket]) -> None:
        """Handle all error sockets."""
        for socket in sockets:
            self.handle_error(socket)

    def write_to(self, socket: s.socket) -> None:
        """Write to a single socket.

        Sends as much of the client's send buffer as the socket accepts and
        keeps the unsent remainder for the next round. A send failure
        disconnects the client. Sockets that are no longer registered are
        ignored.
        """
        client = self.clients.get(socket)
        if client is None:
            # Disconnected earlier in this select round (e.g. as someone's peer).
            return
        buffer = client.send_buffer
        try:
            n = socket.send(buffer)
        except s.error as e:  # noqa: UP024
            # ruff complains that socket.error is an alias to OSError and should use
            # it instead. however, keeping it like this in case this implementation
            # detail changes and `socket.error` becomes something else
            self.disconnect(socket, str(e))
            return
        # send() may accept only part of the buffer: log what actually went out.
        _logger.debug("sent %s to %s", buffer[:n], socket)
        client.send_buffer = buffer[n:]

    def write_to_all(self, sockets: list[s.socket]) -> None:
        """Write to all sockets."""
        for socket in sockets:
            self.write_to(socket)

    def _accept(self) -> None:
        """Accept one connection on the listening socket and register it."""
        new_sock, address = self.socket.accept()
        # ruff doesnt like a boolean argument without any name
        # but that's the function signature, nothing we can do here
        new_sock.setblocking(False)  # noqa: FBT003
        # NOTE: address is `Any` based on official Python hinting...
        client = self.clients[new_sock] = Client(address)
        _logger.info("%s: connected", client)

    def read_from(self, socket: s.socket) -> None:
        """Read from a single socket.

        For the listening socket this accepts a new client. For a client
        socket it reads up to 4096 bytes, appends them to the client's receive
        buffer and dispatches every complete (newline-terminated) message to
        the client's state; a trailing partial message stays buffered.

        The client is disconnected on a zero-length read, on a socket error,
        or when handling a message raises. Sockets that are no longer
        registered are ignored.
        """
        if socket is self.socket:
            self._accept()
            return
        client = self.clients.get(socket)
        if client is None:
            # Disconnected earlier in this select round (e.g. as someone's peer).
            return
        try:
            recvd = socket.recv(4096)
        except ConnectionResetError:
            self.disconnect(socket)
            return
        except BlockingIOError:
            # Nothing to read after all; the socket stays registered and is
            # picked up again by a later select round.
            return
        except s.error as e:  # noqa: UP024
            # Same policy as write_to(): any other socket error drops the client.
            self.disconnect(socket, str(e))
            return
        if not recvd:
            # Zero-length read from a non-blocking socket is
            # a disconnect.
            self.disconnect(socket, "client disconnected")
            return
        recv_buffer = client.recv_buffer + recvd
        while True:
            message, sep, recv_buffer = recv_buffer.partition(b"\n")
            if not sep:
                # No newline, buffer the partial message.
                client.recv_buffer = message
                return
            _logger.debug("received: %s", message)
            try:
                old = client.state
                client.state, state_changed = old.handle(socket, self, message)
                if state_changed:
                    _logger.debug("transition: %s -> %s", old, client.state)
            except Exception as e:
                msg = "server error"
                _logger.exception(msg, exc_info=e)
                self.disconnect(socket, msg)
                return
            if socket not in self.clients:
                # The client was disconnected while handling this message;
                # do not feed the rest of its data to the state machine.
                return

    def read_all(self, sockets: list[s.socket]) -> None:
        """Read from all sockets."""
        for socket in sockets:
            self.read_from(socket)

    def _send_found(self, index: int, about: Client, to: Client) -> None:
        """Queue a "found" message describing ``about`` for the client ``to``.

        The message holds the literal "found", ``index``, whatever the state of
        ``about`` writes, and finally the server rules.
        """
        writer = Writer()
        writer.add("found")
        writer.add(index)
        about.state.write(writer)
        self.write_server_rules(writer)
        writer.send(to)

    def connect(self, s_connecting: s.socket, s_finding: s.socket) -> None:
        """Tell two clients about each other's existence.

        Both clients must be in the ``Finding`` state, otherwise an error is
        logged and nothing happens. Each client receives a "found" message
        about the other one (index 0 goes to the connecting client, index 1 to
        the finding client) and both are moved to the ``Connected`` state,
        pointing at each other's socket.
        """
        c_connecting = self.clients[s_connecting]
        c_finding = self.clients[s_finding]
        if not (
            isinstance(c_connecting.state, Finding)
            and isinstance(c_finding.state, Finding)
        ):
            _logger.error(
                "Can only use Server.connect() on players in the Finding state",
            )
            return
        # let them know about each other
        self._send_found(0, c_finding, c_connecting)
        self._send_found(1, c_connecting, c_finding)
        # mark them as connected
        c_connecting.state = Connected(s_finding)
        c_finding.state = Connected(s_connecting)
        _logger.info("%s: connected to %s", c_connecting, c_finding)

    def disconnect(self, socket: s.socket, reason: str = "unknown error") -> None:
        """Close a client's connection.

        Unregisters the client, tries to send it a "disconnect" message with
        ``reason``, closes the socket and, if the client was connected to a
        peer, disconnects that peer with the reason "peer disconnected".

        Sending the reason is best effort: if it fails the error is logged,
        but the socket is still closed and the peer is still disconnected.
        Unknown sockets are ignored.
        """
        _logger.debug("disconnecting %s. reason: %s", socket, reason)
        try:
            client = self.clients.pop(socket)
        # this happens, at least, when a bad message comes in (eg: bots looking for
        # vulnerabilities). socket wasn't setup as a client yet and thus .pop() fails...
        # instead of cluttering logs with it, lets just ignore the exception
        except KeyError:
            return
        try:
            writer = Writer()
            writer.add("disconnect")
            writer.add(reason)
            writer.send_now(socket)
        except Exception as e:
            # Typical when the remote end is already gone; keep cleaning up.
            _logger.exception("Couldnt send reason to socket", exc_info=e)
        try:
            socket.close()
        except s.error as e:  # noqa: UP024
            _logger.exception("Couldnt close socket", exc_info=e)
        # disconnect the other end
        # (the client was popped above, so the peer's own back-reference to
        # this socket hits the KeyError branch and the recursion stops)
        if isinstance(client.state, Connected):
            self.disconnect(client.state.peer, "peer disconnected")

    def write_server_rules(self, writer: Writer) -> None:
        """Dump server's rules into a writer.

        Writes the number of rules first, then every rule as a raw entry.
        """
        writer.add(len(self.rules))
        for r in self.rules:
            writer.add_raw(r)