// ove/infer.rs

// Copyright (C) 2026 Kamil Lulko <kamil.lulko@gmail.com>
//
// SPDX-License-Identifier: GPL-3.0-or-later
//
// This file is part of oveRTOS.
6
//! ML inference primitives for oveRTOS.
//!
//! Provides a safe Rust wrapper around the `ove_model_*` C API for
//! running TFLite model inference.
11
use crate::bindings;
use crate::error::{Error, Result};
14
// SAFETY (module-wide contract for the `unsafe { bindings::ove_*(...) }` FFI
// calls below): any handle passed to the C API is non-null and refers to a
// live RTOS object — wrapper constructors establish validity via
// `Error::from_code`, and `Drop` (or an explicit `deinit`) is the only place
// a handle is released. Pointer and slice arguments reference caller-owned
// memory valid for the duration of the call; the C side copies whatever it
// retains and does not alias them past return (verified against the
// signatures in `include/ove/*.h`). Blocks that deviate — `transmute`, raw
// pointer casts from user data, slice reconstruction via `from_raw_parts`,
// or storing a callback across the FFI boundary — carry their own
// `// SAFETY:` comment.
26
/// Tensor element types matching the C `enum ove_tensor_type`.
///
/// The explicit discriminants are the raw integer codes used when tensor
/// metadata reported by the C API is decoded into this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum TensorType {
    /// 32-bit floating-point elements (raw code 0).
    Float32 = 0,
    /// Signed 8-bit integer elements (raw code 1).
    Int8 = 1,
    /// Unsigned 8-bit integer elements (raw code 2).
    Uint8 = 2,
    /// Signed 16-bit integer elements (raw code 3).
    Int16 = 3,
    /// Signed 32-bit integer elements (raw code 4).
    Int32 = 4,
}
37
38/// Tensor metadata descriptor.
39#[derive(Debug, Clone)]
40pub struct TensorInfo {
41    /// Pointer to tensor data in the arena.
42    pub data: *mut u8,
43    /// Total size in bytes.
44    pub size: usize,
45    /// Element type.
46    pub tensor_type: TensorType,
47    /// Number of dimensions.
48    pub ndims: u32,
49    /// Shape array (up to 5 dimensions).
50    pub dims: [i32; 5],
51}
52
/// Model configuration.
///
/// Borrows the serialized model for `'a`; the configuration itself owns
/// nothing.
pub struct ModelConfig<'a> {
    /// Reference to the .tflite FlatBuffer data.
    pub model_data: &'a [u8],
    /// Tensor arena size in bytes.
    pub arena_size: usize,
}
60
61/// An ML inference model session.
62///
63/// Wraps a TFLM `MicroInterpreter` with automatic cleanup.
64pub struct Model {
65    handle: bindings::ove_model_t,
66}
67
68impl Model {
69    /// Create a new model via heap allocation.
70    #[cfg(not(zero_heap))]
71    pub fn new(config: &ModelConfig) -> Result<Self> {
72        let mut handle: bindings::ove_model_t = core::ptr::null_mut();
73        let c_cfg = bindings::ove_model_config {
74            model_data: config.model_data.as_ptr() as *const core::ffi::c_void,
75            model_size: config.model_data.len(),
76            arena_size: config.arena_size,
77        };
78        let rc = unsafe { bindings::ove_model_create(&mut handle, &c_cfg) };
79        Error::from_code(rc)?;
80        Ok(Self { handle })
81    }
82
83    /// Create from caller-provided storage and arena.
84    ///
85    /// Available in both heap and zero-heap modes.  Useful when the same
86    /// storage/arena must be reused for different models (e.g. two-stage
87    /// inference pipelines).
88    ///
89    /// # Safety
90    /// Caller must ensure `storage` and `arena` outlive the `Model` and are
91    /// not shared with another primitive.
92    pub unsafe fn from_static(
93        storage: *mut bindings::ove_model_storage_t,
94        arena: *mut u8,
95        config: &ModelConfig,
96    ) -> Result<Self> {
97        let mut handle: bindings::ove_model_t = core::ptr::null_mut();
98        let c_cfg = bindings::ove_model_config {
99            model_data: config.model_data.as_ptr() as *const core::ffi::c_void,
100            model_size: config.model_data.len(),
101            arena_size: config.arena_size,
102        };
103        let rc = unsafe {
104            bindings::ove_model_init(
105                &mut handle,
106                storage,
107                arena as *mut core::ffi::c_void,
108                &c_cfg,
109            )
110        };
111        Error::from_code(rc)?;
112        Ok(Self { handle })
113    }
114
115    /// Run the model forward pass.
116    pub fn invoke(&self) -> Result<()> {
117        let rc = unsafe { bindings::ove_model_invoke(self.handle) };
118        Error::from_code(rc)
119    }
120
121    /// Get tensor info for an input tensor.
122    pub fn input(&self, index: u32) -> Result<TensorInfo> {
123        let mut info: bindings::ove_tensor_info = unsafe { core::mem::zeroed() };
124        let rc = unsafe { bindings::ove_model_input(self.handle, index, &mut info) };
125        Error::from_code(rc)?;
126        Ok(TensorInfo {
127            data: info.data as *mut u8,
128            size: info.size,
129            tensor_type: match info.type_ {
130                1 => TensorType::Int8,
131                2 => TensorType::Uint8,
132                3 => TensorType::Int16,
133                4 => TensorType::Int32,
134                _ => TensorType::Float32,
135            },
136            ndims: info.ndims,
137            dims: info.dims,
138        })
139    }
140
141    /// Get tensor info for an output tensor.
142    pub fn output(&self, index: u32) -> Result<TensorInfo> {
143        let mut info: bindings::ove_tensor_info = unsafe { core::mem::zeroed() };
144        let rc = unsafe { bindings::ove_model_output(self.handle, index, &mut info) };
145        Error::from_code(rc)?;
146        Ok(TensorInfo {
147            data: info.data as *mut u8,
148            size: info.size,
149            tensor_type: match info.type_ {
150                1 => TensorType::Int8,
151                2 => TensorType::Uint8,
152                3 => TensorType::Int16,
153                4 => TensorType::Int32,
154                _ => TensorType::Float32,
155            },
156            ndims: info.ndims,
157            dims: info.dims,
158        })
159    }
160
161    /// Return last inference duration in microseconds.
162    pub fn last_inference_us(&self) -> u64 {
163        unsafe { bindings::ove_model_last_inference_us(self.handle) }
164    }
165}
166
167crate::ove_handle_impl!(Model, ove_model_destroy, ove_model_deinit);
168
// ---------------------------------------------------------------------------
// ModelStorage — safe reusable arena
// ---------------------------------------------------------------------------
172
173/// Reusable model storage and arena pair for sequential inference.
174///
175/// Owns both the C `ove_model_storage_t` and a 16-byte aligned arena
176/// buffer.  Call [`load()`](ModelStorage::load) to create a [`Model`]
177/// session; the borrow checker ensures the storage is not shared.
178///
179/// # Example
180///
181/// ```ignore
182/// let mut storage = ModelStorage::<32768>::new();
183/// let model = storage.load(model_bytes)?;
184/// let input = model.input_slice_mut::<i16>(0)?;
185/// input[0] = 42;
186/// model.invoke()?;
187/// let output = model.output_slice::<i8>(0)?;
188/// ```
189#[repr(C, align(16))]
190pub struct ModelStorage<const ARENA_SIZE: usize> {
191    storage: bindings::ove_model_storage_t,
192    arena: [u8; ARENA_SIZE],
193}
194
195impl<const ARENA_SIZE: usize> Default for ModelStorage<ARENA_SIZE> {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
201impl<const ARENA_SIZE: usize> ModelStorage<ARENA_SIZE> {
202    /// Create a zeroed storage + arena pair.
203    pub fn new() -> Self {
204        // SAFETY: ove_model_storage_t is a C struct that is valid when zeroed.
205        unsafe { core::mem::zeroed() }
206    }
207
208    /// Load a model into this storage, returning a session handle.
209    ///
210    /// The arena size is supplied by the const generic `ARENA_SIZE` —
211    /// no need to repeat it.  The returned [`Model`] borrows `self`
212    /// mutably, so the compiler prevents concurrent use or re-loading
213    /// until the model is dropped.
214    pub fn load(&mut self, model_data: &[u8]) -> Result<Model> {
215        let config = ModelConfig {
216            model_data,
217            arena_size: ARENA_SIZE,
218        };
219        unsafe { Model::from_static(&mut self.storage, self.arena.as_mut_ptr(), &config) }
220    }
221}
222
// ---------------------------------------------------------------------------
// Typed tensor accessors
// ---------------------------------------------------------------------------
226
227impl Model {
228    /// Get input tensor data as a mutable typed slice.
229    ///
230    /// The slice length is `tensor_info.size / size_of::<T>()`.
231    ///
232    /// # Errors
233    /// Returns an error if the tensor index is invalid.
234    #[allow(clippy::mut_from_ref)] // tensor arena is owned by the C session, not by &self
235    pub fn input_slice_mut<T>(&self, index: u32) -> Result<&mut [T]> {
236        let info = self.input(index)?;
237        let count = info.size / core::mem::size_of::<T>();
238        // SAFETY: The tensor arena is owned by the model session and
239        // valid for the lifetime of this Model.  We have &self so the
240        // model is alive.  The caller must not alias this slice with
241        // another call to input_slice_mut for the same tensor index.
242        Ok(unsafe { core::slice::from_raw_parts_mut(info.data as *mut T, count) })
243    }
244
245    /// Get output tensor data as a typed slice.
246    ///
247    /// The slice length is `tensor_info.size / size_of::<T>()`.
248    ///
249    /// # Errors
250    /// Returns an error if the tensor index is invalid.
251    pub fn output_slice<T>(&self, index: u32) -> Result<&[T]> {
252        let info = self.output(index)?;
253        let count = info.size / core::mem::size_of::<T>();
254        // SAFETY: Same as input_slice_mut, but immutable.
255        Ok(unsafe { core::slice::from_raw_parts(info.data as *const T, count) })
256    }
257}