hyperion/servers/
flat.rs

//! flatbuffers protocol server implementation
2
3use std::net::SocketAddr;
4
5use futures::prelude::*;
6use thiserror::Error;
7use tokio::net::TcpStream;
8
9use crate::{
10    api::flat::{self, message, FlatApiError},
11    global::{Global, InputMessage, InputSourceHandle, PriorityGuard},
12};
13
14#[derive(Debug, Error)]
15pub enum FlatServerError {
16    #[error("i/o error: {0}")]
17    Io(#[from] futures_io::Error),
18    #[error("error decoding frame")]
19    FlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),
20    #[error(transparent)]
21    Api(#[from] FlatApiError),
22}
23
24fn register_response(builder: &mut flatbuffers::FlatBufferBuilder, priority: i32) -> bytes::Bytes {
25    let mut reply = message::ReplyBuilder::new(builder);
26    reply.add_registered(priority);
27
28    let reply = reply.finish();
29
30    builder.finish(reply, None);
31    bytes::Bytes::copy_from_slice(builder.finished_data())
32}
33
34fn error_response(
35    builder: &mut flatbuffers::FlatBufferBuilder,
36    error: impl std::fmt::Display,
37) -> bytes::Bytes {
38    let error = builder.create_string(error.to_string().as_str());
39
40    let mut reply = message::ReplyBuilder::new(builder);
41    reply.add_error(error);
42
43    let reply = reply.finish();
44
45    builder.finish(reply, None);
46    bytes::Bytes::copy_from_slice(builder.finished_data())
47}
48
49async fn handle_request(
50    peer_addr: SocketAddr,
51    request_bytes: bytes::BytesMut,
52    source: &mut Option<InputSourceHandle<InputMessage>>,
53    global: &Global,
54    priority_guard: &mut Option<PriorityGuard>,
55) -> Result<(), FlatServerError> {
56    let request = message::root_as_request(request_bytes.as_ref())?;
57
58    trace!(request = ?request.command_type(), "processing");
59
60    Ok(flat::handle_request(peer_addr, request, source, global, priority_guard).await?)
61}
62
63#[instrument(skip(socket, global))]
64pub async fn handle_client(
65    (socket, peer_addr): (TcpStream, SocketAddr),
66    global: Global,
67) -> Result<(), FlatServerError> {
68    debug!("accepted new connection");
69
70    let framed = tokio_util::codec::LengthDelimitedCodec::builder()
71        .length_field_length(4)
72        .new_framed(socket);
73
74    let (mut writer, mut reader) = framed.split();
75
76    let mut source = None;
77    let mut priority_guard = None;
78    let mut builder = flatbuffers::FlatBufferBuilder::new();
79
80    while let Some(request_bytes) = reader.next().await {
81        let request_bytes = match request_bytes {
82            Ok(rb) => rb,
83            Err(error) => {
84                error!(error = %error, "error reading frame");
85                continue;
86            }
87        };
88
89        builder.reset();
90
91        let reply = match handle_request(
92            peer_addr,
93            request_bytes,
94            &mut source,
95            &global,
96            &mut priority_guard,
97        )
98        .await
99        {
100            Ok(()) => {
101                if let Some(source) = source.as_ref() {
102                    register_response(&mut builder, source.priority().unwrap())
103                } else {
104                    error_response(&mut builder, "unregistered source")
105                }
106            }
107            Err(error) => {
108                error!(error = %error, "error processing request");
109
110                error_response(&mut builder, error)
111            }
112        };
113
114        trace!(response = ?reply, "sending");
115        writer.send(reply).await?;
116        writer.flush().await?;
117    }
118
119    Ok(())
120}