// ove/infer.rs
// Copyright (C) 2026 Kamil Lulko <kamil.lulko@gmail.com>
//
// SPDX-License-Identifier: GPL-3.0-or-later
//
// This file is part of oveRTOS.

//! ML inference primitives for oveRTOS.
//!
//! Provides a safe Rust wrapper around the `ove_model_*` C API for
//! running TFLite model inference.

use crate::bindings;
use crate::error::{Error, Result};

/// Tensor element types matching the C `enum ove_tensor_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum TensorType {
    Float32 = 0,
    Int8 = 1,
    Uint8 = 2,
    Int16 = 3,
    Int32 = 4,
}

impl TensorType {
    /// Decode a raw C tensor-type value into the Rust enum.
    ///
    /// Unknown values fall back to `Float32`, the same mapping used when
    /// decoding tensor descriptors returned by the C API.
    pub fn from_raw(raw: u32) -> Self {
        match raw {
            1 => TensorType::Int8,
            2 => TensorType::Uint8,
            3 => TensorType::Int16,
            4 => TensorType::Int32,
            // 0 and any unrecognized value decode as Float32.
            _ => TensorType::Float32,
        }
    }
}

26/// Tensor metadata descriptor.
27#[derive(Debug, Clone)]
28pub struct TensorInfo {
29 /// Pointer to tensor data in the arena.
30 pub data: *mut u8,
31 /// Total size in bytes.
32 pub size: usize,
33 /// Element type.
34 pub tensor_type: TensorType,
35 /// Number of dimensions.
36 pub ndims: u32,
37 /// Shape array (up to 5 dimensions).
38 pub dims: [i32; 5],
39}
40
/// Model configuration.
///
/// Derives `Debug`/`Clone` for parity with [`TensorInfo`]; the clone is
/// shallow (the FlatBuffer bytes stay borrowed).
#[derive(Debug, Clone)]
pub struct ModelConfig<'a> {
    /// Reference to the .tflite FlatBuffer data.
    pub model_data: &'a [u8],
    /// Tensor arena size in bytes.
    pub arena_size: usize,
}

/// An ML inference model session.
///
/// Wraps a TFLM `MicroInterpreter` with automatic cleanup: the C-side
/// session is released when the `Model` is dropped.
pub struct Model {
    // Opaque handle produced by ove_model_create / ove_model_init.
    // Drop treats a null handle as "nothing to release".
    handle: bindings::ove_model_t,
}

56impl Model {
57 /// Create a new model via heap allocation.
58 #[cfg(not(zero_heap))]
59 pub fn new(config: &ModelConfig) -> Result<Self> {
60 let mut handle: bindings::ove_model_t = core::ptr::null_mut();
61 let c_cfg = bindings::ove_model_config {
62 model_data: config.model_data.as_ptr() as *const core::ffi::c_void,
63 model_size: config.model_data.len(),
64 arena_size: config.arena_size,
65 };
66 let rc = unsafe { bindings::ove_model_create(&mut handle, &c_cfg) };
67 Error::from_code(rc)?;
68 Ok(Self { handle })
69 }
70
71 /// Create from caller-provided storage and arena.
72 ///
73 /// Available in both heap and zero-heap modes. Useful when the same
74 /// storage/arena must be reused for different models (e.g. two-stage
75 /// inference pipelines).
76 ///
77 /// # Safety
78 /// Caller must ensure `storage` and `arena` outlive the `Model` and are
79 /// not shared with another primitive.
80 pub unsafe fn from_static(
81 storage: *mut bindings::ove_model_storage_t,
82 arena: *mut u8,
83 config: &ModelConfig,
84 ) -> Result<Self> {
85 let mut handle: bindings::ove_model_t = core::ptr::null_mut();
86 let c_cfg = bindings::ove_model_config {
87 model_data: config.model_data.as_ptr() as *const core::ffi::c_void,
88 model_size: config.model_data.len(),
89 arena_size: config.arena_size,
90 };
91 let rc = bindings::ove_model_init(
92 &mut handle,
93 storage,
94 arena as *mut core::ffi::c_void,
95 &c_cfg,
96 );
97 Error::from_code(rc)?;
98 Ok(Self { handle })
99 }
100
101 /// Run the model forward pass.
102 pub fn invoke(&self) -> Result<()> {
103 let rc = unsafe { bindings::ove_model_invoke(self.handle) };
104 Error::from_code(rc)
105 }
106
107 /// Get tensor info for an input tensor.
108 pub fn input(&self, index: u32) -> Result<TensorInfo> {
109 let mut info: bindings::ove_tensor_info = unsafe { core::mem::zeroed() };
110 let rc = unsafe {
111 bindings::ove_model_input(self.handle, index, &mut info)
112 };
113 Error::from_code(rc)?;
114 Ok(TensorInfo {
115 data: info.data as *mut u8,
116 size: info.size,
117 tensor_type: match info.type_ {
118 1 => TensorType::Int8,
119 2 => TensorType::Uint8,
120 3 => TensorType::Int16,
121 4 => TensorType::Int32,
122 _ => TensorType::Float32,
123 },
124 ndims: info.ndims,
125 dims: info.dims,
126 })
127 }
128
129 /// Get tensor info for an output tensor.
130 pub fn output(&self, index: u32) -> Result<TensorInfo> {
131 let mut info: bindings::ove_tensor_info = unsafe { core::mem::zeroed() };
132 let rc = unsafe {
133 bindings::ove_model_output(self.handle, index, &mut info)
134 };
135 Error::from_code(rc)?;
136 Ok(TensorInfo {
137 data: info.data as *mut u8,
138 size: info.size,
139 tensor_type: match info.type_ {
140 1 => TensorType::Int8,
141 2 => TensorType::Uint8,
142 3 => TensorType::Int16,
143 4 => TensorType::Int32,
144 _ => TensorType::Float32,
145 },
146 ndims: info.ndims,
147 dims: info.dims,
148 })
149 }
150
151 /// Return last inference duration in microseconds.
152 pub fn last_inference_us(&self) -> u64 {
153 unsafe { bindings::ove_model_last_inference_us(self.handle) }
154 }
155}
156
157impl Drop for Model {
158 fn drop(&mut self) {
159 if !self.handle.is_null() {
160 #[cfg(not(zero_heap))]
161 unsafe {
162 bindings::ove_model_destroy(self.handle);
163 }
164 #[cfg(zero_heap)]
165 unsafe {
166 bindings::ove_model_deinit(self.handle);
167 }
168 }
169 }
170}
171
// SAFETY: `handle` is an opaque pointer with no thread affinity visible
// from this file; moving a Model between threads is assumed safe —
// TODO confirm the C layer keeps no thread-local state per session.
unsafe impl Send for Model {}
// NOTE(review): `Sync` combined with the `&self`-based mutable accessor
// (`input_slice_mut`) permits aliased mutation from multiple threads —
// confirm the C layer tolerates concurrent access, or restrict shared
// use at a higher level.
unsafe impl Sync for Model {}

// ---------------------------------------------------------------------------
// ModelStorage — safe reusable arena
// ---------------------------------------------------------------------------

/// Reusable model storage and arena pair for sequential inference.
///
/// Owns both the C `ove_model_storage_t` and a 16-byte aligned arena
/// buffer. Call [`load()`](ModelStorage::load) to create a [`Model`]
/// session.
///
/// # Example
///
/// ```ignore
/// let mut storage = ModelStorage::<32768>::new();
/// let model = storage.load(&model_bytes)?;
/// let input = model.input_slice_mut::<i16>(0)?;
/// input[0] = 42;
/// model.invoke()?;
/// let output = model.output_slice::<i8>(0)?;
/// ```
#[repr(C, align(16))]
pub struct ModelStorage<const ARENA_SIZE: usize> {
    // C-side bookkeeping struct; zero-initialized by `new()`.
    storage: bindings::ove_model_storage_t,
    // Tensor arena handed to the C layer.
    // NOTE(review): with repr(C), `arena` is 16-byte aligned only if
    // size_of::<ove_model_storage_t>() is a multiple of 16 — confirm,
    // or move the align attribute onto the arena member itself.
    arena: [u8; ARENA_SIZE],
}

201impl<const ARENA_SIZE: usize> ModelStorage<ARENA_SIZE> {
202 /// Create a zeroed storage + arena pair.
203 pub fn new() -> Self {
204 // SAFETY: ove_model_storage_t is a C struct that is valid when zeroed.
205 unsafe {
206 core::mem::zeroed()
207 }
208 }
209
210 /// Load a model into this storage, returning a session handle.
211 ///
212 /// The arena size is supplied by the const generic `ARENA_SIZE` —
213 /// no need to repeat it. The returned [`Model`] borrows `self`
214 /// mutably, so the compiler prevents concurrent use or re-loading
215 /// until the model is dropped.
216 pub fn load(&mut self, model_data: &[u8]) -> Result<Model> {
217 let config = ModelConfig {
218 model_data,
219 arena_size: ARENA_SIZE,
220 };
221 unsafe {
222 Model::from_static(
223 &mut self.storage,
224 self.arena.as_mut_ptr(),
225 &config,
226 )
227 }
228 }
229}
230
// ---------------------------------------------------------------------------
// Typed tensor accessors
// ---------------------------------------------------------------------------

235impl Model {
236 /// Get input tensor data as a mutable typed slice.
237 ///
238 /// The slice length is `tensor_info.size / size_of::<T>()`.
239 ///
240 /// # Errors
241 /// Returns an error if the tensor index is invalid.
242 pub fn input_slice_mut<T>(&self, index: u32) -> Result<&mut [T]> {
243 let info = self.input(index)?;
244 let count = info.size / core::mem::size_of::<T>();
245 // SAFETY: The tensor arena is owned by the model session and
246 // valid for the lifetime of this Model. We have &self so the
247 // model is alive. The caller must not alias this slice with
248 // another call to input_slice_mut for the same tensor index.
249 Ok(unsafe {
250 core::slice::from_raw_parts_mut(info.data as *mut T, count)
251 })
252 }
253
254 /// Get output tensor data as a typed slice.
255 ///
256 /// The slice length is `tensor_info.size / size_of::<T>()`.
257 ///
258 /// # Errors
259 /// Returns an error if the tensor index is invalid.
260 pub fn output_slice<T>(&self, index: u32) -> Result<&[T]> {
261 let info = self.output(index)?;
262 let count = info.size / core::mem::size_of::<T>();
263 // SAFETY: Same as input_slice_mut, but immutable.
264 Ok(unsafe {
265 core::slice::from_raw_parts(info.data as *const T, count)
266 })
267 }
268}