Параллельная карта в Rust

Параллельная карта, написанная на Rust. Я новичок в Rust и мне интересно, есть ли что-то, что можно было бы сделать лучше или эффективнее.

use crossbeam_channel::unbounded;
use std::{thread, time};

fn parallel_map<T, U, F>(mut input_vec: Vec<T>, num_threads: usize, f: F) -> Vec<U>
where
    F: FnOnce(T) -> U + Send + Copy + 'static,
    T: Send + 'static,
    U: Send + 'static + Default,
{
    let mut output_vec: Vec<U> = Vec::with_capacity(input_vec.len());
    let mut threads = Vec::new();

    let (in_s, in_r) = unbounded::<(T, usize)>();
    let (out_s, out_r) = unbounded::<(U, usize)>();

    for _ in 0..num_threads {
        let in_r = in_r.clone();
        let out_s = out_s.clone();
        threads.push(thread::spawn(move || {
            while let Ok((value, index)) = in_r.recv() {
                let res = f(value);
                out_s.send((res, index)).expect("Failed to send");
            }
        }));
    }

    while let Some(val) = input_vec.pop() {
        in_s.send((val, input_vec.len())).expect("Failed to send");
    }

    drop(in_s);
    drop(out_s);

    let mut collect_results: Vec<(U, usize)> = Vec::with_capacity(output_vec.capacity());
    while let Ok(res) = out_r.recv() {
        collect_results.push(res);
    }

    collect_results.sort_by(|(_, a_index), (_, b_index)| a_index.partial_cmp(b_index).unwrap());

    output_vec.extend(collect_results.into_iter().map(|(val, _)| val));

    for thread in threads {
        thread.join().expect("Failed to join thread");
    }

    output_vec
}

fn main() {
    let v = vec![6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 12, 18, 11, 5, 20];
    let squares = parallel_map(v, 10, |num| {
        println!("{} squared is {}", num, num * num);
        thread::sleep(time::Duration::from_millis(500));
        num * num
    });
    println!("squares: {:?}", squares);
}
```

1 ответ
1

Ящик rayon реализует аналогичный функционал:

use rayon::prelude::*;

let mut par_iter = (0..5).into_par_iter().map(|x| x * x);

let squares: Vec<_> = par_iter.collect();

В вашем решении многопоточность выглядит нормально.

Та часть, где вы собираете результаты, могла бы быть выполнена более эффективно. В частности, можно было бы избежать одного распределения и копии данных с небольшой небезопасностью.

let mut output_vec: Vec<U> = Vec::with_capacity(initial_len);
unsafe {
    output_vec.set_len(initial_len);
}
while let Ok((res, index)) = out_r.recv() {
    let old = mem::replace(&mut output_vec[index], res);
    mem::forget(old);
}

Такую же оптимизацию можно было бы сделать немного проще с помощью ptr::write. Кроме того, вам может потребоваться утверждать, что output_vec имеет правильную длину, и forget vec заранее на случай, если утверждение не будет выполнено. Меня беспокоит то, что предложенный выше код не защищен от неправильных падений, если U имеет деструктор и код panicс.

Ваша граница U в виде Default кажется ненужным.

    Добавить комментарий

    Ваш адрес email не будет опубликован. Обязательные поля помечены *