Created
February 23, 2015 23:32
-
-
Save mapio/57299694ef94cc88dddb to your computer and use it in GitHub Desktop.
An implementation of an argmax collector for the Java 8 Stream API (reversing the comparator turns it into an argmin collector).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList; | |
import java.util.Arrays; | |
import java.util.Comparator; | |
import java.util.List; | |
import java.util.stream.Collector; | |
/**
 * Mutable accumulation state for a {@link Collector} that gathers every stream element
 * ranking greatest under a given {@link Comparator} (the "argmax" set).
 *
 * <p>Ties are kept: every element comparing equal to the current maximum is retained, in
 * encounter order. Passing a reversed comparator yields the argmin set instead.
 *
 * @param <T> the stream element type
 */
class ArgMaxCollector<T> {

    /** Current maximum; only meaningful while {@link #argMax} is non-empty. */
    private T max = null;

    /** Elements seen so far that compare equal to {@link #max}, in encounter order. */
    private List<T> argMax = new ArrayList<T>();

    /** Ordering that defines "greatest". */
    private final Comparator<? super T> comparator;

    private ArgMaxCollector(Comparator<? super T> comparator) {
        this.comparator = comparator;
    }

    /**
     * Folds one element into the running argmax.
     *
     * @param element the next stream element
     */
    public void accept(T element) {
        // Emptiness (not max == null) signals "nothing seen yet", so a null element
        // under a null-tolerant comparator is treated like any other value.
        if (argMax.isEmpty()) {
            max = element;
            argMax.add(element);
            return;
        }
        int cmp = comparator.compare(max, element);
        if (cmp < 0) {
            max = element;
            argMax.clear();
            argMax.add(element);
        } else if (cmp == 0) {
            argMax.add(element);
        }
    }

    /**
     * Merges {@code other}'s partial result into this one. {@code other} is assumed to cover
     * elements that come after this accumulator's, so tied elements are appended after ours.
     *
     * <p>Either side may be empty. Parallel streams produce empty partial results (for example
     * when a filter removes every element of a split), and such a side has no meaningful
     * {@code max} to compare.
     *
     * @param other the accumulator to merge into this one; it is not modified
     */
    public void combine(ArgMaxCollector<T> other) {
        if (other.argMax.isEmpty()) {
            return;
        }
        if (argMax.isEmpty()) {
            max = other.max;
            // Copy rather than alias so the two accumulators never share a list.
            argMax.addAll(other.argMax);
            return;
        }
        int cmp = comparator.compare(max, other.max);
        if (cmp < 0) {
            max = other.max;
            argMax.clear();
            argMax.addAll(other.argMax);
        } else if (cmp == 0) {
            argMax.addAll(other.argMax);
        }
    }

    /**
     * Returns the accumulated argmax elements (empty if no element was accepted).
     *
     * @return this accumulator's live result list
     */
    public List<T> get() {
        return argMax;
    }

    /**
     * Creates a collector returning all elements that are maximal under {@code comparator},
     * in encounter order.
     *
     * @param comparator ordering defining "greatest"; reverse it to obtain the argmin
     * @param <T> the stream element type
     * @return a collector producing the (possibly empty) list of maximal elements
     */
    public static <T> Collector<T, ArgMaxCollector<T>, List<T>> collector(Comparator<? super T> comparator) {
        return Collector.of(
                () -> new ArgMaxCollector<T>(comparator),
                (acc, element) -> acc.accept(element),
                (left, right) -> {
                    left.combine(right);
                    return left;
                },
                acc -> acc.get());
    }
}
public class ArgMax { | |
public static void main( String[] args ) { | |
List<String> names = Arrays.asList( new String[] { "one", "two", "three", "four", "five", "six", "seven" } ); | |
Collector<String, ArgMaxCollector<String>, List<String>> argMax = ArgMaxCollector.collector( Comparator.comparing( String::length ) ); | |
Collector<String, ArgMaxCollector<String>, List<String>> argMin = ArgMaxCollector.collector( Comparator.comparing( String::length ).reversed() ); | |
System.out.println( names.stream().collect( argMax ) ); | |
System.out.println( names.stream().collect( argMin ) ); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[three, seven]
[one, two, six]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment