diff --git a/src/server/middleware/mod.rs b/src/server/middleware/mod.rs new file mode 100644 index 0000000..708af95 --- /dev/null +++ b/src/server/middleware/mod.rs @@ -0,0 +1,108 @@ +use http::{Request, Response, StatusCode}; +use pin_project_lite::pin_project; +use std::{env, future::Future, pin::Pin, task::{Context, Poll}}; +use cookie::Cookie; +use dotenvy::dotenv; +use http::header::COOKIE; +use tower_layer::Layer; +use tower_service::Service; + +#[derive(Debug, Clone, Copy)] +pub struct AuthLayer {} + +impl AuthLayer { + pub fn new() -> Self { + AuthLayer {} + } +} + +impl Layer for AuthLayer { + type Service = Auth; + + fn layer(&self, inner: S) -> Self::Service { + Auth::new(inner) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct Auth { + inner: S, +} + +impl Auth { + pub fn new(inner: S) -> Self { + Self { inner } + } + + pub fn layer() -> AuthLayer { + AuthLayer::new() + } +} + +impl Service> for Auth +where + S: Service, Response=Response>, + ResBody: Default, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let cookies = req.headers() + .get(COOKIE) + .and_then(|header| header.to_str().ok()) + .map(Cookie::split_parse_encoded); + + let token = cookies.and_then( + |cookies| cookies + .filter_map(|cookie| cookie.ok()) + .find(|cookie| cookie.name() == "auth_token") + .map(|cookie| cookie.value().to_string()) + ); + + ResponseFuture { + inner: self.inner.call(req), + token, + } + } +} + +pin_project! { + pub struct ResponseFuture { + #[pin] + inner: F, + #[pin] + token: Option, + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, + B: Default, +{ + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + dotenv().expect("Could not load environment variables from the .env file."); + + let this = self.project(); + + if !(*this.token).as_ref().is_some_and( + |token| token == env::var("AUTH_TOKEN") + .expect("The environment variable DATABASE_URL has to be set.").as_str() + ) { + let mut res = Response::new(B::default()); + *res.status_mut() = StatusCode::UNAUTHORIZED; + return Poll::Ready(Ok(res)); + } + + this.inner.poll(cx) + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 50ad1c8..512e582 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -3,3 +3,4 @@ pub(crate) mod projects; pub(crate) mod tasks; pub(crate) mod subtasks; pub(crate) mod internationalization; +mod middleware; diff --git a/src/server/tasks.rs b/src/server/tasks.rs index 192726f..7037b8a 100644 --- a/src/server/tasks.rs +++ b/src/server/tasks.rs @@ -14,10 +14,14 @@ use crate::models::subtask::Subtask; use crate::server::subtasks::restore_subtasks_of_task; #[server] +#[middleware(crate::server::middleware::AuthLayer::new())] pub(crate) async fn create_task(new_task: NewTask) -> Result>> { use crate::schema::tasks; + let headers: http::HeaderMap = extract().await.unwrap(); + todo!(); + new_task.validate() .map_err::, _>(|errors| errors.into())?;