diff --git a/Cargo.toml b/Cargo.toml index a215d08c0f..4b0ff3c4ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,4 +28,5 @@ members = [ "examples/hello_alt_methods", "examples/raw_upload", "examples/pastebin", + "examples/state", ] diff --git a/examples/from_request/src/main.rs b/examples/from_request/src/main.rs index 8ecb3c419e..cc2ea6c24f 100644 --- a/examples/from_request/src/main.rs +++ b/examples/from_request/src/main.rs @@ -40,13 +40,15 @@ mod test { use rocket::http::Header; fn test_header_count<'h>(headers: Vec>) { + let rocket = rocket::ignite() + .mount("/", routes![super::header_count]); + let num_headers = headers.len(); let mut req = MockRequest::new(Get, "/"); for header in headers { req = req.header(header); } - let rocket = rocket::ignite().mount("/", routes![super::header_count]); let mut response = req.dispatch_with(&rocket); let expect = format!("Your request contained {} headers!", num_headers); diff --git a/examples/handlebars_templates/src/tests.rs b/examples/handlebars_templates/src/tests.rs index 9e385afd46..18654b3a4f 100644 --- a/examples/handlebars_templates/src/tests.rs +++ b/examples/handlebars_templates/src/tests.rs @@ -11,7 +11,8 @@ macro_rules! run_test { .mount("/", routes![super::index, super::get]) .catch(errors![super::not_found]); - $test_fn($req.dispatch_with(&rocket)); + let mut req = $req; + $test_fn(req.dispatch_with(&rocket)); }) } @@ -19,7 +20,7 @@ macro_rules! run_test { fn test_root() { // Check that the redirect works. for method in &[Get, Head] { - let mut req = MockRequest::new(*method, "/"); + let req = MockRequest::new(*method, "/"); run_test!(req, |mut response: Response| { assert_eq!(response.status(), Status::SeeOther); assert!(response.body().is_none()); @@ -31,7 +32,7 @@ fn test_root() { // Check that other request methods are not accepted (and instead caught). for method in &[Post, Put, Delete, Options, Trace, Connect, Patch] { - let mut req = MockRequest::new(*method, "/"); + let req = MockRequest::new(*method, "/"); run_test!(req, |mut response: Response| { assert_eq!(response.status(), Status::NotFound); @@ -48,7 +49,7 @@ fn test_root() { #[test] fn test_name() { // Check that the /hello/ route works. - let mut req = MockRequest::new(Get, "/hello/Jack"); + let req = MockRequest::new(Get, "/hello/Jack"); run_test!(req, |mut response: Response| { assert_eq!(response.status(), Status::Ok); @@ -66,7 +67,7 @@ fn test_name() { #[test] fn test_404() { // Check that the error catcher works. - let mut req = MockRequest::new(Get, "/hello/"); + let req = MockRequest::new(Get, "/hello/"); run_test!(req, |mut response: Response| { assert_eq!(response.status(), Status::NotFound); diff --git a/examples/json/src/tests.rs b/examples/json/src/tests.rs index 97ba7b4885..4476ed0c8d 100644 --- a/examples/json/src/tests.rs +++ b/examples/json/src/tests.rs @@ -10,14 +10,15 @@ macro_rules! run_test { .mount("/message", routes![super::new, super::update, super::get]) .catch(errors![super::not_found]); - $test_fn($req.dispatch_with(&rocket)); + let mut req = $req; + $test_fn(req.dispatch_with(&rocket)); }) } #[test] fn bad_get_put() { // Try to get a message with an ID that doesn't exist. - let mut req = MockRequest::new(Get, "/message/99").header(ContentType::JSON); + let req = MockRequest::new(Get, "/message/99").header(ContentType::JSON); run_test!(req, |mut response: Response| { assert_eq!(response.status(), Status::NotFound); @@ -27,7 +28,7 @@ fn bad_get_put() { }); // Try to get a message with an invalid ID. - let mut req = MockRequest::new(Get, "/message/hi").header(ContentType::JSON); + let req = MockRequest::new(Get, "/message/hi").header(ContentType::JSON); run_test!(req, |mut response: Response| { assert_eq!(response.status(), Status::NotFound); let body = response.body().unwrap().into_string().unwrap(); @@ -35,13 +36,13 @@ fn bad_get_put() { }); // Try to put a message without a proper body. - let mut req = MockRequest::new(Put, "/message/80").header(ContentType::JSON); + let req = MockRequest::new(Put, "/message/80").header(ContentType::JSON); run_test!(req, |response: Response| { assert_eq!(response.status(), Status::BadRequest); }); // Try to put a message for an ID that doesn't exist. - let mut req = MockRequest::new(Put, "/message/80") + let req = MockRequest::new(Put, "/message/80") .header(ContentType::JSON) .body(r#"{ "contents": "Bye bye, world!" }"#); @@ -53,13 +54,13 @@ fn bad_get_put() { #[test] fn post_get_put_get() { // Check that a message with ID 1 doesn't exist. - let mut req = MockRequest::new(Get, "/message/1").header(ContentType::JSON); + let req = MockRequest::new(Get, "/message/1").header(ContentType::JSON); run_test!(req, |response: Response| { assert_eq!(response.status(), Status::NotFound); }); // Add a new message with ID 1. - let mut req = MockRequest::new(Post, "/message/1") + let req = MockRequest::new(Post, "/message/1") .header(ContentType::JSON) .body(r#"{ "contents": "Hello, world!" }"#); @@ -68,7 +69,7 @@ fn post_get_put_get() { }); // Check that the message exists with the correct contents. - let mut req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON); + let req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON); run_test!(req, |mut response: Response| { assert_eq!(response.status(), Status::Ok); let body = response.body().unwrap().into_string().unwrap(); @@ -76,7 +77,7 @@ fn post_get_put_get() { }); // Change the message contents. - let mut req = MockRequest::new(Put, "/message/1") + let req = MockRequest::new(Put, "/message/1") .header(ContentType::JSON) .body(r#"{ "contents": "Bye bye, world!" }"#); @@ -85,7 +86,7 @@ fn post_get_put_get() { }); // Check that the message exists with the updated contents. - let mut req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON); + let req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON); run_test!(req, |mut response: Response| { assert_eq!(response.status(), Status::Ok); let body = response.body().unwrap().into_string().unwrap(); diff --git a/examples/state/Cargo.toml b/examples/state/Cargo.toml new file mode 100644 index 0000000000..48efe471e1 --- /dev/null +++ b/examples/state/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "state" +version = "0.0.1" +workspace = "../../" + +[dependencies] +rocket = { path = "../../lib" } +rocket_codegen = { path = "../../codegen" } + +[dev-dependencies] +rocket = { path = "../../lib", features = ["testing"] } diff --git a/examples/state/src/main.rs b/examples/state/src/main.rs new file mode 100644 index 0000000000..4a750f7d2a --- /dev/null +++ b/examples/state/src/main.rs @@ -0,0 +1,36 @@ +#![feature(plugin)] +#![plugin(rocket_codegen)] + +extern crate rocket; + +#[cfg(test)] mod tests; + +use std::sync::atomic::{AtomicUsize, Ordering}; + +use rocket::State; +use rocket::response::content; + +struct HitCount(AtomicUsize); + +#[get("/")] +fn index(hit_count: State) -> content::HTML { + hit_count.0.fetch_add(1, Ordering::Relaxed); + let msg = "Your visit has been recorded!"; + let count = format!("Visits: {}", count(hit_count)); + content::HTML(format!("{}

{}", msg, count)) +} + +#[get("/count")] +fn count(hit_count: State) -> String { + hit_count.0.load(Ordering::Relaxed).to_string() +} + +fn rocket() -> rocket::Rocket { + rocket::ignite() + .mount("/", routes![index, count]) + .manage(HitCount(AtomicUsize::new(0))) +} + +fn main() { + rocket().launch(); +} diff --git a/examples/state/src/tests.rs b/examples/state/src/tests.rs new file mode 100644 index 0000000000..7986d52ad0 --- /dev/null +++ b/examples/state/src/tests.rs @@ -0,0 +1,43 @@ +use rocket::Rocket; +use rocket::testing::MockRequest; +use rocket::http::Method::*; +use rocket::http::Status; + +fn register_hit(rocket: &Rocket) { + let mut req = MockRequest::new(Get, "/"); + let response = req.dispatch_with(&rocket); + assert_eq!(response.status(), Status::Ok); +} + +fn get_count(rocket: &Rocket) -> usize { + let mut req = MockRequest::new(Get, "/count"); + let mut response = req.dispatch_with(&rocket); + let body_string = response.body().and_then(|b| b.into_string()).unwrap(); + body_string.parse().unwrap() +} + +#[test] +fn test_count() { + let rocket = super::rocket(); + + // Count should start at 0. + assert_eq!(get_count(&rocket), 0); + + for _ in 0..99 { register_hit(&rocket); } + assert_eq!(get_count(&rocket), 99); + + register_hit(&rocket); + assert_eq!(get_count(&rocket), 100); +} + +// Cargo runs each test in parallel on different threads. We use all of these +// tests below to show (and assert) that state is managed per-Rocket instance. +#[test] fn test_count_parallel() { test_count() } +#[test] fn test_count_parallel_2() { test_count() } +#[test] fn test_count_parallel_3() { test_count() } +#[test] fn test_count_parallel_4() { test_count() } +#[test] fn test_count_parallel_5() { test_count() } +#[test] fn test_count_parallel_6() { test_count() } +#[test] fn test_count_parallel_7() { test_count() } +#[test] fn test_count_parallel_8() { test_count() } +#[test] fn test_count_parallel_9() { test_count() } diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 180f745a31..53cb7c8c68 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -21,6 +21,7 @@ url = "^1" hyper = { version = "^0.9.14", default-features = false } toml = { version = "^0.2", default-features = false } num_cpus = "1" +state = "^0.2" # cookie = "^0.3" [dev-dependencies] diff --git a/lib/src/lib.rs b/lib/src/lib.rs index 87abb98a1e..78fc8d041e 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -96,6 +96,7 @@ extern crate hyper; extern crate url; extern crate toml; extern crate num_cpus; +extern crate state; #[cfg(test)] #[macro_use] extern crate lazy_static; @@ -123,7 +124,7 @@ mod ext; #[doc(inline)] pub use outcome::Outcome; #[doc(inline)] pub use data::Data; pub use router::Route; -pub use request::Request; +pub use request::{Request, State}; pub use error::Error; pub use catcher::Catcher; pub use rocket::Rocket; diff --git a/lib/src/request/mod.rs b/lib/src/request/mod.rs index 05cab1ae3c..825b4e7f88 100644 --- a/lib/src/request/mod.rs +++ b/lib/src/request/mod.rs @@ -4,11 +4,13 @@ mod request; mod param; mod form; mod from_request; +mod state; pub use self::request::Request; pub use self::from_request::{FromRequest, Outcome}; pub use self::param::{FromParam, FromSegments}; pub use self::form::{Form, FromForm, FromFormValue, FormItems}; +pub use self::state::State; /// Type alias to retrieve flash messages from a request. pub type FlashMessage = ::response::Flash<()>; diff --git a/lib/src/request/request.rs b/lib/src/request/request.rs index b5982930e3..786f4e589b 100644 --- a/lib/src/request/request.rs +++ b/lib/src/request/request.rs @@ -5,6 +5,8 @@ use std::fmt; use term_painter::Color::*; use term_painter::ToStyle; +use state::Container; + use error::Error; use super::{FromParam, FromSegments}; @@ -28,6 +30,7 @@ pub struct Request<'r> { remote: Option, params: RefCell>, cookies: Cookies, + state: Option<&'r Container>, } impl<'r> Request<'r> { @@ -51,6 +54,7 @@ impl<'r> Request<'r> { remote: None, params: RefCell::new(Vec::new()), cookies: Cookies::new(&[]), + state: None } } @@ -391,13 +395,25 @@ impl<'r> Request<'r> { Some(Segments(&path[i..j])) } + /// Get the managed state container, if it exists. For internal use only! + #[doc(hidden)] + pub fn get_state(&self) -> Option<&'r Container> { + self.state + } + + /// Set the state. For internal use only! + #[doc(hidden)] + pub fn set_state(&mut self, state: &'r Container) { + self.state = Some(state); + } + /// Convert from Hyper types into a Rocket Request. #[doc(hidden)] pub fn from_hyp(h_method: hyper::Method, h_headers: hyper::header::Headers, h_uri: hyper::RequestUri, h_addr: SocketAddr, - ) -> Result, String> { + ) -> Result, String> { // Get a copy of the URI for later use. let uri = match h_uri { hyper::RequestUri::AbsolutePath(s) => s, diff --git a/lib/src/request/state.rs b/lib/src/request/state.rs new file mode 100644 index 0000000000..d6f2150b85 --- /dev/null +++ b/lib/src/request/state.rs @@ -0,0 +1,49 @@ +use std::ops::Deref; + +use request::{self, FromRequest, Request}; +use outcome::Outcome; +use http::Status; + +// TODO: Doc. +#[derive(Debug, PartialEq, Eq)] +pub struct State<'r, T: Send + Sync + 'static>(&'r T); + +impl<'r, T: Send + Sync + 'static> State<'r, T> { + /// Retrieve a borrow to the underyling value. + /// + /// Using this method is typically unnecessary as `State` implements `Deref` + /// with a `Target` of `T`. This means Rocket will automatically coerce a + /// `State` to an `&T` when the types call for it. + pub fn inner(&self) -> &'r T { + self.0 + } +} + +// TODO: Doc. +impl<'a, 'r, T: Send + Sync + 'static> FromRequest<'a, 'r> for State<'r, T> { + type Error = (); + + fn from_request(req: &'a Request<'r>) -> request::Outcome, ()> { + if let Some(state) = req.get_state() { + match state.try_get::() { + Some(state) => Outcome::Success(State(state)), + None => { + error_!("Attempted to retrieve unmanaged state!"); + Outcome::Failure((Status::InternalServerError, ())) + } + } + } else { + error_!("Internal Rocket error: managed state is unset!"); + error_!("Please report this error in the Rocket GitHub issue tracker."); + Outcome::Failure((Status::InternalServerError, ())) + } + } +} + +impl<'r, T: Send + Sync + 'static> Deref for State<'r, T> { + type Target = T; + + fn deref(&self) -> &T { + self.0 + } +} diff --git a/lib/src/rocket.rs b/lib/src/rocket.rs index 25b80cb051..9d593121be 100644 --- a/lib/src/rocket.rs +++ b/lib/src/rocket.rs @@ -7,6 +7,8 @@ use std::io::{self, Write}; use term_painter::Color::*; use term_painter::ToStyle; +use state::Container; + use {logger, handler}; use ext::ReadExt; use config::{self, Config}; @@ -29,6 +31,7 @@ pub struct Rocket { router: Router, default_catchers: HashMap, catchers: HashMap, + state: Container } #[doc(hidden)] @@ -175,9 +178,13 @@ impl Rocket { #[doc(hidden)] #[inline(always)] - pub fn dispatch<'r>(&self, request: &'r mut Request, data: Data) -> Response<'r> { + pub fn dispatch<'s, 'r>(&'s self, request: &'r mut Request<'s>, data: Data) + -> Response<'r> { info!("{}:", request); + // Inform the request about the state. + request.set_state(&self.state); + // Do a bit of preprocessing before routing. self.preprocess_request(request, &data); @@ -353,6 +360,7 @@ impl Rocket { router: Router::new(), default_catchers: catcher::defaults::get(), catchers: catcher::defaults::get(), + state: Container::new() } } @@ -472,6 +480,50 @@ impl Rocket { self } + /// Add `state` to the state managed by this instance of Rocket. + /// + /// Managed state can be retrieved by any request handler via the + /// [State](/rocket/struct.State.html) request guard. In particular, if a + /// value of type `T` is managed by Rocket, adding `State` to the list of + /// arguments in a request handler instructs Rocket to retrieve the managed + /// value. + /// + /// # Panics + /// + /// Panics if state of type `T` is already being managed. + /// + /// # Example + /// + /// ```rust + /// # #![feature(plugin)] + /// # #![plugin(rocket_codegen)] + /// # extern crate rocket; + /// use rocket::State; + /// + /// struct MyValue(usize); + /// + /// #[get("/")] + /// fn index(state: State) -> String { + /// format!("The stateful value is: {}", state.0) + /// } + /// + /// fn main() { + /// # if false { // We don't actually want to launch the server in an example. + /// rocket::ignite() + /// .manage(MyValue(10)) + /// # .launch() + /// # } + /// } + /// ``` + pub fn manage(self, state: T) -> Self { + if !self.state.set::(state) { + error!("State for this type is already being managed!"); + panic!("Aborting due to duplicately managed state."); + } + + self + } + /// Starts the application server and begins listening for and dispatching /// requests to mounted routes and catchers. /// diff --git a/lib/src/testing.rs b/lib/src/testing.rs index 134f6ed522..cb5a6f83db 100644 --- a/lib/src/testing.rs +++ b/lib/src/testing.rs @@ -111,12 +111,12 @@ use http::{Method, Header, Cookie}; use std::net::SocketAddr; /// A type for mocking requests for testing Rocket applications. -pub struct MockRequest { - request: Request<'static>, +pub struct MockRequest<'r> { + request: Request<'r>, data: Data } -impl MockRequest { +impl<'r> MockRequest<'r> { /// Constructs a new mocked request with the given `method` and `uri`. #[inline] pub fn new>(method: Method, uri: S) -> Self { @@ -259,7 +259,7 @@ impl MockRequest { /// assert_eq!(body_str, Some("Hello, world!".to_string())); /// # } /// ``` - pub fn dispatch_with<'r>(&'r mut self, rocket: &Rocket) -> Response<'r> { + pub fn dispatch_with<'s>(&'s mut self, rocket: &'r Rocket) -> Response<'s> { let data = ::std::mem::replace(&mut self.data, Data::new(vec![])); rocket.dispatch(&mut self.request, data) }