diff --git a/src/main.rs b/src/main.rs index 2c2e240..951ba44 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ use anyhow::{anyhow, bail, Result}; use log::{debug, info, trace, warn}; use regex::Regex; -use rocket::{fairing::AdHoc, get, post, routes, State}; +use rocket::{fairing::AdHoc, get, http::Status, post, routes, Response, State}; use rocket_contrib::json::Json; use serde::{Deserialize, Serialize}; @@ -85,14 +85,15 @@ fn execute_hook(name: &str, hook: &Hook, data: &serde_json::Value) -> Result<()> } #[post("/", format = "json", data = "")] -fn receive_hook(address: SocketAddr, config: State, data: Json) -> Result { - info!("POST request received from: {}", address); +fn receive_hook(address: SocketAddr, config: State, data: Json) -> Result { + info!("Post request received from: {}", address); + let mut response = Response::new(); let data = serde_json::to_value(data.0)?; trace!("Data received from: {}\n{}", address, data); - if let Some(secret) = data.pointer("secret") { + if let Some(secret) = data.pointer("/secret") { if let Some(secret) = secret.as_str() { let hooks: HashMap<&String, &Hook> = config .hooks @@ -101,16 +102,23 @@ fn receive_hook(address: SocketAddr, config: State, data: Json) -> .collect(); if hooks.is_empty() { - warn!("Secret did not match any hook"); + warn!("Secret from {} did not match any hook", address); + response.set_status(Status::Unauthorized); } else { for (hook_name, hook) in hooks { execute_hook(&hook_name, &hook, &data)?; } } + } else { + warn!("Data received from {} contains invalid data", address); + response.set_status(Status::BadRequest); } + } else { + warn!("Data received from {} did not contain a secret", address); + response.set_status(Status::NotFound); } - Ok("Request received.".to_string()) + Ok(response) } fn get_config() -> Result { @@ -145,9 +153,7 @@ fn get_config() -> Result { fn main() -> Result<()> { env_logger::init(); - let config = get_config()?; - - let config: Config = serde_yaml::from_reader(BufReader::new(config))?; + let config: Config = serde_yaml::from_reader(BufReader::new(get_config()?))?; trace!("Parsed configuration:\n{}", serde_yaml::to_string(&config)?); @@ -160,3 +166,77 @@ fn main() -> Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use rocket::{http::ContentType, local::Client}; + + #[test] + fn index() { + let rocket = rocket::ignite().mount("/", routes![index]); + + let client = Client::new(rocket).unwrap(); + let mut response = client.get("/").dispatch(); + + assert_eq!(response.status(), Status::Ok); + assert_eq!(response.body_string(), Some("Hello, webhookey!".into())); + } + + #[test] + fn secret() { + let mut hooks = HashMap::new(); + hooks.insert( + "test_hook".to_string(), + Hook { + action: None, + secrets: vec!["valid".to_string()], + filters: HashMap::new(), + }, + ); + let config = Config { hooks: hooks }; + + let rocket = rocket::ignite() + .mount("/", routes![receive_hook]) + .attach(AdHoc::on_attach("webhookey config", move |rocket| { + Ok(rocket.manage(config)) + })); + + let client = Client::new(rocket).unwrap(); + let response = client + .post("/") + .header(ContentType::JSON) + .remote("127.0.0.1:8000".parse().unwrap()) + .body(r#"{ "secret": "valid" }"#) + .dispatch(); + + assert_eq!(response.status(), Status::Ok); + + let response = client + .post("/") + .header(ContentType::JSON) + .remote("127.0.0.1:8000".parse().unwrap()) + .body(r#"{ "secret": "invalid" }"#) + .dispatch(); + + assert_eq!(response.status(), Status::Unauthorized); + + let response = client + .post("/") + .header(ContentType::JSON) + .remote("127.0.0.1:8000".parse().unwrap()) + .body(r#"{ "not_secret": "invalid" }"#) + .dispatch(); + + assert_eq!(response.status(), Status::NotFound); + + let response = client + .post("/") + .header(ContentType::JSON) + .remote("127.0.0.1:8000".parse().unwrap()) + .body(r#"{ "not_secret": "invalid" "#) + .dispatch(); + + assert_eq!(response.status(), Status::BadRequest); + } +}