Многопоточная реализация алгоритма кластеризации K-средних на Java

Здравствуйте, я написал многопоточную реализацию алгоритма кластеризации K-средних. Основные цели — это оперативность и масштабируемая производительность на процессорах mluticore. Я ожидаю, что в коде не будет условий гонки и гонки данных, и он будет хорошо масштабироваться с большим количеством ядер ЦП.

package bg.unisofia.fmi.rsa;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ParallelKmeans {

    private static CountDownLatch countDownLatch;
    private final int n;
    private final int k;
    public int numThreads = 1;
    List<Node> observations = new ArrayList<>();
    float[][] clusters;

    public ParallelKmeans(int n, int k) {
        this.n = n;
        this.k = k;
        clusters = new float[k][n];
        for (float[] cluster : clusters) {
            for (int i = 0; i < cluster.length; i++) {
                cluster[i] = (float) Math.random();
            }
        }
    }

    public void assignStep(ExecutorService executorService) throws InterruptedException {
        Runnable[] assignWorkers = new AssignWorker[numThreads];
        final int chunk = observations.size() / assignWorkers.length;
        countDownLatch = new CountDownLatch(numThreads);
        for (int j = 0; j < assignWorkers.length; j++) {
            assignWorkers[j] = new AssignWorker(j * chunk, (j + 1) * chunk);
            executorService.execute(assignWorkers[j]);
        }
        countDownLatch.await();

    }

    public void updateStep(ExecutorService executorService) throws InterruptedException {

        countDownLatch = new CountDownLatch(numThreads);

        UpdateWorker[] updateWorkers = new UpdateWorker[numThreads];
        final int chunk = observations.size() / updateWorkers.length;
        for (int j = 0; j < updateWorkers.length; j++) {
            updateWorkers[j] = new UpdateWorker(j * chunk, (j + 1) * chunk);
            executorService.execute(updateWorkers[j]);
        }
        countDownLatch.await();
        clusters = new float[k][n];
        int[] counts = new int[k];

        for (UpdateWorker u : updateWorkers) {
            VectorMath.add(counts, u.getCounts());
            for (int j = 0; j < k; j++) {
                VectorMath.add(clusters[j], u.getClusters()[j]);
            }
        }

        for (int j = 0; j < clusters.length; j++) {
            if (counts[j] != 0) {
                VectorMath.divide(clusters[j], counts[j]);
            }
        }
    }

    void cluster() throws InterruptedException {
        ExecutorService executorService = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);
        for (int i = 0; i < 50; i++) {
            assignStep(executorService);
            updateStep(executorService);
        }
        executorService.shutdown();
    }

    public static class Node {
        float[] vec;
        int cluster;
    }

    class AssignWorker implements Runnable {
        int l, r;

        public AssignWorker(int l, int r) {
            this.l = l;
            this.r = r;
        }

        @Override
        public void run() {
            List<Node> chunk = observations.subList(l, r);
            for (Node ob : chunk) {
                float minDist = Float.POSITIVE_INFINITY;
                int idx = 0;
                for (int i = 0; i < clusters.length; i++) {
                    if (minDist > VectorMath.dist(ob.vec, clusters[i])) {
                        minDist = VectorMath.dist(ob.vec, clusters[i]);
                        idx = i;
                    }
                }
                ob.cluster = idx;
            }
            countDownLatch.countDown();
        }
    }

    class UpdateWorker implements Runnable {
        int[] counts;
        int l, r;
        float[][] clusters;

        UpdateWorker(int l, int r) {
            this.l = l;
            this.r = r;
        }

        int[] getCounts() {
            return counts;
        }

        public float[][] getClusters() {
            return clusters;
        }

        @Override
        public void run() {
            this.counts = new int[k];
            this.clusters = new float[k][n];
            for (Node ob : observations.subList(l, r)) {
                VectorMath.add(this.clusters[ob.cluster], ob.vec);
                this.counts[ob.cluster]++;
            }
            countDownLatch.countDown();
        }
    }

}
```

0

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *