1use std::{convert::TryInto, fmt::Display, num::NonZeroUsize, sync::Arc};
2
3use lru::LruCache;
4use thiserror::Error;
5use tokio::sync::RwLock;
6use warp::{ws::Message, Filter, Rejection, Reply};
7
8use crate::{
9 api::json::{
10 message::{HyperionMessage, HyperionResponse},
11 ClientConnection, JsonApiError,
12 },
13 global::{Global, InputSourceError},
14};
15
/// Errors that can occur while serving a request on behalf of a web [`Session`].
#[derive(Debug, Error)]
pub enum SessionError {
    /// An error from the input source layer (see [`InputSourceError`]).
    #[error(transparent)]
    InputSource(#[from] InputSourceError),
    /// An error from the JSON API layer (see [`JsonApiError`]).
    #[error(transparent)]
    Api(#[from] JsonApiError),
    /// The received message kind is not supported (`handle_message` only accepts text frames).
    #[error("not implemented")]
    NotImplemented,
    /// A JSON (de)serialization error, e.g. a text frame that is not a valid request.
    #[error("invalid request: {0}")]
    Serde(#[from] serde_json::Error),
}
27
/// Per-client state for the web interface.
#[derive(Default, Debug)]
pub struct Session {
    /// Session identifier. A default-constructed session has the nil UUID; a random id is
    /// assigned lazily and then used as the session's input source `session_id`.
    id: uuid::Uuid,
    /// JSON API connection for this session, created on first use.
    json_api: Option<ClientConnection>,
}
33
34impl Session {
35 async fn json_api(&mut self, global: &Global) -> Result<&mut ClientConnection, SessionError> {
36 if self.json_api.is_none() {
37 if self.id.is_nil() {
39 self.id = uuid::Uuid::new_v4();
40 }
41
42 self.json_api = Some(ClientConnection::new(
44 global
45 .register_input_source(
46 crate::global::InputSourceName::Web {
47 session_id: self.id,
48 },
49 None,
50 )
51 .await?,
52 ));
53 }
54
55 Ok(self.json_api.as_mut().unwrap())
56 }
57
58 async fn handle_message(
59 &mut self,
60 global: &Global,
61 message: Message,
62 ) -> Result<Message, SessionError> {
63 let json_api = self.json_api(global).await?;
64
65 if message.is_text() {
66 let request: HyperionMessage = serde_json::from_str(message.to_str().unwrap())?;
67 let response = json_api.handle_request(request, global).await?;
68 return Ok(Message::text(serde_json::to_string(&response).unwrap()));
69 }
70
71 Err(SessionError::NotImplemented)
72 }
73
74 fn error_message<T: Display>(&self, e: T) -> Message {
75 Message::text(
76 serde_json::to_string(&serde_json::json!({ "error": e.to_string() })).unwrap(),
77 )
78 }
79
80 #[instrument(skip(global, result))]
81 pub async fn handle_result(
82 &mut self,
83 global: &Global,
84 result: Result<Message, warp::Error>,
85 ) -> Option<Message> {
86 match result {
87 Ok(message) => {
88 trace!(message = ?message, "ws message");
89
90 if message.is_close() {
91 return None;
92 }
93
94 let response = self.handle_message(global, message).await;
95
96 trace!(response = ?response, "ws response");
97
98 match response {
99 Ok(message) => Some(message),
100 Err(error) => Some(self.error_message(error)),
101 }
102 }
103 Err(error) => Some(self.error_message(error)),
104 }
105 }
106
107 #[instrument(skip(global, request))]
108 pub async fn handle_request(
109 &mut self,
110 global: &Global,
111 request: HyperionMessage,
112 ) -> HyperionResponse {
113 trace!(request = ?request, "JSON RPC request");
114
115 let tan = request.tan;
116 let api = match self.json_api(global).await {
117 Ok(api) => api,
118 Err(error) => {
119 return HyperionResponse::error(&error).with_tan(tan);
120 }
121 };
122
123 let response = match api.handle_request(request, global).await {
124 Ok(response) => response,
125 Err(error) => {
126 error!(error = %error, "error processing request");
127 HyperionResponse::error(&error)
128 }
129 };
130
131 trace!(response = ?response, "ws response");
132 response.with_tan(tan)
133 }
134}
135
/// Name of the cookie that carries the session id.
const COOKIE_NAME: &str = "hyperion_rs_sid";

/// Shared LRU cache mapping session ids to their (individually locked) session state.
type SessionData = Arc<RwLock<LruCache<uuid::Uuid, Arc<RwLock<Session>>>>>;
139
/// Store of web sessions. Cloning it shares the same underlying cache.
#[derive(Clone)]
pub struct SessionStore {
    /// Session cache shared by all clones of this store.
    sessions: SessionData,
}
144
/// The session resolved for a single request, together with the store it belongs to.
pub struct SessionInstance {
    /// Session state used while handling the request.
    session: Arc<RwLock<Session>>,
    /// Store the session is (re)inserted into when the reply is built.
    sessions: SessionData,
}
149
impl SessionInstance {
    /// Returns the shared session state for this request.
    pub fn session(&self) -> &Arc<RwLock<Session>> {
        &self.session
    }
}
155
/// A reply wrapper that may attach a `Set-Cookie` header carrying the session id.
pub struct WithSession<T: Reply> {
    /// The wrapped reply.
    reply: T,
    /// Cookie value (the session id) to send, if any.
    set_cookie: Option<String>,
}
160
161impl<T: Reply> WithSession<T> {
162 pub async fn new(reply: T, instance: SessionInstance) -> Self {
163 let id = instance.session.read().await.id;
164
165 let set_cookie = if instance.sessions.read().await.peek(&id).is_none() {
166 let mut sessions = instance.sessions.write().await;
167
168 if sessions.put(id, instance.session.clone()).is_none() {
169 Some(id.to_string())
170 } else {
171 None
173 }
174 } else {
175 None
177 };
178
179 Self { reply, set_cookie }
180 }
181}
182
183impl<T: Reply> Reply for WithSession<T> {
184 fn into_response(self) -> warp::reply::Response {
185 let mut inner = self.reply.into_response();
186
187 if let Some(cookie_value) = self.set_cookie {
188 inner.headers_mut().insert(
190 "Set-Cookie",
191 cookie::Cookie::build((COOKIE_NAME, cookie_value))
192 .to_string()
193 .try_into()
194 .unwrap(),
195 );
196 }
197
198 inner
199 }
200}
201
202impl SessionStore {
203 pub fn new(max_sessions: NonZeroUsize) -> Self {
204 Self {
205 sessions: Arc::new(RwLock::new(LruCache::new(max_sessions))),
206 }
207 }
208
209 pub fn request(
210 &self,
211 ) -> impl Filter<Extract = (SessionInstance,), Error = Rejection> + Clone + 'static {
212 let sessions = self.sessions.clone();
213
214 warp::any()
215 .and(warp::any().map(move || sessions.clone()))
216 .and(warp::cookie::optional(COOKIE_NAME))
217 .and_then(
218 |sessions: SessionData, sid_cookie: Option<String>| async move {
219 match sid_cookie
220 .and_then(|cookie_value| uuid::Uuid::parse_str(&cookie_value).ok())
221 {
222 Some(sid) => {
223 let session = sessions.write().await.get(&sid).cloned();
225
226 let session = if let Some(session) = session {
228 session
229 } else {
230 Arc::new(RwLock::new(Session::default()))
231 };
232
233 Ok::<_, Rejection>(SessionInstance {
234 session,
235 sessions: sessions.clone(),
236 })
237 }
238 None => {
239 Ok::<_, Rejection>(SessionInstance {
241 session: Arc::new(RwLock::new(Session::default())),
242 sessions: sessions.clone(),
243 })
244 }
245 }
246 },
247 )
248 }
249}
250
251pub async fn reply_session<T: Reply>(
252 reply: T,
253 session: SessionInstance,
254) -> Result<WithSession<T>, Rejection> {
255 Ok(WithSession::new(reply, session).await)
256}