ove/infer.rs
1// Copyright (C) 2026 Kamil Lulko <kamil.lulko@gmail.com>
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4//
5// This file is part of oveRTOS.
6
7//! ML inference primitives for oveRTOS.
8//!
9//! Provides a safe Rust wrapper around the `ove_model_*` C API for
10//! running TFLite model inference.
11
12use crate::bindings;
13use crate::error::{Error, Result};
14
15// SAFETY (module-wide contract for the `unsafe { bindings::ove_*(...) }` FFI
16// calls below): any handle passed to the C API is non-null and refers to a
17// live RTOS object — wrapper constructors establish validity via
18// `Error::from_code`, and `Drop` (or an explicit `deinit`) is the only place
19// a handle is released. Pointer and slice arguments reference caller-owned
20// memory valid for the duration of the call; the C side copies whatever it
21// retains and does not alias them past return (verified against the
22// signatures in `include/ove/*.h`). Blocks that deviate — `transmute`, raw
23// pointer casts from user data, slice reconstruction via `from_raw_parts`,
24// or storing a callback across the FFI boundary — carry their own
25// `// SAFETY:` comment.
26
/// Element type of a tensor, mirroring the C `enum ove_tensor_type`.
///
/// The explicit discriminants are the raw numeric values used for each
/// element type, so `variant as u32` yields the C-side value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u32)]
pub enum TensorType {
    /// 32-bit floating point elements.
    Float32 = 0,
    /// Signed 8-bit integer elements.
    Int8 = 1,
    /// Unsigned 8-bit integer elements.
    Uint8 = 2,
    /// Signed 16-bit integer elements.
    Int16 = 3,
    /// Signed 32-bit integer elements.
    Int32 = 4,
}
37
38/// Tensor metadata descriptor.
39#[derive(Debug, Clone)]
40pub struct TensorInfo {
41 /// Pointer to tensor data in the arena.
42 pub data: *mut u8,
43 /// Total size in bytes.
44 pub size: usize,
45 /// Element type.
46 pub tensor_type: TensorType,
47 /// Number of dimensions.
48 pub ndims: u32,
49 /// Shape array (up to 5 dimensions).
50 pub dims: [i32; 5],
51}
52
/// Parameters needed to set up a model session.
pub struct ModelConfig<'a> {
    /// The `.tflite` FlatBuffer bytes, borrowed for `'a`.
    pub model_data: &'a [u8],
    /// Size of the tensor arena, in bytes.
    pub arena_size: usize,
}
60
61/// An ML inference model session.
62///
63/// Wraps a TFLM `MicroInterpreter` with automatic cleanup.
64pub struct Model {
65 handle: bindings::ove_model_t,
66}
67
68impl Model {
69 /// Create a new model via heap allocation.
70 #[cfg(not(zero_heap))]
71 pub fn new(config: &ModelConfig) -> Result<Self> {
72 let mut handle: bindings::ove_model_t = core::ptr::null_mut();
73 let c_cfg = bindings::ove_model_config {
74 model_data: config.model_data.as_ptr() as *const core::ffi::c_void,
75 model_size: config.model_data.len(),
76 arena_size: config.arena_size,
77 };
78 let rc = unsafe { bindings::ove_model_create(&mut handle, &c_cfg) };
79 Error::from_code(rc)?;
80 Ok(Self { handle })
81 }
82
83 /// Create from caller-provided storage and arena.
84 ///
85 /// Available in both heap and zero-heap modes. Useful when the same
86 /// storage/arena must be reused for different models (e.g. two-stage
87 /// inference pipelines).
88 ///
89 /// # Safety
90 /// Caller must ensure `storage` and `arena` outlive the `Model` and are
91 /// not shared with another primitive.
92 pub unsafe fn from_static(
93 storage: *mut bindings::ove_model_storage_t,
94 arena: *mut u8,
95 config: &ModelConfig,
96 ) -> Result<Self> {
97 let mut handle: bindings::ove_model_t = core::ptr::null_mut();
98 let c_cfg = bindings::ove_model_config {
99 model_data: config.model_data.as_ptr() as *const core::ffi::c_void,
100 model_size: config.model_data.len(),
101 arena_size: config.arena_size,
102 };
103 let rc = unsafe {
104 bindings::ove_model_init(
105 &mut handle,
106 storage,
107 arena as *mut core::ffi::c_void,
108 &c_cfg,
109 )
110 };
111 Error::from_code(rc)?;
112 Ok(Self { handle })
113 }
114
115 /// Run the model forward pass.
116 pub fn invoke(&self) -> Result<()> {
117 let rc = unsafe { bindings::ove_model_invoke(self.handle) };
118 Error::from_code(rc)
119 }
120
121 /// Get tensor info for an input tensor.
122 pub fn input(&self, index: u32) -> Result<TensorInfo> {
123 let mut info: bindings::ove_tensor_info = unsafe { core::mem::zeroed() };
124 let rc = unsafe { bindings::ove_model_input(self.handle, index, &mut info) };
125 Error::from_code(rc)?;
126 Ok(TensorInfo {
127 data: info.data as *mut u8,
128 size: info.size,
129 tensor_type: match info.type_ {
130 1 => TensorType::Int8,
131 2 => TensorType::Uint8,
132 3 => TensorType::Int16,
133 4 => TensorType::Int32,
134 _ => TensorType::Float32,
135 },
136 ndims: info.ndims,
137 dims: info.dims,
138 })
139 }
140
141 /// Get tensor info for an output tensor.
142 pub fn output(&self, index: u32) -> Result<TensorInfo> {
143 let mut info: bindings::ove_tensor_info = unsafe { core::mem::zeroed() };
144 let rc = unsafe { bindings::ove_model_output(self.handle, index, &mut info) };
145 Error::from_code(rc)?;
146 Ok(TensorInfo {
147 data: info.data as *mut u8,
148 size: info.size,
149 tensor_type: match info.type_ {
150 1 => TensorType::Int8,
151 2 => TensorType::Uint8,
152 3 => TensorType::Int16,
153 4 => TensorType::Int32,
154 _ => TensorType::Float32,
155 },
156 ndims: info.ndims,
157 dims: info.dims,
158 })
159 }
160
161 /// Return last inference duration in microseconds.
162 pub fn last_inference_us(&self) -> u64 {
163 unsafe { bindings::ove_model_last_inference_us(self.handle) }
164 }
165}
166
// Handle boilerplate for `Model`, generated by the crate's
// `ove_handle_impl!` macro from the two C entry points named here.
// NOTE(review): presumably this supplies the `Drop` impl that the
// module-wide SAFETY comment refers to (destroy for heap-created handles,
// deinit for statically initialised ones) — confirm against the macro.
crate::ove_handle_impl!(Model, ove_model_destroy, ove_model_deinit);
168
169// ---------------------------------------------------------------------------
170// ModelStorage — safe reusable arena
171// ---------------------------------------------------------------------------
172
173/// Reusable model storage and arena pair for sequential inference.
174///
175/// Owns both the C `ove_model_storage_t` and a 16-byte aligned arena
176/// buffer. Call [`load()`](ModelStorage::load) to create a [`Model`]
177/// session; the borrow checker ensures the storage is not shared.
178///
179/// # Example
180///
181/// ```ignore
182/// let mut storage = ModelStorage::<32768>::new();
183/// let model = storage.load(model_bytes)?;
184/// let input = model.input_slice_mut::<i16>(0)?;
185/// input[0] = 42;
186/// model.invoke()?;
187/// let output = model.output_slice::<i8>(0)?;
188/// ```
189#[repr(C, align(16))]
190pub struct ModelStorage<const ARENA_SIZE: usize> {
191 storage: bindings::ove_model_storage_t,
192 arena: [u8; ARENA_SIZE],
193}
194
195impl<const ARENA_SIZE: usize> Default for ModelStorage<ARENA_SIZE> {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
201impl<const ARENA_SIZE: usize> ModelStorage<ARENA_SIZE> {
202 /// Create a zeroed storage + arena pair.
203 pub fn new() -> Self {
204 // SAFETY: ove_model_storage_t is a C struct that is valid when zeroed.
205 unsafe { core::mem::zeroed() }
206 }
207
208 /// Load a model into this storage, returning a session handle.
209 ///
210 /// The arena size is supplied by the const generic `ARENA_SIZE` —
211 /// no need to repeat it. The returned [`Model`] borrows `self`
212 /// mutably, so the compiler prevents concurrent use or re-loading
213 /// until the model is dropped.
214 pub fn load(&mut self, model_data: &[u8]) -> Result<Model> {
215 let config = ModelConfig {
216 model_data,
217 arena_size: ARENA_SIZE,
218 };
219 unsafe { Model::from_static(&mut self.storage, self.arena.as_mut_ptr(), &config) }
220 }
221}
222
223// ---------------------------------------------------------------------------
224// Typed tensor accessors
225// ---------------------------------------------------------------------------
226
227impl Model {
228 /// Get input tensor data as a mutable typed slice.
229 ///
230 /// The slice length is `tensor_info.size / size_of::<T>()`.
231 ///
232 /// # Errors
233 /// Returns an error if the tensor index is invalid.
234 #[allow(clippy::mut_from_ref)] // tensor arena is owned by the C session, not by &self
235 pub fn input_slice_mut<T>(&self, index: u32) -> Result<&mut [T]> {
236 let info = self.input(index)?;
237 let count = info.size / core::mem::size_of::<T>();
238 // SAFETY: The tensor arena is owned by the model session and
239 // valid for the lifetime of this Model. We have &self so the
240 // model is alive. The caller must not alias this slice with
241 // another call to input_slice_mut for the same tensor index.
242 Ok(unsafe { core::slice::from_raw_parts_mut(info.data as *mut T, count) })
243 }
244
245 /// Get output tensor data as a typed slice.
246 ///
247 /// The slice length is `tensor_info.size / size_of::<T>()`.
248 ///
249 /// # Errors
250 /// Returns an error if the tensor index is invalid.
251 pub fn output_slice<T>(&self, index: u32) -> Result<&[T]> {
252 let info = self.output(index)?;
253 let count = info.size / core::mem::size_of::<T>();
254 // SAFETY: Same as input_slice_mut, but immutable.
255 Ok(unsafe { core::slice::from_raw_parts(info.data as *const T, count) })
256 }
257}