Простое хеш-множество целых чисел на Java — продолжение 2

(См. предыдущую версию.)

Теперь код выглядит так:

com.github.coderodde.util.IntHashSet:

package com.github.coderodde.util;

import java.util.HashSet;
import java.util.Random;
import java.util.Set;

/**
 * This class implements a simple hash set for primitive {@code int} values.
 * It is used in the {@link com.github.coderodde.util.LinkedList} in order to 
 * keep track of nodes that are being pointed to by fingers.
 * <p>
 * Collisions are resolved by separate chaining. The table capacity is always a
 * power of two, so the bucket index of a value is {@code value & mask}; this is
 * a valid index for negative values too, so any {@code int} may be stored.
 * The table doubles when the number of elements exceeds its capacity (up to
 * {@code MAXIMUM_CAPACITY}) and shrinks to a quarter of its capacity when it
 * becomes at most a quarter full, but never below {@code INITIAL_CAPACITY}.
 * <p>
 * This class performs no synchronization of its own.
 * 
 * @author Rodion "rodde" Efremov
 * @version 1.6 (Aug 29, 2021)
 * @since 1.6 (Aug 29, 2021)
 */
public class IntHashSet {

    /** Capacity of a new or cleared table; must be a power of two. */
    private static final int INITIAL_CAPACITY = 8;

    /**
     * Largest power-of-two capacity representable as an {@code int}. Doubling
     * past it would overflow, so once reached the table stops growing and the
     * collision chains simply get longer.
     */
    private static final int MAXIMUM_CAPACITY = 1 << 30;

    /** A single entry of a collision chain. */
    private static final class Node {
        Node next;
        int integer;

        Node(int integer, Node next) {
            this.integer = integer;
            this.next = next;
        }

        @Override
        public String toString() {
            return "Chain node, integer = " + integer;
        }
    }

    /** Bucket array; its length is always a power of two. */
    private Node[] table = new Node[INITIAL_CAPACITY];

    /** Number of distinct values currently stored. */
    private int size = 0;

    /** Always equal to {@code table.length - 1}; maps a value to its bucket. */
    private int mask = INITIAL_CAPACITY - 1;
    
    @Override
    public String toString() {
        return "size = " + size;
    }

    /**
     * Returns the number of values in this set.
     *
     * @return the number of stored values.
     */
    public int size() {
        return size;
    }

    /**
     * Tells whether this set contains no values.
     *
     * @return {@code true} if and only if the set is empty.
     */
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * Adds {@code integer} to this set. Does nothing if it is already present.
     * May double the table capacity.
     *
     * @param integer the value to add.
     */
    public void add(int integer) {
        if (contains(integer)) {
            return;
        }
        
        size++;
        
        // Load factor 1.0: grow once there are more values than buckets,
        // unless doubling would overflow the array length.
        if (size > table.length && table.length < MAXIMUM_CAPACITY) {
            rehash(2 * table.length, null);
        }
        
        // Bucket index is computed after a possible resize (mask may change).
        int index = integer & mask;
        table[index] = new Node(integer, table[index]);
    }

    /**
     * Tells whether {@code integer} is in this set.
     *
     * @param integer the value to look up.
     * @return {@code true} if and only if the value is present.
     */
    public boolean contains(int integer) {
        for (Node node = table[integer & mask]; node != null; node = node.next) {
            if (node.integer == integer) {
                return true;
            }
        }

        return false;
    }

    /**
     * Removes {@code integer} from this set. Does nothing if it is absent.
     * May shrink the table to a quarter of its capacity.
     *
     * @param integer the value to remove.
     */
    public void remove(int integer) {
        int index = integer & mask;
        Node prev = null;
        Node node = table[index];
        
        while (node != null && node.integer != integer) {
            prev = node;
            node = node.next;
        }
        
        if (node == null) {
            return;
        }

        size--;
        
        // `size <= table.length / 4` is the overflow-safe form of
        // `size * 4 <= table.length` (table.length is a power of two >= 8,
        // so the division is exact).
        if (size <= table.length / 4 && table.length >= INITIAL_CAPACITY * 4) {
            // The rehash drops `node`, so no explicit unlinking is needed.
            rehash(table.length / 4, node);
        } else if (prev == null) {
            table[index] = node.next;
        } else {
            prev.next = node.next;
        }
    }

    /** Removes all values and resets the table to its initial capacity. */
    public void clear() {
        size = 0;
        table = new Node[INITIAL_CAPACITY];
        mask = table.length - 1;
    }

    /**
     * Moves every node into a fresh table of {@code newCapacity} buckets,
     * omitting {@code skip} (pass {@code null} to keep every node). Nodes are
     * relinked in place rather than copied. Updates {@code table} and
     * {@code mask}; does not touch {@code size}.
     *
     * @param newCapacity the new table length; must be a power of two.
     * @param skip        a node to drop during the move, or {@code null}.
     */
    private void rehash(int newCapacity, Node skip) {
        Node[] newTable = new Node[newCapacity];
        int newMask = newCapacity - 1;

        for (Node head : table) {
            Node current = head;

            while (current != null) {
                // Read the successor first: relinking overwrites `next`.
                Node next = current.next;

                if (current != skip) {
                    int newIndex = current.integer & newMask;
                    current.next = newTable[newIndex];
                    newTable[newIndex] = current;
                }

                current = next;
            }
        }

        table = newTable;
        mask = newMask;
    }
    
    /** Number of operations per benchmark phase in {@link #main}. */
    private static final int DATA_LENGTH = 5_000_000;
    
    /**
     * Crude micro-benchmark comparing this class with {@link HashSet}.
     * NOTE(review): the results of {@code contains} are discarded and timing
     * uses millisecond resolution, so treat the numbers as rough indications
     * only; a harness such as JMH would give more trustworthy results.
     *
     * @param args ignored.
     */
    public static void main(String[] args) {
        Random random = new Random(10L);
        
        int[] addData      = getAddData(random);
        int[] containsData = getContainsData(random);
        int[] removeData   = getRemoveData(random);
        
        for (int iter = 0; iter < 5; iter++) {
            System.out.println(">>> Iteration: " + (iter + 1) + "/5");
            
            IntHashSet myset = new IntHashSet();
            Set<Integer> set = new HashSet<>();

            long start = System.currentTimeMillis();
            for (int i : addData) {
                myset.add(i);
            }
            long end = System.currentTimeMillis();

            System.out.println("    IntHashSet.add in " + (end - start));

            start = System.currentTimeMillis();
            for (int i : addData) {
                set.add(i);
            }
            end = System.currentTimeMillis();

            // "\n" yields the blank separator line between phases.
            System.out.println("    HashSet.add in " + (end - start) + "\n");

            start = System.currentTimeMillis();
            for (int i : containsData) {
                myset.contains(i);
            }
            end = System.currentTimeMillis();

            System.out.println("    IntHashSet.contains in " + (end - start));

            start = System.currentTimeMillis();
            for (int i : containsData) {
                set.contains(i);
            }
            end = System.currentTimeMillis();

            System.out.println("    HashSet.contains in " + (end - start) + 
                    "\n");

            start = System.currentTimeMillis();
            for (int i : removeData) {
                myset.remove(i);
            }
            end = System.currentTimeMillis();

            System.out.println("    IntHashSet.remove in " + (end - start));

            start = System.currentTimeMillis();
            for (int i : removeData) {
                set.remove(i);
            }
            end = System.currentTimeMillis();

            System.out.println("    HashSet.remove in " + (end - start) + "\n");
        }
    }
        
    private static int[] getAddData(Random random) {
        return getData(DATA_LENGTH, 3 * DATA_LENGTH / 2, random);
    }
        
    private static int[] getContainsData(Random random) {
        return getData(DATA_LENGTH, 3 * DATA_LENGTH / 2, random);
    }
        
    private static int[] getRemoveData(Random random) {
        return getData(DATA_LENGTH, 3 * DATA_LENGTH / 2, random);
    }
    
    /**
     * Returns {@code length} pseudo-random values in {@code [0, maxValue]}.
     */
    private static int[] getData(int length, int maxValue, Random random) {
        int[] data = new int[length];
        
        for (int i = 0; i < length; i++) {
            data[i] = random.nextInt(maxValue + 1);
        }
        
        return data;
    }
}

com.github.coderodde.util.IntHashSetTest:

package com.github.coderodde.util;

import java.util.Random;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import org.junit.Before;
import org.junit.Test;

public class IntHashSetTest {

    private final IntHashSet set = new IntHashSet();

    @Before
    public void beforeTest() {
        set.clear();
    }
    
    /**
     * With the initial capacity of 8, the values 1, 9 and 17 share a single
     * collision chain. New nodes are prepended, so this removes from the tail,
     * the head and the middle of that chain, and verifies the contents after
     * every removal.
     */
    @Test
    public void removeFromCollisionChainBug() {
        bar("removeFromCollisionChainBug");
        
        set.add(0b00001); // 1
        set.add(0b01001); // 9
        set.add(0b10001); // 17
        
        set.remove(1); // remove from tail
        
        assertFalse(set.contains(1));
        assertTrue(set.contains(9));
        assertTrue(set.contains(17));
        
        set.add(0b00001); // 1
        set.add(0b01001); // 9
        set.add(0b10001); // 17
        
        set.remove(1); // remove from head
        
        assertFalse(set.contains(1));
        assertTrue(set.contains(9));
        assertTrue(set.contains(17));
        
        set.add(0b00001); // 1
        set.add(0b01001); // 9
        set.add(0b10001); // 17
    
        set.remove(17); // remove from middle
        
        assertTrue(set.contains(1));
        assertTrue(set.contains(9));
        assertFalse(set.contains(17));
        
        bar("removeFromCollisionChainBug done!");
    }
    
    /**
     * Adding nine values forces the table to grow past its initial capacity;
     * afterwards every value must be removable.
     */
    @Test
    public void removeBug() {
        bar("removeBug");
        
        for (int i = 0; i < 9; i++) {
            set.add(i);
        }
        
        for (int i = 0; i < 9; i++) {
            assertTrue(set.contains(i));
        }
        
        for (int i = 0; i < 9; i++) {
            set.remove(i);
        }
        
        for (int i = 0; i < 9; i++) {
            assertFalse(set.contains(i));
        }
        
        bar("removeBug done!");
    }
    
    @Test
    public void removeFirstMiddleLast() {
        bar("removeFirstMiddleLast");
        
        // All three ints will end up in the same collision chain:
        set.add(1);  // 0b00001
        set.add(9);  // 0b01001
        set.add(17); // 0b10001
        
        assertTrue(set.contains(1));
        assertTrue(set.contains(9));
        assertTrue(set.contains(17));
        
        set.remove(1);
        
        assertFalse(set.contains(1));
        assertTrue(set.contains(9));
        assertTrue(set.contains(17));
        
        set.remove(17);
        
        assertFalse(set.contains(1));
        assertTrue(set.contains(9));
        assertFalse(set.contains(17));
        
        set.remove(9);
        
        assertFalse(set.contains(1));
        assertFalse(set.contains(9));
        assertFalse(set.contains(17));
        
        bar("removeFirstMiddleLast done!");
    }

    /** Negative values must be storable, findable and removable. */
    @Test
    public void negativeValues() {
        bar("negativeValues");
        
        for (int i = -100; i < 0; i++) {
            set.add(i);
        }
        
        for (int i = -100; i < 0; i++) {
            assertTrue(set.contains(i));
            assertFalse(set.contains(-i));
        }
        
        for (int i = -100; i < 0; i++) {
            set.remove(i);
            assertFalse(set.contains(i));
        }
        
        bar("negativeValues done!");
    }

    /**
     * Grows the table well beyond its initial capacity, then removes enough
     * values to trigger several shrinks. The survivors must remain reachable.
     */
    @Test
    public void shrinkKeepsRemainingElements() {
        bar("shrinkKeepsRemainingElements");
        
        for (int i = 0; i < 1_000; i++) {
            set.add(i);
        }
        
        for (int i = 0; i < 990; i++) {
            set.remove(i);
        }
        
        for (int i = 0; i < 990; i++) {
            assertFalse(set.contains(i));
        }
        
        for (int i = 990; i < 1_000; i++) {
            assertTrue(set.contains(i));
        }
        
        bar("shrinkKeepsRemainingElements done!");
    }

    @Test
    public void add() {
        bar("add");
        
        for (int i = 0; i < 500; i++) {
            set.add(i);
        }

        for (int i = 0; i < 500; i++) {
            assertTrue(set.contains(i));
        }

        for (int i = 500; i < 1_000; i++) {
            assertFalse(set.contains(i));
        }

        for (int i = 450; i < 550; i++) {
            set.remove(i);
        }

        for (int i = 450; i < 1_000; i++) {
            assertFalse(set.contains(i));
        }

        for (int i = 0; i < 450; i++) {
            assertTrue(set.contains(i));
        }
        
        bar("add done!");
    }

    @Test
    public void contains() {
        bar("contains");
        
        set.add(10);
        set.add(20);
        set.add(30);

        for (int i = 1; i < 40; i++) {
            if (i % 10 == 0) {
                assertTrue(set.contains(i));
            } else {
                assertFalse(set.contains(i));
            }
        }
        
        bar("contains done!");
    }

    @Test
    public void remove() {
        bar("remove");
        
        set.add(1);
        set.add(2);
        set.add(3);
        set.add(4);
        set.add(5);

        set.remove(2);
        set.remove(4);

        set.add(2);

        assertFalse(set.contains(4));

        assertTrue(set.contains(1));
        assertTrue(set.contains(2));
        assertTrue(set.contains(3));
        assertTrue(set.contains(5));
        
        bar("remove done!");
    }

    @Test
    public void clear() {
        bar("clear");
        
        for (int i = 0; i < 100; i++) {
            set.add(i);
        }

        for (int i = 0; i < 100; i++) {
            assertTrue(set.contains(i));
        }

        set.clear();

        for (int i = 0; i < 100; i++) {
            assertFalse(set.contains(i));
        }
        
        bar("clear done!");
    }

    @Test 
    public void bruteForceAdd() {
        bar("bruteForceAdd");
        
        Random random = new Random(13L);

        int[] data = new int[10_000];

        for (int i = 0; i < data.length; i++) {
            int datum = random.nextInt(5_000);
            data[i] = datum;
            set.add(datum);
        }

        for (int i = 0; i < data.length; i++) {
            assertTrue(set.contains(data[i]));
        }
        
        bar("bruteForceAdd done!");
    }

    /**
     * Removes the added values in shuffled order (duplicates included, so some
     * removals target absent values) and checks that each is gone afterwards.
     */
    @Test
    public void bruteForceRemove() {
        bar("bruteForceRemove");
        
        Random random = new Random(100L);

        int[] data = new int[10_000];

        for (int i = 0; i < data.length; i++) {
            int datum = random.nextInt(5_000);
            data[i] = datum;
            set.add(datum);
        }

        shuffle(data, random);

        for (int datum : data) {
            // Removing an absent value is a no-op, so no contains() guard.
            set.remove(datum);
            assertFalse(set.contains(datum));
        }
        
        bar("bruteForceRemove done!");
    }

    /** Fisher-Yates shuffle of {@code data} in place. */
    private static void shuffle(int[] data, Random random) {
        for (int i = data.length - 1; i > 0; --i) {
            final int j = random.nextInt(i + 1);
            swap(data, i, j);
        }
    }

    private static void swap(int[] data, int i, int j) {
        int tmp = data[i];
        data[i] = data[j];
        data[j] = tmp;
    }
   
    /** Prints a progress banner for the running test. */
    private static void bar(String text) {
        System.out.println("--- " + text);
    }
}

Показатели производительности:

>>> Iteration: 1/5
    IntHashSet.add in 818
    HashSet.add in 1306

    IntHashSet.contains in 150
    HashSet.contains in 321

    IntHashSet.remove in 250
    HashSet.remove in 263

>>> Iteration: 2/5
    IntHashSet.add in 607
    HashSet.add in 1130

    IntHashSet.contains in 151
    HashSet.contains in 280

    IntHashSet.remove in 179
    HashSet.remove in 203

>>> Iteration: 3/5
    IntHashSet.add in 577
    HashSet.add in 1060

    IntHashSet.contains in 159
    HashSet.contains in 292

    IntHashSet.remove in 189
    HashSet.remove in 229

>>> Iteration: 4/5
    IntHashSet.add in 522
    HashSet.add in 891

    IntHashSet.contains in 166
    HashSet.contains in 316

    IntHashSet.remove in 193
    HashSet.remove in 233

>>> Iteration: 5/5
    IntHashSet.add in 665
    HashSet.add in 940

    IntHashSet.contains in 160
    HashSet.contains in 349

    IntHashSet.remove in 199
    HashSet.remove in 232

Запрос на критику

Пожалуйста, расскажите мне все, что придет в голову! ^^

0

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *