Skip to content

Instantly share code, notes, and snippets.

@nicolaferraro
Created September 4, 2016 09:55
Show Gist options
  • Save nicolaferraro/201efb0fc3c725afff8a1f06b75faea5 to your computer and use it in GitHub Desktop.
Save nicolaferraro/201efb0fc3c725afff8a1f06b75faea5 to your computer and use it in GitHub Desktop.
Spark try-map functionality for Java
package it.test;
import java.io.Serializable;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.spark.Accumulator;
import org.apache.spark.AccumulatorParam;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
/**
 * Demonstrates a "try-map" over a Spark RDD using the Spark 1.x Java API.
 *
 * <p>Records whose conversion throws are dropped from the output. The exceptions they raised are
 * collected into an accumulator that the driver reads after the job has run.
 */
public class TryMapJava {

    /**
     * Runs the demo on a local Spark context.
     *
     * <p>It converts the integers 1..99 with {@link #myConversion(int)}, prints how many
     * conversions succeeded, and then prints the message of every captured exception.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local").setAppName("Test Exceptions");
        // try-with-resources guarantees the context is stopped, even if the job fails.
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            Accumulator<HashSet<Throwable>> exceptions =
                    sc.accumulator(new HashSet<>(), new HashSetAccumulatorParam<>());

            // A plain .map(TryMapJava::myConversion) would fail the whole job on the first
            // exception. Wrapping the function in TryWrapper and using flatMap instead skips the
            // failing records and records their exceptions in the accumulator.
            long result = sc.parallelize(IntStream.range(1, 100).boxed().collect(Collectors.toList()))
                    .flatMap(TryWrapper.of(exceptions, TryMapJava::myConversion))
                    .count();

            // Printing the result
            System.out.println(result);

            int i = 1;
            for (Throwable t : exceptions.value()) {
                System.out.println("Exception " + (i++) + ": " + t.getMessage());
            }
        }
    }

    /**
     * Formats {@code num} as {@code ">> num"}.
     *
     * <p>It deliberately fails for multiples of five so that the demo has exceptions to capture.
     *
     * @param num the value to convert
     * @return {@code ">> "} followed by {@code num}
     * @throws RuntimeException if {@code num} is divisible by 5
     */
    public static String myConversion(int num) {
        if (num % 5 == 0) {
            throw new RuntimeException("A forced exception...");
        }
        return ">> " + num;
    }

    // Classes used to implement the functionality

    /**
     * Adapts a {@link Function} into a {@link FlatMapFunction} that emits zero or one element.
     *
     * <p>When the wrapped function returns normally, its result is emitted. When it throws an
     * {@link Exception}, nothing is emitted and the exception is added to the supplied accumulator.
     *
     * <p>NOTE(review): the Spark programming guide states that accumulator updates made inside
     * transformations such as {@code flatMap} may be applied more than once if a task is re-run.
     * The collected exceptions may therefore contain duplicates from retried tasks; confirm
     * whether that matters for the caller.
     *
     * @param <S> input element type
     * @param <R> output element type
     */
    public static class TryWrapper<S, R> implements FlatMapFunction<S, R>, Serializable {
        private final Accumulator<HashSet<Throwable>> exceptions;
        private final Function<S, R> function;

        private TryWrapper(Accumulator<HashSet<Throwable>> exceptions, Function<S, R> function) {
            this.exceptions = exceptions;
            this.function = function;
        }

        /**
         * Creates a wrapper around {@code function}.
         *
         * @param exceptions accumulator that receives every exception thrown by {@code function}
         * @param function the per-element conversion to protect
         * @param <S> input element type
         * @param <R> output element type
         * @return a flat-map function that emits the result, or nothing if {@code function} threw
         */
        public static <S, R> TryWrapper<S, R> of(
                Accumulator<HashSet<Throwable>> exceptions, Function<S, R> function) {
            return new TryWrapper<>(exceptions, function);
        }

        /**
         * Applies the wrapped function to {@code s}.
         *
         * @param s the input element
         * @return a one-element list holding the result, or an empty list if the function threw
         * @throws Exception never thrown in practice; the signature is kept for the interface
         */
        @Override
        public Iterable<R> call(S s) throws Exception {
            List<R> output = new LinkedList<>();
            try {
                output.add(function.call(s));
            } catch (Exception e) {
                // Only Exceptions are captured. JVM Errors (OutOfMemoryError, StackOverflowError,
                // ...) propagate and fail the task instead of being silently swallowed.
                HashSet<Throwable> failure = new HashSet<>();
                failure.add(e);
                exceptions.add(failure);
            }
            return output;
        }
    }

    /**
     * An {@link AccumulatorParam} that merges {@link HashSet}s by set union.
     *
     * @param <T> element type
     */
    public static class HashSetAccumulatorParam<T> implements AccumulatorParam<HashSet<T>> {

        /** Adds every element of {@code t2} into {@code t1} (mutating it) and returns {@code t1}. */
        @Override
        public HashSet<T> addAccumulator(HashSet<T> t1, HashSet<T> t2) {
            t1.addAll(t2);
            return t1;
        }

        /** Merges {@code r2} into {@code r1} via {@link #addAccumulator}; {@code r1} is modified. */
        @Override
        public HashSet<T> addInPlace(HashSet<T> r1, HashSet<T> r2) {
            return addAccumulator(r1, r2);
        }

        /** Returns a fresh empty set; {@code initialValue} is ignored. */
        @Override
        public HashSet<T> zero(HashSet<T> initialValue) {
            return new HashSet<>();
        }
    }
}
@AjayKhetan
Copy link

When I imported the above code, I got the following error:
'call(S)' in 'TryMapJava.TryWrapper' clashes with 'call(T)' in 'org.apache.spark.api.java.function.FlatMapFunction'; attempting to use incompatible return type

@fazut
Copy link

fazut commented Aug 11, 2020

package it.test;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.spark.AccumulatorParam;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.util.AccumulatorV2;
import org.apache.spark.util.CollectionAccumulator;

/**
 * Demonstrates a "try-map" over a Spark RDD using the Spark 2.x Java API.
 *
 * <p>Records whose conversion throws are dropped from the output. The exceptions they raised are
 * collected into a {@link CollectionAccumulator} that the driver reads after the job has run.
 */
public class TryMapJava {

    /**
     * Runs the demo on a local Spark context.
     *
     * <p>It converts the integers 1..99 with {@link #myConversion(int)}, prints how many
     * conversions succeeded, and then prints the message of every captured exception.
     *
     * @param args ignored
     */
    public static void main(String[] args) {

        SparkConf conf = new SparkConf().setMaster("local").setAppName("Test Exceptions");
        // try-with-resources guarantees the context is stopped, even if the job fails.
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {

            CollectionAccumulator<Throwable> exceptions = sc.sc().collectionAccumulator();

            // A plain .map(TryMapJava::myConversion) would fail the whole job on the first
            // exception. Wrapping the function in TryWrapper and using flatMap instead skips the
            // failing records and records their exceptions in the accumulator.
            long result = sc.parallelize(IntStream.range(1, 100).boxed().collect(Collectors.toList()))
                    .flatMap(TryWrapper.of(exceptions, TryMapJava::myConversion))
                    .count();

            // Printing the result
            System.out.println(result);

            int i = 1;
            for (Throwable t : exceptions.value()) {
                System.out.println("Exception " + (i++) + ": " + t.getMessage());
            }
        }
    }

    /**
     * Formats {@code num} as {@code ">> num"}.
     *
     * <p>It deliberately fails for multiples of five so that the demo has exceptions to capture.
     *
     * @param num the value to convert
     * @return {@code ">> "} followed by {@code num}
     * @throws RuntimeException if {@code num} is divisible by 5
     */
    public static String myConversion(int num) {
        if (num % 5 == 0) {
            throw new RuntimeException("A forced exception...");
        }
        return ">> " + num;
    }

    /**
     * Adapts a {@link Function} into a {@link FlatMapFunction} that emits zero or one element.
     *
     * <p>When the wrapped function returns normally, its result is emitted. When it throws an
     * {@link Exception}, nothing is emitted and the exception is added to the supplied accumulator.
     *
     * <p>NOTE(review): the Spark programming guide states that accumulator updates made inside
     * transformations such as {@code flatMap} may be applied more than once if a task is re-run.
     * The collected exceptions may therefore contain duplicates from retried tasks; confirm
     * whether that matters for the caller.
     *
     * @param <S> input element type
     * @param <R> output element type
     */
    public static class TryWrapper<S, R> implements FlatMapFunction<S, R>, Serializable {

        private final CollectionAccumulator<Throwable> exceptions;

        private final Function<S, R> function;

        private TryWrapper(CollectionAccumulator<Throwable> exceptions, Function<S, R> function) {
            this.exceptions = exceptions;
            this.function = function;
        }

        /**
         * Creates a wrapper around {@code function}.
         *
         * @param exceptions accumulator that receives every exception thrown by {@code function}
         * @param function the per-element conversion to protect
         * @param <S> input element type
         * @param <R> output element type
         * @return a flat-map function that emits the result, or nothing if {@code function} threw
         */
        public static <S, R> TryWrapper<S, R> of(
                CollectionAccumulator<Throwable> exceptions, Function<S, R> function) {
            return new TryWrapper<>(exceptions, function);
        }

        /**
         * Applies the wrapped function to {@code s}.
         *
         * @param s the input element
         * @return an iterator over the single result, or an empty iterator if the function threw
         * @throws Exception never thrown in practice; the signature is kept for the interface
         */
        @Override
        public Iterator<R> call(S s) throws Exception {
            List<R> output = new LinkedList<>();
            try {
                output.add(function.call(s));
            } catch (Exception e) {
                // Only Exceptions are captured. JVM Errors (OutOfMemoryError, StackOverflowError,
                // ...) propagate and fail the task instead of being silently swallowed.
                exceptions.add(e);
            }

            return output.iterator();
        }

    }

}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment