poc_kokkos_rs/
functor.rs

1//! functor & kernel related code
2//!
//! This module contains all functor and kernel related code. Its content
//! is highly dependent on the enabled features, since the traits that a
//! kernel must satisfy change entirely depending on the backend used.
6//!
//! Kernel signatures are handled using `cargo` features. Using conditional
//! compilation, the exact traits kernels must implement are adjusted according
//! to the backend used to dispatch statements.
10//!
11//! In order to have actual closures match the required trait implementation,
12//! the same mechanism is used to define operations on [`Views`][crate::view].
13
14#[cfg(doc)]
15use crate::routines::parameters::RangePolicy;
16
/// Kernel argument enum
///
/// In the Kokkos library, there is a finite number of kernel signatures.
/// Each is associated to/determined by a given execution policy.
/// In order to have kernel genericity in Rust, without introducing overhead
/// due to downcasting, the solution was to define kernel arguments as a
/// struct-like enum.
///
/// Every variant only carries plain index data, so the type derives the
/// common value traits (`Debug`, `Clone`, `Copy`, `PartialEq`, `Eq`). This
/// makes arguments printable, comparable in assertions, and free to pass
/// by value.
///
/// ### Example
///
/// One-dimensional kernel:
/// ```
/// // Range is defined in the execution policy
/// use poc_kokkos_rs::functor::KernelArgs;
///
/// let kern = |arg: KernelArgs<1>| match arg {
///         KernelArgs::Index1D(i) => {
///             // body of the kernel
///             println!("Hello from iteration {i}")
///         },
///         KernelArgs::IndexND(_) => unimplemented!(),
///         KernelArgs::Handle => unimplemented!(),
///     };
/// ```
///
/// 3D kernel:
/// ```
/// use poc_kokkos_rs::functor::KernelArgs;
///
/// // Use the array
/// let kern = |arg: KernelArgs<3>| match arg {
///         KernelArgs::Index1D(_) => unimplemented!(),
///         KernelArgs::IndexND(idx) => { // idx: [usize; 3]
///             // body of the kernel
///             println!("Hello from iteration {idx:?}")
///         },
///         KernelArgs::Handle => unimplemented!(),
///     };
///
/// // Decompose the array
/// let kern = |arg: KernelArgs<3>| match arg {
///         KernelArgs::Index1D(_) => unimplemented!(),
///         KernelArgs::IndexND([i, j, k]) => { // i,j,k: usize
///             // body of the kernel
///             println!("Hello from iteration {i},{j},{k}");
///         },
///         KernelArgs::Handle => unimplemented!(),
///     };
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelArgs<const N: usize> {
    /// Arguments of a one-dimensional kernel (e.g. a [RangePolicy][RangePolicy::RangePolicy]).
    Index1D(usize),
    /// Arguments of a `N`-dimensional kernel (e.g. a [MDRangePolicy][RangePolicy::MDRangePolicy]).
    IndexND([usize; N]),
    /// Arguments of a team-based kernel.
    Handle,
}
74
75cfg_if::cfg_if! {
76    if #[cfg(feature = "rayon")] {
77        /// `parallel_for` kernel type. Depends on enabled feature(s).
78        ///
79        /// This type alias is configured according to enabled feature in order to adjust
80        /// the signatures of kernels to match the requirements of the underlying dispatch routines.
81        ///
82        /// ### Possible Values
83        /// - `rayon` feature enabled: `Box<dyn Fn(KernelArgs<N>) + Send + Sync + 'a>`
84        /// - `threads` feature enabled: `Box<dyn Fn(KernelArgs<N>) + Send + 'a>`
85        /// - no feature enabled: fall back to [`SerialForKernelType`][SerialForKernelType]
86        ///
87        /// **Current version**: `rayon`
88        pub type ForKernelType<'a, const N: usize> = Box<dyn Fn(KernelArgs<N>) + Send + Sync + 'a>;
89    } else if #[cfg(feature = "threads")] {
90        /// `parallel_for` kernel type. Depends on enabled feature(s).
91        ///
92        /// This type alias is configured according to enabled feature in order to adjust
93        /// the signatures of kernels to match the requirements of the underlying dispatch routines.
94        ///
95        /// ### Possible Values
96        /// - `rayon` feature enabled: `Box<dyn Fn(KernelArgs<N>) + Send + Sync + 'a>`
97        /// - `threads` feature enabled: `Box<dyn Fn(KernelArgs<N>) + Send + 'a>`
98        /// - no feature enabled: fall back to [`SerialForKernelType`][SerialForKernelType]
99        ///
100        /// **Current version**: `threads`
101        pub type ForKernelType<'a, const N: usize> = Box<dyn Fn(KernelArgs<N>) + Send + 'a>;
102    } else {
103        /// `parallel_for` kernel type. Depends on enabled feature(s).
104        ///
105        /// This type alias is configured according to enabled feature in order to adjust
106        /// the signatures of kernels to match the requirements of the underlying dispatch routines.
107        ///
108        /// ### Possible Values
109        /// - `rayon` feature enabled: `Box<dyn Fn(KernelArgs<N>) + Send + Sync + 'a>`
110        /// - `threads` feature enabled: `Box<dyn Fn(KernelArgs<N>) + Send + 'a>`
111        /// - no feature enabled: fall back to [`SerialForKernelType`][SerialForKernelType]
112        ///
113        /// **Current version**: no feature
114        pub type ForKernelType<'a, const N: usize> = SerialForKernelType<'a, N>;
115    }
116}
117
118/// Serial kernel type. Does not depend on enabled feature(s).
119///
120/// This is the minimal required trait implementation for closures passed to a
121/// `for_each` statement.
122pub type SerialForKernelType<'a, const N: usize> = Box<dyn FnMut(KernelArgs<N>) + 'a>;