poc_kokkos_rs/routines/mod.rs
1//! parallel statement related code
2//!
3//! This module contains code used for the implementation of parallel statements, e.g.
4//! `parallel_for`, a Kokkos specific implementation of commonly used patterns.
5//!
6//! Parameters of the aforementioned statements are defined in the [`parameters`] sub-module.
7//!
8//! Dispatch code is defined in the [`dispatch`] sub-module.
9//!
10//! Currently implemented statements:
11//!
12//! - `parallel_for`
13
14pub mod dispatch;
15pub mod parameters;
16
17use std::fmt::Display;
18
19use crate::functor::KernelArgs;
20
21use self::{dispatch::DispatchError, parameters::ExecutionPolicy};
22
23// Enums
24
25/// Enum used to classify possible errors occuring in a parallel statement.
26#[derive(Debug)]
27pub enum StatementError {
28 /// Error occured during dispatch; The specific [DispatchError] is
29 /// used as the internal value of this variant.
30 Dispatch(DispatchError),
31 /// Error raised when parallel hierarchy isn't respected.
32 InconsistentDepth,
33 /// What did I mean by this?
34 InconsistentExecSpace,
35}
36
37impl From<DispatchError> for StatementError {
38 fn from(e: DispatchError) -> Self {
39 StatementError::Dispatch(e)
40 }
41}
42
43impl Display for StatementError {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 match self {
46 StatementError::Dispatch(e) => write!(f, "{}", e),
47 StatementError::InconsistentDepth => {
48 write!(f, "inconsistent depth & range policy association")
49 }
50 StatementError::InconsistentExecSpace => {
51 write!(f, "?")
52 }
53 }
54 }
55}
56
57impl std::error::Error for StatementError {
58 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
59 match self {
60 StatementError::Dispatch(e) => Some(e),
61 StatementError::InconsistentDepth => None,
62 StatementError::InconsistentExecSpace => None,
63 }
64 }
65}
66
67// Statements
68
69// All of this would be half as long if impl trait in type aliases was stabilized
70
71cfg_if::cfg_if! {
72 if #[cfg(feature = "threads")] {
73 /// Parallel For statement.
74 ///
75 /// **Current version**: `threads`
76 ///
77 /// ### Example
78 ///
79 /// ```rust
80 /// use poc_kokkos_rs::{
81 /// functor::KernelArgs,
82 /// routines::{
83 /// parallel_for,
84 /// parameters::{ExecutionPolicy, ExecutionSpace, RangePolicy, Schedule},
85 /// },
86 /// };
87 ///
88 /// let length: usize = 8;
89 ///
90 /// let kern = |arg: KernelArgs<1>| match arg {
91 /// KernelArgs::Index1D(i) => {
92 /// // body of the kernel
93 /// println!("Hello from iteration {i}")
94 /// },
95 /// KernelArgs::IndexND(_) => unimplemented!(),
96 /// KernelArgs::Handle => unimplemented!(),
97 /// };
98 ///
99 /// let execp = ExecutionPolicy {
100 /// space: ExecutionSpace::DeviceCPU,
101 /// range: RangePolicy::RangePolicy(0..length),
102 /// schedule: Schedule::Static,
103 /// };
104 ///
105 /// parallel_for(execp, kern).unwrap();
106 /// ```
107 pub fn parallel_for<const N: usize>(
108 execp: ExecutionPolicy<N>,
109 func: impl Fn(KernelArgs<N>) + Send + Sync + Clone,
110 ) -> Result<(), StatementError> {
111 // checks...
112
113 // data prep?
114 let kernel = Box::new(func);
115
116 // dispatch
117 let res = match execp.space {
118 parameters::ExecutionSpace::Serial => dispatch::serial(execp, kernel),
119 parameters::ExecutionSpace::DeviceCPU => dispatch::cpu(execp, kernel),
120 parameters::ExecutionSpace::DeviceGPU => dispatch::gpu(execp, kernel),
121 };
122
123 // Ok or converts error
124 res.map_err(|e| e.into())
125 }
126 } else if #[cfg(feature = "rayon")] {
127 /// Parallel For statement.
128 ///
129 /// **Current version**: `rayon`
130 ///
131 /// ### Example
132 ///
133 /// ```rust
134 /// use poc_kokkos_rs::{
135 /// functor::KernelArgs,
136 /// routines::{
137 /// parallel_for,
138 /// parameters::{ExecutionPolicy, ExecutionSpace, RangePolicy, Schedule},
139 /// },
140 /// };
141 ///
142 /// let length: usize = 8;
143 ///
144 /// let kern = |arg: KernelArgs<1>| match arg {
145 /// KernelArgs::Index1D(i) => {
146 /// // body of the kernel
147 /// println!("Hello from iteration {i}")
148 /// },
149 /// KernelArgs::IndexND(_) => unimplemented!(),
150 /// KernelArgs::Handle => unimplemented!(),
151 /// };
152 ///
153 /// let execp = ExecutionPolicy {
154 /// space: ExecutionSpace::DeviceCPU,
155 /// range: RangePolicy::RangePolicy(0..length),
156 /// schedule: Schedule::Static,
157 /// };
158 ///
159 /// parallel_for(execp, kern).unwrap();
160 /// ```
161 pub fn parallel_for<const N: usize>(
162 execp: ExecutionPolicy<N>,
163 func: impl Fn(KernelArgs<N>) + Send + Sync,
164 ) -> Result<(), StatementError> {
165 // checks...
166
167 // data prep?
168 let kernel = Box::new(func);
169
170 // dispatch
171 let res = match execp.space {
172 parameters::ExecutionSpace::Serial => dispatch::serial(execp, kernel),
173 parameters::ExecutionSpace::DeviceCPU => dispatch::cpu(execp, kernel),
174 parameters::ExecutionSpace::DeviceGPU => dispatch::gpu(execp, kernel),
175 };
176
177 // Ok or converts error
178 res.map_err(|e| e.into())
179 }
180 } else {
181 /// Parallel For statement.
182 ///
183 /// **Current version**: no feature
184 ///
185 /// ### Example
186 ///
187 /// ```rust
188 /// use poc_kokkos_rs::{
189 /// functor::KernelArgs,
190 /// routines::{
191 /// parallel_for,
192 /// parameters::{ExecutionPolicy, ExecutionSpace, RangePolicy, Schedule},
193 /// },
194 /// };
195 ///
196 /// let length: usize = 8;
197 ///
198 /// let kern = |arg: KernelArgs<1>| match arg {
199 /// KernelArgs::Index1D(i) => {
200 /// // body of the kernel
201 /// println!("Hello from iteration {i}")
202 /// },
203 /// KernelArgs::IndexND(_) => unimplemented!(),
204 /// KernelArgs::Handle => unimplemented!(),
205 /// };
206 ///
207 /// let execp = ExecutionPolicy {
208 /// space: ExecutionSpace::DeviceCPU,
209 /// range: RangePolicy::RangePolicy(0..length),
210 /// schedule: Schedule::Static,
211 /// };
212 ///
213 /// parallel_for(execp, kern).unwrap();
214 /// ```
215 pub fn parallel_for<const N: usize>(
216 execp: ExecutionPolicy<N>,
217 func: impl FnMut(KernelArgs<N>),
218 ) -> Result<(), StatementError> {
219 // checks...
220
221 // data prep?
222 let kernel = Box::new(func);
223
224 // dispatch
225 let res = match execp.space {
226 parameters::ExecutionSpace::Serial => dispatch::serial(execp, kernel),
227 parameters::ExecutionSpace::DeviceCPU => dispatch::cpu(execp, kernel),
228 parameters::ExecutionSpace::DeviceGPU => dispatch::gpu(execp, kernel),
229 };
230
231 // Ok or converts error
232 res.map_err(|e| e.into())
233 }
234 }
235}