poc_kokkos_rs/view/parameters.rs
1//! view parameterization code
2//!
//! This module contains all code used to parameterize views. Some parameters directly
//! replicate Kokkos ones while others are Rust-specific. Currently supported parameters
5//! include:
6//!
7//! - Type of the data owned by the view (Rust-specific)
8//! - Memory layout
9//!
10//! Possible future implementations include:
11//!
12//! - Memory space
13//! - Memory traits
14
15use std::fmt::Debug;
16
17#[cfg(any(feature = "rayon", feature = "threads", feature = "gpu"))]
18use atomic::Atomic;
19use bytemuck::Pod;
20
21/// Maximum possible depth (i.e. number of dimensions) for a view.
22pub const MAX_VIEW_DEPTH: usize = 8;
23
24/// Supertrait with common trait that elements of a View should implement.
25pub trait DataTraits: Debug + Clone + Copy + Default + Pod {}
26
27impl DataTraits for f64 {}
28impl DataTraits for f32 {}
29
/// Type actually stored in a View for elements of type `Elem`.
///
/// The alias resolves differently depending on the enabled crate features, so that
/// views stay thread-safe whenever parallel execution is possible:
///
/// - any of the `rayon`, `threads` or `gpu` features enabled: `Atomic<Elem>`. The atomic
///   wrapper lets view operations be implemented with thread-safe methods.
/// - none of these features enabled: plain `Elem`.
///
/// **Active definition**: no feature (plain `Elem`).
#[cfg(not(any(feature = "rayon", feature = "threads", feature = "gpu")))]
pub type InnerDataType<Elem> = Elem;

/// Type actually stored in a View for elements of type `Elem`.
///
/// The alias resolves differently depending on the enabled crate features, so that
/// views stay thread-safe whenever parallel execution is possible:
///
/// - any of the `rayon`, `threads` or `gpu` features enabled: `Atomic<Elem>`. The atomic
///   wrapper lets view operations be implemented with thread-safe methods.
/// - none of these features enabled: plain `Elem`.
///
/// **Active definition**: thread-safe (`Atomic<Elem>`).
#[cfg(any(feature = "rayon", feature = "threads", feature = "gpu"))]
pub type InnerDataType<Elem> = Atomic<Elem>;
55
56#[derive(Debug)]
57/// Enum used to identify the type of data the view is holding.
58///
59/// This should eventually be removed. See the [view][crate::view] module documentation
60/// for more information.
61///
62/// The policy used to implement the [PartialEq] trait is based on Kokkos'
63/// [`equal` algorithm](https://kokkos.github.io/kokkos-core-wiki/API/algorithms/std-algorithms/all/StdEqual.html).
64/// Essentially, it corresponds to equality by reference instead of equality by value.
65pub enum DataType<'a, T>
66where
67 T: DataTraits,
68{
69 /// The view owns the data.
70 Owned(Vec<InnerDataType<T>>),
71 /// The view borrows the data and can only read it.
72 Borrowed(&'a [InnerDataType<T>]),
73 /// The view borrows the data and can both read and modify it.
74 MutBorrowed(&'a mut [InnerDataType<T>]),
75}
76
77impl<T> PartialEq for DataType<'_, T>
78where
79 T: DataTraits,
80{
81 fn eq(&self, other: &Self) -> bool {
82 match (self, other) {
83 // are the deref operations necessary?
84 // this one is technically necessary because self==self should return true
85 (Self::Owned(l0), Self::Owned(r0)) => (*l0).as_ptr() == (*r0).as_ptr(),
86 // compare pointers
87 // deref Owned only once, twice the others
88 (Self::Owned(l0), Self::Borrowed(r0)) => (*l0).as_ptr() == (**r0).as_ptr(),
89 (Self::Owned(l0), Self::MutBorrowed(r0)) => (*l0).as_ptr() == (**r0).as_ptr(),
90 (Self::Borrowed(l0), Self::Owned(r0)) => (**l0).as_ptr() == (*r0).as_ptr(),
91 (Self::MutBorrowed(l0), Self::Owned(r0)) => (**l0).as_ptr() == (*r0).as_ptr(),
92
93 (Self::Borrowed(l0), Self::Borrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
94 (Self::Borrowed(l0), Self::MutBorrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
95 (Self::MutBorrowed(l0), Self::MutBorrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
96 (Self::MutBorrowed(l0), Self::Borrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
97 }
98 }
99}
100
/// Memory layout of a view, i.e. how multi-dimensional indices map to offsets in the
/// underlying one-dimensional storage. A struct-like variant is used for custom strides
/// to improve readability.
///
/// The exact stride of each index is derived from the view dimensions at view
/// initialization (see [`compute_stride`]).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Layout<const N: usize> {
    /// Highest stride for the first index, decreasing as the index position increases;
    /// the last index is contiguous (stride 1). This is the default layout.
    #[default]
    Right,
    /// Lowest stride (1) for the first index, increasing as the index position
    /// increases; the first index is contiguous.
    Left,
    /// Custom stride for each index. Must be compatible with dimensions.
    Stride {
        /// Stride of each index.
        s: [usize; N],
    },
}
115
116/// Compute correct strides of each index using dimensions and specified layout.
117pub fn compute_stride<const N: usize>(dim: &[usize; N], layout: &Layout<N>) -> [usize; N] {
118 assert_eq!(N.clamp(1, MAX_VIEW_DEPTH), N); // 1 <= N <= MAX_N
119 match layout {
120 Layout::Right => {
121 let mut stride = [1; N];
122
123 let mut tmp: usize = 1;
124 for i in (1..N).rev() {
125 tmp *= dim[i];
126 stride[N - i] = tmp;
127 }
128
129 stride.reverse();
130 stride
131 }
132 Layout::Left => {
133 let mut stride = [1; N];
134
135 let mut tmp: usize = 1;
136 for i in 0..N - 1 {
137 tmp *= dim[i];
138 stride[i + 1] = tmp;
139 }
140
141 stride
142 }
143 Layout::Stride { s } => *s,
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn stride_right() {
        // dim = [n0, n1, n2, n3]
        let dim = [3, 4, 5, 6];

        let cmp_stride = compute_stride(&dim, &Layout::Right);
        // n3 * n2 * n1, n3 * n2, n3, 1
        let ref_stride: [usize; 4] = [6 * 5 * 4, 6 * 5, 6, 1];

        assert_eq!(cmp_stride, ref_stride);
    }

    #[test]
    fn stride_left() {
        // dim = [n0, n1, n2, n3]
        let dim = [3, 4, 5, 6];

        let cmp_stride = compute_stride(&dim, &Layout::Left);
        // 1, n0, n0 * n1, n0 * n1 * n2
        let ref_stride: [usize; 4] = [1, 3, 3 * 4, 3 * 4 * 5];

        assert_eq!(cmp_stride, ref_stride);
    }

    #[test]
    fn one_d_stride() {
        // 1d view (vector) of length 8: its single stride is 1 for both layouts
        let dim: [usize; 1] = [8];
        let ref_stride: [usize; 1] = [1];
        assert_eq!(ref_stride, compute_stride(&dim, &Layout::Right));
        assert_eq!(ref_stride, compute_stride(&dim, &Layout::Left));
    }

    #[test]
    fn stride_custom() {
        // custom strides are returned unchanged
        let dim = [3, 4];
        let s = [1, 3];
        assert_eq!(compute_stride(&dim, &Layout::Stride { s }), s);
    }

    #[test]
    fn layout_default_is_right() {
        assert_eq!(Layout::<2>::default(), Layout::Right);
    }

    #[test]
    #[should_panic]
    fn zero_depth_rejected() {
        let _ = compute_stride::<0>(&[], &Layout::Right);
    }

    #[test]
    #[should_panic]
    fn too_deep_rejected() {
        let _ = compute_stride(&[1; MAX_VIEW_DEPTH + 1], &Layout::Right);
    }
}