poc_kokkos_rs/view/
parameters.rs

//! View parameterization code
//!
//! This module contains all code used to parameterize views. Some parameters directly
//! replicate Kokkos ones while others are Rust-specific. Currently supported parameters
//! include:
//!
//! - Type of the data owned by the view (Rust-specific)
//! - Memory layout
//!
//! Possible future implementations include:
//!
//! - Memory space
//! - Memory traits
14
15use std::fmt::Debug;
16
17#[cfg(any(feature = "rayon", feature = "threads", feature = "gpu"))]
18use atomic::Atomic;
19use bytemuck::Pod;
20
/// Upper bound on the number of dimensions (i.e. the depth) a view can have.
pub const MAX_VIEW_DEPTH: usize = 8;
23
24/// Supertrait with common trait that elements of a View should implement.
25pub trait DataTraits: Debug + Clone + Copy + Default + Pod {}
26
27impl DataTraits for f64 {}
28impl DataTraits for f32 {}
29
/// Alias for the storage type of a View element of type `T`.
///
/// What this resolves to depends on the enabled features, so that Views stay
/// thread-safe whenever a parallel backend may be in use:
///
/// - at least one of `rayon`, `threads` or `gpu` enabled: `InnerDataType<T> = Atomic<T>`;
///   the atomic wrapper lets view operations be implemented with thread-safe methods.
/// - none of them enabled: `InnerDataType<T> = T`.
///
/// **Current version**: no feature
#[cfg(not(any(feature = "rayon", feature = "threads", feature = "gpu")))]
pub type InnerDataType<T> = T;
42
/// Alias for the storage type of a View element of type `T`.
///
/// What this resolves to depends on the enabled features, so that Views stay
/// thread-safe whenever a parallel backend may be in use:
///
/// - at least one of `rayon`, `threads` or `gpu` enabled: `InnerDataType<T> = Atomic<T>`;
///   the atomic wrapper lets view operations be implemented with thread-safe methods.
/// - none of them enabled: `InnerDataType<T> = T`.
///
/// **Current version**: thread-safe
#[cfg(any(feature = "rayon", feature = "threads", feature = "gpu"))]
pub type InnerDataType<T> = Atomic<T>;
55
56#[derive(Debug)]
57/// Enum used to identify the type of data the view is holding.
58///
59/// This should eventually be removed. See the [view][crate::view] module documentation
60/// for more information.
61///
62/// The policy used to implement the [PartialEq] trait is based on Kokkos'
63/// [`equal` algorithm](https://kokkos.github.io/kokkos-core-wiki/API/algorithms/std-algorithms/all/StdEqual.html).
64/// Essentially, it corresponds to equality by reference instead of equality by value.
65pub enum DataType<'a, T>
66where
67    T: DataTraits,
68{
69    /// The view owns the data.
70    Owned(Vec<InnerDataType<T>>),
71    /// The view borrows the data and can only read it.
72    Borrowed(&'a [InnerDataType<T>]),
73    /// The view borrows the data and can both read and modify it.
74    MutBorrowed(&'a mut [InnerDataType<T>]),
75}
76
77impl<T> PartialEq for DataType<'_, T>
78where
79    T: DataTraits,
80{
81    fn eq(&self, other: &Self) -> bool {
82        match (self, other) {
83            // are the deref operations necessary?
84            // this one is technically necessary because self==self should return true
85            (Self::Owned(l0), Self::Owned(r0)) => (*l0).as_ptr() == (*r0).as_ptr(),
86            // compare pointers
87            // deref Owned only once, twice the others
88            (Self::Owned(l0), Self::Borrowed(r0)) => (*l0).as_ptr() == (**r0).as_ptr(),
89            (Self::Owned(l0), Self::MutBorrowed(r0)) => (*l0).as_ptr() == (**r0).as_ptr(),
90            (Self::Borrowed(l0), Self::Owned(r0)) => (**l0).as_ptr() == (*r0).as_ptr(),
91            (Self::MutBorrowed(l0), Self::Owned(r0)) => (**l0).as_ptr() == (*r0).as_ptr(),
92
93            (Self::Borrowed(l0), Self::Borrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
94            (Self::Borrowed(l0), Self::MutBorrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
95            (Self::MutBorrowed(l0), Self::MutBorrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
96            (Self::MutBorrowed(l0), Self::Borrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
97        }
98    }
99}
100
/// Memory layout of a view, i.e. how a multi-dimensional index maps onto the
/// underlying linear storage.
///
/// [`Layout::Stride`] is a struct-like variant so that call sites read as
/// `Layout::Stride { s: [...] }`, which improves readability.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Layout<const N: usize> {
    /// Highest stride for the first index; strides decrease as the index position
    /// increases, and the last index has stride 1 (row-major order).
    /// Exact strides are computed from the dimensions at view initialization.
    #[default]
    Right,
    /// Stride 1 for the first index; strides increase as the index position
    /// increases (column-major order).
    /// Exact strides are computed from the dimensions at view initialization.
    Left,
    /// Custom stride for each index. Must be compatible with the dimensions.
    Stride {
        /// Stride of each index.
        s: [usize; N],
    },
}
115
116/// Compute correct strides of each index using dimensions and specified layout.
117pub fn compute_stride<const N: usize>(dim: &[usize; N], layout: &Layout<N>) -> [usize; N] {
118    assert_eq!(N.clamp(1, MAX_VIEW_DEPTH), N); // 1 <= N <= MAX_N
119    match layout {
120        Layout::Right => {
121            let mut stride = [1; N];
122
123            let mut tmp: usize = 1;
124            for i in (1..N).rev() {
125                tmp *= dim[i];
126                stride[N - i] = tmp;
127            }
128
129            stride.reverse();
130            stride
131        }
132        Layout::Left => {
133            let mut stride = [1; N];
134
135            let mut tmp: usize = 1;
136            for i in 0..N - 1 {
137                tmp *= dim[i];
138                stride[i + 1] = tmp;
139            }
140
141            stride
142        }
143        Layout::Stride { s } => *s,
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn stride_right() {
        // dim = [n0, n1, n2, n3]
        let dim = [3, 4, 5, 6];

        let cmp_stride = compute_stride(&dim, &Layout::Right);
        // n3 * n2 * n1, n3 * n2, n3, 1
        let ref_stride: [usize; 4] = [6 * 5 * 4, 6 * 5, 6, 1];

        assert_eq!(cmp_stride, ref_stride);
    }

    #[test]
    fn stride_left() {
        // dim = [n0, n1, n2, n3]
        let dim = [3, 4, 5, 6];

        let cmp_stride = compute_stride(&dim, &Layout::Left);
        // 1, n0, n0 * n1, n0 * n1 * n2
        let ref_stride: [usize; 4] = [1, 3, 3 * 4, 3 * 4 * 5];

        assert_eq!(cmp_stride, ref_stride);
    }

    #[test]
    fn one_d_stride() {
        // 1d view (vector) of length 8: both layouts reduce to a unit stride
        let dim: [usize; 1] = [8];
        let ref_stride: [usize; 1] = [1];
        let mut cmp_stride = compute_stride(&dim, &Layout::Right);
        assert_eq!(ref_stride, cmp_stride);
        cmp_stride = compute_stride(&dim, &Layout::Left);
        assert_eq!(ref_stride, cmp_stride);
    }

    #[test]
    fn stride_custom_is_passed_through() {
        let dim = [3, 4];
        let custom = [1, 7];
        assert_eq!(compute_stride(&dim, &Layout::Stride { s: custom }), custom);
    }

    #[test]
    fn left_is_right_of_reversed_dims() {
        // LayoutLeft of dims == reversed LayoutRight of reversed dims
        let dim = [3, 4, 5, 6];
        let mut rev_dim = dim;
        rev_dim.reverse();

        let mut left = compute_stride(&dim, &Layout::Left);
        left.reverse();

        assert_eq!(left, compute_stride(&rev_dim, &Layout::Right));
    }

    #[test]
    fn default_layout_is_right() {
        assert_eq!(Layout::<2>::default(), Layout::Right);
    }

    #[test]
    #[should_panic]
    fn zero_depth_is_rejected() {
        compute_stride::<0>(&[], &Layout::Right);
    }

    #[test]
    #[should_panic]
    fn excessive_depth_is_rejected() {
        compute_stride(&[1; MAX_VIEW_DEPTH + 1], &Layout::Right);
    }
}