PyTorch batch sampler for non-balanced subsets
__author__ = "Elad Nachmias"
__email__ = "eladnah@gmail.com"
__date__ = "2023-10-18"
import dataclasses
import enum
import itertools
import warnings
from collections import defaultdict, Counter
from typing import Optional, Dict, List, Sequence, Iterable, Union, Set, Tuple
import numpy as np
import torch
from scipy.stats import binom
from torch.utils.data import Sampler
class SamplingPolicyKind(enum.Enum):
DownSample = 'DownSample'
CombinedUpDownSample = 'CombinedUpDownSample'
@property
def is_up_sampling(self) -> bool:
return self == SamplingPolicyKind.CombinedUpDownSample
@dataclasses.dataclass(frozen=True)
class DownSamplingInfo:
most_constraining_subset_idx: int
expected_epoch_size: int
first_batch_constraining_factors: np.ndarray
constraining_factors: np.ndarray
@dataclasses.dataclass(frozen=True)
class CombinedUpDownSamplingInfo:
percentages_of_subsets: np.ndarray
required_ratio_per_percentage: np.ndarray
class StratifiedSubsetsBatchSampler(Sampler[Sequence[int]]):
"""
Produces batches of sample indices. The samples are drawn from the different subsets wrt the given desired distribution. That is,
the expected %samples from a subset in an average batch should match that subset's predefined desired probability.
Predefined fixed probabilities may be given as percentage strings 'XX%'. In such cases, the given percentage is kept as-is in the
eventual distribution, without any further normalization.
Otherwise, plain floats/integers are treated as "dynamic weights". These are non-normalized weights, one per subset, that are
normalized together to form a discrete distribution over the complement of the sum of the predefined percentages.
If a subset is not explicitly given in the weighting configuration but is present in the data, it is implicitly assigned a
dynamic weight of 1. Weights may be zero (in that case, those subsets are discarded).
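For example (hypothetical subset names): subsets_weighting={'A': '25%', 'B': 2.0, 'C': 1.0} fixes subset 'A' at 25% of an average
batch, while 'B' and 'C' split the remaining 75% proportionally to their weights, i.e. 50% and 25% respectively.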
There are two sampling policies. In both policies, the expected representation of a subset in a batch follows the chosen distribution.
The two policies differ in the overall epoch size and in the number of samples observed from each subset before concluding an epoch.
Note that a subset can be assigned a high probability while having very few samples available, compared to another subset with a low
probability of being chosen but with a much bigger pool of samples. This means that some of the subsets may be down-sampled more than
others. The two policies treat this imbalance differently.
* One policy is "pure down-sampling". In this approach, as soon as the first subset gets exhausted, the epoch immediately ends. Note
that with this policy, when there is a very small subset with relatively high representation, the other subsets would be
significantly down-sampled. Each epoch re-shuffles the samples within each subset, so eventually (over many epochs) all samples
of every subset would be consumed anyway.
* The second (default) policy combines up-sampling with down-sampling. Unlike the pure down-sampling policy, the combined policy
requires some minimal amount of samples from each subset to be consumed before concluding the epoch. We don't expect to cover the
entire dataset in each epoch, as there might be some very big subsets with a low sampling ratio that would otherwise cause
undesirably huge epochs. However, we still want to enforce some notion of dataset coverage, so we define a more relaxed coverage
condition. To conclude an epoch, the percentiles of [(%samples consumed from s) for s in subsets] are required to reach some
predefined ratios. For example, at least 25% of the subsets should produce at least 80% of their data, and at least 75% of the
subsets should produce at least 5% of their data. The configuration consists of two such (subsets-percentage, subset-size-ratio)
pairs. Having a per-epoch data coverage guarantee gives the epoch an additional semantic meaning as an atomic "unit", in the sense
that the model has effectively traversed a sufficiently representative portion of the data.
Note that during an epoch, each subset's samples are re-shuffled every time that subset gets exhausted. Thus, the most constraining
subset would be re-shuffled several times, while the least constraining subset would be shuffled only once at the very beginning.
This guarantees that all samples of a subset have been seen k times before any sample is seen k+1 times. In that sense, the combined
up&down sampling policy differs from the simpler alternative of memory-less sampling.
Read more about the epoch stop-condition of the combined sampling policy in the docstring of `_should_stop_epoch()`.
"""
def __init__(
self,
batch_size: int,
subset_name_by_sample_idx: Sequence[str],
subsets_weighting: Dict[str, Union[str, float, int]],
generator: Optional[torch.Generator] = None,
drop_last_partial_batch: bool = False,
sampling_policy: SamplingPolicyKind = SamplingPolicyKind.CombinedUpDownSample,
stop_condition_subset_percentiles_lower_bounds: Optional[List[Dict[str, float]]] = None,
min_ratio_of_subsets_reached_induced_epoch_size: Optional[float] = None,
required_ratio_of_met_percentiles: float = 0.75,
size_ratio_margin: float = 5e-3):
"""
:param batch_size: Desired batch size. The last batch might be smaller under pure down-sampling if `drop_last_partial_batch` is False.
:param subset_name_by_sample_idx: Name of the subset that each sample belongs to. Sequential container (list) ordered by sample idx.
:param subsets_weighting: Explicit percentage / auto-normalized weight for each subset, by name. Missing subsets are treated as
having an auto-normalized weight of 1. Read more in the main class docstring above.
:param generator: Optional. PyTorch random number generator. Used both for (i) sampling the subsets that will participate in
each batch wrt given probabilities, and (ii) shuffling the samples within each subset.
:param drop_last_partial_batch: Relevant only for pure sub-sampling. Whether to discard incomplete batch.
:param sampling_policy: Whether to use pure down-sampling or combined up&down-sampling. Read more about the policies in the main class docstring above.
:param stop_condition_subset_percentiles_lower_bounds: For up sampling's stop condition. Read more in stop condition doc.
:param min_ratio_of_subsets_reached_induced_epoch_size: For up sampling's stop condition. Read more in stop condition doc.
:param required_ratio_of_met_percentiles: For up-sampling. Requires the flipped-sigmoid coverage condition to hold for at least
this ratio of the checked percentiles.
:param size_ratio_margin: For up-sampling. Additively relaxes the sigmoid requirement by a constant margin (0.5% by default, subtracted from the required ratio).
"""
super().__init__(data_source=None)
self._batch_size = batch_size
self._generator = generator
if drop_last_partial_batch and sampling_policy.is_up_sampling:
raise ValueError(f'`drop_last_partial_batch` is not relevant for up-sampling policies, as the stopping condition is '
f'checked only between batches and not per drawn sample within each batch, as in the pure down-sampling policy.')
self._drop_last_partial_batch = drop_last_partial_batch
self._sampling_policy = sampling_policy
if self._sampling_policy.is_up_sampling and stop_condition_subset_percentiles_lower_bounds is None:
stop_condition_subset_percentiles_lower_bounds = [
{'subsets_percent': 0.25, 'size_ratio': 0.80},
{'subsets_percent': 0.75, 'size_ratio': 0.05},
]
self._stop_condition_subset_percentiles_lower_bounds = stop_condition_subset_percentiles_lower_bounds
self._min_ratio_of_subsets_reached_induced_epoch_size = min_ratio_of_subsets_reached_induced_epoch_size
self._required_ratio_of_met_percentiles = required_ratio_of_met_percentiles
self._size_ratio_margin = size_ratio_margin
self._subsets_sampling_probs = self._parse_subsets_weighting_cfg(
subsets_weighting_cfg=subsets_weighting,
given_subset_names=set(subset_name_by_sample_idx))
del subsets_weighting # to avoid re-using after removal of non-relevant subsets
# Note: this is after the removal of zero-probability subsets.
subset_names_set = set(self._subsets_sampling_probs.keys())
self._sorted_subset_names = tuple(sorted(subset_names_set)) # tuple for immutability
self._nr_subsets = len(subset_names_set)
self._subsets_sampling_probs_vector = np.array([
self._subsets_sampling_probs[subset_name] for subset_name in self._sorted_subset_names])
# Sample indices contained in each subset.
subset_names_to_samples_idx_map: Dict[str, List] = defaultdict(list)
for sample_idx, subset_name in enumerate(subset_name_by_sample_idx):
if subset_name not in subset_names_set:
continue
subset_names_to_samples_idx_map[subset_name].append(sample_idx)
del subset_name_by_sample_idx # to avoid re-using after removal of non-relevant subsets
self._subset_names_to_samples_indices_map = {k: np.array(v) for k, v in subset_names_to_samples_idx_map.items()}
self._subset_sizes = np.array([
len(self._subset_names_to_samples_indices_map[subset_name])
for subset_name in self._sorted_subset_names])
# Calculate start&end points per subset to simulate subset sampling (wrt given probabilities) using uniform random float within
# the interval [0,1]. The epsilon value we calculate below is for safely comparing floats to identify the range that contains
# any given floating point in the closed interval [0,1].
# Probs are already sanitized above this stage:
assert np.all((~np.isclose(self._subsets_sampling_probs_vector, 0.)) & (self._subsets_sampling_probs_vector > 0))
self._probs_limits_comparing_eps = min(1e-6, 0.01 * np.min(self._subsets_sampling_probs_vector))
assert self._probs_limits_comparing_eps > 0.
self._subsets_weighting_ends = np.cumsum(self._subsets_sampling_probs_vector) # inclusive cumulative sum
self._subsets_weighting_ends[-1] += self._probs_limits_comparing_eps # as when checking, the end should be strictly larger
self._subsets_weighting_starts = self._subsets_weighting_ends - self._subsets_sampling_probs_vector
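# Illustrative example (not from the original code): for sampling probabilities [0.5, 0.3, 0.2] the per-subset intervals become
# [0.0, 0.5), [0.5, 0.8), [0.8, 1.0 + eps); a uniform draw of e.g. 0.63 then falls in the 2nd interval and selects the 2nd subset.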
# For each individual subset, we calculate the expected total number of samples that would be taken (from all subsets) until
# that subset is exhausted, assuming that subset is the only constraining one (restricted number of available samples to choose
# from for this subset only).
self._epoch_size_induced_by_subset = self._subset_sizes / self._subsets_sampling_probs_vector
self._down_sampling_info: Optional[DownSamplingInfo] = None
self._up_sampling_info: Optional[CombinedUpDownSamplingInfo] = None
if self.is_up_sampling:
self._up_sampling_info = self._calc_stop_condition_required_subset_sizes_ratios()
else:
self._down_sampling_info = self._calc_down_sampling_info()
self._check_and_warn_too_restrictive_subsets()
def _calc_down_sampling_info(self) -> 'DownSamplingInfo':
assert not self._sampling_policy.is_up_sampling
# The most constraining subset is the one that induces minimal epoch size.
most_constraining_subset_idx = np.argmin(self._epoch_size_induced_by_subset)
# Approximate estimator for the expected number of samples that would be produced by the sampler until exhausting the first subset.
expected_epoch_size = self._epoch_size_induced_by_subset[most_constraining_subset_idx]
# For each subset, the probability of exhausting it while forming the first batch, i.e. choosing it more times for the batch
# than the number of samples available in it.
first_batch_constraining_factors = self.calc_constraining_factors_by_epoch_size(self._batch_size)
# For each subset, probability to be the first exhausted subset in the entire epoch.
# Note: The constraining factors are calculated wrt a concrete given "epoch size", i.e. the number of samples drawn up to some
# point. Effectively, to identify the most restrictive subsets over the entire iteration, we calculate these factors at the point
# where we estimate the iteration would end (first exhausted subset).
constraining_factors = self.calc_constraining_factors_by_epoch_size(expected_epoch_size)
return DownSamplingInfo(
most_constraining_subset_idx=most_constraining_subset_idx,
expected_epoch_size=expected_epoch_size,
first_batch_constraining_factors=first_batch_constraining_factors,
constraining_factors=constraining_factors)
def _calc_stop_condition_required_subset_sizes_ratios(self) -> 'CombinedUpDownSamplingInfo':
# Calculate the subsets coverage requirement by fitting a (y-flipped) sigmoid that maps a percentile of the subsets (in [0, 1])
# to the ratio (again in [0, 1]) to which the subsets at that percentile should be consumed.
assert self._sampling_policy.is_up_sampling
assert len(self._stop_condition_subset_percentiles_lower_bounds) == 2
percent_1 = self._stop_condition_subset_percentiles_lower_bounds[0]['subsets_percent']
percentile_size_ratio_1 = self._stop_condition_subset_percentiles_lower_bounds[0]['size_ratio']
percent_2 = self._stop_condition_subset_percentiles_lower_bounds[1]['subsets_percent']
percentile_size_ratio_2 = self._stop_condition_subset_percentiles_lower_bounds[1]['size_ratio']
assert np.sign(percent_1 - percent_2) * np.sign(percentile_size_ratio_1 - percentile_size_ratio_2) < 0, \
'subsets ratio constraints are expected to be monotonically decreasing'
# Analytically find (a, b) such that the (y-flipped) sigmoid 1 - 1 / (1 + exp(-a * (x+b))) passes through the 2 points given
# above (defined by the user). The given points can be seen as a general "example" of the coverage requirement function shape.
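# Derivation sketch (added for clarity): setting 1 - 1/(1 + exp(-a*(x+b))) = r and isolating the exponent gives
# -a*(x+b) = log(r/(1-r)) =: t. Writing this for the two given points yields two linear equations in (a, b):
# -a*(p1+b) = t1 and -a*(p2+b) = t2, whose solution is a = (t2-t1)/(p1-p2) and b = (t2*p1 - t1*p2)/(t1-t2),
# which is exactly what is computed below as `exp_scale` and `exp_translate`.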
t1 = np.log(percentile_size_ratio_1 / (1 - percentile_size_ratio_1))
t2 = np.log(percentile_size_ratio_2 / (1 - percentile_size_ratio_2))
exp_scale = (t2 - t1) / (percent_1 - percent_2)
exp_translate = (t2 * percent_1 - t1 * percent_2) / (t1 - t2)
flipped_sigmoid_fn = lambda x: 1 - 1 / (1 + np.exp(-exp_scale * (x + exp_translate)))
# Here we sample the fitted sigmoid at the relevant points (wrt the #subsets we actually have).
nr_percentiles = self._nr_subsets + 1 # to span over [0%, 100%]
percentages_of_subsets = np.linspace(0., 1., num=nr_percentiles)
required_ratios = flipped_sigmoid_fn(percentages_of_subsets)
required_ratios = required_ratios[np.newaxis, ...].repeat(self.nr_subsets, axis=0)
return CombinedUpDownSamplingInfo(
percentages_of_subsets=percentages_of_subsets,
required_ratio_per_percentage=required_ratios)
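# Illustrative values (not from the original code): with the default pairs (25%, 80%) and (75%, 5%) and 4 subsets, the sampled
# percentages are [0., .25, .5, .75, 1.] and the fitted sigmoid yields required consumption ratios of roughly
# [0.97, 0.80, 0.31, 0.05, 0.01] (the two user-given points are reproduced exactly by construction).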
def _check_and_warn_too_restrictive_subsets(
self,
max_allowed_constraining_factor: float = 0.05,
min_desired_percent_constraining_subsets: float = 20.,
big_subsets_factor: float = 2.,
max_percent_of_big_subsets: float = 70.):
"""
Emit warnings about badly aligned subsets, i.e. when the most-constraining subset induces an epoch size that is substantially
smaller than the epoch sizes induced by many other subsets. Also warns when the first batch may not be formed completely.
"""
assert not self._sampling_policy.is_up_sampling # the restrictions checked here are relevant only for pure sub-sampling policy.
most_constraining_subset_name = self._sorted_subset_names[self.down_sampling_info.most_constraining_subset_idx]
if self._down_sampling_info.expected_epoch_size < self._batch_size:
warnings.warn(
f'Expected epoch size [{int(round(self._down_sampling_info.expected_epoch_size)):,}] (induced by most constraining subset '
f'`{most_constraining_subset_name}`) is smaller than the target batch size [{self._batch_size:,}].')
if np.any(self._down_sampling_info.first_batch_constraining_factors > max_allowed_constraining_factor):
drop_last_additional_msg = ''
if self._drop_last_partial_batch:
drop_last_additional_msg = \
f'Note that drop_last_partial_batch is set, so not completing the 1st batch would practically yield nothing. '
warnings.warn(
f'Limited 1st batch: The following subsets may prevent forming a complete *first* batch (of size {self._batch_size:,}) '
f'with their corresponding probabilities. {drop_last_additional_msg}' +
('; '.join(
f'{self._sorted_subset_names[subset_idx]} '
f'[{100. * self._down_sampling_info.first_batch_constraining_factors[subset_idx]:.0f}%]'
for subset_idx in np.where(
self._down_sampling_info.first_batch_constraining_factors > max_allowed_constraining_factor)[0]
)))
limited_epoch_size_msgs = []
nr_constraining_subsets = np.sum(self._down_sampling_info.constraining_factors > max_allowed_constraining_factor)
percent_constraining_subsets = 100 * nr_constraining_subsets / self._nr_subsets
if percent_constraining_subsets < min_desired_percent_constraining_subsets:
limited_epoch_size_msgs.append(
f'Only {percent_constraining_subsets:.0f}% of the subsets are >{100. * max_allowed_constraining_factor:.0f}% '
f'probable to exhaust after reaching the expected epoch size.')
percent_of_big_subsets = \
100 * np.mean(self._epoch_size_induced_by_subset > self._down_sampling_info.expected_epoch_size * big_subsets_factor)
if percent_of_big_subsets > max_percent_of_big_subsets:
limited_epoch_size_msgs.append(
f'Too many other subsets ({percent_of_big_subsets:.0f}%) are comparably much less-constraining, as they induce epoch sizes '
f'that are bigger by x{big_subsets_factor:.2f} than the expected epoch size.')
if len(limited_epoch_size_msgs) > 0:
warnings.warn(
f'Very limited data utilization: Most-constraining subset `{most_constraining_subset_name}` induces very small '
f'expected epoch size [={int(round(self._down_sampling_info.expected_epoch_size)):,}]. ' +
(' '.join(limited_epoch_size_msgs)) +
f' Here are the most constraining subsets with their corresponding exhaust probabilities at the expected end of epoch: ' +
('; '.join(
f'{self._sorted_subset_names[subset_idx]}: {100. * self._down_sampling_info.constraining_factors[subset_idx]:.0f}%'
for subset_idx in np.where(self._down_sampling_info.constraining_factors > max_allowed_constraining_factor)[0]
)) + '.')
"""
Warning example (taken from test):
UserWarning: Very limited data utilization: Most-constraining subset `A2LS` induces very small expected epoch size [=174,280].
Only 8% of the subsets are >5% probable to exhaust after reaching the expected epoch size. Too many other subsets (75%) are
comparably much less-constraining, as they induce epoch sizes that are bigger by x2.00 than the expected epoch size. Here are
the most constraining subsets with their corresponding exhaust probabilities at the expected end of epoch: A2LS: 50%.
"""
@classmethod
def _parse_subsets_weighting_cfg(
cls,
subsets_weighting_cfg: Dict[str, Union[str, float, int]],
given_subset_names: Set[str]) -> Dict[str, float]:
if subsets_weighting_cfg is None:
subsets_weighting_cfg = {}
# Parse (optional) predefined fixed percentages `XX%`. These keep the same given percentage as-is in the eventual distribution
# without further normalization.
subsets_weighting = {k: v for k, v in subsets_weighting_cfg.items() if k in given_subset_names}
subsets_weighting.update({k: 1. for k in given_subset_names - subsets_weighting_cfg.keys()})
predefined_fixed_ratios = {
subset_name: float(weight[:-1]) / 100.
for subset_name, weight in subsets_weighting.items()
if isinstance(weight, str) and weight.endswith('%')}
tot_predefined_percentages = sum(predefined_fixed_ratios.values())
assert tot_predefined_percentages <= 1. or np.isclose(tot_predefined_percentages, 1.)
# Parse (optional) "dynamic weights". These are non-normalized floats per each subset. It has to be normalized to form a discrete
# distribution. It completes the complementary part of the sum of predefined percentages.
dynamic_weights = {
subset_name: float(weight)
for subset_name, weight in subsets_weighting.items()
if subset_name not in predefined_fixed_ratios
}
tot_dynamic_weights = sum(dynamic_weights.values())
if tot_dynamic_weights > 0 and np.isclose(tot_predefined_percentages, 1.):
raise ValueError(f'Non-zero dynamic weights are given, but the pre-defined fixed percentages are at 100%.')
residual_non_predefined_ratio = 1. - tot_predefined_percentages
dynamic_weights = {
subset_name: residual_non_predefined_ratio * weight / tot_dynamic_weights
for subset_name, weight in dynamic_weights.items()
}
ret_subsets_weighting = {**predefined_fixed_ratios, **dynamic_weights}
# Discard subsets with zero probability.
ret_subsets_weighting = {k: v for k, v in ret_subsets_weighting.items() if ~np.isclose(v, 0.)}
return ret_subsets_weighting
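# Worked example of the parsing above (hypothetical names, not from the original code): given subsets {A, B, C, D} with
# cfg {'A': '25%', 'B': 2.0, 'C': 0.0} (D missing, hence implicitly weighted 1), the fixed ratio is A=0.25, subset C is discarded
# for having a zero weight, and the residual 0.75 is split between B and D as 2:1, yielding {'A': 0.25, 'B': 0.5, 'D': 0.25}.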
def nr_samples_in_subset(self, subset_name: str) -> int:
return len(self._subset_names_to_samples_indices_map.get(subset_name, ()))
@property
def sorted_subset_names(self) -> Tuple[str, ...]:
return self._sorted_subset_names
@property
def down_sampling_info(self) -> Optional['DownSamplingInfo']:
return self._down_sampling_info
@property
def batch_size(self) -> int:
return self._batch_size
@property
def nr_subsets(self) -> int:
return self._nr_subsets
def calc_constraining_factors_by_epoch_size(self, epoch_size: int) -> np.ndarray:
# For each subset, the probability of exhausting its samples while producing `epoch_size` samples overall. This equals the
# probability of more than `subset_size` successes in `epoch_size` independent Bernoulli trials with success probability p,
# i.e. one minus the Binomial CDF evaluated at the subset size.
return np.array([
1. - binom.cdf(k=subset_size, n=epoch_size, p=p)
for p, subset_size in zip(self._subsets_sampling_probs_vector, self._subset_sizes)
])
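# Numeric intuition (not from the original code): a subset with 100 samples and sampling probability 0.3 is expected to be drawn
# about 120 times within an epoch of 400 samples, so its constraining factor 1 - BinomCDF(k=100, n=400, p=0.3) is close to 1;
# with probability 0.2 it would be expected to be drawn only ~80 times, and the factor drops well below 0.5.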
def _sample_subsets_indices(self, n: int, generator: torch.Generator) -> np.ndarray:
subset_probs = \
torch.rand(size=(n,), dtype=torch.float32, generator=generator).numpy()
# Add a trailing dim with one entry per subset, for the interval-membership check below.
extended_subset_probs = subset_probs[..., np.newaxis].repeat(self._nr_subsets, axis=-1)
# Identify the interval (each corresponding to a subset) that each sampled float value falls in.
# Note that for the i-th subset, the corresponding interval is closed on the left and open on the right: [start, end). We use the
# predefined epsilon to perform safe floating-point comparisons.
chosen_subsets_mask = \
(self._subsets_weighting_starts - self._probs_limits_comparing_eps < extended_subset_probs) & \
(extended_subset_probs < self._subsets_weighting_ends)
chosen_subsets_indices = np.argmax(chosen_subsets_mask, axis=1)
assert np.all(np.any(chosen_subsets_mask, axis=-1))
return chosen_subsets_indices
def _initialize_iterator_state(self) -> 'IteratorInternalState':
# Note: If an rng is given in the c'tor, it lives throughout the entire life of the sampler (that is, earlier calls to iter()
# advance its state, so successive epochs don't repeat). Otherwise, we create a generator per iteration with a random seed,
# similarly to the convention of other PyTorch sampler implementations.
if self._generator is None:
generator = torch.Generator()
generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
else:
generator = self._generator
# Here we permute the samples of each individual subset up front. These permutations are kept over the course of the current
# iteration (epoch). This guarantees (i) non-repetition of the same samples in different batches; (ii) batching samples together
# pseudo-randomly; and (iii) pseudo-randomly sub-sampling big (less constraining) subsets.
# Note: These are not permutations over the original sample indices, but over the indices of entries within each per-subset list
# of samples. That is, for a subset `s` that has k samples, the permutation indices are in {0, ..., k-1}, and the value `i`
# refers to the i-th entry within the list `self._subset_names_to_samples_indices_map[s]`.
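# Illustrative example (hypothetical numbers): if subset `s` holds the global sample indices [7, 12, 40] and its permutation is
# [2, 0, 1], then the epoch draws sample 40 first, then 7, then 12 from this subset.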
permutation_per_subset = {
subset_name: self._create_within_subset_permutation(subset_name=subset_name, generator=generator)
for subset_name in self._sorted_subset_names}
# We update the following counters during batch production.
nr_taken_samples_per_subset = np.full(shape=len(self._sorted_subset_names), fill_value=0)
nr_available_samples_per_subset = np.array([
self.nr_samples_in_subset(subset_name=subset_name)
for subset_name in self._sorted_subset_names])
return IteratorInternalState(
generator=generator,
permutation_per_subset=permutation_per_subset,
nr_taken_samples_per_subset=nr_taken_samples_per_subset,
nr_taken_samples_per_subset_in_cur_perm=np.copy(nr_taken_samples_per_subset),
nr_available_samples_per_subset=nr_available_samples_per_subset,
tot_nr_produced_samples=0,
is_last_batch=False)
def _create_within_subset_permutation(self, subset_name: str, generator: torch.Generator):
return torch.randperm(self.nr_samples_in_subset(subset_name=subset_name), generator=generator).numpy()
def __iter__(self) -> Iterable[Sequence[int]]:
"""
In each iteration, the iterator produces a list (batch) of sample indices that are sampled wrt the given distribution. The sampling
involves both sampling of the representation of each subset in the yielded batch, and also sampling the items themselves that are
being drawn from each subset.
The lifetime of an iterator usually corresponds to a training epoch.
In the pure down-sampling policy, the iterator keeps producing batches until the first (most constraining) subset is exhausted.
We stop sampling at that point to preserve the distribution. That is, at the end of the iteration at least one subset has always
been fully consumed, while other less-constraining subsets are down-sampled.
Read about the stopping condition of the combined up&down sampling policy in `_should_stop_epoch()`.
"""
iterator_state = self._initialize_iterator_state()
for batch_idx in itertools.count():
if self._should_stop_epoch(iterator_state=iterator_state):
iterator_state.is_last_batch = True
break
chosen_subsets_indices = self._sample_subsets_indices(n=self._batch_size, generator=iterator_state.generator)
batch = []
for chosen_subset_idx in chosen_subsets_indices:
if self._should_stop_batch(chosen_subset_idx=chosen_subset_idx, iterator_state=iterator_state):
# Tried to consume a sample from an already-exhausted subset for the 1st time.
# End the batch here to preserve the distribution; the samples drawn so far remain i.i.d.
iterator_state.is_last_batch = True
break
if self.is_up_sampling and iterator_state.nr_available_samples_per_subset[chosen_subset_idx] == 0:
subset_name = self._sorted_subset_names[chosen_subset_idx]
new_permutation = self._create_within_subset_permutation(subset_name=subset_name, generator=iterator_state.generator)
iterator_state.permutation_per_subset[subset_name] = new_permutation
iterator_state.nr_available_samples_per_subset[chosen_subset_idx] = self._subset_sizes[chosen_subset_idx]
iterator_state.nr_taken_samples_per_subset_in_cur_perm[chosen_subset_idx] = 0
sample_idx = self._draw_sample_from_subset(subset_idx=chosen_subset_idx, iterator_state=iterator_state)
batch.append(sample_idx)
assert iterator_state.is_last_batch or len(batch) == self._batch_size
if iterator_state.is_last_batch and (len(batch) == 0 or self._drop_last_partial_batch):
if batch_idx == 0:
# We didn't manage to complete even the 1st batch. This is a corner case, possibly caused by a high batch size combined
# with a too-restrictive subset.
exhausted_subset_idx = int(np.argmax(iterator_state.nr_available_samples_per_subset < 1))
exhausted_subset_name = self._sorted_subset_names[exhausted_subset_idx]
max_constraining_subset_idx = int(np.argmax(self._down_sampling_info.first_batch_constraining_factors))
first_batch_most_constraining_subset_name = self._sorted_subset_names[max_constraining_subset_idx]
warnings.warn(
f'Could not form even the 1st batch of the epoch. Fully exhausted subset {exhausted_subset_name} (containing '
f'{self._subset_sizes[exhausted_subset_idx]:,} samples in total, sampled into the batch with probability '
f'{100. * self._subsets_sampling_probs_vector[exhausted_subset_idx]:.0f}%, inducing batch constraining factor '
f'{100. * self._down_sampling_info.first_batch_constraining_factors[exhausted_subset_idx]:.0f}%). '
f'Most constraining subset is {first_batch_most_constraining_subset_name}. Consider decreasing batch size (chosen '
f'{self._batch_size}).')
return # firstly reached an exhausted subset - end of epoch - *before* producing last partial batch
yield batch
if iterator_state.is_last_batch:
return # firstly reached an exhausted subset - end of epoch - *after* producing last partial batch
def _draw_sample_from_subset(self, subset_idx: int, iterator_state: 'IteratorInternalState') -> int:
subset_name = self._sorted_subset_names[subset_idx]
assert iterator_state.nr_available_samples_per_subset[subset_idx] >= 1
# Use the pre-drawn permutation to effectively simulate sampling an item from the subset.
within_subset_idx = \
iterator_state.permutation_per_subset[subset_name][iterator_state.nr_taken_samples_per_subset_in_cur_perm[subset_idx]]
sample_idx = self._subset_names_to_samples_indices_map[subset_name][within_subset_idx]
iterator_state.tot_nr_produced_samples += 1
iterator_state.nr_taken_samples_per_subset[subset_idx] += 1
iterator_state.nr_taken_samples_per_subset_in_cur_perm[subset_idx] += 1
iterator_state.nr_available_samples_per_subset[subset_idx] -= 1
return sample_idx
@property
def is_up_sampling(self) -> bool:
return self._sampling_policy.is_up_sampling
def _should_stop_epoch(self, iterator_state: 'IteratorInternalState') -> bool:
"""
For the up-sampling policy, check that:
(i) At least one subset is 100% exhausted (probably the most constraining one).
(ii) Percentiles of [(%samples consumed from s) for s in subsets] reach required ratios;
In the combined up&down sampling policy, we don't expect to cover the entire dataset, as there might be some very big
subsets that are represented in a low sampling ratio. However, we still want to enforce some notion of dataset coverage,
so below we define a more relaxed coverage condition.
The user provides two pairs of (percentile, required subset ratio). Then, the sampler analytically interpolates a monotonically
decreasing (flipped) sigmoid that passes through these 2 points. We use this fitted sigmoid to calculate the required ratio at
each relevant percentile. For example, if there are 4 subsets then the checked percentiles are {0%, 25%, 50%, 75%, 100%} (the
0% one is trivially met).
(iii) Optionally (only when `min_ratio_of_subsets_reached_induced_epoch_size` is given): the #samples produced in the current
epoch reaches the induced epoch sizes of at least that ratio of the subsets.
"""
if not self.is_up_sampling:
return False # in the non-up-sampling case, the stop condition is already being checked within each sample in the batch
subsets_consumption_ratios = (iterator_state.nr_taken_samples_per_subset / self._subset_sizes)
at_least_one_subset_exhausted = np.any(np.isclose(subsets_consumption_ratios, 1.) | (subsets_consumption_ratios > 1.))
# Check that enough #subsets consumed enough of their overall volume. For example, at least 25% of the subsets consumed at least
# 80% of their overall volume, and at least 75% of the subsets consumed at least 5% of their volume.
subsets_consumption_ratios = \
subsets_consumption_ratios[..., np.newaxis].repeat(len(self._up_sampling_info.percentages_of_subsets), axis=1)
size_ratio_eps = max(1e-6, self._size_ratio_margin) # additionally to margin, also avoid floating-point comparison numerical error
percent_eps = min(1e-6, 0.1 * self._up_sampling_info.percentages_of_subsets[1]) # avoid floating-point comparison numerical error
assert percent_eps > 0
is_sample_consumption_requirement_met_per_percentile = \
np.mean(subsets_consumption_ratios > self._up_sampling_info.required_ratio_per_percentage - size_ratio_eps, axis=0) > \
self._up_sampling_info.percentages_of_subsets - percent_eps
ratio_of_subset_percentiles_meet_sample_consumption_requirement = np.mean(is_sample_consumption_requirement_met_per_percentile)
consumed_enough_samples = \
ratio_of_subset_percentiles_meet_sample_consumption_requirement > self._required_ratio_of_met_percentiles - percent_eps
# Each subset induces an individual effective epoch size (as if it was the only constraining subset). Check whether the current
# epoch produced enough samples to reach the induced epoch sizes of enough subsets.
enough_subsets_reached_induced_epoch = True
if self._min_ratio_of_subsets_reached_induced_epoch_size is not None:
ratio_of_subsets_reached_induced_epoch_size = \
np.mean(self._epoch_size_induced_by_subset <= iterator_state.tot_nr_produced_samples)
enough_subsets_reached_induced_epoch = \
ratio_of_subsets_reached_induced_epoch_size >= self._min_ratio_of_subsets_reached_induced_epoch_size
should_stop = at_least_one_subset_exhausted and consumed_enough_samples and enough_subsets_reached_induced_epoch
return should_stop
def _should_stop_batch(self, chosen_subset_idx: int, iterator_state: 'IteratorInternalState') -> bool:
return not self.is_up_sampling and iterator_state.nr_available_samples_per_subset[chosen_subset_idx] == 0
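# Minimal usage sketch added for illustration (not part of the original gist): the sampler yields whole batches of indices, so it
# is plugged into a torch DataLoader via `batch_sampler=` (rather than `sampler=`). The dataset and subset labels below are
# hypothetical placeholders.
def _example_dataloader_usage():
    from torch.utils.data import DataLoader, TensorDataset
    # A toy dataset of 1,000 scalar samples.
    dataset = TensorDataset(torch.arange(1_000, dtype=torch.float32).unsqueeze(-1))
    # One subset label per sample index; 'a' keeps a fixed 30%, while 'b' and 'c' split the remaining 70% as 2:1.
    subset_name_by_sample_idx = ['a'] * 100 + ['b'] * 600 + ['c'] * 300
    sampler = StratifiedSubsetsBatchSampler(
        batch_size=32,
        subset_name_by_sample_idx=subset_name_by_sample_idx,
        subsets_weighting={'a': '30%', 'b': 2.0, 'c': 1.0})
    loader = DataLoader(dataset, batch_sampler=sampler)
    for (batch,) in loader:
        pass  # a training step would consume `batch` here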
@dataclasses.dataclass
class IteratorInternalState:
"""
Encapsulates the relevant fields that are used during an iteration (epoch).
This design pattern allows inner aux methods to receive and update the state of the iteration.
"""
generator: torch.Generator
permutation_per_subset: Dict[str, np.ndarray]
nr_taken_samples_per_subset: np.ndarray
nr_taken_samples_per_subset_in_cur_perm: np.ndarray
nr_available_samples_per_subset: np.ndarray
tot_nr_produced_samples: int
is_last_batch: bool
@dataclasses.dataclass
class StratifiesSamplerTestInputs:
sampling_policy: SamplingPolicyKind
subset_names: List[str]
subset_name_to_idx: Dict[str, int]
subset_name_by_sample_idx: List[str]
tot_dataset_size: int
overall_nr_available_samples_per_subset: np.ndarray
desired_subsets_weighting_cfg: Dict[str, Union[str, float, int]]
desired_subsets_representation_ratio: np.ndarray
stop_condition_subset_percentiles_lower_bounds: List[Dict[str, float]]
nr_epochs_to_test: int
batch_size: int
rng_seed: int = 42
@property
def nr_subsets(self) -> int:
return len(self.subset_names)
def _init_stratified_subset_batch_sampler_test_input(sampling_policy: SamplingPolicyKind) -> StratifiesSamplerTestInputs:
subset_names = ["s1", "s2", "s3", "s4", "s5", "s6", "s7"]
# We ensure the subset names get the same indices as promised by the sampler for later calling & checking internal aux calculations
# of the sampler like the constraining factors.
subset_names = sorted(subset_names)
# Define the actual sizes of the subsets.
overall_available_samples_balance_across_subsets = {
# Ascending availability - these demonstrate inverse availability<->desired balance
"s1": 1 / 22, "s2": 2 / 22, "s3": 3 / 22, "s4": 4 / 22, "s5": 5 / 22,
# This subset will get zero desired weight although having >0 available samples
"s6": 3 / 22,
# This subset will have predefined desired representation percentage. Ensure that it is not the most-restrictive one.
"s7": 4 / 22
}
assert set(subset_names) == overall_available_samples_balance_across_subsets.keys()
assert np.isclose(sum(overall_available_samples_balance_across_subsets.values()), 1)
# Calculate the actual sizes wrt the desired total dataset size
tot_estimated_dataset_size = 1_000_000
overall_nr_available_samples_per_subset = np.array([
overall_available_samples_balance_across_subsets[subset_name] for subset_name in subset_names])
overall_nr_available_samples_per_subset = overall_nr_available_samples_per_subset / np.sum(overall_nr_available_samples_per_subset)
overall_nr_available_samples_per_subset = np.round(overall_nr_available_samples_per_subset * tot_estimated_dataset_size).astype(int)
tot_dataset_size = int(sum(overall_nr_available_samples_per_subset))
# This is given as the weighting input to the sampler.
desired_subsets_weighting_cfg = {
"s1": 5., "s2": 4., "s3": 3., "s4": 2., "s5": 1., # inverse to availability
"s6": 0., # Verify that zero weight is supported (for some subset)
"s7": '25%' # predefined percentage
}
# Calculate the final effective representation percentages. As an array sorted by the subset-name.
desired_subsets_representation_ratio = {
"s1": 0.25, "s2": 0.20, "s3": 0.15, "s4": 0.10, "s5": 0.05, # These are 75%, each unit is 5%
"s6": 0.,
"s7": 0.25
}
desired_subsets_representation_ratio = np.array([desired_subsets_representation_ratio[subset_name] for subset_name in subset_names])
stop_condition_subset_percentiles_lower_bounds = [
{'subsets_percent': 0.25, 'size_ratio': 0.80},
{'subsets_percent': 0.75, 'size_ratio': 0.05},
]
if sampling_policy == SamplingPolicyKind.DownSample:
# In pure down-sampling, the most constraining subset causes small epochs. We'd like to test several batches per epoch.
batch_size = 10
else:
# In the combined up&down sampling, the up-sampling allows having big batches (and we want to verify it).
batch_size = 1_000
# The sampler requires this mapping (sample idx --> subset name)
subset_name_by_sample_idx = [
subset_name
for subset_name, nr_samples_in_subset in zip(subset_names, overall_nr_available_samples_per_subset)
for _ in range(nr_samples_in_subset)]
return StratifiesSamplerTestInputs(
sampling_policy=sampling_policy,
subset_names=subset_names,
subset_name_to_idx={subset_name: idx for idx, subset_name in enumerate(subset_names)},
subset_name_by_sample_idx=subset_name_by_sample_idx,
tot_dataset_size=tot_dataset_size,
overall_nr_available_samples_per_subset=overall_nr_available_samples_per_subset,
desired_subsets_weighting_cfg=desired_subsets_weighting_cfg,
desired_subsets_representation_ratio=desired_subsets_representation_ratio,
stop_condition_subset_percentiles_lower_bounds=stop_condition_subset_percentiles_lower_bounds,
nr_epochs_to_test=5,
batch_size=batch_size
)
@dataclasses.dataclass
class StratifiesSamplerTestResults:
sampler: StratifiedSubsetsBatchSampler
batches_samples_indices_per_epoch: List[List[List[int]]]
def execute_stratified_subset_batch_sampler_test_input(
test_input: StratifiesSamplerTestInputs) -> StratifiesSamplerTestResults:
rng = torch.Generator()
rng.manual_seed(test_input.rng_seed)
sampler = StratifiedSubsetsBatchSampler(
batch_size=test_input.batch_size,
subset_name_by_sample_idx=test_input.subset_name_by_sample_idx,
subsets_weighting=test_input.desired_subsets_weighting_cfg,
drop_last_partial_batch=False,
generator=rng,
sampling_policy=test_input.sampling_policy,
stop_condition_subset_percentiles_lower_bounds=test_input.stop_condition_subset_percentiles_lower_bounds if
test_input.sampling_policy.is_up_sampling else None,
min_ratio_of_subsets_reached_induced_epoch_size=0.25 if test_input.sampling_policy.is_up_sampling else None)
batches_samples_indices_per_epoch: List[List[List[int]]] = []
for epoch_idx in range(test_input.nr_epochs_to_test):
batches_samples_indices = [list(indices) for indices in sampler]
batches_samples_indices_per_epoch.append(batches_samples_indices)
return StratifiesSamplerTestResults(
sampler=sampler,
batches_samples_indices_per_epoch=batches_samples_indices_per_epoch
)
def test_stratified_subset_batch_sampler():
"""
The test creates a maximally unbalanced corner case where the number of available samples per subset is opposite to the desired
representation of the subsets. That is, the biggest subset is the least-represented one. This creates a maximal variance in the
induced epoch sizes. The test also checks the two different ways to specify subset representation: a predefined percentage and a
dynamic weight. Additionally, we check the support for providing a zero weight. The test covers both the pure down-sampling and
the combined up&down sampling policies.
"""
# Test both sampling policies
for sampling_policy in SamplingPolicyKind:
test_input = _init_stratified_subset_batch_sampler_test_input(sampling_policy=sampling_policy)
execution_results = execute_stratified_subset_batch_sampler_test_input(test_input=test_input)
epoch_sample_firstly_seen: Dict[int, int] = {} # mapping: sample idx -> epoch idx
for epoch_idx in range(test_input.nr_epochs_to_test):
nr_occurrences_per_sample_idx_in_cur_epoch: Dict[int, int] = defaultdict(int)
sample_indices_per_nr_occurrences_per_subset: Dict[str, Dict[int, Set[int]]] = defaultdict(lambda: defaultdict(set))
for sample_idx in range(test_input.tot_dataset_size):
subset_name = test_input.subset_name_by_sample_idx[sample_idx]
sample_indices_per_nr_occurrences_per_subset[subset_name][0].add(sample_idx)
nr_new_samples_seen_in_cur_epoch = 0
batches_samples_indices = execution_results.batches_samples_indices_per_epoch[epoch_idx]
for batch_idx, batch in enumerate(batches_samples_indices):
for sample_idx in batch:
if sampling_policy == SamplingPolicyKind.DownSample:
# In pure down-sampling, each sample is expected to occur at most once in an epoch
assert sample_idx not in nr_occurrences_per_sample_idx_in_cur_epoch
# Check that within each subset, the per-sample occurrence counts never differ by more than 1 (all seen k or k+1 times).
subset_name = test_input.subset_name_by_sample_idx[sample_idx]
sample_indices_per_nr_occurrences = sample_indices_per_nr_occurrences_per_subset[subset_name]
sample_indices_per_nr_occurrences[nr_occurrences_per_sample_idx_in_cur_epoch[sample_idx]].remove(sample_idx)
if len(sample_indices_per_nr_occurrences[nr_occurrences_per_sample_idx_in_cur_epoch[sample_idx]]) == 0:
sample_indices_per_nr_occurrences.pop(nr_occurrences_per_sample_idx_in_cur_epoch[sample_idx])
nr_occurrences_per_sample_idx_in_cur_epoch[sample_idx] += 1
sample_indices_per_nr_occurrences[nr_occurrences_per_sample_idx_in_cur_epoch[sample_idx]].add(sample_idx)
min_nr_occurrences_per_sample_idx = min(sample_indices_per_nr_occurrences.keys())
max_nr_occurrences_per_sample_idx = max(sample_indices_per_nr_occurrences.keys())
assert max_nr_occurrences_per_sample_idx - min_nr_occurrences_per_sample_idx in {0, 1}
# If this sample is seen for the first time (among all epochs), log the current epoch idx as the 1st one that this
# sample is encountered at. This aggregation is being checked later on.
if sample_idx not in epoch_sample_firstly_seen:
epoch_sample_firstly_seen[sample_idx] = epoch_idx
nr_new_samples_seen_in_cur_epoch += 1
assert nr_new_samples_seen_in_cur_epoch > 0 # sanity check
# Batches (except last one) are expected to be full.
assert all(len(batch) == test_input.batch_size for batch in batches_samples_indices[:-1])
tot_nr_samples_taken = sum(len(batch) for batch in batches_samples_indices)
# In pure down-sampling, the most constraining subset doesn't allow exhausting all subsets.
if sampling_policy == SamplingPolicyKind.DownSample:
assert tot_nr_samples_taken <= test_input.tot_dataset_size
# Aggregate the #occurrences per subset.
chosen_subsets_indices_counter = Counter(
test_input.subset_name_to_idx[test_input.subset_name_by_sample_idx[sample_idx]]
for batch in batches_samples_indices for sample_idx in batch)
nr_taken_samples_per_subset = np.full(shape=test_input.nr_subsets, fill_value=np.nan)
freqs = np.full(shape=test_input.nr_subsets, fill_value=0.)
for subset_idx, count in chosen_subsets_indices_counter.items():
freqs[subset_idx] = count / tot_nr_samples_taken
nr_taken_samples_per_subset[subset_idx] = count
# Check that a subset occurred at least once iff it has a desired weight > 0.
desired_zero_mask = np.isclose(test_input.desired_subsets_representation_ratio, 0.) # subsets are zero represented
assert np.sum(desired_zero_mask) == 1 # test input sanity check - it should be only s6
assert np.isclose(freqs[desired_zero_mask], 0.), 'zero represented subset should not be chosen at all by the sampler'
assert np.all(~np.isclose(freqs, 0.) | desired_zero_mask), '>0 represented subsets should occur at least once'
# Check that the actual subsets representation realization is close enough to the desired representation.
abs_freq_err = np.abs(freqs - test_input.desired_subsets_representation_ratio)
# The rel err is wrt the ground-truth desired distribution.
rel_freq_err = abs_freq_err[~desired_zero_mask] / test_input.desired_subsets_representation_ratio[~desired_zero_mask]
if sampling_policy.is_up_sampling:
# In up sampling, we can expect a closer similarity between realization and desired distribution, as the sample size is
# big enough.
allowed_rel_err = 0.1
else:
# In pure down-sampling, the most constraining subset is very restrictive, so by the time the epoch ends the total #taken
# samples is small and the realized frequencies are noisier.
allowed_rel_err = 0.5
assert np.all(rel_freq_err < allowed_rel_err), 'subset occurrence freq in practice differs too much from desired representation'
# Verify that the priorly expected most-constraining subset is indeed exhausted.
if sampling_policy == SamplingPolicyKind.DownSample:
posterior_constraining_factors = \
execution_results.sampler.calc_constraining_factors_by_epoch_size(epoch_size=tot_nr_samples_taken)
posterior_most_constraining_subset_idx = np.argmax(posterior_constraining_factors)
most_constraining_subset_name = \
execution_results.sampler.sorted_subset_names[execution_results.sampler.down_sampling_info.most_constraining_subset_idx]
prior_most_constraining_subset_idx = test_input.subset_name_to_idx[most_constraining_subset_name]
exhausted_subsets_mask = nr_taken_samples_per_subset >= test_input.overall_nr_available_samples_per_subset
assert exhausted_subsets_mask[prior_most_constraining_subset_idx]
assert np.sum(exhausted_subsets_mask) >= 1
assert exhausted_subsets_mask[posterior_most_constraining_subset_idx]
nr_unique_encountered_samples = len(epoch_sample_firstly_seen)
ratio_unique_encountered_samples = nr_unique_encountered_samples / test_input.tot_dataset_size
if sampling_policy.is_up_sampling:
# In up sampling, more overall unique samples are expected to be encountered.
assert 0.65 < ratio_unique_encountered_samples < 0.70
else:
# In down-sampling, the most constraining subset is very restrictive, which is why fewer samples are expected to be encountered.
assert 0.45 < ratio_unique_encountered_samples < 0.50
if __name__ == '__main__':
test_stratified_subset_batch_sampler()