hyperion/effects/providers/python/
context.rs

1use std::{
2    cell::RefCell,
3    panic,
4    sync::{Arc, Once, Weak},
5};
6
7use drop_bomb::DropBomb;
8use futures::Future;
9use pyo3::prelude::*;
10
11use super::{hyperion, RuntimeMethods};
12
/// Guards the one-time, process-wide Python setup done by `Context::with`:
/// registering the `hyperion` module in the inittab and initializing the
/// interpreter.
static INITIALIZED_PYTHON: Once = Once::new();
14
15thread_local! {
16    /// Current effect context
17    static CONTEXT: RefCell<Option<Context>> = const { RefCell::new(None) };
18}
19
20/// Python effect module context
21pub struct Context {
22    tstate: *mut pyo3::ffi::PyThreadState,
23    methods: Weak<dyn RuntimeMethods>,
24    bomb: DropBomb,
25}
26
27impl Context {
28    unsafe fn new(_py: Python, methods: Weak<dyn RuntimeMethods>) -> Result<Self, ()> {
29        // Get the main_state ptr
30        let main_state = pyo3::ffi::PyEval_SaveThread();
31
32        // Acquire GIL again
33        pyo3::ffi::PyEval_RestoreThread(main_state);
34
35        // Create new subinterp
36        let tstate = pyo3::ffi::Py_NewInterpreter();
37
38        // Restore GIL
39        pyo3::ffi::PyThreadState_Swap(main_state);
40
41        // Return object
42        if tstate.is_null() {
43            Err(())
44        } else {
45            Ok(Self {
46                tstate,
47                methods,
48                bomb: DropBomb::new("Context::release must be called before dropping it"),
49            })
50        }
51    }
52
53    unsafe fn release(&mut self, _py: Python) {
54        // TODO: Stop sub threads?
55
56        // Make this context subinterp current
57        let main_thread = pyo3::ffi::PyThreadState_Swap(self.tstate);
58
59        // Terminate it
60        pyo3::ffi::Py_EndInterpreter(self.tstate);
61
62        // Restore the main thread
63        pyo3::ffi::PyThreadState_Swap(main_thread);
64
65        // We're clear for dropping
66        self.bomb.defuse();
67    }
68
69    pub fn run<U>(&self, _py: Python, f: impl FnOnce() -> U) -> U {
70        unsafe {
71            // Switch to the context thread
72            let main_state = pyo3::ffi::PyThreadState_Swap(self.tstate);
73
74            // Run user function
75            let result = panic::catch_unwind(panic::AssertUnwindSafe(f));
76
77            // Switch back to the main thread
78            pyo3::ffi::PyThreadState_Swap(main_state);
79
80            // Return result
81            match result {
82                Ok(result) => result,
83                Err(panic) => panic::panic_any(panic),
84            }
85        }
86    }
87
88    pub fn with<U>(methods: Arc<dyn RuntimeMethods>, f: impl FnOnce(&Self) -> U) -> U {
89        unsafe {
90            // Initialize the Python interpreter global state
91            INITIALIZED_PYTHON.call_once(|| {
92                // Register our module through inittab
93                pyo3::append_to_inittab!(hyperion);
94                Python::initialize();
95            });
96
97            let result = CONTEXT.with(|ctx| {
98                // Initialize the thread-local state, i.e. interpreter
99                *ctx.borrow_mut() = Some(Python::attach(|py| {
100                    Self::new(py, Arc::downgrade(&methods))
101                        .expect("failed initializing python subinterp")
102                }));
103
104                // Run user callback
105                let result = {
106                    let borrow = ctx.borrow();
107                    let ctx = borrow.as_ref().unwrap();
108                    panic::catch_unwind(panic::AssertUnwindSafe(|| f(ctx)))
109                };
110
111                // Free the interpreter
112                if let Some(mut ctx) = ctx.borrow_mut().take() {
113                    Python::attach(|py| {
114                        ctx.release(py);
115                    })
116                }
117
118                result
119            });
120
121            // Return result
122            match result {
123                Ok(result) => result,
124                Err(panic) => panic::panic_any(panic),
125            }
126        }
127    }
128
129    pub fn with_current<F, U>(f: impl FnOnce(Arc<dyn RuntimeMethods>) -> F) -> U
130    where
131        F: Future<Output = U>,
132    {
133        CONTEXT.with(|ctx| {
134            futures::executor::block_on(f(ctx
135                .borrow()
136                .as_ref()
137                .expect("no current context")
138                .methods
139                .upgrade()
140                .expect("no current methods")))
141        })
142    }
143}