Skip to content

Instantly share code, notes, and snippets.

@sohang3112
Last active August 22, 2024 07:02
Show Gist options
  • Save sohang3112/f9cbd71fcabaf70855b1f5261e7db5e7 to your computer and use it in GitHub Desktop.
Save sohang3112/f9cbd71fcabaf70855b1f5261e7db5e7 to your computer and use it in GitHub Desktop.
Efficient set of positive integers, stored as the bits of a single integer and maintained with bitwise operations.
from __future__ import annotations
from typing import Iterable
class IntSet:
    """
    Efficiently store positive integers in a set - internally uses bit operations.

    The whole set is a single Python int: element ``k`` (``k >= 1``) is present
    exactly when bit ``k - 1`` of that int is 1. Only positive integers can be
    stored. ``add`` rejects 0 and negatives, while ``in`` and ``discard`` simply
    treat them as absent.

    >>> IntSet(0b10011)  # Set of all the bits which are set to 1 (using 1-based bit indexing)
    IntSet([1, 2, 5])
    >>> s = IntSet().add_from_iterable([1,2,7,1])
    >>> list(s)  # duplicates in input iterable were removed
    [1, 2, 7]
    >>> len(s)
    3
    >>> s.add(8)
    >>> s
    IntSet([1, 2, 7, 8])
    >>> s.discard(1)
    >>> 1 in s
    False
    >>> s.discard(0)  # discarding a value that can never be present is a no-op
    >>> s2 = IntSet().add_from_iterable([2,3,5,7])
    >>> s.union(s2)
    IntSet([2, 3, 5, 7, 8])
    >>> s.intersection(s2)
    IntSet([2, 7])
    >>> s.difference(s2)
    IntSet([8])
    >>> IntSet().add_from_iterable(range(1,6)).has_first_n(5)
    True
    >>> IntSet().add_from_iterable(range(1,9)).has_first_n(5)  # extra elements are allowed
    True
    """

    def __init__(self, bits: int = 0):
        """Create a set from a bitmask, where bit ``k - 1`` of *bits* means element ``k``.

        Raises:
            ValueError: if *bits* is negative. Python ints behave like infinite
                two's-complement, so a negative mask would mean "infinitely many
                elements", which membership and iteration cannot agree on.
        """
        if bits < 0:
            raise ValueError(f"bits must be non-negative, got {bits}")
        self._added = bits

    def __bool__(self) -> bool:
        """Return True if the set has at least one element."""
        return self._added != 0

    def __len__(self) -> int:
        """Return the number of elements, i.e. the number of 1 bits in the mask."""
        return bin(self._added).count("1")

    def __contains__(self, x: int) -> bool:
        """Return True if *x* is in the set. Values below 1 are never members."""
        return x >= 1 and (self._added >> (x - 1)) & 1 == 1

    def __iter__(self) -> Iterable[int]:
        """Yield the elements in ascending order."""
        n = self._added
        while n > 0:
            lowest = n & -n  # isolate the lowest 1 bit
            yield lowest.bit_length()  # bit k-1 -> element k
            n ^= lowest  # clear it and move on to the next set bit

    def __repr__(self) -> str:
        return "{}({})".format(type(self).__name__, list(self))

    def has_first_n(self, n: int) -> bool:
        """Does set have all of 1,2..n ?

        Elements larger than *n* may also be present. ``n == 0`` is vacuously
        True. A negative *n* raises ValueError (negative shift count).
        """
        mask = (1 << n) - 1  # bits 0..n-1, i.e. elements 1..n
        return (self._added & mask) == mask

    def add(self, x: int) -> None:
        """Add the positive integer *x* to the set.

        Raises:
            ValueError: if *x* is 0 or negative (it has no bit position).
        """
        if x < 1:
            raise ValueError(f"{x} cannot be added because it is negative or 0")
        self._added |= 1 << (x - 1)

    def add_from_iterable(self, elems: Iterable[int]) -> IntSet:
        """Add every element of *elems* in place and return ``self`` for chaining."""
        for x in elems:
            self.add(x)
        return self

    def discard(self, x: int) -> None:
        """Remove element from set, if present. Values below 1 are ignored."""
        if x >= 1:  # 0 / negatives can never be members; also avoids a negative shift
            self._added &= ~(1 << (x - 1))

    def union(self, intset: IntSet) -> IntSet:
        """Return a new set with the elements of both sets."""
        return IntSet(self._added | intset._added)

    def intersection(self, intset: IntSet) -> IntSet:
        """Return a new set with the elements common to both sets."""
        return IntSet(self._added & intset._added)

    def difference(self, intset: IntSet) -> IntSet:
        """Return a new set with the elements of ``self`` that are not in *intset*."""
        return IntSet(self._added & ~intset._added)
if __name__ == "__main__":
    import doctest

    # Run the module's doctests and reflect failures in the exit status so
    # shells/CI notice them. testmod() returns a (failed, attempted) named tuple.
    raise SystemExit(1 if doctest.testmod().failed else 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment