@dylon
Last active December 22, 2015 10:59
Demonstrates how to randomly assign elements to classes, where each class may have a different weight (all weights must be non-negative and sum to 1.0).
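Why subtracting weights works: `classify()` draws u uniformly from [0, 1) and returns the first class at which the running remainder drops to or below zero. That is the first class whose cumulative weight reaches u. Writing w_i for the weights in declaration order and W_i for their running sums:

```latex
W_0 = 0, \qquad W_i = \sum_{j=1}^{i} w_j, \qquad u \sim \mathcal{U}[0, 1)

\Pr[\text{class } i \text{ chosen}] = \Pr[W_{i-1} \le u < W_i] = W_i - W_{i-1} = w_i
```

Whether a boundary point belongs to the class on its left or its right only changes the outcome on a set of probability zero. For the weights used here the running sums are 0.6, 0.8 and 1.0, so TRAIN covers roughly [0, 0.6), TEST [0.6, 0.8) and CROSS_VALIDATION [0.8, 1.0). The requirement that the weights sum to 1.0 is what makes the last interval end where the draws do.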
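import java.util.EnumMap;
import java.util.Map;

public class WeightedRandomTest {
    private static final int NUM_CLASSIFICATIONS = 1_000_000;

    /**
     * Picks a classification at random with probability equal to its weight.
     * Subtracts each weight in turn from a uniform draw in [0, 1) and returns
     * the first class at which the remainder becomes negative. The strict
     * comparison means a zero-weight class is never returned.
     */
    private static Classification classify() {
        double factor = Math.random();
        final Classification[] classifications = Classification.values();
        for (final Classification classification : classifications) {
            factor -= classification.getWeight();
            if (factor < 0.0) {
                return classification;
            }
        }
        // Floating-point rounding can leave a tiny non-negative remainder when the
        // draw is close to 1.0. Returning null here would crash the tally in main,
        // so fall back to the last class instead.
        return classifications[classifications.length - 1];
    }

    public static void main(final String[] args) {
        // Start every class at a count of zero (EnumMap iterates in declaration order).
        final Map<Classification, Integer> histogram = new EnumMap<>(Classification.class);
        for (final Classification classification : Classification.values()) {
            histogram.put(classification, 0);
        }
        // Draw NUM_CLASSIFICATIONS samples and count how often each class is chosen.
        for (int i = 0; i < NUM_CLASSIFICATIONS; ++i) {
            histogram.merge(classify(), 1, Integer::sum);
        }
        // Compare each class's configured weight with its observed frequency.
        for (final Classification classification : Classification.values()) {
            System.out.printf("%s{expected=%.5f, actual=%.5f}%n",
                    classification,
                    classification.getWeight(),
                    (double) histogram.get(classification) / NUM_CLASSIFICATIONS);
        }
    }

    /** The classes to assign; weights must be non-negative and sum to 1.0. */
    private enum Classification {
        TRAIN(0.60),
        TEST(0.20),
        CROSS_VALIDATION(0.20);

        private final double weight;

        Classification(final double weight) {
            this.weight = weight;
        }

        public double getWeight() {
            return weight;
        }
    }
}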
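The linear scan in `classify()` costs O(k) per draw for k classes, which is fine for three classes. With many classes, a common alternative is to precompute the cumulative weights once and locate each draw with a binary search in O(log k). The sketch below is our own illustration, not part of the gist: `CumulativeWeightSampler` is a hypothetical class, and the 1e-9 tolerance used to check that the weights sum to 1.0 is an assumption. It also enforces the preconditions stated in the description.

```java
import java.util.Random;

/**
 * Hypothetical helper, not part of the original gist: samples index i with
 * probability weights[i] using precomputed cumulative weights and a binary search.
 */
final class CumulativeWeightSampler {
    // Assumed tolerance for "weights sum to 1.0" under floating-point rounding.
    private static final double SUM_TOLERANCE = 1e-9;

    private final double[] cumulative; // cumulative[i] = weights[0] + ... + weights[i]
    private final Random random;

    CumulativeWeightSampler(final double[] weights, final Random random) {
        if (weights.length == 0) {
            throw new IllegalArgumentException("at least one weight is required");
        }
        cumulative = new double[weights.length];
        double total = 0.0;
        for (int i = 0; i < weights.length; ++i) {
            if (!(weights[i] >= 0.0)) { // rejects negative weights and NaN
                throw new IllegalArgumentException("weight " + i + " is negative or NaN");
            }
            total += weights[i];
            cumulative[i] = total;
        }
        if (Math.abs(total - 1.0) > SUM_TOLERANCE) {
            throw new IllegalArgumentException("weights must sum to 1.0, got " + total);
        }
        this.random = random;
    }

    /** Returns the first index whose cumulative weight exceeds a uniform draw in [0, 1). */
    int sample() {
        final double u = random.nextDouble();
        int lo = 0;
        int hi = cumulative.length - 1;
        while (lo < hi) {
            final int mid = (lo + hi) >>> 1;
            if (cumulative[mid] > u) {
                hi = mid;     // the answer is at mid or to its left
            } else {
                lo = mid + 1; // the answer is strictly to the right of mid
            }
        }
        // If rounding left every cumulative weight <= u, lo is the last index.
        return lo;
    }

    public static void main(final String[] args) {
        // Same weights as TRAIN, TEST and CROSS_VALIDATION in the gist.
        final double[] weights = {0.60, 0.20, 0.20};
        final CumulativeWeightSampler sampler =
                new CumulativeWeightSampler(weights, new Random(42));
        final int draws = 1_000_000;
        final int[] counts = new int[weights.length];
        for (int i = 0; i < draws; ++i) {
            counts[sampler.sample()]++;
        }
        for (int i = 0; i < weights.length; ++i) {
            System.out.printf("index %d{expected=%.5f, actual=%.5f}%n",
                    i, weights[i], (double) counts[i] / draws);
        }
    }
}
```

Taking a `Random` in the constructor (rather than calling `Math.random()`) lets a caller pass a seeded generator, which makes runs reproducible when debugging.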
dylon commented Sep 6, 2013

Sample output:

TRAIN{expected=0.60000, actual=0.60054}
TEST{expected=0.20000, actual=0.19918}
CROSS_VALIDATION{expected=0.20000, actual=0.20028}
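The observed frequencies differ from the weights in the third or fourth decimal place. One way to judge whether that is ordinary sampling noise (our addition, not part of the gist) is to divide each deviation by the binomial standard error sqrt(p(1 - p) / n), with n = 1,000,000 draws as set by `NUM_CLASSIFICATIONS`. The hypothetical `SampleOutputCheck` class below hard-codes the three values from the sample output.

```java
/**
 * Hypothetical check, not part of the original gist: expresses each observed
 * frequency's deviation from its weight in binomial standard errors.
 */
final class SampleOutputCheck {
    /** (actual - expected) / sqrt(expected * (1 - expected) / n). */
    static double zScore(final double expected, final double actual, final int n) {
        final double standardError = Math.sqrt(expected * (1.0 - expected) / n);
        return (actual - expected) / standardError;
    }

    public static void main(final String[] args) {
        final int n = 1_000_000; // matches NUM_CLASSIFICATIONS
        // Values copied from the sample output above.
        final String[] names = {"TRAIN", "TEST", "CROSS_VALIDATION"};
        final double[] expected = {0.60, 0.20, 0.20};
        final double[] actual = {0.60054, 0.19918, 0.20028};
        for (int i = 0; i < names.length; ++i) {
            System.out.printf("%s: z = %+.2f%n", names[i], zScore(expected[i], actual[i], n));
        }
    }
}
```

For these numbers the z-scores come out at about +1.10, -2.05 and +0.70, so none of the deviations is far outside what random sampling produces at this sample size. Because the three counts are multinomial rather than independent, a chi-square goodness-of-fit test would be the more formal way to check all classes at once.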
