mastersrv: Add Action: delete to /ddnet/15/register

This will delete a server from the server list without a timeout,
allowing servers that want to stop registering or servers that shut down
gracefully to no longer appear in the server list.
This commit is contained in:
heinrich5991 2022-06-03 11:22:52 +02:00
parent 264c6f969d
commit ce08cc0e53
2 changed files with 257 additions and 88 deletions

View file

@ -22,15 +22,29 @@ pub struct Addr {
pub protocol: Protocol,
}
/// A register address, serialized like
/// tw-0.6+udp://connecting-address.invalid:8303.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct RegisterAddr {
pub port: u16,
pub protocol: Protocol,
}
impl fmt::Display for Protocol {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.as_str().fmt(f)
}
}
#[derive(Debug)]
#[derive(Clone, Copy, Debug)]
pub struct UnknownProtocol;
impl fmt::Display for UnknownProtocol {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
"protocol must be one of tw-0.5+udp, tw-0.6+udp or tw-0.7+udp".fmt(f)
}
}
impl FromStr for Protocol {
type Err = UnknownProtocol;
fn from_str(s: &str) -> Result<Protocol, UnknownProtocol> {
@ -101,7 +115,7 @@ impl fmt::Display for Addr {
}
}
#[derive(Debug)]
#[derive(Clone, Copy, Debug)]
pub struct InvalidAddr;
impl FromStr for Addr {
@ -160,6 +174,100 @@ impl<'de> serde::Deserialize<'de> for Addr {
}
}
impl RegisterAddr {
pub fn with_ip(self, ip: IpAddr) -> Addr {
Addr {
ip,
port: self.port,
protocol: self.protocol,
}
}
}
impl fmt::Display for RegisterAddr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut buf: ArrayString<[u8; 128]> = ArrayString::new();
write!(
&mut buf,
"{}://connecting-address.invalid:{}",
self.protocol, self.port,
)
.unwrap();
buf.fmt(f)
}
}
#[derive(Clone, Copy, Debug)]
pub enum ParseRegisterAddrError {
Url(url::ParseError),
Protocol(UnknownProtocol),
HostNotConnectingAddressInvalid,
PortNotPresent,
}
impl fmt::Display for ParseRegisterAddrError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use self::ParseRegisterAddrError::*;
match *self {
Url(e) => write!(f, "URL parse error: {}", e),
Protocol(e) => write!(f, "protocol parse error: {}", e),
HostNotConnectingAddressInvalid => write!(
f,
"register address must have domain connecting-address.invalid"
),
PortNotPresent => write!(f, "register address must specify port"),
}
}
}
impl FromStr for RegisterAddr {
type Err = ParseRegisterAddrError;
fn from_str(s: &str) -> Result<RegisterAddr, ParseRegisterAddrError> {
use self::ParseRegisterAddrError as Error;
let url = Url::parse(s).map_err(Error::Url)?;
let protocol: Protocol = url.scheme().parse().map_err(Error::Protocol)?;
if url.host_str() != Some("connecting-address.invalid") {
return Err(Error::HostNotConnectingAddressInvalid);
}
let port = url.port().ok_or(Error::PortNotPresent)?;
Ok(RegisterAddr { port, protocol })
}
}
impl serde::Serialize for RegisterAddr {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut buf: ArrayString<[u8; 128]> = ArrayString::new();
write!(&mut buf, "{}", self).unwrap();
serializer.serialize_str(&buf)
}
}
struct RegisterAddrVisitor;
impl<'de> serde::de::Visitor<'de> for RegisterAddrVisitor {
type Value = RegisterAddr;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("a URL like tw-0.6+udp://connecting-address.invalid:8303")
}
fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<RegisterAddr, E> {
let invalid_value = || E::invalid_value(serde::de::Unexpected::Str(v), &self);
Ok(RegisterAddr::from_str(v).map_err(|_| invalid_value())?)
}
}
impl<'de> serde::Deserialize<'de> for RegisterAddr {
fn deserialize<D>(deserializer: D) -> Result<RegisterAddr, D::Error>
where
D: serde::de::Deserializer<'de>,
{
deserializer.deserialize_str(RegisterAddrVisitor)
}
}
#[cfg(test)]
mod test {
use super::Addr;
@ -186,4 +294,15 @@ mod test {
}
);
}
#[test]
fn register_addr_from_str() {
assert_eq!(
RegisterAddr::from_str("tw-0.6+udp://connecting-address.invalid:8303").unwrap(),
RegisterAddr {
port: 8303,
protocol: Protocol::V6,
}
);
}
}

View file

@ -22,6 +22,7 @@ use std::mem;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::panic;
use std::panic::UnwindSafe;
use std::path::Path;
use std::process;
use std::str;
@ -35,7 +36,6 @@ use tokio::fs;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use tokio::time;
use url::Url;
use warp::Filter;
#[macro_use]
@ -43,6 +43,7 @@ extern crate log;
use crate::addr::Addr;
use crate::addr::Protocol;
use crate::addr::RegisterAddr;
use crate::locations::Location;
use crate::locations::Locations;
@ -55,11 +56,9 @@ const SERVER_TIMEOUT_SECONDS: u64 = 30;
type ShortString = ArrayString<[u8; 64]>;
// TODO: delete action for server shutdown
#[derive(Debug, Deserialize)]
struct Register {
address: Url,
address: RegisterAddr,
secret: ShortString,
connless_request_token: Option<ShortString>,
challenge_secret: ShortString,
@ -68,6 +67,12 @@ struct Register {
info: Option<json::Value>,
}
#[derive(Debug, Deserialize)]
struct Delete {
address: RegisterAddr,
secret: ShortString,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "snake_case", tag = "status")]
enum RegisterResponse {
@ -337,6 +342,11 @@ enum AddResult {
Obsolete,
}
enum RemoveResult {
Removed,
NotFound,
}
struct FromDumpError;
impl Servers {
@ -415,6 +425,19 @@ impl Servers {
AddResult::Refreshed
}
}
fn remove(&mut self, addr: Addr, secret: ShortString) -> RemoveResult {
match self.addresses.get(&addr) {
Some(a_info) if secret == a_info.secret => {}
_ => return RemoveResult::NotFound,
}
assert!(self.addresses.remove(&addr).is_some());
let server = self.servers.get_mut(&secret).unwrap();
server.addresses.retain(|&a| a != addr);
if server.addresses.is_empty() {
assert!(self.servers.remove(&secret).is_some());
}
RemoveResult::Removed
}
fn prune_before(&mut self, time: Timestamp, log: bool) {
let mut remove = Vec::new();
for (&addr, a_info) in &self.addresses {
@ -658,11 +681,7 @@ fn handle_register(
remote_addr: IpAddr,
register: Register,
) -> Result<RegisterResponse, RegisterError> {
let protocol: Protocol = register.address.scheme().parse().map_err(|_| {
"register address must start with one of tw-0.5+udp://, tw-0.6+udp://, tw-0.7+udp://"
})?;
let connless_request_token_7 = match protocol {
let connless_request_token_7 = match register.address.protocol {
Protocol::V5 => None,
Protocol::V6 => None,
Protocol::V7 => {
@ -677,20 +696,8 @@ fn handle_register(
Some(token)
}
};
if register.address.host_str() != Some("connecting-address.invalid") {
return Err("register address must have domain connecting-address.invalid".into());
}
let port = if let Some(p) = register.address.port() {
p
} else {
return Err("register address must specify port".into());
};
let addr = Addr {
ip: remote_addr,
port,
protocol,
};
let addr = register.address.with_ip(remote_addr);
let challenge = shared.challenge_for_addr(&addr);
let correct_challenge = register
@ -756,7 +763,7 @@ fn handle_register(
tokio::spawn(send_challenge(
connless_request_token_7,
shared.socket.clone(),
SocketAddr::new(remote_addr, port),
SocketAddr::new(addr.ip, addr.port),
register.challenge_secret,
challenge.current,
));
@ -765,37 +772,50 @@ fn handle_register(
Ok(result)
}
fn handle_delete(
shared: Shared,
remote_addr: IpAddr,
delete: Delete,
) -> Result<RegisterResponse, RegisterError> {
let addr = delete.address.with_ip(remote_addr);
match shared.lock_servers().remove(addr, delete.secret) {
RemoveResult::Removed => {
debug!("successfully removed {}", addr);
Ok(RegisterResponse::Success)
}
RemoveResult::NotFound => Err("could not find registered server".into()),
}
}
fn parse_opt<T: str::FromStr>(
headers: &warp::http::HeaderMap,
name: &str,
) -> Result<Option<T>, RegisterError>
where
T::Err: fmt::Display,
{
headers
.get(name)
.map(|v| -> Result<T, RegisterError> {
v.to_str()
.map_err(|e| RegisterError::new(format!("invalid header {}: {}", name, e)))?
.parse()
.map_err(|e| RegisterError::new(format!("invalid header {}: {}", name, e)))
})
.transpose()
}
fn parse<T: str::FromStr>(headers: &warp::http::HeaderMap, name: &str) -> Result<T, RegisterError>
where
T::Err: fmt::Display,
{
parse_opt(headers, name)?
.ok_or_else(|| RegisterError::new(format!("missing required header {}", name)))
}
fn register_from_headers(
headers: &warp::http::HeaderMap,
info: &[u8],
) -> Result<Register, RegisterError> {
fn parse_opt<T: str::FromStr>(
headers: &warp::http::HeaderMap,
name: &str,
) -> Result<Option<T>, RegisterError>
where
T::Err: fmt::Display,
{
headers
.get(name)
.map(|v| -> Result<T, RegisterError> {
v.to_str()
.map_err(|e| RegisterError::new(format!("invalid header {}: {}", name, e)))?
.parse()
.map_err(|e| RegisterError::new(format!("invalid header {}: {}", name, e)))
})
.transpose()
}
fn parse<T: str::FromStr>(
headers: &warp::http::HeaderMap,
name: &str,
) -> Result<T, RegisterError>
where
T::Err: fmt::Display,
{
parse_opt(headers, name)?
.ok_or_else(|| RegisterError::new(format!("missing required header {}", name)))
}
Ok(Register {
address: parse(headers, "Address")?,
secret: parse(headers, "Secret")?,
@ -816,6 +836,13 @@ fn register_from_headers(
})
}
fn delete_from_headers(headers: &warp::http::HeaderMap) -> Result<Delete, RegisterError> {
Ok(Delete {
address: parse(headers, "Address")?,
secret: parse(headers, "Secret")?,
})
}
async fn recover(err: warp::Rejection) -> Result<impl warp::Reply, warp::Rejection> {
use warp::http::StatusCode;
let (e, status): (&dyn fmt::Display, _) = if err.is_not_found() {
@ -967,16 +994,58 @@ async fn main() {
timekeeper,
));
let register = warp::post()
.and(warp::path!("ddnet" / "15" / "register"))
let connecting_addr = move |addr: Option<SocketAddr>,
headers: &warp::http::HeaderMap|
-> Result<IpAddr, RegisterError> {
let mut addr = if let Some(header) = &connecting_ip_header {
headers
.get(header)
.ok_or_else(|| RegisterError::new(format!("missing {} header", header)))?
.to_str()
.map_err(|_| RegisterError::from("non-ASCII in connecting IP header"))?
.parse()
.map_err(|e| RegisterError::new(format!("{}", e)))?
} else {
addr.unwrap().ip()
};
if let IpAddr::V6(v6) = addr {
if let Some(v4) = v6.to_ipv4() {
// TODO: switch to `to_ipv4_mapped` in the future.
if !v6.is_loopback() {
addr = IpAddr::from(v4);
}
}
}
Ok(addr)
};
fn build_response<F>(f: F) -> Result<warp::http::Response<String>, warp::http::Error>
where
F: FnOnce() -> Result<RegisterResponse, RegisterError> + UnwindSafe,
{
let (http_status, body) = match panic::catch_unwind(f) {
Ok(Ok(r)) => (warp::http::StatusCode::OK, r),
Ok(Err(e)) => (e.status(), RegisterResponse::Error(e)),
Err(_) => (
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
RegisterResponse::Error("unexpected panic".into()),
),
};
warp::http::Response::builder()
.status(http_status)
.header(warp::http::header::CONTENT_TYPE, "application/json")
.body(json::to_string(&body).unwrap() + "\n")
}
let register = warp::path!("ddnet" / "15" / "register")
.and(warp::post())
.and(warp::header::headers_cloned())
.and(warp::addr::remote())
.and(warp::body::content_length_limit(16 * 1024)) // limit body size to 16 KiB
.and(warp::body::bytes())
.map(
move |headers: warp::http::HeaderMap, addr: Option<SocketAddr>, info: bytes::Bytes| {
let (http_status, body) = match panic::catch_unwind(|| {
let register = register_from_headers(&headers, &info)?;
build_response(|| {
let shared = Shared {
challenger: &challenger,
locations: &locations,
@ -984,40 +1053,21 @@ async fn main() {
socket: &socket.0,
timekeeper,
};
let mut addr = if let Some(header) = &connecting_ip_header {
headers
.get(header)
.ok_or_else(|| {
RegisterError::new(format!("missing {} header", header))
})?
.to_str()
.map_err(|_| RegisterError::from("non-ASCII in connecting IP header"))?
.parse()
.map_err(|e| RegisterError::new(format!("{}", e)))?
} else {
addr.unwrap().ip()
};
if let IpAddr::V6(v6) = addr {
if let Some(v4) = v6.to_ipv4() {
// TODO: switch to `to_ipv4_mapped` in the future.
if !v6.is_loopback() {
addr = IpAddr::from(v4);
}
let addr = connecting_addr(addr, &headers)?;
match headers.get("Action").map(warp::http::HeaderValue::as_bytes) {
None => {
let register = register_from_headers(&headers, &info)?;
handle_register(shared, addr, register)
}
Some(b"delete") => {
let delete = delete_from_headers(&headers)?;
handle_delete(shared, addr, delete)
}
Some(action) => {
Err(RegisterError::new(format!("Unknown Action header value {:?} (must either not be present or have value \"delete\"", String::from_utf8_lossy(action))))
}
}
handle_register(shared, addr, register)
}) {
Ok(Ok(r)) => (warp::http::StatusCode::OK, r),
Ok(Err(e)) => (e.status(), RegisterResponse::Error(e)),
Err(_) => (
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
RegisterResponse::Error("unexpected panic".into()),
),
};
warp::http::Response::builder()
.status(http_status)
.header(warp::http::header::CONTENT_TYPE, "application/json")
.body(json::to_string(&body).unwrap() + "\n")
})
},
)
.recover(recover);