From af2785f030b80f3d1f33e3eba2f3bf5064a30cb2 Mon Sep 17 00:00:00 2001
From: Jim Hodapp <james.hodapp@gmail.com>
Date: Tue, 16 May 2023 14:47:10 -0500
Subject: [PATCH] Basics of moving to Axum from Rocket are working, except for
 adding the CORS whitelisting headers and sending out server sent events. SSE
 is still a work-in-progress.

---
 Cargo.toml        |  26 +++--
 src/controller.rs | 264 +++++++++++++++++++++++++++++++++-------------
 src/lib.rs        | 114 ++++++++++----------
 src/main.rs       |  75 +++++++++++--
 src/models.rs     |  12 +--
 src/schema.rs     |   2 +
 6 files changed, 341 insertions(+), 152 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index a3688da..e180935 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,12 +4,22 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
-rocket = { version = "0.5.0-rc.2", features = ["json"] }
-serde = "1.0.147"
-diesel = { version = "1.4.8", features = ["postgres"] }
-env_logger = "0.10.0"
-log = "0.4.17"
+async-stream = "0.3"
+axum = { version = "0.6", features = ["tokio", "headers"]}
+axum-macros = "0.3"
+deadpool-diesel = { version = "0.4.1", features = ["postgres"] }
+diesel = { version = "2", features = ["postgres"] }
+diesel_migrations = "2"
+futures = "0.3"
+futures-util = "0.3"
+headers = "0.3"
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1"
+tokio = { version = "1.0", features = ["full"] }
+tokio-stream = "0.1"
+tower-http = { version = "0.4.0", features = ["fs", "trace"] }
+tracing = "0.1"
+tracing-subscriber = { version = "0.3", features = ["env-filter"] }
 
-[dependencies.rocket_sync_db_pools]
-version = "0.1.0-rc.2"
-features = ["diesel_postgres_pool"]
\ No newline at end of file
+env_logger = "0.10.0"
+log = "0.4.17"
\ No newline at end of file
diff --git a/src/controller.rs b/src/controller.rs
index 3f607d0..97aa3fb 100644
--- a/src/controller.rs
+++ b/src/controller.rs
@@ -1,91 +1,205 @@
-use rocket::response::status::Created;
-use rocket::response::stream::{Event, EventStream};
-use rocket::serde::json::Json;
-use rocket::tokio::select;
-use rocket::tokio::sync::broadcast::Sender;
-use rocket::tokio::time::{self, Duration};
-use rocket::{Shutdown, State};
-
-use crate::schema::*;
-use crate::PgConnection;
+//use async_stream::stream;
+use axum::{
+    extract::State,
+    extract::TypedHeader,
+    http::StatusCode,
+    response::sse::{Event, Sse},
+    response::Json,
+};
+use futures::stream::{self, Stream};
+//use futures_util::select;
 use diesel::prelude::*;
+use std::{convert::Infallible, time::Duration};
+use tokio::runtime;
+use tokio_stream::StreamExt as _;
 
-use log::{debug, error, info};
+use crate::schema::*;
 
 use crate::models::{ApiError, NewReading, Reading};
 
 /// Returns an infinite stream of server-sent events. Each event is a message
 /// pulled from a broadcast queue sent by the `post` handler.
-#[get("/events")]
-pub(crate) async fn events(
-    conn: PgConnection,
-    queue: &State<Sender<Reading>>,
-    mut end: Shutdown,
-) -> EventStream![] {
-    let _rx = queue.subscribe();
-    EventStream! {
-        let mut interval = time::interval(Duration::from_secs(5));
-        loop {
-            select! {
-                _ = interval.tick() => {
-                    match get_latest_reading(&conn).await {
-                        Ok(reading) => {
-                            yield Event::json(&reading);
-                            yield Event::data(format!("{}°C", reading.temperature)).event("temperature");
-                            yield Event::data(format!("{}%", reading.humidity)).event("humidity");
-                            yield Event::data(format!("{} mbars", reading.pressure)).event("pressure");
-                            yield Event::data(format!("{}", reading.air_purity)).event("air_purity");
-                            yield Event::data(format!("{} pcs/ltr", reading.dust_concentration)).event("dust_concentration");
-                        }
-                        Err(e) => { error!("Err: failed to retrieve latest reading: {:?}", e) }
-                    }
-                }
-
-                // Handle graceful shutdown of infinite EventStream
-                _ = &mut end => {
-                    info!("EventStream graceful shutdown requested, handling...");
-                    break;
-                }
-            }
-        }
-    }
+// #[get("/events")]
+// pub(crate) async fn events(
+//     conn: PgConnection,
+//     queue: &State<Sender<Reading>>,
+//     mut end: Shutdown,
+// ) -> EventStream![] {
+//     let _rx = queue.subscribe();
+//     EventStream! {
+//         let mut interval = time::interval(Duration::from_secs(5));
+//         loop {
+//             select! {
+//                 _ = interval.tick() => {
+//                     match get_latest_reading(&conn).await {
+//                         Ok(reading) => {
+//                             yield Event::json(&reading);
+//                             yield Event::data(format!("{}°C", reading.temperature)).event("temperature");
+//                             yield Event::data(format!("{}%", reading.humidity)).event("humidity");
+//                             yield Event::data(format!("{} mbars", reading.pressure)).event("pressure");
+//                             yield Event::data(format!("{}", reading.air_purity)).event("air_purity");
+//                             yield Event::data(format!("{} pcs/ltr", reading.dust_concentration)).event("dust_concentration");
+//                         }
+//                         Err(e) => { error!("Err: failed to retrieve latest reading: {:?}", e) }
+//                     }
+//                 }
+
+//                 // Handle graceful shutdown of infinite EventStream
+//                 _ = &mut end => {
+//                     info!("EventStream graceful shutdown requested, handling...");
+//                     break;
+//                 }
+//             }
+//         }
+//     }
+// }
+
+pub(crate) async fn events_handler(
+    State(pool): State<deadpool_diesel::postgres::Pool>,
+    TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
+) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
+    println!("`{}` connected", user_agent.as_str());
+
+    //let mut interval = tokio::time::interval(Duration::from_secs(5));
+
+    // let astream = stream! {
+    // match get_latest_reading(&conn).await {
+    //     Ok(reading) => {
+    //         yield Event::json(&reading);
+    //         yield Event::data(format!("{}°C", reading.temperature)).event("temperature");
+    //         yield Event::data(format!("{}%", reading.humidity)).event("humidity");
+    //         yield Event::data(format!("{} mbars", reading.pressure)).event("pressure");
+    //         yield Event::data(format!("{}", reading.air_purity)).event("air_purity");
+    //         yield Event::data(format!("{} pcs/ltr", reading.dust_concentration)).event("dust_concentration");
+    //     }
+    //     Err(e) => { error!("Err: failed to retrieve latest reading: {:?}", e) }
+    // }
+    // };
+
+    // A `Stream` that repeats an event every 5 seconds
+    // let stream = stream::repeat_with(|| async {
+    //         let event = match get_latest_reading(axum::extract::State(pool)).await {
+    //             //Ok(reading) => { Event::data(format!("{}°C", reading.temperature)).event("temperature") }
+    //             Ok(reading) => { Event::default().data("Successfully retrieved latest reading") }
+    //             Err(e) => { Event::default().data("Failed to retrieve latest reading") }
+    //         }
+    //     })
+    //     .map(Ok)
+    //     .throttle(Duration::from_secs(5));
+
+    // match get_latest_reading(axum::extract::State(pool)).await {
+    //     Ok(reading) => {
+    //         println!("{}°C", reading.temperature);
+    //     }
+    //     Err(e) => { println!("Err: failed to retrieve latest reading: {:?}", e) }
+    // }
+
+    // TODO: maybe we pass in a stream of Event instances filled with Readings into event_handler instead of
+    // asking this method to retrieve Readings from the DB
+    let stream =
+        stream::repeat_with(
+            || match get_latest_reading(axum::extract::State(pool)).await {
+                Ok(reading) => Event::default()
+                    .data(format!("{}°C", reading.temperature))
+                    .event("temperature"),
+                Err(e) => Event::default()
+                    .data("Err: failed to retrieve latest reading: {e}")
+                    .event("error"),
+            },
+        )
+        .map(Ok)
+        .throttle(Duration::from_secs(5));
+
+    // let stream = stream::repeat_with(|| Event::default().data("hi!"))
+    //     .map(Ok)
+    //     .throttle(Duration::from_secs(1));
+
+    Sse::new(stream).keep_alive(
+        axum::response::sse::KeepAlive::new()
+            .interval(Duration::from_secs(5))
+            .text("keep-alive-text"),
+    )
 }
 
-async fn get_latest_reading(conn: &PgConnection) -> Result<Reading, Json<ApiError>> {
-    // Get the last inserted temperature value
-    let reading = conn
-        .run(move |c| {
+// async fn get_latest_reading(conn: &PgConnection) -> Result<Reading, Json<ApiError>> {
+//     // Get the last inserted temperature value
+//     let reading = conn
+//         .run(move |c| {
+//             readings::table
+//                 .order(readings::id.desc())
+//                 .first::<Reading>(c)
+//         })
+//         .await
+//         .map_err(|e| {
+//             Json(ApiError {
+//                 details: e.to_string(),
+//             })
+//         });
+
+//     debug!("Reading: {:?}", reading);
+
+//     reading
+// }
+
+// #[post("/readings/add", data = "<reading>")]
+// pub(crate) async fn create_reading(
+//     conn: PgConnection,
+//     reading: Json<NewReading>,
+// ) -> Result<Created<Json<Reading>>, Json<ApiError>> {
+//     conn.run(move |c| {
+//         diesel::insert_into(readings::table)
+//             .values(&reading.into_inner())
+//             .get_result(c)
+//     })
+//     .await
+//     .map(|a| Created::new("/").body(Json(a)))
+//     .map_err(|e| {
+//         Json(ApiError {
+//             details: e.to_string(),
+//         })
+//     })
+// }
+
+pub(crate) async fn create_reading(
+    State(pool): State<deadpool_diesel::postgres::Pool>,
+    Json(new_reading): Json<NewReading>,
+) -> Result<Json<Reading>, (StatusCode, String)> {
+    let conn = pool.get().await.map_err(internal_error)?;
+    let res = conn
+        .interact(|conn| {
+            diesel::insert_into(readings::table)
+                .values(new_reading)
+                .returning(Reading::as_returning())
+                .get_result(conn)
+        })
+        .await
+        .map_err(internal_error)?
+        .map_err(internal_error)?;
+    Ok(Json(res))
+}
+
+pub(crate) async fn get_latest_reading(
+    State(pool): State<deadpool_diesel::postgres::Pool>,
+) -> Result<Json<Reading>, (StatusCode, String)> {
+    let conn = pool.get().await.map_err(internal_error)?;
+    let res = conn
+        //.interact(|conn| users::table.select(Reading::as_select()).load(conn))
+        .interact(|conn| {
             readings::table
                 .order(readings::id.desc())
-                .first::<Reading>(c)
+                .first::<Reading>(conn)
         })
         .await
-        .map_err(|e| {
-            Json(ApiError {
-                details: e.to_string(),
-            })
-        });
-
-    debug!("Reading: {:?}", reading);
-
-    reading
+        .map_err(internal_error)?
+        .map_err(internal_error)?;
+    Ok(Json(res))
 }
 
-#[post("/readings/add", data = "<reading>")]
-pub(crate) async fn create_reading(
-    conn: PgConnection,
-    reading: Json<NewReading>,
-) -> Result<Created<Json<Reading>>, Json<ApiError>> {
-    conn.run(move |c| {
-        diesel::insert_into(readings::table)
-            .values(&reading.into_inner())
-            .get_result(c)
-    })
-    .await
-    .map(|a| Created::new("/").body(Json(a)))
-    .map_err(|e| {
-        Json(ApiError {
-            details: e.to_string(),
-        })
-    })
+/// Utility function for mapping any error into a `500 Internal Server Error`
+/// response.
+fn internal_error<E>(err: E) -> (StatusCode, String)
+where
+    E: std::error::Error,
+{
+    (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
 }
diff --git a/src/lib.rs b/src/lib.rs
index 2597a02..7bf4d5b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,56 +1,58 @@
-mod controller;
-pub mod models;
-pub mod schema;
-
-#[macro_use]
-extern crate rocket;
-#[macro_use]
-extern crate diesel;
-
-use rocket::fs::{relative, FileServer};
-use rocket::tokio::sync::broadcast::channel;
-use rocket::Build;
-use rocket_sync_db_pools::database;
-
-use rocket::fairing::{Fairing, Info, Kind};
-use rocket::http::Header;
-use rocket::{Request, Response};
-
-use crate::controller::{create_reading, events};
-use crate::models::Reading;
-
-#[database("ambi_rs_dev")]
-pub struct PgConnection(diesel::PgConnection);
-
-pub struct CORS;
-
-#[rocket::async_trait]
-impl Fairing for CORS {
-    fn info(&self) -> Info {
-        Info {
-            name: "Add CORS headers to responses",
-            kind: Kind::Response,
-        }
-    }
-
-    async fn on_response<'r>(&self, _request: &'r Request<'_>, response: &mut Response<'r>) {
-        response.set_header(Header::new("Access-Control-Allow-Origin", "*"));
-        response.set_header(Header::new(
-            "Access-Control-Allow-Methods",
-            "POST, GET, PATCH, OPTIONS",
-        ));
-        response.set_header(Header::new("Access-Control-Allow-Headers", "*"));
-        response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
-    }
-}
-
-#[launch]
-pub fn rocket_builder() -> rocket::Rocket<Build> {
-    rocket::build()
-        .attach(PgConnection::fairing())
-        .attach(CORS)
-        .manage(channel::<Reading>(1024).0)
-        .mount("/", routes![events])
-        .mount("/api", routes![create_reading])
-        .mount("/", FileServer::from(relative!("static")))
-}
+// mod controller;
+// pub mod models;
+// pub mod schema;
+
+// use axum::{
+//     extract::State,
+//     http::StatusCode,
+//     response::Json,
+//     routing::{get, post},
+//     Router,
+// };
+// use diesel::prelude::*;
+// use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
+// use std::net::SocketAddr;
+// use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
+
+// use crate::controller::{create_reading, events};
+// use crate::models::Reading;
+
+// // This embeds the migrations into the application binary
+// // the migration path is relative to the `CARGO_MANIFEST_DIR`
+// pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/");
+
+// #[database("ambi_rs_dev")]
+// pub struct PgConnection(diesel::PgConnection);
+
+// pub struct CORS;
+
+// #[rocket::async_trait]
+// impl Fairing for CORS {
+//     fn info(&self) -> Info {
+//         Info {
+//             name: "Add CORS headers to responses",
+//             kind: Kind::Response,
+//         }
+//     }
+
+//     async fn on_response<'r>(&self, _request: &'r Request<'_>, response: &mut Response<'r>) {
+//         response.set_header(Header::new("Access-Control-Allow-Origin", "*"));
+//         response.set_header(Header::new(
+//             "Access-Control-Allow-Methods",
+//             "POST, GET, PATCH, OPTIONS",
+//         ));
+//         response.set_header(Header::new("Access-Control-Allow-Headers", "*"));
+//         response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
+//     }
+// }
+
+// #[launch]
+// pub fn rocket_builder() -> rocket::Rocket<Build> {
+//     rocket::build()
+//         .attach(PgConnection::fairing())
+//         .attach(CORS)
+//         .manage(channel::<Reading>(1024).0)
+//         .mount("/", routes![events])
+//         .mount("/api", routes![create_reading])
+//         .mount("/", FileServer::from(relative!("static")))
+// }
diff --git a/src/main.rs b/src/main.rs
index b283bbb..fe5bad4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,8 +1,71 @@
-use ambi_rs::rocket_builder;
+pub mod controller;
+pub mod models;
+pub mod schema;
 
-#[rocket::main]
-async fn main() -> Result<(), rocket::Error> {
-    env_logger::init();
-    let _ = rocket_builder().launch().await?;
-    Ok(())
+use axum::{
+    routing::{get, post},
+    Router,
+};
+
+use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
+use std::net::SocketAddr;
+use std::path::PathBuf;
+use tower_http::services::ServeDir;
+use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
+
+use crate::controller::{create_reading, events_handler};
+
+// This embeds the migrations into the application binary
+// the migration path is relative to the `CARGO_MANIFEST_DIR`
+pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/");
+
+#[tokio::main]
+async fn main() {
+    //env_logger::init();
+
+    tracing_subscriber::registry()
+        .with(
+            tracing_subscriber::EnvFilter::try_from_default_env()
+                .unwrap_or_else(|_| "ambi_rs=debug,tower_http=debug".into()),
+        )
+        .with(tracing_subscriber::fmt::layer())
+        .init();
+
+    let db_url = std::env::var("DATABASE_URL").unwrap();
+
+    let static_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("static");
+
+    // set up connection pool
+    let manager = deadpool_diesel::postgres::Manager::new(db_url, deadpool_diesel::Runtime::Tokio1);
+    let pool = deadpool_diesel::postgres::Pool::builder(manager)
+        .build()
+        .unwrap();
+
+    // run the migrations on server startup
+    {
+        let conn = pool.get().await.unwrap();
+        conn.interact(|conn| conn.run_pending_migrations(MIGRATIONS).map(|_| ()))
+            .await
+            .unwrap()
+            .unwrap();
+    }
+
+    let static_files_service = ServeDir::new(static_dir).append_index_html_on_directories(true);
+
+    // build our application with some routes
+    let app = Router::new()
+        .fallback_service(static_files_service)
+        .route("/events", get(events_handler))
+        .route("/api/readings/add", post(create_reading))
+        .with_state(pool);
+
+    // run it with hyper
+    let addr = SocketAddr::from(([127, 0, 0, 1], 8000));
+    tracing::debug!("listening on {}", addr);
+    //let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
+    axum::Server::bind(&addr)
+        .serve(app.into_make_service())
+        .await
+        .unwrap()
+    //axum::serve(listener, app).await.unwrap();
 }
diff --git a/src/models.rs b/src/models.rs
index 5a53459..a72d154 100644
--- a/src/models.rs
+++ b/src/models.rs
@@ -1,9 +1,9 @@
-use crate::schema::readings;
-use diesel::Insertable;
+use diesel::{Insertable, Queryable, Selectable};
 use serde::{Deserialize, Serialize};
 
-#[derive(Debug, Clone, Deserialize, Serialize, Queryable)]
-#[serde(crate = "rocket::serde")]
+use crate::schema::readings;
+
+#[derive(Debug, Clone, Deserialize, Serialize, Queryable, Selectable)]
 pub struct Reading {
     pub id: i32,
     pub temperature: f64,
@@ -14,8 +14,7 @@ pub struct Reading {
 }
 
 #[derive(Debug, Insertable, Deserialize)]
-#[serde(crate = "rocket::serde")]
-#[table_name = "readings"]
+#[diesel(table_name = readings)]
 pub struct NewReading {
     pub temperature: f64,
     pub humidity: f64,
@@ -25,7 +24,6 @@ pub struct NewReading {
 }
 
 #[derive(Serialize, Deserialize, Debug)]
-#[serde(crate = "rocket::serde")]
 pub struct ApiError {
     pub details: String,
 }
diff --git a/src/schema.rs b/src/schema.rs
index 80fcf4a..e7d7051 100644
--- a/src/schema.rs
+++ b/src/schema.rs
@@ -1,5 +1,7 @@
 // @generated automatically by Diesel CLI.
 
+use diesel::prelude::*;
+
 table! {
     readings (id) {
         id -> Int4,