From af2785f030b80f3d1f33e3eba2f3bf5064a30cb2 Mon Sep 17 00:00:00 2001
From: Jim Hodapp <james.hodapp@gmail.com>
Date: Tue, 16 May 2023 14:47:10 -0500
Subject: [PATCH] Basics of moving to Axum from Rocket are working, except for
 adding the CORS whitelisting headers and sending out server-sent events. SSE
 is still a work in progress.

---
 Cargo.toml        |  26 +++--
 src/controller.rs | 264 +++++++++++++++++++++++++++++++++-------------
 src/lib.rs        | 114 ++++++++++----------
 src/main.rs       |  75 +++++++++++--
 src/models.rs     |  12 +--
 src/schema.rs     |   2 +
 6 files changed, 341 insertions(+), 152 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index a3688da..e180935 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,12 +4,22 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
-rocket = { version = "0.5.0-rc.2", features = ["json"] }
-serde = "1.0.147"
-diesel = { version = "1.4.8", features = ["postgres"] }
-env_logger = "0.10.0"
-log = "0.4.17"
+async-stream = "0.3"
+axum = { version = "0.6", features = ["tokio", "headers"] }
+axum-macros = "0.3"
+deadpool-diesel = { version = "0.4.1", features = ["postgres"] }
+diesel = { version = "2", features = ["postgres"] }
+diesel_migrations = "2"
+futures = "0.3"
+futures-util = "0.3"
+headers = "0.3"
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1"
+tokio = { version = "1.0", features = ["full"] }
+tokio-stream = "0.1"
+tower-http = { version = "0.4.0", features = ["fs", "trace"] }
+tracing = "0.1"
+tracing-subscriber = { version = "0.3", features = ["env-filter"] }
 
-[dependencies.rocket_sync_db_pools]
-version = "0.1.0-rc.2"
-features = ["diesel_postgres_pool"]
\ No newline at end of file
+env_logger = "0.10.0"
+log = "0.4.17"
\ No newline at end of file
diff --git a/src/controller.rs b/src/controller.rs
index 3f607d0..97aa3fb 100644
--- a/src/controller.rs
+++ b/src/controller.rs
@@ -1,91 +1,205 @@
-use rocket::response::status::Created;
-use rocket::response::stream::{Event, EventStream};
-use rocket::serde::json::Json;
-use rocket::tokio::select;
-use rocket::tokio::sync::broadcast::Sender;
-use rocket::tokio::time::{self, Duration};
-use rocket::{Shutdown, State};
-
-use crate::schema::*;
-use crate::PgConnection;
+use async_stream::stream;
+use axum::{
+    extract::State,
+    extract::TypedHeader,
+    http::StatusCode,
+    response::sse::{Event, Sse},
+    response::Json,
+};
+use futures::stream::Stream;
 use diesel::prelude::*;
+use std::{convert::Infallible, time::Duration};
 
-use log::{debug, error, info};
+use crate::schema::*;
 use crate::models::{ApiError, NewReading, Reading};
 
 /// Returns an infinite stream of server-sent events. Each event is a message
 /// pulled from a broadcast queue sent by the `post` handler.
-#[get("/events")]
-pub(crate) async fn events(
-    conn: PgConnection,
-    queue: &State<Sender<Reading>>,
-    mut end: Shutdown,
-) -> EventStream![] {
-    let _rx = queue.subscribe();
-    EventStream! {
-        let mut interval = time::interval(Duration::from_secs(5));
-        loop {
-            select! {
-                _ = interval.tick() => {
-                    match get_latest_reading(&conn).await {
-                        Ok(reading) => {
-                            yield Event::json(&reading);
-                            yield Event::data(format!("{}°C", reading.temperature)).event("temperature");
-                            yield Event::data(format!("{}%", reading.humidity)).event("humidity");
-                            yield Event::data(format!("{} mbars", reading.pressure)).event("pressure");
-                            yield Event::data(format!("{}", reading.air_purity)).event("air_purity");
-                            yield Event::data(format!("{} pcs/ltr", reading.dust_concentration)).event("dust_concentration");
-                        }
-                        Err(e) => { error!("Err: failed to retrieve latest reading: {:?}", e) }
-                    }
-                }
-
-                // Handle graceful shutdown of infinite EventStream
-                _ = &mut end => {
-                    info!("EventStream graceful shutdown requested, handling...");
-                    break;
-                }
-            }
-        }
-    }
+// #[get("/events")]
+// pub(crate) async fn events(
+//     conn: PgConnection,
+//     queue: &State<Sender<Reading>>,
+//     mut end: Shutdown,
+// ) -> EventStream![] {
+//     let _rx = queue.subscribe();
+//     EventStream! {
+//         let mut interval = time::interval(Duration::from_secs(5));
+//         loop {
+//             select! {
+//                 _ = interval.tick() => {
+//                     match get_latest_reading(&conn).await {
+//                         Ok(reading) => {
+//                             yield Event::json(&reading);
+//                             yield Event::data(format!("{}°C", reading.temperature)).event("temperature");
+//                             yield Event::data(format!("{}%", reading.humidity)).event("humidity");
+//                             yield Event::data(format!("{} mbars", reading.pressure)).event("pressure");
+//                             yield Event::data(format!("{}", reading.air_purity)).event("air_purity");
+//                             yield Event::data(format!("{} pcs/ltr", reading.dust_concentration)).event("dust_concentration");
+//                         }
+//                         Err(e) => { error!("Err: failed to retrieve latest reading: {:?}", e) }
+//                     }
+//                 }
+
+//                 // Handle graceful shutdown of infinite EventStream
+//                 _ = &mut end => {
+//                     info!("EventStream graceful shutdown requested, handling...");
+//                     break;
+//                 }
+//             }
+//         }
+//     }
+// }
+
+pub(crate) async fn events_handler(
+    State(pool): State<deadpool_diesel::postgres::Pool>,
+    TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
+) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
+    println!("`{}` connected", user_agent.as_str());
+
+    // let astream = stream! {
+    //     match get_latest_reading(&conn).await {
+    //         Ok(reading) => {
+    //             yield Event::json(&reading);
+    //             yield Event::data(format!("{}°C", reading.temperature)).event("temperature");
+    //             yield Event::data(format!("{}%", reading.humidity)).event("humidity");
+    //             yield Event::data(format!("{} mbars", reading.pressure)).event("pressure");
+    //             yield Event::data(format!("{}", reading.air_purity)).event("air_purity");
+    //             yield Event::data(format!("{} pcs/ltr", reading.dust_concentration)).event("dust_concentration");
+    //         }
+    //         Err(e) => { error!("Err: failed to retrieve latest reading: {:?}", e) }
+    //     }
+    // };
+
+    // A `Stream` that repeats an event every 5 seconds
+    // let stream = stream::repeat_with(|| async {
+    //     let event = match get_latest_reading(axum::extract::State(pool)).await {
+    //         //Ok(reading) => { Event::data(format!("{}°C", reading.temperature)).event("temperature") }
+    //         Ok(reading) => { Event::default().data("Successfully retrieved latest reading") }
+    //         Err(e) => { Event::default().data("Failed to retrieve latest reading") }
+    //     }
+    // })
+    // .map(Ok)
+    // .throttle(Duration::from_secs(5));
+
+    // match get_latest_reading(axum::extract::State(pool)).await {
+    //     Ok(reading) => {
+    //         println!("{}°C", reading.temperature);
+    //     }
+    //     Err(e) => { println!("Err: failed to retrieve latest reading: {:?}", e) }
+    // }
+
+    // TODO: maybe we pass in a stream of Event instances filled with Readings into
+    // events_handler instead of asking this method to retrieve Readings from the DB
+    //
+    // `stream::repeat_with` takes a synchronous closure, so it cannot await the
+    // database call; build the stream with the `async-stream` crate instead and
+    // poll for the latest reading every 5 seconds.
+    let stream = stream! {
+        let mut interval = tokio::time::interval(Duration::from_secs(5));
+        loop {
+            interval.tick().await;
+            let event = match get_latest_reading(State(pool.clone())).await {
+                Ok(reading) => Event::default()
+                    .data(format!("{}°C", reading.temperature))
+                    .event("temperature"),
+                Err(e) => Event::default()
+                    .data(format!("Err: failed to retrieve latest reading: {}", e.1))
+                    .event("error"),
+            };
+            yield Ok::<Event, Infallible>(event);
+        }
+    };
+
+    // let stream = stream::repeat_with(|| Event::default().data("hi!"))
+    //     .map(Ok)
+    //     .throttle(Duration::from_secs(1));
+
+    Sse::new(stream).keep_alive(
+        axum::response::sse::KeepAlive::new()
+            .interval(Duration::from_secs(5))
+            .text("keep-alive-text"),
+    )
 }
 
-async fn get_latest_reading(conn: &PgConnection) -> Result<Reading, Json<ApiError>> {
-    // Get the last inserted temperature value
-    let reading = conn
-        .run(move |c| {
+// async fn get_latest_reading(conn: &PgConnection) -> Result<Reading, Json<ApiError>> {
+//     // Get the last inserted temperature value
+//     let reading = conn
+//         .run(move |c| {
+//             readings::table
+//                 .order(readings::id.desc())
+//                 .first::<Reading>(c)
+//         })
+//         .await
+//         .map_err(|e| {
+//             Json(ApiError {
+//                 details: e.to_string(),
+//             })
+//         });
+
+//     debug!("Reading: {:?}", reading);
+
+//     reading
+// }
+
+// #[post("/readings/add", data = "<reading>")]
+// pub(crate) async fn create_reading(
+//     conn: PgConnection,
+//     reading: Json<NewReading>,
+// ) -> Result<Created<Json<Reading>>, Json<ApiError>> {
+//     conn.run(move |c| {
+//         diesel::insert_into(readings::table)
+//             .values(&reading.into_inner())
+//             .get_result(c)
+//     })
+//     .await
+//     .map(|a| Created::new("/").body(Json(a)))
+//     .map_err(|e| {
+//         Json(ApiError {
+//             details: e.to_string(),
+//         })
+//     })
+// }
+
+pub(crate) async fn create_reading(
+    State(pool): State<deadpool_diesel::postgres::Pool>,
+    Json(new_reading): Json<NewReading>,
+) -> Result<Json<Reading>, (StatusCode, String)> {
+    let conn = pool.get().await.map_err(internal_error)?;
+    let res = conn
+        .interact(|conn| {
+            diesel::insert_into(readings::table)
+            .values(new_reading)
+            .returning(Reading::as_returning())
+            .get_result(conn)
+        })
+        .await
+        .map_err(internal_error)?
+        .map_err(internal_error)?;
+    Ok(Json(res))
+}
+
+pub(crate) async fn get_latest_reading(
+    State(pool): State<deadpool_diesel::postgres::Pool>,
+) -> Result<Json<Reading>, (StatusCode, String)> {
+    let conn = pool.get().await.map_err(internal_error)?;
+    let res = conn
+        .interact(|conn| {
             readings::table
                 .order(readings::id.desc())
-                .first::<Reading>(c)
+                .first::<Reading>(conn)
         })
         .await
-        .map_err(|e| {
-            Json(ApiError {
-                details: e.to_string(),
-            })
-        });
-
-    debug!("Reading: {:?}", reading);
-
-    reading
+        .map_err(internal_error)?
+        .map_err(internal_error)?;
+    Ok(Json(res))
 }
 
-#[post("/readings/add", data = "<reading>")]
-pub(crate) async fn create_reading(
-    conn: PgConnection,
-    reading: Json<NewReading>,
-) -> Result<Created<Json<Reading>>, Json<ApiError>> {
-    conn.run(move |c| {
-        diesel::insert_into(readings::table)
-            .values(&reading.into_inner())
-            .get_result(c)
-    })
-    .await
-    .map(|a| Created::new("/").body(Json(a)))
-    .map_err(|e| {
-        Json(ApiError {
-            details: e.to_string(),
-        })
-    })
+/// Utility function for mapping any error into a `500 Internal Server Error`
+/// response.
+fn internal_error<E>(err: E) -> (StatusCode, String)
+where
+    E: std::error::Error,
+{
+    (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
 }
diff --git a/src/lib.rs b/src/lib.rs
index 2597a02..7bf4d5b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,56 +1,58 @@
-mod controller;
-pub mod models;
-pub mod schema;
-
-#[macro_use]
-extern crate rocket;
-#[macro_use]
-extern crate diesel;
-
-use rocket::fs::{relative, FileServer};
-use rocket::tokio::sync::broadcast::channel;
-use rocket::Build;
-use rocket_sync_db_pools::database;
-
-use rocket::fairing::{Fairing, Info, Kind};
-use rocket::http::Header;
-use rocket::{Request, Response};
-
-use crate::controller::{create_reading, events};
-use crate::models::Reading;
-
-#[database("ambi_rs_dev")]
-pub struct PgConnection(diesel::PgConnection);
-
-pub struct CORS;
-
-#[rocket::async_trait]
-impl Fairing for CORS {
-    fn info(&self) -> Info {
-        Info {
-            name: "Add CORS headers to responses",
-            kind: Kind::Response,
-        }
-    }
-
-    async fn on_response<'r>(&self, _request: &'r Request<'_>, response: &mut Response<'r>) {
-        response.set_header(Header::new("Access-Control-Allow-Origin", "*"));
-        response.set_header(Header::new(
-            "Access-Control-Allow-Methods",
-            "POST, GET, PATCH, OPTIONS",
-        ));
-        response.set_header(Header::new("Access-Control-Allow-Headers", "*"));
-        response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
-    }
-}
-
-#[launch]
-pub fn rocket_builder() -> rocket::Rocket<Build> {
-    rocket::build()
-        .attach(PgConnection::fairing())
-        .attach(CORS)
-        .manage(channel::<Reading>(1024).0)
-        .mount("/", routes![events])
-        .mount("/api", routes![create_reading])
-        .mount("/", FileServer::from(relative!("static")))
-}
+// mod controller;
+// pub mod models;
+// pub mod schema;
+
+// use axum::{
+//     extract::State,
+//     http::StatusCode,
+//     response::Json,
+//     routing::{get, post},
+//     Router,
+// };
+// use diesel::prelude::*;
+// use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
+// use std::net::SocketAddr;
+// use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
+
+// use crate::controller::{create_reading, events};
+// use crate::models::Reading;
+
+// // This embeds the migrations into the application binary
+// // the migration path is relative to the `CARGO_MANIFEST_DIR`
+// pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/");
+
+// #[database("ambi_rs_dev")]
+// pub struct PgConnection(diesel::PgConnection);
+
+// pub struct CORS;
+
+// #[rocket::async_trait]
+// impl Fairing for CORS {
+//     fn info(&self) -> Info {
+//         Info {
+//             name: "Add CORS headers to responses",
+//             kind: Kind::Response,
+//         }
+//     }
+
+//     async fn on_response<'r>(&self, _request: &'r Request<'_>, response: &mut Response<'r>) {
+//         response.set_header(Header::new("Access-Control-Allow-Origin", "*"));
+//         response.set_header(Header::new(
+//             "Access-Control-Allow-Methods",
+//             "POST, GET, PATCH, OPTIONS",
+//         ));
+//         response.set_header(Header::new("Access-Control-Allow-Headers", "*"));
+//         response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
+//     }
+// }
+
+// #[launch]
+// pub fn rocket_builder() -> rocket::Rocket<Build> {
+//     rocket::build()
+//         .attach(PgConnection::fairing())
+//         .attach(CORS)
+//         .manage(channel::<Reading>(1024).0)
+//         .mount("/", routes![events])
+//         .mount("/api", routes![create_reading])
+//         .mount("/", FileServer::from(relative!("static")))
+// }
diff --git a/src/main.rs b/src/main.rs
index b283bbb..fe5bad4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,8 +1,71 @@
-use ambi_rs::rocket_builder;
+pub mod controller;
+pub mod models;
+pub mod schema;
 
-#[rocket::main]
-async fn main() -> Result<(), rocket::Error> {
-    env_logger::init();
-    let _ = rocket_builder().launch().await?;
-    Ok(())
+use axum::{
+    routing::{get, post},
+    Router,
+};
+
+use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
+use std::net::SocketAddr;
+use std::path::PathBuf;
+use tower_http::services::ServeDir;
+use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
+
+use crate::controller::{create_reading, events_handler};
+
+// This embeds the migrations into the application binary
+// the migration path is relative to the `CARGO_MANIFEST_DIR`
+pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/");
+
+#[tokio::main]
+async fn main() {
+    //env_logger::init();
+
+    tracing_subscriber::registry()
+        .with(
+            tracing_subscriber::EnvFilter::try_from_default_env()
+                .unwrap_or_else(|_| "ambi_rs=debug,tower_http=debug".into()),
+        )
+        .with(tracing_subscriber::fmt::layer())
+        .init();
+
+    let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
+
+    let static_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("static");
+
+    // set up connection pool
+    let manager = deadpool_diesel::postgres::Manager::new(db_url, deadpool_diesel::Runtime::Tokio1);
+    let pool = deadpool_diesel::postgres::Pool::builder(manager)
+        .build()
+        .unwrap();
+
+    // run the migrations on server startup
+    {
+        let conn = pool.get().await.unwrap();
+        conn.interact(|conn| conn.run_pending_migrations(MIGRATIONS).map(|_| ()))
+            .await
+            .unwrap()
+            .unwrap();
+    }
+
+    let static_files_service = ServeDir::new(static_dir).append_index_html_on_directories(true);
+
+    // build our application with some routes
+    let app = Router::new()
+        .fallback_service(static_files_service)
+        .route("/events", get(events_handler))
+        .route("/api/readings/add", post(create_reading))
+        .with_state(pool);
+
+    // run it with hyper
+    let addr = SocketAddr::from(([127, 0, 0, 1], 8000));
+    tracing::debug!("listening on {}", addr);
+    //let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
+    axum::Server::bind(&addr)
+        .serve(app.into_make_service())
+        .await
+        .unwrap()
+    //axum::serve(listener, app).await.unwrap();
 }
diff --git a/src/models.rs b/src/models.rs
index 5a53459..a72d154 100644
--- a/src/models.rs
+++ b/src/models.rs
@@ -1,9 +1,9 @@
-use crate::schema::readings;
-use diesel::Insertable;
+use diesel::{Insertable, Queryable, Selectable};
 use serde::{Deserialize, Serialize};
 
-#[derive(Debug, Clone, Deserialize, Serialize, Queryable)]
-#[serde(crate = "rocket::serde")]
+use crate::schema::readings;
+
+#[derive(Debug, Clone, Deserialize, Serialize, Queryable, Selectable)]
 pub struct Reading {
     pub id: i32,
     pub temperature: f64,
@@ -14,8 +14,7 @@ pub struct Reading {
 }
 
 #[derive(Debug, Insertable, Deserialize)]
-#[serde(crate = "rocket::serde")]
-#[table_name = "readings"]
+#[diesel(table_name = readings)]
 pub struct NewReading {
     pub temperature: f64,
     pub humidity: f64,
@@ -25,7 +24,6 @@ pub struct NewReading {
 }
 
 #[derive(Serialize, Deserialize, Debug)]
-#[serde(crate = "rocket::serde")]
 pub struct ApiError {
     pub details: String,
 }
diff --git a/src/schema.rs b/src/schema.rs
index 80fcf4a..e7d7051 100644
--- a/src/schema.rs
+++ b/src/schema.rs
@@ -1,5 +1,7 @@
 // @generated automatically by Diesel CLI.
 
+use diesel::prelude::*;
+
 table! {
     readings (id) {
         id -> Int4,
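
Note on the remaining CORS work called out in the commit message: the Rocket `CORS` fairing that this patch comments out in `src/lib.rs` has a counterpart in tower-http's `CorsLayer`. A minimal sketch follows, assuming the `cors` feature is added to the existing tower-http dependency (it is not enabled in this patch's `Cargo.toml`, which only enables "fs" and "trace"); `cors_layer` is an illustrative name, not something in the codebase:

    use axum::http::Method;
    use tower_http::cors::{Any, CorsLayer};

    // Roughly mirrors the headers the old Rocket fairing set. The
    // `Access-Control-Allow-Credentials: true` header is deliberately omitted:
    // tower-http refuses to combine credentials with a wildcard origin, and
    // browsers reject that combination anyway.
    fn cors_layer() -> CorsLayer {
        CorsLayer::new()
            .allow_origin(Any)
            .allow_methods([Method::POST, Method::GET, Method::PATCH, Method::OPTIONS])
            .allow_headers(Any)
    }

The layer would then be attached to the router in `main`, e.g. `Router::new().route(...).layer(cors_layer()).with_state(pool)`.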