Параллельная карта, написанная на Rust. Я новичок в Rust и мне интересно, есть ли что-то, что можно было бы сделать лучше или эффективнее.
use crossbeam_channel::unbounded;
use std::{thread, time};
fn parallel_map<T, U, F>(mut input_vec: Vec<T>, num_threads: usize, f: F) -> Vec<U>
where
F: FnOnce(T) -> U + Send + Copy + 'static,
T: Send + 'static,
U: Send + 'static + Default,
{
let mut output_vec: Vec<U> = Vec::with_capacity(input_vec.len());
let mut threads = Vec::new();
let (in_s, in_r) = unbounded::<(T, usize)>();
let (out_s, out_r) = unbounded::<(U, usize)>();
for _ in 0..num_threads {
let in_r = in_r.clone();
let out_s = out_s.clone();
threads.push(thread::spawn(move || {
while let Ok((value, index)) = in_r.recv() {
let res = f(value);
out_s.send((res, index)).expect("Failed to send");
}
}));
}
while let Some(val) = input_vec.pop() {
in_s.send((val, input_vec.len())).expect("Failed to send");
}
drop(in_s);
drop(out_s);
let mut collect_results: Vec<(U, usize)> = Vec::with_capacity(output_vec.capacity());
while let Ok(res) = out_r.recv() {
collect_results.push(res);
}
collect_results.sort_by(|(_, a_index), (_, b_index)| a_index.partial_cmp(b_index).unwrap());
output_vec.extend(collect_results.into_iter().map(|(val, _)| val));
for thread in threads {
thread.join().expect("Failed to join thread");
}
output_vec
}
fn main() {
let v = vec![6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 12, 18, 11, 5, 20];
let squares = parallel_map(v, 10, |num| {
println!("{} squared is {}", num, num * num);
thread::sleep(time::Duration::from_millis(500));
num * num
});
println!("squares: {:?}", squares);
}
```
1 ответ
Ящик rayon
реализует аналогичный функционал:
use rayon::prelude::*;
let mut par_iter = (0..5).into_par_iter().map(|x| x * x);
let squares: Vec<_> = par_iter.collect();
В вашем решении многопоточность выглядит нормально.
Та часть, где вы собираете результаты, могла бы быть выполнена более эффективно. В частности, можно было бы избежать одного распределения и копии данных с небольшой небезопасностью.
let mut output_vec: Vec<U> = Vec::with_capacity(initial_len);
unsafe {
output_vec.set_len(initial_len);
}
while let Ok((res, index)) = out_r.recv() {
let old = mem::replace(&mut output_vec[index], res);
mem::forget(old);
}
Такую же оптимизацию можно было бы сделать немного проще с помощью ptr::write
. Кроме того, вам может потребоваться утверждать, что output_vec
имеет правильную длину, и forget
vec заранее на случай, если утверждение не будет выполнено. Меня беспокоит то, что предложенный выше код не защищен от неправильных падений, если U
имеет деструктор и код panic
с.
Ваша граница U
в виде Default
кажется ненужным.