mirror of
https://github.com/ddnet/ddnet.git
synced 2024-11-10 01:58:19 +00:00
mastersrv: Add Action: delete
to /ddnet/15/register
This will delete a server from the server list without a timeout, allowing servers that want to stop registering or servers that shut down gracefully to no longer appear in the server list.
This commit is contained in:
parent
264c6f969d
commit
ce08cc0e53
|
@ -22,15 +22,29 @@ pub struct Addr {
|
|||
pub protocol: Protocol,
|
||||
}
|
||||
|
||||
/// A register address, serialized like
|
||||
/// tw-0.6+udp://connecting-address.invalid:8303.
|
||||
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
|
||||
pub struct RegisterAddr {
|
||||
pub port: u16,
|
||||
pub protocol: Protocol,
|
||||
}
|
||||
|
||||
impl fmt::Display for Protocol {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
self.as_str().fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct UnknownProtocol;
|
||||
|
||||
impl fmt::Display for UnknownProtocol {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
"protocol must be one of tw-0.5+udp, tw-0.6+udp or tw-0.7+udp".fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Protocol {
|
||||
type Err = UnknownProtocol;
|
||||
fn from_str(s: &str) -> Result<Protocol, UnknownProtocol> {
|
||||
|
@ -101,7 +115,7 @@ impl fmt::Display for Addr {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct InvalidAddr;
|
||||
|
||||
impl FromStr for Addr {
|
||||
|
@ -160,6 +174,100 @@ impl<'de> serde::Deserialize<'de> for Addr {
|
|||
}
|
||||
}
|
||||
|
||||
impl RegisterAddr {
|
||||
pub fn with_ip(self, ip: IpAddr) -> Addr {
|
||||
Addr {
|
||||
ip,
|
||||
port: self.port,
|
||||
protocol: self.protocol,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for RegisterAddr {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let mut buf: ArrayString<[u8; 128]> = ArrayString::new();
|
||||
write!(
|
||||
&mut buf,
|
||||
"{}://connecting-address.invalid:{}",
|
||||
self.protocol, self.port,
|
||||
)
|
||||
.unwrap();
|
||||
buf.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum ParseRegisterAddrError {
|
||||
Url(url::ParseError),
|
||||
Protocol(UnknownProtocol),
|
||||
HostNotConnectingAddressInvalid,
|
||||
PortNotPresent,
|
||||
}
|
||||
|
||||
impl fmt::Display for ParseRegisterAddrError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
use self::ParseRegisterAddrError::*;
|
||||
match *self {
|
||||
Url(e) => write!(f, "URL parse error: {}", e),
|
||||
Protocol(e) => write!(f, "protocol parse error: {}", e),
|
||||
HostNotConnectingAddressInvalid => write!(
|
||||
f,
|
||||
"register address must have domain connecting-address.invalid"
|
||||
),
|
||||
PortNotPresent => write!(f, "register address must specify port"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for RegisterAddr {
|
||||
type Err = ParseRegisterAddrError;
|
||||
fn from_str(s: &str) -> Result<RegisterAddr, ParseRegisterAddrError> {
|
||||
use self::ParseRegisterAddrError as Error;
|
||||
let url = Url::parse(s).map_err(Error::Url)?;
|
||||
let protocol: Protocol = url.scheme().parse().map_err(Error::Protocol)?;
|
||||
if url.host_str() != Some("connecting-address.invalid") {
|
||||
return Err(Error::HostNotConnectingAddressInvalid);
|
||||
}
|
||||
let port = url.port().ok_or(Error::PortNotPresent)?;
|
||||
Ok(RegisterAddr { port, protocol })
|
||||
}
|
||||
}
|
||||
|
||||
impl serde::Serialize for RegisterAddr {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut buf: ArrayString<[u8; 128]> = ArrayString::new();
|
||||
write!(&mut buf, "{}", self).unwrap();
|
||||
serializer.serialize_str(&buf)
|
||||
}
|
||||
}
|
||||
|
||||
struct RegisterAddrVisitor;
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for RegisterAddrVisitor {
|
||||
type Value = RegisterAddr;
|
||||
|
||||
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str("a URL like tw-0.6+udp://connecting-address.invalid:8303")
|
||||
}
|
||||
fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<RegisterAddr, E> {
|
||||
let invalid_value = || E::invalid_value(serde::de::Unexpected::Str(v), &self);
|
||||
Ok(RegisterAddr::from_str(v).map_err(|_| invalid_value())?)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::Deserialize<'de> for RegisterAddr {
|
||||
fn deserialize<D>(deserializer: D) -> Result<RegisterAddr, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
{
|
||||
deserializer.deserialize_str(RegisterAddrVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::Addr;
|
||||
|
@ -186,4 +294,15 @@ mod test {
|
|||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_addr_from_str() {
|
||||
assert_eq!(
|
||||
RegisterAddr::from_str("tw-0.6+udp://connecting-address.invalid:8303").unwrap(),
|
||||
RegisterAddr {
|
||||
port: 8303,
|
||||
protocol: Protocol::V6,
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ use std::mem;
|
|||
use std::net::IpAddr;
|
||||
use std::net::SocketAddr;
|
||||
use std::panic;
|
||||
use std::panic::UnwindSafe;
|
||||
use std::path::Path;
|
||||
use std::process;
|
||||
use std::str;
|
||||
|
@ -35,7 +36,6 @@ use tokio::fs;
|
|||
use tokio::fs::File;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::time;
|
||||
use url::Url;
|
||||
use warp::Filter;
|
||||
|
||||
#[macro_use]
|
||||
|
@ -43,6 +43,7 @@ extern crate log;
|
|||
|
||||
use crate::addr::Addr;
|
||||
use crate::addr::Protocol;
|
||||
use crate::addr::RegisterAddr;
|
||||
use crate::locations::Location;
|
||||
use crate::locations::Locations;
|
||||
|
||||
|
@ -55,11 +56,9 @@ const SERVER_TIMEOUT_SECONDS: u64 = 30;
|
|||
|
||||
type ShortString = ArrayString<[u8; 64]>;
|
||||
|
||||
// TODO: delete action for server shutdown
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Register {
|
||||
address: Url,
|
||||
address: RegisterAddr,
|
||||
secret: ShortString,
|
||||
connless_request_token: Option<ShortString>,
|
||||
challenge_secret: ShortString,
|
||||
|
@ -68,6 +67,12 @@ struct Register {
|
|||
info: Option<json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Delete {
|
||||
address: RegisterAddr,
|
||||
secret: ShortString,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "status")]
|
||||
enum RegisterResponse {
|
||||
|
@ -337,6 +342,11 @@ enum AddResult {
|
|||
Obsolete,
|
||||
}
|
||||
|
||||
enum RemoveResult {
|
||||
Removed,
|
||||
NotFound,
|
||||
}
|
||||
|
||||
struct FromDumpError;
|
||||
|
||||
impl Servers {
|
||||
|
@ -415,6 +425,19 @@ impl Servers {
|
|||
AddResult::Refreshed
|
||||
}
|
||||
}
|
||||
fn remove(&mut self, addr: Addr, secret: ShortString) -> RemoveResult {
|
||||
match self.addresses.get(&addr) {
|
||||
Some(a_info) if secret == a_info.secret => {}
|
||||
_ => return RemoveResult::NotFound,
|
||||
}
|
||||
assert!(self.addresses.remove(&addr).is_some());
|
||||
let server = self.servers.get_mut(&secret).unwrap();
|
||||
server.addresses.retain(|&a| a != addr);
|
||||
if server.addresses.is_empty() {
|
||||
assert!(self.servers.remove(&secret).is_some());
|
||||
}
|
||||
RemoveResult::Removed
|
||||
}
|
||||
fn prune_before(&mut self, time: Timestamp, log: bool) {
|
||||
let mut remove = Vec::new();
|
||||
for (&addr, a_info) in &self.addresses {
|
||||
|
@ -658,11 +681,7 @@ fn handle_register(
|
|||
remote_addr: IpAddr,
|
||||
register: Register,
|
||||
) -> Result<RegisterResponse, RegisterError> {
|
||||
let protocol: Protocol = register.address.scheme().parse().map_err(|_| {
|
||||
"register address must start with one of tw-0.5+udp://, tw-0.6+udp://, tw-0.7+udp://"
|
||||
})?;
|
||||
|
||||
let connless_request_token_7 = match protocol {
|
||||
let connless_request_token_7 = match register.address.protocol {
|
||||
Protocol::V5 => None,
|
||||
Protocol::V6 => None,
|
||||
Protocol::V7 => {
|
||||
|
@ -677,20 +696,8 @@ fn handle_register(
|
|||
Some(token)
|
||||
}
|
||||
};
|
||||
if register.address.host_str() != Some("connecting-address.invalid") {
|
||||
return Err("register address must have domain connecting-address.invalid".into());
|
||||
}
|
||||
let port = if let Some(p) = register.address.port() {
|
||||
p
|
||||
} else {
|
||||
return Err("register address must specify port".into());
|
||||
};
|
||||
|
||||
let addr = Addr {
|
||||
ip: remote_addr,
|
||||
port,
|
||||
protocol,
|
||||
};
|
||||
let addr = register.address.with_ip(remote_addr);
|
||||
let challenge = shared.challenge_for_addr(&addr);
|
||||
|
||||
let correct_challenge = register
|
||||
|
@ -756,7 +763,7 @@ fn handle_register(
|
|||
tokio::spawn(send_challenge(
|
||||
connless_request_token_7,
|
||||
shared.socket.clone(),
|
||||
SocketAddr::new(remote_addr, port),
|
||||
SocketAddr::new(addr.ip, addr.port),
|
||||
register.challenge_secret,
|
||||
challenge.current,
|
||||
));
|
||||
|
@ -765,37 +772,50 @@ fn handle_register(
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
fn handle_delete(
|
||||
shared: Shared,
|
||||
remote_addr: IpAddr,
|
||||
delete: Delete,
|
||||
) -> Result<RegisterResponse, RegisterError> {
|
||||
let addr = delete.address.with_ip(remote_addr);
|
||||
match shared.lock_servers().remove(addr, delete.secret) {
|
||||
RemoveResult::Removed => {
|
||||
debug!("successfully removed {}", addr);
|
||||
Ok(RegisterResponse::Success)
|
||||
}
|
||||
RemoveResult::NotFound => Err("could not find registered server".into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_opt<T: str::FromStr>(
|
||||
headers: &warp::http::HeaderMap,
|
||||
name: &str,
|
||||
) -> Result<Option<T>, RegisterError>
|
||||
where
|
||||
T::Err: fmt::Display,
|
||||
{
|
||||
headers
|
||||
.get(name)
|
||||
.map(|v| -> Result<T, RegisterError> {
|
||||
v.to_str()
|
||||
.map_err(|e| RegisterError::new(format!("invalid header {}: {}", name, e)))?
|
||||
.parse()
|
||||
.map_err(|e| RegisterError::new(format!("invalid header {}: {}", name, e)))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
fn parse<T: str::FromStr>(headers: &warp::http::HeaderMap, name: &str) -> Result<T, RegisterError>
|
||||
where
|
||||
T::Err: fmt::Display,
|
||||
{
|
||||
parse_opt(headers, name)?
|
||||
.ok_or_else(|| RegisterError::new(format!("missing required header {}", name)))
|
||||
}
|
||||
|
||||
fn register_from_headers(
|
||||
headers: &warp::http::HeaderMap,
|
||||
info: &[u8],
|
||||
) -> Result<Register, RegisterError> {
|
||||
fn parse_opt<T: str::FromStr>(
|
||||
headers: &warp::http::HeaderMap,
|
||||
name: &str,
|
||||
) -> Result<Option<T>, RegisterError>
|
||||
where
|
||||
T::Err: fmt::Display,
|
||||
{
|
||||
headers
|
||||
.get(name)
|
||||
.map(|v| -> Result<T, RegisterError> {
|
||||
v.to_str()
|
||||
.map_err(|e| RegisterError::new(format!("invalid header {}: {}", name, e)))?
|
||||
.parse()
|
||||
.map_err(|e| RegisterError::new(format!("invalid header {}: {}", name, e)))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
fn parse<T: str::FromStr>(
|
||||
headers: &warp::http::HeaderMap,
|
||||
name: &str,
|
||||
) -> Result<T, RegisterError>
|
||||
where
|
||||
T::Err: fmt::Display,
|
||||
{
|
||||
parse_opt(headers, name)?
|
||||
.ok_or_else(|| RegisterError::new(format!("missing required header {}", name)))
|
||||
}
|
||||
Ok(Register {
|
||||
address: parse(headers, "Address")?,
|
||||
secret: parse(headers, "Secret")?,
|
||||
|
@ -816,6 +836,13 @@ fn register_from_headers(
|
|||
})
|
||||
}
|
||||
|
||||
fn delete_from_headers(headers: &warp::http::HeaderMap) -> Result<Delete, RegisterError> {
|
||||
Ok(Delete {
|
||||
address: parse(headers, "Address")?,
|
||||
secret: parse(headers, "Secret")?,
|
||||
})
|
||||
}
|
||||
|
||||
async fn recover(err: warp::Rejection) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
use warp::http::StatusCode;
|
||||
let (e, status): (&dyn fmt::Display, _) = if err.is_not_found() {
|
||||
|
@ -967,16 +994,58 @@ async fn main() {
|
|||
timekeeper,
|
||||
));
|
||||
|
||||
let register = warp::post()
|
||||
.and(warp::path!("ddnet" / "15" / "register"))
|
||||
let connecting_addr = move |addr: Option<SocketAddr>,
|
||||
headers: &warp::http::HeaderMap|
|
||||
-> Result<IpAddr, RegisterError> {
|
||||
let mut addr = if let Some(header) = &connecting_ip_header {
|
||||
headers
|
||||
.get(header)
|
||||
.ok_or_else(|| RegisterError::new(format!("missing {} header", header)))?
|
||||
.to_str()
|
||||
.map_err(|_| RegisterError::from("non-ASCII in connecting IP header"))?
|
||||
.parse()
|
||||
.map_err(|e| RegisterError::new(format!("{}", e)))?
|
||||
} else {
|
||||
addr.unwrap().ip()
|
||||
};
|
||||
if let IpAddr::V6(v6) = addr {
|
||||
if let Some(v4) = v6.to_ipv4() {
|
||||
// TODO: switch to `to_ipv4_mapped` in the future.
|
||||
if !v6.is_loopback() {
|
||||
addr = IpAddr::from(v4);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(addr)
|
||||
};
|
||||
|
||||
fn build_response<F>(f: F) -> Result<warp::http::Response<String>, warp::http::Error>
|
||||
where
|
||||
F: FnOnce() -> Result<RegisterResponse, RegisterError> + UnwindSafe,
|
||||
{
|
||||
let (http_status, body) = match panic::catch_unwind(f) {
|
||||
Ok(Ok(r)) => (warp::http::StatusCode::OK, r),
|
||||
Ok(Err(e)) => (e.status(), RegisterResponse::Error(e)),
|
||||
Err(_) => (
|
||||
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
RegisterResponse::Error("unexpected panic".into()),
|
||||
),
|
||||
};
|
||||
warp::http::Response::builder()
|
||||
.status(http_status)
|
||||
.header(warp::http::header::CONTENT_TYPE, "application/json")
|
||||
.body(json::to_string(&body).unwrap() + "\n")
|
||||
}
|
||||
|
||||
let register = warp::path!("ddnet" / "15" / "register")
|
||||
.and(warp::post())
|
||||
.and(warp::header::headers_cloned())
|
||||
.and(warp::addr::remote())
|
||||
.and(warp::body::content_length_limit(16 * 1024)) // limit body size to 16 KiB
|
||||
.and(warp::body::bytes())
|
||||
.map(
|
||||
move |headers: warp::http::HeaderMap, addr: Option<SocketAddr>, info: bytes::Bytes| {
|
||||
let (http_status, body) = match panic::catch_unwind(|| {
|
||||
let register = register_from_headers(&headers, &info)?;
|
||||
build_response(|| {
|
||||
let shared = Shared {
|
||||
challenger: &challenger,
|
||||
locations: &locations,
|
||||
|
@ -984,40 +1053,21 @@ async fn main() {
|
|||
socket: &socket.0,
|
||||
timekeeper,
|
||||
};
|
||||
let mut addr = if let Some(header) = &connecting_ip_header {
|
||||
headers
|
||||
.get(header)
|
||||
.ok_or_else(|| {
|
||||
RegisterError::new(format!("missing {} header", header))
|
||||
})?
|
||||
.to_str()
|
||||
.map_err(|_| RegisterError::from("non-ASCII in connecting IP header"))?
|
||||
.parse()
|
||||
.map_err(|e| RegisterError::new(format!("{}", e)))?
|
||||
} else {
|
||||
addr.unwrap().ip()
|
||||
};
|
||||
if let IpAddr::V6(v6) = addr {
|
||||
if let Some(v4) = v6.to_ipv4() {
|
||||
// TODO: switch to `to_ipv4_mapped` in the future.
|
||||
if !v6.is_loopback() {
|
||||
addr = IpAddr::from(v4);
|
||||
}
|
||||
let addr = connecting_addr(addr, &headers)?;
|
||||
match headers.get("Action").map(warp::http::HeaderValue::as_bytes) {
|
||||
None => {
|
||||
let register = register_from_headers(&headers, &info)?;
|
||||
handle_register(shared, addr, register)
|
||||
}
|
||||
Some(b"delete") => {
|
||||
let delete = delete_from_headers(&headers)?;
|
||||
handle_delete(shared, addr, delete)
|
||||
}
|
||||
Some(action) => {
|
||||
Err(RegisterError::new(format!("Unknown Action header value {:?} (must either not be present or have value \"delete\"", String::from_utf8_lossy(action))))
|
||||
}
|
||||
}
|
||||
handle_register(shared, addr, register)
|
||||
}) {
|
||||
Ok(Ok(r)) => (warp::http::StatusCode::OK, r),
|
||||
Ok(Err(e)) => (e.status(), RegisterResponse::Error(e)),
|
||||
Err(_) => (
|
||||
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
RegisterResponse::Error("unexpected panic".into()),
|
||||
),
|
||||
};
|
||||
warp::http::Response::builder()
|
||||
.status(http_status)
|
||||
.header(warp::http::header::CONTENT_TYPE, "application/json")
|
||||
.body(json::to_string(&body).unwrap() + "\n")
|
||||
})
|
||||
},
|
||||
)
|
||||
.recover(recover);
|
||||
|
|
Loading…
Reference in a new issue