Last active
December 19, 2017 20:16
-
-
Save adgaudio/a7a3ae68232ff1068e8b4086b3dd013d to your computer and use it in GitHub Desktop.
A weighted counter that remembers the most frequent and recent pairs on a 2-color graph.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
 * A weighted counter that remembers the most frequent and most recent pairs on a 2-color graph,
 * where every pair {@code (a_i, b_i)} has {@code a_i} drawn from set A and {@code b_i} drawn from
 * set B, and A and B are disjoint.
 *
 * <p>For each element of A the counter keeps one score per element of B it has been paired with,
 * maintained by the recurrence:
 *
 * <pre>
 *   score = memory * prev_score + (1 - memory) * (+1 if this pair was just observed, else -1)
 * </pre>
 *
 * <p>{@code memory} is a value in [0, 1] that chooses how much history to take into account:
 * with {@code memory == 0} only the latest observation matters (scores are exactly +1 or -1),
 * and values closer to 1 make scores change more slowly. A pair seen for the first time starts
 * at {@code 1 - memory}.
 *
 * <p>{@link #updateScores} takes O(size(B)) time, because it rescales every score recorded for
 * the given A element. {@link #getClosestMatch} takes expected O(1) time. Space is
 * O(size(A) * size(B) + size(A)), so only use this for small sets or sparse graphs (few distinct
 * pairs).
 *
 * <p>Not thread-safe: all state lives in unsynchronized {@link HashMap}s.
 *
 * @param <A> element type of the first color class
 * @param <B> element type of the second color class
 */
public class DecayCounter<A, B> {
    /** For each A element, the current decayed score of every B element seen with it. */
    private final HashMap<A, HashMap<B, Double>> weights;
    /** Fraction of the previous score retained on each update; validated to lie in [0, 1]. */
    private final double memory;
    /** Cached highest-scoring B for each A, so {@link #getClosestMatch} never rescans scores. */
    private final HashMap<A, B> maxVals;

    /** Small demo: feeds a few pairs through a default counter and prints its state each time. */
    public static void main(String[] args) {
        System.out.println("hello world");
        DecayCounter<String, String> t = new DecayCounter<>();
        t.updateScores("courier", "eventType");
        System.out.println(t.weights);
        System.out.println(t.maxVals);
        t.updateScores("a", "B");
        System.out.println(t.weights);
        System.out.println(t.maxVals);
        t.updateScores("a", "A");
        System.out.println(t.weights);
        System.out.println(t.maxVals);
        t.updateScores("b", "A");
        System.out.println(t.weights);
        System.out.println(t.maxVals);
    }

    /**
     * Creates an empty counter.
     *
     * @param memory how much of the previous score to keep on each update, in [0, 1]
     * @throws NullPointerException if {@code memory} is null
     * @throws IllegalArgumentException if {@code memory} is outside [0, 1] or NaN
     */
    DecayCounter(Double memory) {
        Objects.requireNonNull(memory, "memory");
        // Negated range test so NaN is rejected too; unlike the previous `assert`, this check
        // also runs when assertions are disabled.
        if (!(memory >= 0 && memory <= 1)) {
            throw new IllegalArgumentException("memory must be in [0, 1], got " + memory);
        }
        this.memory = memory;
        this.weights = new HashMap<>();
        this.maxVals = new HashMap<>();
    }

    /** Creates an empty counter with {@code memory = 0.75}. */
    DecayCounter() {
        this(0.75);
    }

    /**
     * Replaces all scores with a copy of {@code weights} and recomputes the best match per A.
     *
     * <p>The maps are copied, so later updates never modify the caller's data. Any best-match
     * entry left over from earlier updates is discarded.
     *
     * @param weights for each A element, a map from B elements to their starting scores
     * @throws NullPointerException if {@code weights}, any inner map, or any score is null
     */
    public void initializeWeights(HashMap<A, HashMap<B, Double>> weights) {
        Objects.requireNonNull(weights, "weights");
        this.weights.clear();
        this.maxVals.clear();
        for (Map.Entry<A, HashMap<B, Double>> entry : weights.entrySet()) {
            HashMap<B, Double> scores = new HashMap<>(entry.getValue());
            this.weights.put(entry.getKey(), scores);
            B best = argMax(scores);
            if (best != null) {
                maxVals.put(entry.getKey(), best);
            }
        }
    }

    /**
     * Returns the key with the strictly highest score, keeping the first in iteration order on
     * ties. Returns null if the map is empty or no score exceeds negative infinity (e.g. all NaN).
     */
    private static <K> K argMax(Map<K, Double> scores) {
        double best = Double.NEGATIVE_INFINITY;
        K bestKey = null;
        for (Map.Entry<K, Double> e : scores.entrySet()) {
            double value = e.getValue();
            if (value > best) {
                best = value;
                bestKey = e.getKey();
            }
        }
        return bestKey;
    }

    /**
     * Returns the highest-scoring B element recorded for {@code elementTypeA}.
     *
     * @return the cached best match, or null if {@code elementTypeA} has never been seen
     */
    public B getClosestMatch(A elementTypeA) {
        return maxVals.get(elementTypeA);
    }

    /**
     * Records an observation of the pair {@code (elementTypeA, elementTypeB)}.
     *
     * <p>Every score already recorded for {@code elementTypeA} decays toward -1, except the
     * observed pair's, which moves toward +1. An unseen pair enters at {@code 1 - memory}.
     *
     * @return the new best match for {@code elementTypeA} (also cached for
     *     {@link #getClosestMatch})
     */
    public B updateScores(A elementTypeA, B elementTypeB) {
        // Score of a brand-new pair: one positive observation applied to a zero prior.
        double coldStart = 1 - memory;
        HashMap<B, Double> scores = weights.get(elementTypeA);
        if (scores == null) {
            scores = new HashMap<>();
            scores.put(elementTypeB, coldStart);
            weights.put(elementTypeA, scores);
            maxVals.put(elementTypeA, elementTypeB);
            return elementTypeB;
        }
        // Seed with -infinity (not -1.0) so a winner is found even when every score is <= -1.
        double maxScore = Double.NEGATIVE_INFINITY;
        B maxElementTypeB = null;
        for (Map.Entry<B, Double> entry : scores.entrySet()) {
            double updated;
            // Objects.equals tolerates a null B key, which HashMap permits.
            if (Objects.equals(entry.getKey(), elementTypeB)) {
                updated = entry.getValue() * memory + coldStart;
            } else {
                updated = entry.getValue() * memory - coldStart;
            }
            entry.setValue(updated); // write-through; not a structural modification
            if (updated > maxScore) {
                maxScore = updated;
                maxElementTypeB = entry.getKey();
            }
        }
        // Cold start: first time this B is paired with this A.
        if (!scores.containsKey(elementTypeB)) {
            scores.put(elementTypeB, coldStart);
            if (coldStart > maxScore) {
                maxElementTypeB = elementTypeB;
            }
        }
        maxVals.put(elementTypeA, maxElementTypeB);
        return maxElementTypeB;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment