poc_kokkos_rs/functor.rs
1//! functor & kernel related code
2//!
//! This module contains all functor and kernel related code. Its content
//! is highly dependent on the enabled features, since the traits that a
//! kernel must satisfy change entirely depending on the backend used.
6//!
//! Kernel signatures are handled using `cargo` features. Using conditional
//! compilation, the exact traits that kernels must implement are adjusted
//! according to the backend used to dispatch statements.
10//!
11//! In order to have actual closures match the required trait implementation,
12//! the same mechanism is used to define operations on [`Views`][crate::view].
13
14#[cfg(doc)]
15use crate::routines::parameters::RangePolicy;
16
17/// Kernel argument enum
18///
19/// In the Kokkos library, there is a finite number of kernel signatures.
20/// Each is associated to/determined by a given execution policy.
21/// In order to have kernel genericity in Rust, without introducing overhead
22/// due to downcasting, the solution was to define kernel arguments as a
23/// struct-like enum.
24///
25/// ### Example
26///
27/// One-dimensional kernel:
28/// ```
29/// // Range is defined in the execution policy
30/// use poc_kokkos_rs::functor::KernelArgs;
31///
32/// let kern = |arg: KernelArgs<1>| match arg {
33/// KernelArgs::Index1D(i) => {
34/// // body of the kernel
35/// println!("Hello from iteration {i}")
36/// },
37/// KernelArgs::IndexND(_) => unimplemented!(),
38/// KernelArgs::Handle => unimplemented!(),
39/// };
40/// ```
41///
42/// 3D kernel:
43/// ```
44/// use poc_kokkos_rs::functor::KernelArgs;
45///
46/// // Use the array
47/// let kern = |arg: KernelArgs<3>| match arg {
48/// KernelArgs::Index1D(_) => unimplemented!(),
49/// KernelArgs::IndexND(idx) => { // idx: [usize; 3]
50/// // body of the kernel
51/// println!("Hello from iteration {idx:?}")
52/// },
53/// KernelArgs::Handle => unimplemented!(),
54/// };
55///
56/// // Decompose the array
57/// let kern = |arg: KernelArgs<3>| match arg {
58/// KernelArgs::Index1D(_) => unimplemented!(),
59/// KernelArgs::IndexND([i, j, k]) => { // i,j,k: usize
60/// // body of the kernel
61/// println!("Hello from iteration {i},{j},{k}");
62/// },
63/// KernelArgs::Handle => unimplemented!(),
64/// };
65/// ```
66pub enum KernelArgs<const N: usize> {
67 /// Arguments of a one-dimensionnal kernel (e.g. a [RangePolicy][RangePolicy::RangePolicy]).
68 Index1D(usize),
69 /// Arguments of a `N`-dimensionnal kernel (e.g. a [MDRangePolicy][RangePolicy::MDRangePolicy]).
70 IndexND([usize; N]),
71 /// Arguments of a team-based kernel.
72 Handle,
73}
74
75cfg_if::cfg_if! {
76 if #[cfg(feature = "rayon")] {
77 /// `parallel_for` kernel type. Depends on enabled feature(s).
78 ///
79 /// This type alias is configured according to enabled feature in order to adjust
80 /// the signatures of kernels to match the requirements of the underlying dispatch routines.
81 ///
82 /// ### Possible Values
83 /// - `rayon` feature enabled: `Box<dyn Fn(KernelArgs<N>) + Send + Sync + 'a>`
84 /// - `threads` feature enabled: `Box<dyn Fn(KernelArgs<N>) + Send + 'a>`
85 /// - no feature enabled: fall back to [`SerialForKernelType`][SerialForKernelType]
86 ///
87 /// **Current version**: `rayon`
88 pub type ForKernelType<'a, const N: usize> = Box<dyn Fn(KernelArgs<N>) + Send + Sync + 'a>;
89 } else if #[cfg(feature = "threads")] {
90 /// `parallel_for` kernel type. Depends on enabled feature(s).
91 ///
92 /// This type alias is configured according to enabled feature in order to adjust
93 /// the signatures of kernels to match the requirements of the underlying dispatch routines.
94 ///
95 /// ### Possible Values
96 /// - `rayon` feature enabled: `Box<dyn Fn(KernelArgs<N>) + Send + Sync + 'a>`
97 /// - `threads` feature enabled: `Box<dyn Fn(KernelArgs<N>) + Send + 'a>`
98 /// - no feature enabled: fall back to [`SerialForKernelType`][SerialForKernelType]
99 ///
100 /// **Current version**: `threads`
101 pub type ForKernelType<'a, const N: usize> = Box<dyn Fn(KernelArgs<N>) + Send + 'a>;
102 } else {
103 /// `parallel_for` kernel type. Depends on enabled feature(s).
104 ///
105 /// This type alias is configured according to enabled feature in order to adjust
106 /// the signatures of kernels to match the requirements of the underlying dispatch routines.
107 ///
108 /// ### Possible Values
109 /// - `rayon` feature enabled: `Box<dyn Fn(KernelArgs<N>) + Send + Sync + 'a>`
110 /// - `threads` feature enabled: `Box<dyn Fn(KernelArgs<N>) + Send + 'a>`
111 /// - no feature enabled: fall back to [`SerialForKernelType`][SerialForKernelType]
112 ///
113 /// **Current version**: no feature
114 pub type ForKernelType<'a, const N: usize> = SerialForKernelType<'a, N>;
115 }
116}
117
118/// Serial kernel type. Does not depend on enabled feature(s).
119///
120/// This is the minimal required trait implementation for closures passed to a
121/// `for_each` statement.
122pub type SerialForKernelType<'a, const N: usize> = Box<dyn FnMut(KernelArgs<N>) + 'a>;