// ove/infer.rs

1// Copyright (C) 2026 Kamil Lulko <kamil.lulko@gmail.com>
2//
3// SPDX-License-Identifier: GPL-3.0-or-later
4//
5// This file is part of oveRTOS.
6
7//! ML inference primitives for oveRTOS.
8//!
9//! Provides a safe Rust wrapper around the `ove_model_*` C API for
10//! running TFLite model inference.
11
12use crate::bindings;
13use crate::error::{Error, Result};
14
/// Tensor element types matching the C `enum ove_tensor_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum TensorType {
    /// 32-bit floating point.
    Float32 = 0,
    /// Signed 8-bit integer.
    Int8 = 1,
    /// Unsigned 8-bit integer.
    Uint8 = 2,
    /// Signed 16-bit integer.
    Int16 = 3,
    /// Signed 32-bit integer.
    Int32 = 4,
}

impl TensorType {
    /// Convert a raw C `ove_tensor_type` discriminant into a `TensorType`.
    ///
    /// Unknown discriminants fall back to `Float32`, matching the fallback
    /// used by the tensor-info conversion elsewhere in this module.  This is
    /// the single canonical mapping; prefer it over hand-written matches.
    pub fn from_raw(raw: u32) -> Self {
        match raw {
            1 => TensorType::Int8,
            2 => TensorType::Uint8,
            3 => TensorType::Int16,
            4 => TensorType::Int32,
            _ => TensorType::Float32,
        }
    }
}
25
/// Tensor metadata descriptor.
///
/// Mirrors the C `ove_tensor_info` struct filled in by
/// `ove_model_input` / `ove_model_output`.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Pointer to tensor data in the arena.
    ///
    /// NOTE(review): this raw pointer points into the model session's arena
    /// and is presumably only valid while the originating `Model` is alive —
    /// confirm callers do not cache it past drop.
    pub data: *mut u8,
    /// Total size in bytes.
    pub size: usize,
    /// Element type.
    pub tensor_type: TensorType,
    /// Number of dimensions.
    pub ndims: u32,
    /// Shape array (up to 5 dimensions).
    ///
    /// Presumably only the first `ndims` entries are meaningful — verify
    /// against the C API.
    pub dims: [i32; 5],
}
40
/// Model configuration.
pub struct ModelConfig<'a> {
    /// Reference to the .tflite FlatBuffer data.
    ///
    /// NOTE(review): the pointer to this buffer is handed to the C side at
    /// create/init time; the C API may retain it for the session's lifetime —
    /// confirm the buffer outlives the `Model`, not just this config.
    pub model_data: &'a [u8],
    /// Tensor arena size in bytes.
    pub arena_size: usize,
}
48
/// An ML inference model session.
///
/// Wraps a TFLM `MicroInterpreter` with automatic cleanup.
pub struct Model {
    // Opaque C session handle.  Owned by this struct; released in `Drop`
    // (null only if construction failed before the handle was set).
    handle: bindings::ove_model_t,
}
55
56impl Model {
57    /// Create a new model via heap allocation.
58    #[cfg(not(zero_heap))]
59    pub fn new(config: &ModelConfig) -> Result<Self> {
60        let mut handle: bindings::ove_model_t = core::ptr::null_mut();
61        let c_cfg = bindings::ove_model_config {
62            model_data: config.model_data.as_ptr() as *const core::ffi::c_void,
63            model_size: config.model_data.len(),
64            arena_size: config.arena_size,
65        };
66        let rc = unsafe { bindings::ove_model_create(&mut handle, &c_cfg) };
67        Error::from_code(rc)?;
68        Ok(Self { handle })
69    }
70
71    /// Create from caller-provided storage and arena.
72    ///
73    /// Available in both heap and zero-heap modes.  Useful when the same
74    /// storage/arena must be reused for different models (e.g. two-stage
75    /// inference pipelines).
76    ///
77    /// # Safety
78    /// Caller must ensure `storage` and `arena` outlive the `Model` and are
79    /// not shared with another primitive.
80    pub unsafe fn from_static(
81        storage: *mut bindings::ove_model_storage_t,
82        arena: *mut u8,
83        config: &ModelConfig,
84    ) -> Result<Self> {
85        let mut handle: bindings::ove_model_t = core::ptr::null_mut();
86        let c_cfg = bindings::ove_model_config {
87            model_data: config.model_data.as_ptr() as *const core::ffi::c_void,
88            model_size: config.model_data.len(),
89            arena_size: config.arena_size,
90        };
91        let rc = bindings::ove_model_init(
92            &mut handle,
93            storage,
94            arena as *mut core::ffi::c_void,
95            &c_cfg,
96        );
97        Error::from_code(rc)?;
98        Ok(Self { handle })
99    }
100
101    /// Run the model forward pass.
102    pub fn invoke(&self) -> Result<()> {
103        let rc = unsafe { bindings::ove_model_invoke(self.handle) };
104        Error::from_code(rc)
105    }
106
107    /// Get tensor info for an input tensor.
108    pub fn input(&self, index: u32) -> Result<TensorInfo> {
109        let mut info: bindings::ove_tensor_info = unsafe { core::mem::zeroed() };
110        let rc = unsafe {
111            bindings::ove_model_input(self.handle, index, &mut info)
112        };
113        Error::from_code(rc)?;
114        Ok(TensorInfo {
115            data: info.data as *mut u8,
116            size: info.size,
117            tensor_type: match info.type_ {
118                1 => TensorType::Int8,
119                2 => TensorType::Uint8,
120                3 => TensorType::Int16,
121                4 => TensorType::Int32,
122                _ => TensorType::Float32,
123            },
124            ndims: info.ndims,
125            dims: info.dims,
126        })
127    }
128
129    /// Get tensor info for an output tensor.
130    pub fn output(&self, index: u32) -> Result<TensorInfo> {
131        let mut info: bindings::ove_tensor_info = unsafe { core::mem::zeroed() };
132        let rc = unsafe {
133            bindings::ove_model_output(self.handle, index, &mut info)
134        };
135        Error::from_code(rc)?;
136        Ok(TensorInfo {
137            data: info.data as *mut u8,
138            size: info.size,
139            tensor_type: match info.type_ {
140                1 => TensorType::Int8,
141                2 => TensorType::Uint8,
142                3 => TensorType::Int16,
143                4 => TensorType::Int32,
144                _ => TensorType::Float32,
145            },
146            ndims: info.ndims,
147            dims: info.dims,
148        })
149    }
150
151    /// Return last inference duration in microseconds.
152    pub fn last_inference_us(&self) -> u64 {
153        unsafe { bindings::ove_model_last_inference_us(self.handle) }
154    }
155}
156
impl Drop for Model {
    fn drop(&mut self) {
        // A null handle means construction failed before the C session was
        // created — nothing to release.
        if !self.handle.is_null() {
            // NOTE(review): the teardown call is selected by compile-time
            // cfg, not by how this Model was constructed.  In heap mode
            // (`not(zero_heap)`), a Model obtained via `from_static` /
            // `ModelStorage::load` is still passed to `ove_model_destroy`,
            // which presumably frees caller-owned storage — confirm the C
            // API tolerates this, or track the constructor used in `Model`.
            #[cfg(not(zero_heap))]
            unsafe {
                bindings::ove_model_destroy(self.handle);
            }
            #[cfg(zero_heap)]
            unsafe {
                bindings::ove_model_deinit(self.handle);
            }
        }
    }
}
171
// SAFETY: `Model` only holds an opaque C handle, so moving it to another
// thread is assumed sound provided the C API keeps no thread-local state —
// TODO confirm against the ove_model_* implementation.
unsafe impl Send for Model {}
// NOTE(review): `Sync` permits `invoke(&self)` to run concurrently from
// several threads on the same session.  TFLM interpreters are generally not
// re-entrant — verify that external synchronization is guaranteed, or
// consider dropping this impl at the next API break.
unsafe impl Sync for Model {}
174
175// ---------------------------------------------------------------------------
176// ModelStorage — safe reusable arena
177// ---------------------------------------------------------------------------
178
/// Reusable model storage and arena pair for sequential inference.
///
/// Owns both the C `ove_model_storage_t` and a 16-byte aligned arena
/// buffer.  Call [`load()`](ModelStorage::load) to create a [`Model`]
/// session.
///
/// NOTE(review): [`Model`] carries no lifetime parameter, so the mutable
/// borrow taken by `load()` ends when it returns — the borrow checker does
/// NOT prevent re-loading or dropping this storage while a `Model` created
/// from it is still alive.  Callers must uphold that invariant manually.
///
/// # Example
///
/// ```ignore
/// let mut storage = ModelStorage::<32768>::new();
/// let model = storage.load(model_bytes)?; // model_bytes: &[u8] FlatBuffer
/// let input = model.input_slice_mut::<i16>(0)?;
/// input[0] = 42;
/// model.invoke()?;
/// let output = model.output_slice::<i8>(0)?;
/// ```
#[repr(C, align(16))]
pub struct ModelStorage<const ARENA_SIZE: usize> {
    // C-side session bookkeeping, handed to `ove_model_init` by `load()`.
    storage: bindings::ove_model_storage_t,
    // Tensor arena; 16-byte alignment comes from the struct's `repr`.
    arena: [u8; ARENA_SIZE],
}
200
201impl<const ARENA_SIZE: usize> ModelStorage<ARENA_SIZE> {
202    /// Create a zeroed storage + arena pair.
203    pub fn new() -> Self {
204        // SAFETY: ove_model_storage_t is a C struct that is valid when zeroed.
205        unsafe {
206            core::mem::zeroed()
207        }
208    }
209
210    /// Load a model into this storage, returning a session handle.
211    ///
212    /// The arena size is supplied by the const generic `ARENA_SIZE` —
213    /// no need to repeat it.  The returned [`Model`] borrows `self`
214    /// mutably, so the compiler prevents concurrent use or re-loading
215    /// until the model is dropped.
216    pub fn load(&mut self, model_data: &[u8]) -> Result<Model> {
217        let config = ModelConfig {
218            model_data,
219            arena_size: ARENA_SIZE,
220        };
221        unsafe {
222            Model::from_static(
223                &mut self.storage,
224                self.arena.as_mut_ptr(),
225                &config,
226            )
227        }
228    }
229}
230
231// ---------------------------------------------------------------------------
232// Typed tensor accessors
233// ---------------------------------------------------------------------------
234
235impl Model {
236    /// Get input tensor data as a mutable typed slice.
237    ///
238    /// The slice length is `tensor_info.size / size_of::<T>()`.
239    ///
240    /// # Errors
241    /// Returns an error if the tensor index is invalid.
242    pub fn input_slice_mut<T>(&self, index: u32) -> Result<&mut [T]> {
243        let info = self.input(index)?;
244        let count = info.size / core::mem::size_of::<T>();
245        // SAFETY: The tensor arena is owned by the model session and
246        // valid for the lifetime of this Model.  We have &self so the
247        // model is alive.  The caller must not alias this slice with
248        // another call to input_slice_mut for the same tensor index.
249        Ok(unsafe {
250            core::slice::from_raw_parts_mut(info.data as *mut T, count)
251        })
252    }
253
254    /// Get output tensor data as a typed slice.
255    ///
256    /// The slice length is `tensor_info.size / size_of::<T>()`.
257    ///
258    /// # Errors
259    /// Returns an error if the tensor index is invalid.
260    pub fn output_slice<T>(&self, index: u32) -> Result<&[T]> {
261        let info = self.output(index)?;
262        let count = info.size / core::mem::size_of::<T>();
263        // SAFETY: Same as input_slice_mut, but immutable.
264        Ok(unsafe {
265            core::slice::from_raw_parts(info.data as *const T, count)
266        })
267    }
268}