//! View parameterization code
//!
//! This module contains all code used to parameterize views. Some parameters are direct
//! Kokkos replicates while others are Rust-specific. Currently supported parameters
//! include:
//!
//! - Type of the data owned by the view (Rust-specific)
//! - Memory layout
//!
//! Possible future implementations include:
//!
//! - Memory space
//! - Memory traits
use std::fmt::Debug;
#[cfg(any(feature = "rayon", feature = "threads", feature = "gpu"))]
use atomic::Atomic;
/// Maximum possible depth (i.e. number of dimensions) for a view.
///
/// [`compute_stride`] asserts that the number of dimensions `N` of a view satisfies
/// `1 <= N <= MAX_VIEW_DEPTH`.
pub const MAX_VIEW_DEPTH: usize = 8;
/// Supertrait gathering the common traits that elements of a View should implement.
///
/// A type used as view element must be printable for debugging ([`Debug`]), trivially
/// copyable ([`Clone`] + [`Copy`]) and must provide a default value ([`Default`]).
pub trait DataTraits: Debug + Clone + Copy + Default {}
// Floating-point element types for which the trait is implemented here.
impl DataTraits for f64 {}
impl DataTraits for f32 {}
#[cfg(not(any(feature = "rayon", feature = "threads", feature = "gpu")))]
/// Generic alias for elements of type `T` of a View.
///
/// This alias automatically changes according to the enabled features to ensure
/// thread-safety of Views. There are two possible values:
///
/// - any of the `rayon`, `threads` or `gpu` features enabled:
///   `InnerDataType<T> = Atomic<T>`. By adding the atomic wrapping, operations on views
///   can be implemented using thread-safe methods.
/// - none of these features enabled: `InnerDataType<T> = T`.
///
/// **Current version**: no feature
pub type InnerDataType<T> = T;
#[cfg(any(feature = "rayon", feature = "threads", feature = "gpu"))]
/// Generic alias for elements of type `T` of a View.
///
/// This alias automatically changes according to the enabled features to ensure
/// thread-safety of Views. There are two possible values:
///
/// - any of the `rayon`, `threads` or `gpu` features enabled:
///   `InnerDataType<T> = Atomic<T>`. By adding the atomic wrapping, operations on views
///   can be implemented using thread-safe methods.
/// - none of these features enabled: `InnerDataType<T> = T`.
///
/// **Current version**: thread-safe
pub type InnerDataType<T> = Atomic<T>;
/// Storage flavour of the buffer backing a view.
///
/// Each variant wraps a contiguous sequence of [`InnerDataType<T>`] elements and
/// differs only in how that sequence is held: by value, by shared reference or by
/// exclusive reference.
///
/// This should eventually be removed. See the [view][crate::view] module documentation
/// for more information.
///
/// The policy used to implement the [PartialEq] trait is based on Kokkos'
/// [`equal` algorithm](https://kokkos.github.io/kokkos-core-wiki/API/algorithms/std-algorithms/all/StdEqual.html).
/// Essentially, it corresponds to equality by reference instead of equality by value.
#[derive(Debug)]
pub enum DataType<'a, T: DataTraits> {
    /// The view owns the data.
    Owned(Vec<InnerDataType<T>>),
    /// The view borrows the data and can only read it.
    Borrowed(&'a [InnerDataType<T>]),
    /// The view borrows the data and can both read and modify it.
    MutBorrowed(&'a mut [InnerDataType<T>]),
}
impl<'a, T> PartialEq for DataType<'a, T>
where
T: DataTraits,
{
fn eq(&self, other: &Self) -> bool {
match (self, other) {
// are the deref operations necessary?
// this one is technically necessary because self==self should return true
(Self::Owned(l0), Self::Owned(r0)) => (*l0).as_ptr() == (*r0).as_ptr(),
// compare pointers
// deref Owned only once, twice the others
(Self::Owned(l0), Self::Borrowed(r0)) => (*l0).as_ptr() == (**r0).as_ptr(),
(Self::Owned(l0), Self::MutBorrowed(r0)) => (*l0).as_ptr() == (**r0).as_ptr(),
(Self::Borrowed(l0), Self::Owned(r0)) => (**l0).as_ptr() == (*r0).as_ptr(),
(Self::MutBorrowed(l0), Self::Owned(r0)) => (**l0).as_ptr() == (*r0).as_ptr(),
(Self::Borrowed(l0), Self::Borrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
(Self::Borrowed(l0), Self::MutBorrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
(Self::MutBorrowed(l0), Self::MutBorrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
(Self::MutBorrowed(l0), Self::Borrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
}
}
}
/// Enum used to represent data layout. Struct-like variants are used in order to
/// increase readability.
///
/// `N` is the number of dimensions of the view the layout applies to.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum Layout<const N: usize> {
    /// Highest stride for the first index, decreasing stride as index increases.
    /// Exact stride for each index can be computed from dimensions at view initialization.
    #[default]
    Right,
    /// Lowest stride for the first index, increasing stride as index increases.
    /// Exact stride for each index can be computed from dimensions at view initialization.
    Left,
    /// Custom stride for each index. Must be compatible with dimensions.
    Stride {
        /// Stride associated to each index, in index order.
        s: [usize; N],
    },
}
/// Compute correct strides of each index using dimensions and specified layout.
pub fn compute_stride<const N: usize>(dim: &[usize; N], layout: &Layout<N>) -> [usize; N] {
assert_eq!(N.clamp(1, MAX_VIEW_DEPTH), N); // 1 <= N <= MAX_N
match layout {
Layout::Right => {
let mut stride = [1; N];
let mut tmp: usize = 1;
for i in (1..N).rev() {
tmp *= dim[i];
stride[N - i] = tmp;
}
stride.reverse();
stride
}
Layout::Left => {
let mut stride = [1; N];
let mut tmp: usize = 1;
for i in 0..N - 1 {
tmp *= dim[i];
stride[i + 1] = tmp;
}
stride
}
Layout::Stride { s } => *s,
}
}
// Unit tests for `compute_stride`, covering both computed layouts.
#[cfg(test)]
mod tests {
use super::*;
// Right layout: the stride of an index is the product of the dimensions after it.
#[test]
fn stride_right() {
// dim = [n0, n1, n2, n3]
let dim = [3, 4, 5, 6];
let cmp_stride = compute_stride(&dim, &Layout::Right);
// n3 * n2 * n1, n3 * n2, n3, 1
let ref_stride: [usize; 4] = [6 * 5 * 4, 6 * 5, 6, 1];
assert_eq!(cmp_stride, ref_stride);
}
// Left layout: the stride of an index is the product of the dimensions before it.
#[test]
fn stride_left() {
// dim = [n0, n1, n2, n3]
let dim = [3, 4, 5, 6];
let cmp_stride = compute_stride(&dim, &Layout::Left);
// 1, n0, n0 * n1, n0 * n1 * n2
let ref_stride: [usize; 4] = [1, 3, 3 * 4, 3 * 4 * 5];
assert_eq!(cmp_stride, ref_stride);
}
// With a single dimension, both layouts must yield a unit stride.
#[test]
fn one_d_stride() {
// 1d view (vector) of length 8
let dim: [usize; 1] = [8];
let ref_stride: [usize; 1] = [1];
let mut cmp_stride = compute_stride(&dim, &Layout::Right);
assert_eq!(ref_stride, cmp_stride);
cmp_stride = compute_stride(&dim, &Layout::Left);
assert_eq!(ref_stride, cmp_stride);
}
}