diff --git a/Cargo.lock b/Cargo.lock index 7da721dbca9382c3d6637651afa8b21ad7980650..5f520b6abf2950e22d7a8f9ecef6c3cd4a94157f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1135,6 +1135,7 @@ dependencies = [ "anyhow", "arrayvec", "async-graphql", + "async-mutex", "async-trait", "dubp", "duniter-conf", diff --git a/rust-libs/duniter-conf/src/gva_conf.rs b/rust-libs/duniter-conf/src/gva_conf.rs index bad62480bec695117c26ed5a401ac878e9becabc..38825c5617e0787d5547ad379a2df6ce6dc64658 100644 --- a/rust-libs/duniter-conf/src/gva_conf.rs +++ b/rust-libs/duniter-conf/src/gva_conf.rs @@ -13,7 +13,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see <https://www.gnu.org/licenses/>. -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use crate::*; @@ -30,6 +30,7 @@ pub struct GvaConf { remote_path: Option<String>, remote_subscriptions_path: Option<String>, remote_tls: Option<bool>, + whitelist: Option<Vec<IpAddr>>, } impl GvaConf { @@ -99,4 +100,14 @@ impl GvaConf { pub fn get_remote_tls(&self) -> bool { self.remote_tls.unwrap_or(false) } + pub fn get_whitelist(&self) -> &[IpAddr] { + if let Some(ref whitelist) = self.whitelist { + whitelist + } else { + &[ + IpAddr::V4(Ipv4Addr::LOCALHOST), + IpAddr::V6(Ipv6Addr::LOCALHOST), + ] + } + } } diff --git a/rust-libs/modules/duniter-gva/Cargo.toml b/rust-libs/modules/duniter-gva/Cargo.toml index b918aae3569e46b8f9dffa5c4e43a02a7fc382b3..42b07f85609cd41d71cdf79578945a6cedf87e62 100644 --- a/rust-libs/modules/duniter-gva/Cargo.toml +++ b/rust-libs/modules/duniter-gva/Cargo.toml @@ -9,6 +9,7 @@ edition = "2018" anyhow = "1.0.33" arrayvec = "0.5.1" async-graphql = "2.0.0" +async-mutex = "1.4.0" async-trait = "0.1.41" dubp = { version = "0.32.2" } duniter-conf = { path = "../../duniter-conf" } @@ -31,5 +32,5 @@ warp = "0.2" duniter-dbs = { path = "../../duniter-dbs", features = ["mem"] } mockall = "0.8.0" serde_json = "1.0.53" -tokio = { version = "0.2.22", features = ["macros", "rt-threaded"] } +tokio = { version = "0.2.22", features = ["macros", "rt-threaded", "time"] } unwrap = "1.2.1" diff --git a/rust-libs/modules/duniter-gva/src/anti_spam.rs b/rust-libs/modules/duniter-gva/src/anti_spam.rs new file mode 100644 index 0000000000000000000000000000000000000000..908d51e019af861bdcb950b7d656f3b7a1b5bdb3 --- /dev/null +++ b/rust-libs/modules/duniter-gva/src/anti_spam.rs @@ -0,0 +1,167 @@ +// Copyright (C) 2020 Éloïs SANCHEZ. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>. + +use crate::*; +use async_mutex::Mutex; +use duniter_dbs::kv_typed::prelude::Arc; +use std::{ + collections::{HashMap, HashSet}, + iter::FromIterator, + net::IpAddr, + time::Duration, + time::Instant, +}; + +const COUNT_INTERVAL: usize = 40; +const MIN_DURATION_INTERVAL: Duration = Duration::from_secs(20); +const LARGE_DURATION_INTERVAL: Duration = Duration::from_secs(180); +const REDUCED_COUNT_INTERVAL: usize = COUNT_INTERVAL - 5; +const MAX_BAN_COUNT: usize = 16; +const BAN_FORGET_MIN_DURATION: Duration = Duration::from_secs(180); + +#[derive(Clone)] +pub(crate) struct AntiSpam { + state: Arc<Mutex<AntiSpamInner>>, + whitelist: HashSet<IpAddr>, +} + +struct AntiSpamInner { + ban: HashMap<IpAddr, (bool, usize, Instant)>, + ips_time: HashMap<IpAddr, (usize, Instant)>, +} + +impl From<&GvaConf> for AntiSpam { + fn from(conf: &GvaConf) -> Self { + AntiSpam { + state: Arc::new(Mutex::new(AntiSpamInner { + ban: HashMap::with_capacity(10), + ips_time: HashMap::with_capacity(10), + })), + whitelist: HashSet::from_iter(conf.get_whitelist().iter().copied()), + } + } +} + +impl AntiSpam { + fn verify_interval(ip: IpAddr, state: &mut AntiSpamInner, ban_count: usize) -> bool { + if let Some((count, instant)) = state.ips_time.get(&ip).copied() { + if count == COUNT_INTERVAL { + let duration = Instant::now().duration_since(instant); + if duration > MIN_DURATION_INTERVAL { + if duration > LARGE_DURATION_INTERVAL { + state.ips_time.insert(ip, (1, Instant::now())); + true + } else { + state + .ips_time + .insert(ip, (REDUCED_COUNT_INTERVAL, Instant::now())); + true + } + } else { + state.ban.insert(ip, (true, ban_count, Instant::now())); + false + } + } else { + state.ips_time.insert(ip, (count + 1, instant)); + true + } + } else { + state.ips_time.insert(ip, (1, Instant::now())); + true + } + } + pub(crate) async fn verify(&self, remote_addr_opt: Option<std::net::SocketAddr>) -> bool { + if let Some(remote_addr) = remote_addr_opt { + let ip = remote_addr.ip(); + if self.whitelist.contains(&ip) { + true + } else { + let mut guard = self.state.lock().await; + if let Some((is_banned, ban_count, instant)) = guard.ban.get(&ip).copied() { + let ban_duration = + Duration::from_secs(1 << std::cmp::min(ban_count, MAX_BAN_COUNT)); + if is_banned { + if Instant::now().duration_since(instant) > ban_duration { + guard.ban.insert(ip, (false, ban_count + 1, Instant::now())); + guard.ips_time.insert(ip, (1, Instant::now())); + true + } else { + guard.ban.insert(ip, (true, ban_count + 1, Instant::now())); + false + } + } else if Instant::now().duration_since(instant) + > std::cmp::max(ban_duration, BAN_FORGET_MIN_DURATION) + { + guard.ban.remove(&ip); + guard.ips_time.insert(ip, (1, Instant::now())); + true + } else { + Self::verify_interval(ip, &mut guard, ban_count) + } + } else { + Self::verify_interval(ip, &mut guard, 0) + } + } + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; + use tokio::time::delay_for; + + const LOCAL_IP4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); + const LOCAL_IP6: IpAddr = IpAddr::V6(Ipv6Addr::LOCALHOST); + + #[tokio::test] + async fn test_anti_spam() { + let anti_spam = AntiSpam::from(&GvaConf::default()); + assert!(!anti_spam.verify(None).await); + + for _ in 0..(COUNT_INTERVAL * 2) { + assert!(anti_spam.verify(Some(SocketAddr::new(LOCAL_IP4, 0))).await); + assert!(anti_spam.verify(Some(SocketAddr::new(LOCAL_IP6, 0))).await); + } + + let extern_ip = IpAddr::V4(Ipv4Addr::UNSPECIFIED); + + // Consume max queries + for _ in 0..COUNT_INTERVAL { + assert!(anti_spam.verify(Some(SocketAddr::new(extern_ip, 0))).await); + } + // Should be banned + assert!(!anti_spam.verify(Some(SocketAddr::new(extern_ip, 0))).await); + + // Should be un-banned after one second + delay_for(Duration::from_millis(1_100)).await; + // Re-consume max queries + for _ in 0..COUNT_INTERVAL { + assert!(anti_spam.verify(Some(SocketAddr::new(extern_ip, 0))).await); + } + // Should be banned for 2 seconds this time + delay_for(Duration::from_millis(1_100)).await; + // Attempting a request when I'm banned must be twice my banning time + assert!(!anti_spam.verify(Some(SocketAddr::new(extern_ip, 0))).await); + delay_for(Duration::from_millis(4_100)).await; + // Re-consume max queries + for _ in 0..COUNT_INTERVAL { + assert!(anti_spam.verify(Some(SocketAddr::new(extern_ip, 0))).await); + } + } +} diff --git a/rust-libs/modules/duniter-gva/src/lib.rs b/rust-libs/modules/duniter-gva/src/lib.rs index 025f23f8874cff4b6695834e933122908d4f8fc3..2aec18bd67823bb301bbd8fb52e71884c65705f4 100644 --- a/rust-libs/modules/duniter-gva/src/lib.rs +++ b/rust-libs/modules/duniter-gva/src/lib.rs @@ -24,6 +24,7 @@ pub use duniter_conf::gva_conf::GvaConf; +mod anti_spam; mod entities; mod inputs; mod inputs_validators; diff --git a/rust-libs/modules/duniter-gva/src/warp_.rs b/rust-libs/modules/duniter-gva/src/warp_.rs index 5bf78f67a81a230ae88d8a03bd82d466cc1975a8..878fb5d21d77e8dbcba5ed4c9786dded4dfe9e0f 100644 --- a/rust-libs/modules/duniter-gva/src/warp_.rs +++ b/rust-libs/modules/duniter-gva/src/warp_.rs @@ -13,6 +13,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see <https://www.gnu.org/licenses/>. +use crate::anti_spam::AntiSpam; use crate::*; pub struct BadRequest(pub anyhow::Error); @@ -54,38 +55,49 @@ pub(crate) fn graphql( schema: GraphQlSchema, opts: async_graphql::http::MultipartOptions, ) -> impl warp::Filter<Extract = (impl warp::Reply,), Error = Rejection> + Clone { + let anti_spam = AntiSpam::from(conf); let opts = Arc::new(opts); warp::path::path(conf.get_path()) + .and(warp::addr::remote()) .and(warp::method()) .and(warp::query::raw().or(warp::any().map(String::new)).unify()) .and(warp::header::optional::<String>("content-type")) .and(warp::body::stream()) .and(warp::any().map(move || opts.clone())) .and(warp::any().map(move || schema.clone())) + .and(warp::any().map(move || anti_spam.clone())) .and_then( - |method, + |remote_addr, + method, query: String, content_type, body, opts: Arc<async_graphql::http::MultipartOptions>, - schema| async move { - if method == http::Method::GET { - let request: async_graphql::Request = serde_urlencoded::from_str(&query) + schema, + anti_spam: AntiSpam| async move { + if anti_spam.verify(remote_addr).await { + if method == http::Method::GET { + let request: async_graphql::Request = serde_urlencoded::from_str(&query) + .map_err(|err| warp::reject::custom(BadRequest(err.into())))?; + Ok::<_, Rejection>((schema, request)) + } else { + let request = async_graphql::http::receive_body( + content_type, + futures::TryStreamExt::map_err(body, |err| { + std::io::Error::new(std::io::ErrorKind::Other, err) + }) + .map_ok(|mut buf| warp::Buf::to_bytes(&mut buf)) + .into_async_read(), + async_graphql::http::MultipartOptions::clone(&opts), + ) + .await .map_err(|err| warp::reject::custom(BadRequest(err.into())))?; - Ok::<_, Rejection>((schema, request)) + Ok::<_, Rejection>((schema, request)) + } } else { - let request = async_graphql::http::receive_body( - content_type, - futures::TryStreamExt::map_err(body, |err| { - std::io::Error::new(std::io::ErrorKind::Other, err) - }) - .map_ok(|mut buf| warp::Buf::to_bytes(&mut buf)) - .into_async_read(), - async_graphql::http::MultipartOptions::clone(&opts), - ) - .await - .map_err(|err| warp::reject::custom(BadRequest(err.into())))?; - Ok::<_, Rejection>((schema, request)) + Err(warp::reject::custom(BadRequest(anyhow::Error::msg( + "too many requests", + )))) } }, ) @@ -100,11 +112,25 @@ pub(crate) fn graphql_ws( conf: &GvaConf, schema: GraphQlSchema, ) -> impl warp::Filter<Extract = (impl warp::Reply,), Error = Rejection> + Clone { + let anti_spam = AntiSpam::from(conf); warp::path::path(conf.get_subscriptions_path()) + .and(warp::addr::remote()) .and(warp::ws()) .and(warp::any().map(move || schema.clone())) - .map(|ws: warp::ws::Ws, schema: GraphQlSchema| { - ws.on_upgrade(move |websocket| { + .and(warp::any().map(move || anti_spam.clone())) + .and_then( + |remote_addr, ws: warp::ws::Ws, schema: GraphQlSchema, anti_spam: AntiSpam| async move { + if anti_spam.verify(remote_addr).await { + Ok((ws, schema)) + } else { + Err(warp::reject::custom(BadRequest(anyhow::Error::msg( + "too many requests", + )))) + } + }, + ) + .and_then(|(ws, schema): (warp::ws::Ws, GraphQlSchema)| { + let reply = ws.on_upgrade(move |websocket| { let (ws_sender, ws_receiver) = websocket.split(); async move { @@ -120,7 +146,12 @@ pub(crate) fn graphql_ws( .forward(ws_sender) .await; } - }) + }); + + futures::future::ready(Ok::<_, Rejection>(warp::reply::with_header( + reply, + "Sec-WebSocket-Protocol", + "graphql-ws", + ))) }) - .map(|reply| warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws")) } diff --git a/server.ts b/server.ts index 2387e4802db1e90b6bcbf560a5fb1d23f10e3b46..280773d94ff7b7117c21446d1c9f48d995d48c6a 100644 --- a/server.ts +++ b/server.ts @@ -357,7 +357,7 @@ export class Server extends stream.Duplex implements HookableServer { await this.dal.init(this.conf, commandName); // Get rust endpoints for (let endpoint of this.dal.getRustEndpoints()) { - logger.info("TMP: rustEndpoint: %s", endpoint); + //logger.info("TMP: rustEndpoint: %s", endpoint); this.addEndpointsDefinitions(async () => endpoint); } // Maintenance