1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
//! view parameterization code
//!
//! This module contains all code used to parameterize views. Some parameters directly
//! mirror their Kokkos counterparts, while others are Rust-specific. Currently supported
//! parameters include:
//!
//! - Type of the data owned by the view (Rust-specific)
//! - Memory layout
//!
//! Possible future implementations include:
//!
//! - Memory space
//! - Memory traits

use std::fmt::Debug;

#[cfg(any(feature = "rayon", feature = "threads", feature = "gpu"))]
use atomic::Atomic;

/// Upper bound on the depth of a view, i.e. on how many dimensions it may have.
pub const MAX_VIEW_DEPTH: usize = 8;

/// Supertrait bundling the common traits that every element type stored in a View
/// must implement.
///
/// Currently implemented for `f32` and `f64`.
pub trait DataTraits: Debug + Clone + Copy + Default {}

impl DataTraits for f32 {}
impl DataTraits for f64 {}

/// Generic alias for elements of type `T` of a View.
///
/// This alias automatically changes according to features to ensure thread-safety
/// of Views. There are two possible values:
///
/// - any of the `rayon`, `threads` or `gpu` features enabled:
///   `InnerDataType<T> = Atomic<T>`. By adding the atomic wrapping, operations on
///   views can be implemented using thread-safe methods.
/// - none of those features enabled: `InnerDataType<T> = T`.
///
/// **Current version**: no feature
#[cfg(not(any(feature = "rayon", feature = "threads", feature = "gpu")))]
pub type InnerDataType<T> = T;

/// Generic alias for elements of type `T` of a View.
///
/// This alias automatically changes according to features to ensure thread-safety
/// of Views. There are two possible values:
///
/// - any of the `rayon`, `threads` or `gpu` features enabled:
///   `InnerDataType<T> = Atomic<T>`. By adding the atomic wrapping, operations on
///   views can be implemented using thread-safe methods.
/// - none of those features enabled: `InnerDataType<T> = T`.
///
/// **Current version**: thread-safe
#[cfg(any(feature = "rayon", feature = "threads", feature = "gpu"))]
pub type InnerDataType<T> = Atomic<T>;

/// Storage backing a view: either a vector the view owns, or a slice it borrows.
///
/// This enum is meant to be removed eventually; the [view][crate::view] module
/// documentation explains why.
///
/// Its [PartialEq] implementation is modeled on Kokkos'
/// [`equal` algorithm](https://kokkos.github.io/kokkos-core-wiki/API/algorithms/std-algorithms/all/StdEqual.html):
/// in practice two values are compared by reference (which buffer they refer to)
/// rather than by the element values they hold.
#[derive(Debug)]
pub enum DataType<'a, T: DataTraits> {
    /// Elements allocated and owned by the view itself.
    Owned(Vec<InnerDataType<T>>),
    /// Elements borrowed from elsewhere, read-only.
    Borrowed(&'a [InnerDataType<T>]),
    /// Elements borrowed from elsewhere, readable and writable.
    MutBorrowed(&'a mut [InnerDataType<T>]),
}

impl<'a, T> PartialEq for DataType<'a, T>
where
    T: DataTraits,
{
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            // are the deref operations necessary?
            // this one is technically necessary because self==self should return true
            (Self::Owned(l0), Self::Owned(r0)) => (*l0).as_ptr() == (*r0).as_ptr(),
            // compare pointers
            // deref Owned only once, twice the others
            (Self::Owned(l0), Self::Borrowed(r0)) => (*l0).as_ptr() == (**r0).as_ptr(),
            (Self::Owned(l0), Self::MutBorrowed(r0)) => (*l0).as_ptr() == (**r0).as_ptr(),
            (Self::Borrowed(l0), Self::Owned(r0)) => (**l0).as_ptr() == (*r0).as_ptr(),
            (Self::MutBorrowed(l0), Self::Owned(r0)) => (**l0).as_ptr() == (*r0).as_ptr(),

            (Self::Borrowed(l0), Self::Borrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
            (Self::Borrowed(l0), Self::MutBorrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
            (Self::MutBorrowed(l0), Self::MutBorrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
            (Self::MutBorrowed(l0), Self::Borrowed(r0)) => (**l0).as_ptr() == (**r0).as_ptr(),
        }
    }
}

/// Memory layout of a view: how a multi-dimensional index is mapped to an offset in
/// the underlying storage.
///
/// [`Layout::Stride`] is a struct-like variant so that the custom strides are named
/// (`s`) wherever the layout is built, which improves readability.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Layout<const N: usize> {
    /// Row-major: highest stride for the first index, with the stride decreasing as
    /// the index position increases. The last index is contiguous (stride 1).
    /// Exact stride for each index can be computed from dimensions at view initialization.
    #[default]
    Right,
    /// Column-major: lowest stride for the first index, which is contiguous (stride 1),
    /// with the stride increasing as the index position increases.
    /// Exact stride for each index can be computed from dimensions at view initialization.
    Left,
    /// Custom stride for each index. Must be compatible with dimensions.
    Stride { s: [usize; N] },
}

/// Compute the stride of each index from the view dimensions and the requested layout.
///
/// - [`Layout::Right`]: the last index has stride 1, and each preceding index's
///   stride is the product of all the dimensions that follow it.
/// - [`Layout::Left`]: the first index has stride 1, and each following index's
///   stride is the product of all the dimensions that precede it.
/// - [`Layout::Stride`]: the user-provided strides are returned unchanged.
///
/// # Panics
///
/// Panics if `N` (the view depth) is not in `1..=MAX_VIEW_DEPTH`, or if a computed
/// stride does not fit in a `usize`.
pub fn compute_stride<const N: usize>(dim: &[usize; N], layout: &Layout<N>) -> [usize; N] {
    assert!(
        (1..=MAX_VIEW_DEPTH).contains(&N),
        "view depth must be between 1 and {}, got {}",
        MAX_VIEW_DEPTH,
        N
    );
    // Checked so that an oversized dimension product panics in release builds as well,
    // instead of silently wrapping around to a bogus stride.
    let mul = |stride: usize, extent: usize| -> usize {
        stride
            .checked_mul(extent)
            .expect("view stride overflows usize")
    };
    match layout {
        Layout::Right => {
            let mut stride = [1; N];
            // Walk from the second-to-last index back to the first.
            for i in (0..N - 1).rev() {
                stride[i] = mul(stride[i + 1], dim[i + 1]);
            }
            stride
        }
        Layout::Left => {
            let mut stride = [1; N];
            for i in 1..N {
                stride[i] = mul(stride[i - 1], dim[i - 1]);
            }
            stride
        }
        Layout::Stride { s } => *s,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn stride_right() {
        // dim = [n0, n1, n2, n3]
        let dim = [3, 4, 5, 6];

        let cmp_stride = compute_stride(&dim, &Layout::Right);
        // n3 * n2 * n1, n3 * n2, n3, 1
        let ref_stride: [usize; 4] = [6 * 5 * 4, 6 * 5, 6, 1];

        assert_eq!(cmp_stride, ref_stride);
    }

    #[test]
    fn stride_left() {
        // dim = [n0, n1, n2, n3]
        let dim = [3, 4, 5, 6];

        let cmp_stride = compute_stride(&dim, &Layout::Left);
        // 1, n0, n0 * n1, n0 * n1 * n2
        let ref_stride: [usize; 4] = [1, 3, 3 * 4, 3 * 4 * 5];

        assert_eq!(cmp_stride, ref_stride);
    }

    #[test]
    fn stride_custom() {
        // a custom layout must be passed through untouched, whatever the dimensions
        let dim = [3, 4, 5, 6];
        let s: [usize; 4] = [1, 7, 42, 300];

        assert_eq!(compute_stride(&dim, &Layout::Stride { s }), s);
    }

    #[test]
    fn one_d_stride() {
        // 1d view (vector) of length 8: both layouts yield a unit stride
        let dim: [usize; 1] = [8];
        let ref_stride: [usize; 1] = [1];
        assert_eq!(ref_stride, compute_stride(&dim, &Layout::Right));
        assert_eq!(ref_stride, compute_stride(&dim, &Layout::Left));
    }

    #[test]
    #[should_panic]
    fn too_deep_view_panics() {
        // one dimension more than the supported maximum
        let dim = [1usize; MAX_VIEW_DEPTH + 1];
        let _ = compute_stride(&dim, &Layout::Right);
    }
}