Простое целочисленное хеш-множество на Java — продолжение

Внеся изменения по итогам предыдущего поста, я пришёл к этой реализации. Однако хеширование я оставил как есть.

com.github.coderodde.util.IntHashSet:

package com.github.coderodde.util;

/**
 * This class implements a simple hash set of primitive {@code int} values using
 * separate chaining. It is used in the
 * {@link com.github.coderodde.util.LinkedList} in order to keep track of nodes
 * that are being pointed to by fingers.
 *
 * <p>A value is stored in the bucket selected by its low bits,
 * {@code value & (capacity - 1)}, so the capacity is always a power of two.
 * Because that mask is never negative, negative values are handled correctly
 * as well.
 *
 * <p>The table doubles once the load factor exceeds 0.75, and shrinks to a
 * quarter of its length once the size drops below a quarter of the expansion
 * threshold. It never shrinks below the initial capacity of 8 buckets. For
 * well-distributed values this keeps {@link #add(int)} and
 * {@link #remove(int)} amortized O(1).
 *
 * <p>No synchronization is performed; external locking is required if an
 * instance is shared between threads.
 *
 * @author Rodion "rodde" Efremov
 * @version 1.6 (Aug 29, 2021)
 * @since 1.6 (Aug 29, 2021)
 */
public class IntHashSet {

    /** Bucket count of a fresh or cleared table; must be a power of two. */
    private static final int INITIAL_CAPACITY = 8;

    /** Largest power of two that is a legal array length; growth stops here. */
    private static final int MAXIMUM_CAPACITY = 1 << 30;

    /** The table is doubled once {@code size / capacity} exceeds this ratio. */
    private static final float MAXIMUM_LOAD_FACTOR = 0.75f;

    /** A node of a singly-linked collision chain. */
    private static final class IntHashTableCollisionChainNode {
        IntHashTableCollisionChainNode next;
        final int integer;

        IntHashTableCollisionChainNode(
                int integer, 
                IntHashTableCollisionChainNode next) {
            this.integer = integer;
            this.next = next;
        }

        @Override
        public String toString() {
            return "Chain node, integer = " + integer;
        }
    }

    private IntHashTableCollisionChainNode[] table = 
            new IntHashTableCollisionChainNode[INITIAL_CAPACITY];

    /** Number of distinct values currently stored. */
    private int size = 0;

    /** Always {@code table.length - 1}; maps a value to its bucket index. */
    private int mask = INITIAL_CAPACITY - 1;
    
    @Override
    public String toString() {
        return "size = " + size;
    }

    /**
     * Returns the number of values in this set.
     *
     * @return the number of distinct values currently stored.
     */
    public int size() {
        return size;
    }

    /**
     * Tells whether this set holds no values.
     *
     * @return {@code true} if and only if {@link #size()} is zero.
     */
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * Adds {@code integer} to this set. Does nothing if it is already present.
     *
     * @param integer the value to add.
     */
    public void add(int integer) {
        if (contains(integer)) {
            return;
        }

        size++;

        if (shouldExpand()) {
            expand();
        }

        // Compute the index only after a possible expansion changed the mask.
        final int index = integer & mask;
        table[index] = new IntHashTableCollisionChainNode(integer, table[index]);
    }

    /**
     * Tells whether {@code integer} is in this set.
     *
     * @param integer the value to look up.
     * @return {@code true} if the value is present.
     */
    public boolean contains(int integer) {
        for (IntHashTableCollisionChainNode node = table[integer & mask];
                node != null;
                node = node.next) {
            if (node.integer == integer) {
                return true;
            }
        }

        return false;
    }

    /**
     * Removes {@code integer} from this set. Does nothing if it is absent.
     *
     * @param integer the value to remove.
     */
    public void remove(int integer) {
        final int index = integer & mask;
        IntHashTableCollisionChainNode previous = null;
        IntHashTableCollisionChainNode current = table[index];

        // Single pass: locate the node and remember its predecessor.
        while (current != null && current.integer != integer) {
            previous = current;
            current = current.next;
        }

        if (current == null) {
            return;
        }

        if (previous == null) {
            table[index] = current.next;
        } else {
            previous.next = current.next;
        }

        size--;

        // Shrink only after unlinking, so the removed node is not rehashed.
        if (shouldContract()) {
            contract();
        }
    }

    /** Removes all values and resets the table to its initial capacity. */
    public void clear() {
        size = 0;
        table = new IntHashTableCollisionChainNode[INITIAL_CAPACITY];
        mask = INITIAL_CAPACITY - 1;
    }

    // Keep add(int) an amortized O(1); stop growing at the largest legal array
    // length instead of overflowing table.length * 2.
    private boolean shouldExpand() {
        return table.length < MAXIMUM_CAPACITY
                && size > table.length * MAXIMUM_LOAD_FACTOR;
    }

    // Keep remove(int) an amortized O(1).
    private boolean shouldContract() {
        if (table.length <= INITIAL_CAPACITY) {
            return false;
        }
        
        final int maxCurrentQuota = (int)(table.length * MAXIMUM_LOAD_FACTOR);
        final int minCurrentQuota = maxCurrentQuota / 4;
        return size < minCurrentQuota;
    }

    private void expand() {
        resize(table.length * 2);
    }

    // Never shrink below INITIAL_CAPACITY (16 / 4 would otherwise give 4).
    private void contract() {
        resize(Math.max(INITIAL_CAPACITY, table.length / 4));
    }

    // Moves every existing node into a new table of the given power-of-two
    // length, reusing the nodes rather than allocating new ones.
    private void resize(int newCapacity) {
        final IntHashTableCollisionChainNode[] newTable =
                new IntHashTableCollisionChainNode[newCapacity];
        final int newMask = newCapacity - 1;

        for (IntHashTableCollisionChainNode head : table) {
            IntHashTableCollisionChainNode node = head;

            while (node != null) {
                final IntHashTableCollisionChainNode next = node.next;
                final int index = node.integer & newMask;
                node.next = newTable[index];
                newTable[index] = node;
                node = next;
            }
        }

        table = newTable;
        mask = newMask;
    }
}

com.github.coderodde.util.IntHashSetTest:

package com.github.coderodde.util;

import java.util.Random;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import org.junit.Before;
import org.junit.Test;

/**
 * Unit tests for {@link IntHashSet}, covering insertion, lookup, removal
 * (including of absent values), clearing, negative values, and repeated
 * growth and contraction of the underlying table.
 */
public class IntHashSetTest {

    private final IntHashSet set = new IntHashSet();

    @Before
    public void beforeTest() {
        set.clear();
    }
    
    /**
     * Regression test for removal after the table has grown past its initial
     * capacity: checks after every removal that exactly the removed value is
     * gone and every remaining value is still present.
     */
    @Test
    public void removeBug() {
        for (int i = 0; i < 9; i++) {
            set.add(i);
        }
        
        for (int i = 0; i < 9; i++) {
            set.remove(i);
            assertFalse(set.contains(i));

            for (int j = i + 1; j < 9; j++) {
                assertTrue(set.contains(j));
            }
        }
    }
    
    @Test
    public void removeFirstMiddleLast() {
        // All three ints will end up in the same collision chain:
        set.add(1);  // 0b00001
        set.add(9);  // 0b01001
        set.add(17); // 0b10001
        
        assertTrue(set.contains(1));
        assertTrue(set.contains(9));
        assertTrue(set.contains(17));
        
        set.remove(1);
        
        assertFalse(set.contains(1));
        assertTrue(set.contains(9));
        assertTrue(set.contains(17));
        
        set.remove(17);
        
        assertFalse(set.contains(1));
        assertTrue(set.contains(9));
        assertFalse(set.contains(17));
        
        set.remove(9);
        
        assertFalse(set.contains(1));
        assertFalse(set.contains(9));
        assertFalse(set.contains(17));
    }

    /** Removing a value that is not present must leave the set unchanged. */
    @Test
    public void removeAbsentIsNoOp() {
        set.add(1);

        set.remove(2);  // different bucket
        set.remove(9);  // same bucket as 1 while the table has 8 buckets

        assertTrue(set.contains(1));
        assertFalse(set.contains(2));
        assertFalse(set.contains(9));
    }

    /** Negative values must be stored and found like any other value. */
    @Test
    public void negativeValues() {
        set.add(-1);
        set.add(-9);
        set.add(Integer.MIN_VALUE);

        assertTrue(set.contains(-1));
        assertTrue(set.contains(-9));
        assertTrue(set.contains(Integer.MIN_VALUE));
        assertFalse(set.contains(1));

        set.remove(-9);

        assertTrue(set.contains(-1));
        assertFalse(set.contains(-9));
        assertTrue(set.contains(Integer.MIN_VALUE));
    }

    /** Drives the table through several full expand/contract cycles. */
    @Test
    public void growAndShrinkRepeatedly() {
        for (int round = 0; round < 3; round++) {
            for (int i = 0; i < 1_000; i++) {
                set.add(i);
            }

            for (int i = 0; i < 1_000; i++) {
                assertTrue(set.contains(i));
            }

            for (int i = 0; i < 1_000; i++) {
                set.remove(i);
            }

            for (int i = 0; i < 1_000; i++) {
                assertFalse(set.contains(i));
            }
        }
    }

    @Test
    public void add() {
        for (int i = 0; i < 500; i++) {
            set.add(i);
        }

        for (int i = 0; i < 500; i++) {
            assertTrue(set.contains(i));
        }

        for (int i = 500; i < 1_000; i++) {
            assertFalse(set.contains(i));
        }

        for (int i = 450; i < 550; i++) {
            set.remove(i);
        }

        for (int i = 450; i < 1_000; i++) {
            assertFalse(set.contains(i));
        }

        for (int i = 0; i < 450; i++) {
            assertTrue(set.contains(i));
        }
    }

    @Test
    public void contains() {
        set.add(10);
        set.add(20);
        set.add(30);

        // Starts at 1 because 0 % 10 == 0 but 0 was never added.
        for (int i = 1; i < 40; i++) {
            if (i % 10 == 0) {
                assertTrue(set.contains(i));
            } else {
                assertFalse(set.contains(i));
            }
        }
    }

    @Test
    public void remove() {
        set.add(1);
        set.add(2);
        set.add(3);
        set.add(4);
        set.add(5);

        set.remove(2);
        set.remove(4);

        set.add(2);

        assertFalse(set.contains(4));

        assertTrue(set.contains(1));
        assertTrue(set.contains(2));
        assertTrue(set.contains(3));
        assertTrue(set.contains(5));
    }

    @Test
    public void clear() {
        for (int i = 0; i < 100; i++) {
            set.add(i);
        }

        for (int i = 0; i < 100; i++) {
            assertTrue(set.contains(i));
        }

        set.clear();

        for (int i = 0; i < 100; i++) {
            assertFalse(set.contains(i));
        }
    }

    @Test 
    public void bruteForceAdd() {
        long seed = System.currentTimeMillis();

        System.out.println(
                "--- IntHashSetTest.bruteForceAdd: seed = " + seed + " ---");

        Random random = new Random(seed);

        int[] data = new int[10_000];

        for (int i = 0; i < data.length; i++) {
            int datum = random.nextInt(5_000);
            data[i] = datum;
            set.add(datum);
        }

        for (int datum : data) {
            assertTrue(set.contains(datum));
        }
    }

    @Test
    public void bruteForceRemove() {
        long seed = System.currentTimeMillis();

        System.out.println(
                "--- IntHashSetTest.bruteForceRemove: seed = " + seed + " ---");

        Random random = new Random(seed);

        int[] data = new int[10_000];

        for (int i = 0; i < data.length; i++) {
            int datum = random.nextInt(5_000);
            data[i] = datum;
            set.add(datum);
        }

        shuffle(data, random);

        for (int datum : data) {
            // data contains duplicates, so some calls remove an already-removed
            // value; remove() must treat that as a no-op.
            set.remove(datum);
            assertFalse(set.contains(datum));
        }
    }

    // Fisher-Yates shuffle driven by the test's seeded Random.
    private static void shuffle(int[] data, Random random) {
        for (int i = data.length - 1; i > 0; --i) {
            final int j = random.nextInt(i + 1);
            swap(data, i, j);
        }
    }

    private static void swap(int[] data, int i, int j) {
        int tmp = data[i];
        data[i] = data[j];
        data[j] = tmp;
    }
}

Запрос на критику

Подскажите, пожалуйста, что ещё можно улучшить?

0

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *