1use std::net::SocketAddr;
4
5use futures::prelude::*;
6use thiserror::Error;
7use tokio::net::TcpStream;
8
9use crate::{
10 api::flat::{self, message, FlatApiError},
11 global::{Global, InputMessage, InputSourceHandle, PriorityGuard},
12};
13
14#[derive(Debug, Error)]
15pub enum FlatServerError {
16 #[error("i/o error: {0}")]
17 Io(#[from] futures_io::Error),
18 #[error("error decoding frame")]
19 FlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),
20 #[error(transparent)]
21 Api(#[from] FlatApiError),
22}
23
24fn register_response(builder: &mut flatbuffers::FlatBufferBuilder, priority: i32) -> bytes::Bytes {
25 let mut reply = message::ReplyBuilder::new(builder);
26 reply.add_registered(priority);
27
28 let reply = reply.finish();
29
30 builder.finish(reply, None);
31 bytes::Bytes::copy_from_slice(builder.finished_data())
32}
33
34fn error_response(
35 builder: &mut flatbuffers::FlatBufferBuilder,
36 error: impl std::fmt::Display,
37) -> bytes::Bytes {
38 let error = builder.create_string(error.to_string().as_str());
39
40 let mut reply = message::ReplyBuilder::new(builder);
41 reply.add_error(error);
42
43 let reply = reply.finish();
44
45 builder.finish(reply, None);
46 bytes::Bytes::copy_from_slice(builder.finished_data())
47}
48
49async fn handle_request(
50 peer_addr: SocketAddr,
51 request_bytes: bytes::BytesMut,
52 source: &mut Option<InputSourceHandle<InputMessage>>,
53 global: &Global,
54 priority_guard: &mut Option<PriorityGuard>,
55) -> Result<(), FlatServerError> {
56 let request = message::root_as_request(request_bytes.as_ref())?;
57
58 trace!(request = ?request.command_type(), "processing");
59
60 Ok(flat::handle_request(peer_addr, request, source, global, priority_guard).await?)
61}
62
63#[instrument(skip(socket, global))]
64pub async fn handle_client(
65 (socket, peer_addr): (TcpStream, SocketAddr),
66 global: Global,
67) -> Result<(), FlatServerError> {
68 debug!("accepted new connection");
69
70 let framed = tokio_util::codec::LengthDelimitedCodec::builder()
71 .length_field_length(4)
72 .new_framed(socket);
73
74 let (mut writer, mut reader) = framed.split();
75
76 let mut source = None;
77 let mut priority_guard = None;
78 let mut builder = flatbuffers::FlatBufferBuilder::new();
79
80 while let Some(request_bytes) = reader.next().await {
81 let request_bytes = match request_bytes {
82 Ok(rb) => rb,
83 Err(error) => {
84 error!(error = %error, "error reading frame");
85 continue;
86 }
87 };
88
89 builder.reset();
90
91 let reply = match handle_request(
92 peer_addr,
93 request_bytes,
94 &mut source,
95 &global,
96 &mut priority_guard,
97 )
98 .await
99 {
100 Ok(()) => {
101 if let Some(source) = source.as_ref() {
102 register_response(&mut builder, source.priority().unwrap())
103 } else {
104 error_response(&mut builder, "unregistered source")
105 }
106 }
107 Err(error) => {
108 error!(error = %error, "error processing request");
109
110 error_response(&mut builder, error)
111 }
112 };
113
114 trace!(response = ?reply, "sending");
115 writer.send(reply).await?;
116 writer.flush().await?;
117 }
118
119 Ok(())
120}