Skip to content

Commit

Permalink
Introduce Managed State.
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Jan 21, 2017
1 parent 9ef65a8 commit c815911
Show file tree
Hide file tree
Showing 14 changed files with 239 additions and 23 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ members = [
"examples/hello_alt_methods",
"examples/raw_upload",
"examples/pastebin",
"examples/state",
]
4 changes: 3 additions & 1 deletion examples/from_request/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ mod test {
use rocket::http::Header;

fn test_header_count<'h>(headers: Vec<Header<'static>>) {
let rocket = rocket::ignite()
.mount("/", routes![super::header_count]);

let num_headers = headers.len();
let mut req = MockRequest::new(Get, "/");
for header in headers {
req = req.header(header);
}

let rocket = rocket::ignite().mount("/", routes![super::header_count]);
let mut response = req.dispatch_with(&rocket);

let expect = format!("Your request contained {} headers!", num_headers);
Expand Down
11 changes: 6 additions & 5 deletions examples/handlebars_templates/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ macro_rules! run_test {
.mount("/", routes![super::index, super::get])
.catch(errors![super::not_found]);

$test_fn($req.dispatch_with(&rocket));
let mut req = $req;
$test_fn(req.dispatch_with(&rocket));
})
}

#[test]
fn test_root() {
// Check that the redirect works.
for method in &[Get, Head] {
let mut req = MockRequest::new(*method, "/");
let req = MockRequest::new(*method, "/");
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::SeeOther);
assert!(response.body().is_none());
Expand All @@ -31,7 +32,7 @@ fn test_root() {

// Check that other request methods are not accepted (and instead caught).
for method in &[Post, Put, Delete, Options, Trace, Connect, Patch] {
let mut req = MockRequest::new(*method, "/");
let req = MockRequest::new(*method, "/");
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::NotFound);

Expand All @@ -48,7 +49,7 @@ fn test_root() {
#[test]
fn test_name() {
// Check that the /hello/<name> route works.
let mut req = MockRequest::new(Get, "/hello/Jack");
let req = MockRequest::new(Get, "/hello/Jack");
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::Ok);

Expand All @@ -66,7 +67,7 @@ fn test_name() {
#[test]
fn test_404() {
// Check that the error catcher works.
let mut req = MockRequest::new(Get, "/hello/");
let req = MockRequest::new(Get, "/hello/");
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::NotFound);

Expand Down
21 changes: 11 additions & 10 deletions examples/json/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ macro_rules! run_test {
.mount("/message", routes![super::new, super::update, super::get])
.catch(errors![super::not_found]);

$test_fn($req.dispatch_with(&rocket));
let mut req = $req;
$test_fn(req.dispatch_with(&rocket));
})
}

#[test]
fn bad_get_put() {
// Try to get a message with an ID that doesn't exist.
let mut req = MockRequest::new(Get, "/message/99").header(ContentType::JSON);
let req = MockRequest::new(Get, "/message/99").header(ContentType::JSON);
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::NotFound);

Expand All @@ -27,21 +28,21 @@ fn bad_get_put() {
});

// Try to get a message with an invalid ID.
let mut req = MockRequest::new(Get, "/message/hi").header(ContentType::JSON);
let req = MockRequest::new(Get, "/message/hi").header(ContentType::JSON);
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::NotFound);
let body = response.body().unwrap().into_string().unwrap();
assert!(body.contains("error"));
});

// Try to put a message without a proper body.
let mut req = MockRequest::new(Put, "/message/80").header(ContentType::JSON);
let req = MockRequest::new(Put, "/message/80").header(ContentType::JSON);
run_test!(req, |response: Response| {
assert_eq!(response.status(), Status::BadRequest);
});

// Try to put a message for an ID that doesn't exist.
let mut req = MockRequest::new(Put, "/message/80")
let req = MockRequest::new(Put, "/message/80")
.header(ContentType::JSON)
.body(r#"{ "contents": "Bye bye, world!" }"#);

Expand All @@ -53,13 +54,13 @@ fn bad_get_put() {
#[test]
fn post_get_put_get() {
// Check that a message with ID 1 doesn't exist.
let mut req = MockRequest::new(Get, "/message/1").header(ContentType::JSON);
let req = MockRequest::new(Get, "/message/1").header(ContentType::JSON);
run_test!(req, |response: Response| {
assert_eq!(response.status(), Status::NotFound);
});

// Add a new message with ID 1.
let mut req = MockRequest::new(Post, "/message/1")
let req = MockRequest::new(Post, "/message/1")
.header(ContentType::JSON)
.body(r#"{ "contents": "Hello, world!" }"#);

Expand All @@ -68,15 +69,15 @@ fn post_get_put_get() {
});

// Check that the message exists with the correct contents.
let mut req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON);
let req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON);
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::Ok);
let body = response.body().unwrap().into_string().unwrap();
assert!(body.contains("Hello, world!"));
});

// Change the message contents.
let mut req = MockRequest::new(Put, "/message/1")
let req = MockRequest::new(Put, "/message/1")
.header(ContentType::JSON)
.body(r#"{ "contents": "Bye bye, world!" }"#);

Expand All @@ -85,7 +86,7 @@ fn post_get_put_get() {
});

// Check that the message exists with the updated contents.
let mut req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON);
let req = MockRequest::new(Get, "/message/1") .header(ContentType::JSON);
run_test!(req, |mut response: Response| {
assert_eq!(response.status(), Status::Ok);
let body = response.body().unwrap().into_string().unwrap();
Expand Down
11 changes: 11 additions & 0 deletions examples/state/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[package]
name = "state"
version = "0.0.1"
workspace = "../../"

[dependencies]
rocket = { path = "../../lib" }
rocket_codegen = { path = "../../codegen" }

[dev-dependencies]
rocket = { path = "../../lib", features = ["testing"] }
36 changes: 36 additions & 0 deletions examples/state/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#![feature(plugin)]
#![plugin(rocket_codegen)]

extern crate rocket;

#[cfg(test)] mod tests;

use std::sync::atomic::{AtomicUsize, Ordering};

use rocket::State;
use rocket::response::content;

struct HitCount(AtomicUsize);

#[get("/")]
fn index(hit_count: State<HitCount>) -> content::HTML<String> {
hit_count.0.fetch_add(1, Ordering::Relaxed);
let msg = "Your visit has been recorded!";
let count = format!("Visits: {}", count(hit_count));
content::HTML(format!("{}<br /><br />{}", msg, count))
}

#[get("/count")]
fn count(hit_count: State<HitCount>) -> String {
hit_count.0.load(Ordering::Relaxed).to_string()
}

fn rocket() -> rocket::Rocket {
rocket::ignite()
.mount("/", routes![index, count])
.manage(HitCount(AtomicUsize::new(0)))
}

fn main() {
rocket().launch();
}
43 changes: 43 additions & 0 deletions examples/state/src/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use rocket::Rocket;
use rocket::testing::MockRequest;
use rocket::http::Method::*;
use rocket::http::Status;

fn register_hit(rocket: &Rocket) {
let mut req = MockRequest::new(Get, "/");
let response = req.dispatch_with(&rocket);
assert_eq!(response.status(), Status::Ok);
}

fn get_count(rocket: &Rocket) -> usize {
let mut req = MockRequest::new(Get, "/count");
let mut response = req.dispatch_with(&rocket);
let body_string = response.body().and_then(|b| b.into_string()).unwrap();
body_string.parse().unwrap()
}

#[test]
fn test_count() {
let rocket = super::rocket();

// Count should start at 0.
assert_eq!(get_count(&rocket), 0);

for _ in 0..99 { register_hit(&rocket); }
assert_eq!(get_count(&rocket), 99);

register_hit(&rocket);
assert_eq!(get_count(&rocket), 100);
}

// Cargo runs each test in parallel on different threads. We use all of these
// tests below to show (and assert) that state is managed per-Rocket instance.
#[test] fn test_count_parallel() { test_count() }
#[test] fn test_count_parallel_2() { test_count() }
#[test] fn test_count_parallel_3() { test_count() }
#[test] fn test_count_parallel_4() { test_count() }
#[test] fn test_count_parallel_5() { test_count() }
#[test] fn test_count_parallel_6() { test_count() }
#[test] fn test_count_parallel_7() { test_count() }
#[test] fn test_count_parallel_8() { test_count() }
#[test] fn test_count_parallel_9() { test_count() }
1 change: 1 addition & 0 deletions lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ url = "^1"
hyper = { version = "^0.9.14", default-features = false }
toml = { version = "^0.2", default-features = false }
num_cpus = "1"
state = "^0.2"
# cookie = "^0.3"

[dev-dependencies]
Expand Down
3 changes: 2 additions & 1 deletion lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ extern crate hyper;
extern crate url;
extern crate toml;
extern crate num_cpus;
extern crate state;

#[cfg(test)] #[macro_use] extern crate lazy_static;

Expand Down Expand Up @@ -123,7 +124,7 @@ mod ext;
#[doc(inline)] pub use outcome::Outcome;
#[doc(inline)] pub use data::Data;
pub use router::Route;
pub use request::Request;
pub use request::{Request, State};
pub use error::Error;
pub use catcher::Catcher;
pub use rocket::Rocket;
Expand Down
2 changes: 2 additions & 0 deletions lib/src/request/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ mod request;
mod param;
mod form;
mod from_request;
mod state;

pub use self::request::Request;
pub use self::from_request::{FromRequest, Outcome};
pub use self::param::{FromParam, FromSegments};
pub use self::form::{Form, FromForm, FromFormValue, FormItems};
pub use self::state::State;

/// Type alias to retrieve flash messages from a request.
pub type FlashMessage = ::response::Flash<()>;
18 changes: 17 additions & 1 deletion lib/src/request/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use std::fmt;
use term_painter::Color::*;
use term_painter::ToStyle;

use state::Container;

use error::Error;
use super::{FromParam, FromSegments};

Expand All @@ -28,6 +30,7 @@ pub struct Request<'r> {
remote: Option<SocketAddr>,
params: RefCell<Vec<(usize, usize)>>,
cookies: Cookies,
state: Option<&'r Container>,
}

impl<'r> Request<'r> {
Expand All @@ -51,6 +54,7 @@ impl<'r> Request<'r> {
remote: None,
params: RefCell::new(Vec::new()),
cookies: Cookies::new(&[]),
state: None
}
}

Expand Down Expand Up @@ -391,13 +395,25 @@ impl<'r> Request<'r> {
Some(Segments(&path[i..j]))
}

/// Get the managed state container, if it exists. For internal use only!
#[doc(hidden)]
pub fn get_state(&self) -> Option<&'r Container> {
self.state
}

/// Set the state. For internal use only!
#[doc(hidden)]
pub fn set_state(&mut self, state: &'r Container) {
self.state = Some(state);
}

/// Convert from Hyper types into a Rocket Request.
#[doc(hidden)]
pub fn from_hyp(h_method: hyper::Method,
h_headers: hyper::header::Headers,
h_uri: hyper::RequestUri,
h_addr: SocketAddr,
) -> Result<Request<'static>, String> {
) -> Result<Request<'r>, String> {
// Get a copy of the URI for later use.
let uri = match h_uri {
hyper::RequestUri::AbsolutePath(s) => s,
Expand Down
49 changes: 49 additions & 0 deletions lib/src/request/state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use std::ops::Deref;

use request::{self, FromRequest, Request};
use outcome::Outcome;
use http::Status;

// TODO: Doc.
#[derive(Debug, PartialEq, Eq)]
pub struct State<'r, T: Send + Sync + 'static>(&'r T);

impl<'r, T: Send + Sync + 'static> State<'r, T> {
/// Retrieve a borrow to the underyling value.
///
/// Using this method is typically unnecessary as `State` implements `Deref`
/// with a `Target` of `T`. This means Rocket will automatically coerce a
/// `State<T>` to an `&T` when the types call for it.
pub fn inner(&self) -> &'r T {
self.0
}
}

// TODO: Doc.
impl<'a, 'r, T: Send + Sync + 'static> FromRequest<'a, 'r> for State<'r, T> {
type Error = ();

fn from_request(req: &'a Request<'r>) -> request::Outcome<State<'r, T>, ()> {
if let Some(state) = req.get_state() {
match state.try_get::<T>() {
Some(state) => Outcome::Success(State(state)),
None => {
error_!("Attempted to retrieve unmanaged state!");
Outcome::Failure((Status::InternalServerError, ()))
}
}
} else {
error_!("Internal Rocket error: managed state is unset!");
error_!("Please report this error in the Rocket GitHub issue tracker.");
Outcome::Failure((Status::InternalServerError, ()))
}
}
}

impl<'r, T: Send + Sync + 'static> Deref for State<'r, T> {
type Target = T;

fn deref(&self) -> &T {
self.0
}
}
Loading

0 comments on commit c815911

Please sign in to comment.