import java.util.Iterator; import java.util.concurrent.atomic.AtomicStampedReference; import java.util.ArrayList; import java.util.List; public class SetImpl> implements Set { private class Node { final T item; AtomicStampedReference nextRef; Node(T item) { this.item = item; this.nextRef = new AtomicStampedReference<>(null, 0); } Node(T item, Node next, int stamp) { this.item = item; this.nextRef = new AtomicStampedReference<>(next, stamp); } } private class Bounds { final Node lower, upper; Bounds(Node lower, Node upper) { this.lower = lower; this.upper = upper; } } private final Node head; public SetImpl() { head = new Node(null); head.nextRef = new AtomicStampedReference<>(null, 0); } @Override public boolean add(T value) { while (true) { Bounds bounds = find(value); Node prev = bounds.lower; Node curr = bounds.upper; if (curr != null && curr.item != null && value.compareTo(curr.item) == 0) { return false; } Node node = new Node(value, curr, 0); AtomicStampedReference setTo = (prev == null) ? head.nextRef : prev.nextRef; int[] stampHolder = {0}; setTo.get(stampHolder); if (stampHolder[0] == -1) continue; if (setTo.compareAndSet(curr, node, stampHolder[0], stampHolder[0] + 1)) { return true; } } } @Override public boolean remove(T value) { while (true) { Bounds bounds = find(value); Node prev = bounds.lower; Node curr = bounds.upper; if (curr == null || curr.item == null || value.compareTo(curr.item) != 0) { return false; } int[] stamp = {0}; Node next = curr.nextRef.get(stamp); if (stamp[0] == -1) { return false; } if (!curr.nextRef.compareAndSet(next, next, stamp[0], -1)) { continue; } AtomicStampedReference setTo = (prev == null) ? head.nextRef : prev.nextRef; int[] prevStamp = {0}; Node beforeCur = setTo.get(prevStamp); setTo.compareAndSet(curr, next, prevStamp[0], prevStamp[0] + 1); return true; } } @Override public boolean contains(T value) { Node curr = head.nextRef.getReference(); while (curr != null) { int[] stamp = {0}; Node next = curr.nextRef.get(stamp); if (curr.item != null && curr.item.compareTo(value) >= 0) { return curr.item.compareTo(value) == 0 && stamp[0] != -1; } curr = next; } return false; } @Override public boolean isEmpty() { return iterator().hasNext(); } @Override public Iterator iterator() { return new Iterator() { private final List snapshot; private int index = 0; { snapshot = new ArrayList<>(); while (true) { List tmpNodes = new ArrayList<>(); List tmpStamps = new ArrayList<>(); List tmpSnapshot = new ArrayList<>(); Node curr = head.nextRef.getReference(); while (curr != null) { int[] stamp = {0}; Node next = curr.nextRef.get(stamp); if (stamp[0] != -1 && curr.item != null) { tmpNodes.add(curr); tmpStamps.add(stamp[0]); tmpSnapshot.add(curr.item); } curr = next; } boolean consistent = true; for (int i = 0; i < tmpNodes.size(); i++) { if (tmpNodes.get(i).nextRef.getStamp() != tmpStamps.get(i)) { consistent = false; break; } } if (consistent) { snapshot.addAll(tmpSnapshot); break; } } } @Override public boolean hasNext() { return index < snapshot.size(); } @Override public T next() { return snapshot.get(index++); } }; } private Bounds find(T value) { retry: while (true) { Node prev = null; Node curr = head.nextRef.getReference(); while (true) { if (curr == null) return new Bounds(prev, null); int[] stamp = {0}; Node next = curr.nextRef.get(stamp); if (stamp[0] == -1) { AtomicStampedReference setTo = (prev == null) ? head.nextRef : prev.nextRef; int[] prevStamp = {0}; Node before = setTo.get(prevStamp); if (before != curr || !setTo.compareAndSet(curr, next, prevStamp[0], prevStamp[0] + 1)) { continue retry; } curr = next; continue; } if (curr.item == null || curr.item.compareTo(value) < 0) { prev = curr; curr = next; } else { return new Bounds(prev, curr); } } } } }