kataglyphis_rustprojecttemplate/api/
onnx.rs

1use log::error;
2
/// A single person detection, expressed in *original image pixel coordinates*.
///
/// The box is given by two corners `(x1, y1)` and `(x2, y2)`. Presumably these are the
/// top-left and bottom-right corners; that ordering is not enforced by this type.
#[derive(Debug, Clone)]
pub struct PersonDetection {
    /// Horizontal coordinate of the first box corner, in pixels.
    pub x1: f32,
    /// Vertical coordinate of the first box corner, in pixels.
    pub y1: f32,
    /// Horizontal coordinate of the second box corner, in pixels.
    pub x2: f32,
    /// Vertical coordinate of the second box corner, in pixels.
    pub y2: f32,
    /// Detection score as reported by the underlying detector.
    pub score: f32,
    /// Class index as reported by the underlying model.
    pub class_id: i64,
}
13
14#[flutter_rust_bridge::frb(sync)]
15pub fn detect_persons_rgba(
16    model_path: String,
17    rgba: Vec<u8>,
18    width: u32,
19    height: u32,
20    score_threshold: f32,
21) -> Vec<PersonDetection> {
22    // Note: `person_detection` is feature-gated; keep this function compilable even when
23    // ONNX features are disabled (e.g. for WASM builds).
24    let resolved_model_path = if model_path.trim().is_empty() {
25        #[cfg(any(feature = "onnx_tract", feature = "onnxruntime"))]
26        {
27            crate::person_detection::default_model_path()
28                .to_string_lossy()
29                .to_string()
30        }
31
32        #[cfg(not(any(feature = "onnx_tract", feature = "onnxruntime")))]
33        {
34            String::new()
35        }
36    } else {
37        model_path
38    };
39
40    match detect_persons_rgba_impl(&resolved_model_path, &rgba, width, height, score_threshold) {
41        Ok(v) => v,
42        Err(e) => {
43            error!("detect_persons_rgba failed: {e:#}");
44            Vec::new()
45        }
46    }
47}
48
/// Backend-enabled implementation. Runs person detection with a process-wide cached detector.
///
/// The loaded `PersonDetector` is stored in a global slot together with the path it was
/// loaded from. Calls with the same `model_path` reuse it; a different path triggers a reload.
/// The slot's lock is held for the whole call, including inference, so concurrent calls are
/// serialized on the shared detector.
///
/// # Errors
/// Returns an error if the model cannot be loaded from `model_path` or if inference fails.
#[cfg(any(feature = "onnx_tract", feature = "onnxruntime"))]
fn detect_persons_rgba_impl(
    model_path: &str,
    rgba: &[u8],
    width: u32,
    height: u32,
    score_threshold: f32,
) -> anyhow::Result<Vec<PersonDetection>> {
    use std::sync::{Mutex, OnceLock, PoisonError};

    use anyhow::Context as _;

    use crate::person_detection::PersonDetector;

    /// A loaded detector tagged with the model path it was created from.
    struct Cached {
        model_path: String,
        detector: PersonDetector,
    }

    static DETECTOR: OnceLock<Mutex<Option<Cached>>> = OnceLock::new();

    // A panic in an earlier call (e.g. inside the inference backend) poisons the mutex. The
    // slot is only ever overwritten with a fully constructed `Cached`, so its contents remain
    // valid. Recover the guard instead of panicking on every subsequent call.
    let mut guard = DETECTOR
        .get_or_init(|| Mutex::new(None))
        .lock()
        .unwrap_or_else(PoisonError::into_inner);

    // Reuse the cached detector if it was loaded from the same path; otherwise load a new
    // one. If loading fails, `?` returns before the slot is touched, so any previously cached
    // detector is left in place.
    let cached = match &mut *guard {
        Some(cached) if cached.model_path == model_path => cached,
        slot => {
            let detector = PersonDetector::new(model_path).with_context(|| {
                format!("failed to load person detection model from `{model_path}`")
            })?;
            slot.insert(Cached {
                model_path: model_path.to_owned(),
                detector,
            })
        }
    };

    let detections = cached
        .detector
        .infer_persons_rgba(rgba, width, height, score_threshold)?;

    Ok(detections
        .into_iter()
        .map(|d| PersonDetection {
            x1: d.x1,
            y1: d.y1,
            x2: d.x2,
            y2: d.y2,
            score: d.score,
            class_id: d.class_id,
        })
        .collect())
}
101
102#[cfg(not(any(feature = "onnx_tract", feature = "onnxruntime")))]
103fn detect_persons_rgba_impl(
104    _model_path: &str,
105    _rgba: &[u8],
106    _width: u32,
107    _height: u32,
108    _score_threshold: f32,
109) -> anyhow::Result<Vec<PersonDetection>> {
110    anyhow::bail!("ONNX inference is disabled. Build with --features onnx_tract (or onnxruntime*)")
111}