hyperion/web/session.rs

1use std::{convert::TryInto, fmt::Display, num::NonZeroUsize, sync::Arc};
2
3use lru::LruCache;
4use thiserror::Error;
5use tokio::sync::RwLock;
6use warp::{ws::Message, Filter, Rejection, Reply};
7
8use crate::{
9    api::json::{
10        message::{HyperionMessage, HyperionResponse},
11        ClientConnection, JsonApiError,
12    },
13    global::{Global, InputSourceError},
14};
15
16#[derive(Debug, Error)]
17pub enum SessionError {
18    #[error(transparent)]
19    InputSource(#[from] InputSourceError),
20    #[error(transparent)]
21    Api(#[from] JsonApiError),
22    #[error("not implemented")]
23    NotImplemented,
24    #[error("invalid request: {0}")]
25    Serde(#[from] serde_json::Error),
26}
27
28#[derive(Default, Debug)]
29pub struct Session {
30    id: uuid::Uuid,
31    json_api: Option<ClientConnection>,
32}
33
34impl Session {
35    async fn json_api(&mut self, global: &Global) -> Result<&mut ClientConnection, SessionError> {
36        if self.json_api.is_none() {
37            // Generate a session ID
38            if self.id.is_nil() {
39                self.id = uuid::Uuid::new_v4();
40            }
41
42            // Can't use SocketAddr, see https://github.com/seanmonstar/warp/issues/830
43            self.json_api = Some(ClientConnection::new(
44                global
45                    .register_input_source(
46                        crate::global::InputSourceName::Web {
47                            session_id: self.id,
48                        },
49                        None,
50                    )
51                    .await?,
52            ));
53        }
54
55        Ok(self.json_api.as_mut().unwrap())
56    }
57
58    async fn handle_message(
59        &mut self,
60        global: &Global,
61        message: Message,
62    ) -> Result<Message, SessionError> {
63        let json_api = self.json_api(global).await?;
64
65        if message.is_text() {
66            let request: HyperionMessage = serde_json::from_str(message.to_str().unwrap())?;
67            let response = json_api.handle_request(request, global).await?;
68            return Ok(Message::text(serde_json::to_string(&response).unwrap()));
69        }
70
71        Err(SessionError::NotImplemented)
72    }
73
74    fn error_message<T: Display>(&self, e: T) -> Message {
75        Message::text(
76            serde_json::to_string(&serde_json::json!({ "error": e.to_string() })).unwrap(),
77        )
78    }
79
80    #[instrument(skip(global, result))]
81    pub async fn handle_result(
82        &mut self,
83        global: &Global,
84        result: Result<Message, warp::Error>,
85    ) -> Option<Message> {
86        match result {
87            Ok(message) => {
88                trace!(message = ?message, "ws message");
89
90                if message.is_close() {
91                    return None;
92                }
93
94                let response = self.handle_message(global, message).await;
95
96                trace!(response = ?response, "ws response");
97
98                match response {
99                    Ok(message) => Some(message),
100                    Err(error) => Some(self.error_message(error)),
101                }
102            }
103            Err(error) => Some(self.error_message(error)),
104        }
105    }
106
107    #[instrument(skip(global, request))]
108    pub async fn handle_request(
109        &mut self,
110        global: &Global,
111        request: HyperionMessage,
112    ) -> HyperionResponse {
113        trace!(request = ?request, "JSON RPC request");
114
115        let tan = request.tan;
116        let api = match self.json_api(global).await {
117            Ok(api) => api,
118            Err(error) => {
119                return HyperionResponse::error(&error).with_tan(tan);
120            }
121        };
122
123        let response = match api.handle_request(request, global).await {
124            Ok(response) => response,
125            Err(error) => {
126                error!(error =  %error, "error processing request");
127                HyperionResponse::error(&error)
128            }
129        };
130
131        trace!(response = ?response, "ws response");
132        response.with_tan(tan)
133    }
134}
135
136const COOKIE_NAME: &str = "hyperion_rs_sid";
137
138type SessionData = Arc<RwLock<LruCache<uuid::Uuid, Arc<RwLock<Session>>>>>;
139
140#[derive(Clone)]
141pub struct SessionStore {
142    sessions: SessionData,
143}
144
145pub struct SessionInstance {
146    session: Arc<RwLock<Session>>,
147    sessions: SessionData,
148}
149
150impl SessionInstance {
151    pub fn session(&self) -> &Arc<RwLock<Session>> {
152        &self.session
153    }
154}
155
156pub struct WithSession<T: Reply> {
157    reply: T,
158    set_cookie: Option<String>,
159}
160
161impl<T: Reply> WithSession<T> {
162    pub async fn new(reply: T, instance: SessionInstance) -> Self {
163        let id = instance.session.read().await.id;
164
165        let set_cookie = if instance.sessions.read().await.peek(&id).is_none() {
166            let mut sessions = instance.sessions.write().await;
167
168            if sessions.put(id, instance.session.clone()).is_none() {
169                Some(id.to_string())
170            } else {
171                // Not the same ID, another request set the cookie first
172                None
173            }
174        } else {
175            // Already have an ID, no need for more locking
176            None
177        };
178
179        Self { reply, set_cookie }
180    }
181}
182
183impl<T: Reply> Reply for WithSession<T> {
184    fn into_response(self) -> warp::reply::Response {
185        let mut inner = self.reply.into_response();
186
187        if let Some(cookie_value) = self.set_cookie {
188            // TODO: Other cookie options?
189            inner.headers_mut().insert(
190                "Set-Cookie",
191                cookie::Cookie::build((COOKIE_NAME, cookie_value))
192                    .to_string()
193                    .try_into()
194                    .unwrap(),
195            );
196        }
197
198        inner
199    }
200}
201
202impl SessionStore {
203    pub fn new(max_sessions: NonZeroUsize) -> Self {
204        Self {
205            sessions: Arc::new(RwLock::new(LruCache::new(max_sessions))),
206        }
207    }
208
209    pub fn request(
210        &self,
211    ) -> impl Filter<Extract = (SessionInstance,), Error = Rejection> + Clone + 'static {
212        let sessions = self.sessions.clone();
213
214        warp::any()
215            .and(warp::any().map(move || sessions.clone()))
216            .and(warp::cookie::optional(COOKIE_NAME))
217            .and_then(
218                |sessions: SessionData, sid_cookie: Option<String>| async move {
219                    match sid_cookie
220                        .and_then(|cookie_value| uuid::Uuid::parse_str(&cookie_value).ok())
221                    {
222                        Some(sid) => {
223                            // Get the existing session
224                            let session = sessions.write().await.get(&sid).cloned();
225
226                            // Create if the ID is not found
227                            let session = if let Some(session) = session {
228                                session
229                            } else {
230                                Arc::new(RwLock::new(Session::default()))
231                            };
232
233                            Ok::<_, Rejection>(SessionInstance {
234                                session,
235                                sessions: sessions.clone(),
236                            })
237                        }
238                        None => {
239                            // No session yet, create one
240                            Ok::<_, Rejection>(SessionInstance {
241                                session: Arc::new(RwLock::new(Session::default())),
242                                sessions: sessions.clone(),
243                            })
244                        }
245                    }
246                },
247            )
248    }
249}
250
251pub async fn reply_session<T: Reply>(
252    reply: T,
253    session: SessionInstance,
254) -> Result<WithSession<T>, Rejection> {
255    Ok(WithSession::new(reply, session).await)
256}