kataglyphis_rustprojecttemplate/api/
onnx.rs

1use log::error;
2
3use crate::detection::Detection;
4
5#[flutter_rust_bridge::frb(sync)]
6pub fn detect_persons_rgba(
7    model_path: String,
8    rgba: Vec<u8>,
9    width: u32,
10    height: u32,
11    score_threshold: f32,
12) -> Vec<Detection> {
13    // Note: `person_detection` is feature-gated; keep this function compilable even when
14    // ONNX features are disabled (e.g. for WASM builds).
15    let resolved_model_path = if model_path.trim().is_empty() {
16        #[cfg(onnx)]
17        {
18            crate::person_detection::default_model_path()
19                .to_string_lossy()
20                .to_string()
21        }
22
23        #[cfg(not(onnx))]
24        {
25            String::new()
26        }
27    } else {
28        model_path
29    };
30
31    match detect_persons_rgba_impl(&resolved_model_path, &rgba, width, height, score_threshold) {
32        Ok(v) => v,
33        Err(e) => {
34            error!("detect_persons_rgba failed: {e:#}");
35            Vec::new()
36        }
37    }
38}
39
/// ONNX-backed implementation of `detect_persons_rgba`.
///
/// Keeps one process-wide `PersonDetector` cached together with the path it
/// was loaded from. The detector is (re)loaded when `model_path` differs from
/// the cached path. Loading happens outside the lock, so other callers are not
/// blocked for the duration of a model load.
///
/// Inference always runs under the same lock acquisition that verified (or
/// installed) the cached detector for `model_path`. A concurrent caller with a
/// different path therefore cannot swap the detector out between the check and
/// the inference.
///
/// # Errors
///
/// Propagates errors from `PersonDetector::new` (model load) and from
/// `infer_persons_rgba`.
#[cfg(onnx)]
fn detect_persons_rgba_impl(
    model_path: &str,
    rgba: &[u8],
    width: u32,
    height: u32,
    score_threshold: f32,
) -> anyhow::Result<Vec<Detection>> {
    use std::sync::{Mutex, OnceLock};

    use crate::person_detection::PersonDetector;

    struct Cached {
        model_path: String,
        detector: PersonDetector,
    }

    static DETECTOR: OnceLock<Mutex<Option<Cached>>> = OnceLock::new();
    let mutex = DETECTOR.get_or_init(|| Mutex::new(None));

    // Recover the guard from a poisoned mutex (a previous holder panicked)
    // instead of failing every subsequent call.
    let lock_guard = || {
        mutex
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    };

    // Fast path: the cached detector was loaded from `model_path`. Infer while
    // still holding this guard so the detector cannot be replaced in between.
    {
        let mut guard = lock_guard();
        if let Some(cached) = guard.as_mut().filter(|c| c.model_path == model_path) {
            return cached
                .detector
                .infer_persons_rgba(rgba, width, height, score_threshold);
        }
    }

    // Slow path: load the model *outside* the lock so concurrent callers
    // aren't blocked for the (potentially multi-second) model load.
    let detector = PersonDetector::new(model_path)?;

    let mut guard = lock_guard();
    // Another thread may have installed the same model while we were loading;
    // in that case keep theirs. Otherwise install the one we just loaded.
    // If unused, `detector` is dropped at function exit. It is declared before
    // `guard`, so it is dropped after the lock is released.
    let up_to_date = guard
        .as_ref()
        .is_some_and(|c| c.model_path == model_path);
    if !up_to_date {
        *guard = Some(Cached {
            model_path: model_path.to_owned(),
            detector,
        });
    }

    // Same guard as the check/install above, so this detector matches `model_path`.
    let Some(cached) = guard.as_mut() else {
        // Not reachable: the slot is non-empty after the block above.
        return Ok(Vec::new());
    };
    cached
        .detector
        .infer_persons_rgba(rgba, width, height, score_threshold)
}
102
103#[cfg(not(onnx))]
104fn detect_persons_rgba_impl(
105    _model_path: &str,
106    _rgba: &[u8],
107    _width: u32,
108    _height: u32,
109    _score_threshold: f32,
110) -> anyhow::Result<Vec<Detection>> {
111    anyhow::bail!("ONNX inference is disabled. Build with --features onnx_tract (or onnxruntime*)")
112}