Skip to content

Instantly share code, notes, and snippets.

@jnape
Created December 27, 2017 20:15
Show Gist options
  • Save jnape/89821e751917578cc8c7b69fc1e9ec77 to your computer and use it in GitHub Desktop.
Save jnape/89821e751917578cc8c7b69fc1e9ec77 to your computer and use it in GitHub Desktop.
Type-level encoding and optimization of tail-recursive functions in Java with Lambda
package spike;
import com.jnape.palatable.lambda.adt.choice.Choice2;
import com.jnape.palatable.lambda.adt.coproduct.CoProduct2;
import com.jnape.palatable.lambda.functions.Fn1;
import com.jnape.palatable.lambda.functions.Fn2;
import com.jnape.palatable.lambda.functions.builtin.fn2.Cons;
import com.jnape.palatable.lambda.functor.Bifunctor;

import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.Function;

import static com.jnape.palatable.lambda.adt.Maybe.just;
import static com.jnape.palatable.lambda.adt.Maybe.nothing;
import static com.jnape.palatable.lambda.adt.choice.Choice2.a;
import static com.jnape.palatable.lambda.adt.choice.Choice2.b;
import static com.jnape.palatable.lambda.adt.hlist.Tuple2.fill;
import static com.jnape.palatable.lambda.functions.builtin.fn1.Id.id;
import static com.jnape.palatable.lambda.functions.builtin.fn1.Last.last;
import static com.jnape.palatable.lambda.functions.builtin.fn2.Cons.cons;
import static com.jnape.palatable.lambda.functions.builtin.fn2.Partition.partition;
import static com.jnape.palatable.lambda.functions.builtin.fn2.Unfoldr.unfoldr;
import static spike.Spike.ContinuationStackFrame.recurse;
import static spike.Spike.ContinuationStackFrame.terminate;
import static spike.Spike.Corecursive.corecursive;
import static spike.Spike.OptimizedRecursiveCallStack.optimizedRecursiveCallStack;
import static spike.Spike.Recursive.recursive;
import static spike.Spike.Trampoline.trampoline;
public class Spike {
/**
 * One step of a tail-recursive computation, encoded as a two-way coproduct: either
 * {@link #recurse(Object) recurse} with a new argument of type {@code A}, or
 * {@link #terminate(Object) terminate} with a result of type {@code B}.
 *
 * @param <A> the argument type handed to the next recursive step
 * @param <B> the result type produced on termination
 */
public static abstract class ContinuationStackFrame<A, B> implements CoProduct2<A, B, ContinuationStackFrame<A, B>>, Bifunctor<A, B, ContinuationStackFrame<?, ?>> {

    /**
     * Maps the argument of a recurse frame; a terminate frame maps to a terminate frame carrying the
     * same result.
     */
    @Override
    public <C> ContinuationStackFrame<C, B> biMapL(Function<? super A, ? extends C> fn) {
        return this.<C, B>biMap(fn, id());
    }

    /**
     * Maps the result of a terminate frame; a recurse frame maps to a recurse frame carrying the
     * same argument.
     */
    @Override
    public <C> ContinuationStackFrame<A, C> biMapR(Function<? super B, ? extends C> fn) {
        return this.<A, C>biMap(id(), fn);
    }

    /**
     * Maps whichever side this frame holds, preserving whether it recurses or terminates.
     *
     * @param lFn applied to the argument of a recurse frame
     * @param rFn applied to the result of a terminate frame
     * @return a new frame of the same kind holding the mapped value
     */
    @Override
    public <C, D> ContinuationStackFrame<C, D> biMap(Function<? super A, ? extends C> lFn,
                                                     Function<? super B, ? extends D> rFn) {
        return match(a -> ContinuationStackFrame.<C, D>recurse(lFn.apply(a)),
                     b -> ContinuationStackFrame.<C, D>terminate(rFn.apply(b)));
    }

    /** Creates a frame that asks for another recursive step with argument {@code a}. */
    public static <A, B> ContinuationStackFrame<A, B> recurse(A a) {
        return new Recurse<>(a);
    }

    /** Creates a frame that ends the recursion with result {@code b}. */
    public static <A, B> ContinuationStackFrame<A, B> terminate(B b) {
        return new Terminate<>(b);
    }

    // The "keep going" case: matches on the A side.
    private static final class Recurse<A, B> extends ContinuationStackFrame<A, B> {
        private final A a;

        private Recurse(A a) {
            this.a = a;
        }

        @Override
        public <R> R match(Function<? super A, ? extends R> aFn, Function<? super B, ? extends R> bFn) {
            return aFn.apply(a);
        }

        @Override
        public boolean equals(Object other) {
            return other instanceof Recurse<?, ?> && Objects.equals(a, ((Recurse<?, ?>) other).a);
        }

        @Override
        public int hashCode() {
            return 31 * Objects.hashCode(a);
        }

        @Override
        public String toString() {
            return "Recurse{" +
                    "a=" + a +
                    '}';
        }
    }

    // The "done" case: matches on the B side.
    private static final class Terminate<A, B> extends ContinuationStackFrame<A, B> {
        private final B b;

        private Terminate(B b) {
            this.b = b;
        }

        @Override
        public <R> R match(Function<? super A, ? extends R> aFn, Function<? super B, ? extends R> bFn) {
            return bFn.apply(b);
        }

        @Override
        public boolean equals(Object other) {
            return other instanceof Terminate<?, ?> && Objects.equals(b, ((Terminate<?, ?>) other).b);
        }

        @Override
        public int hashCode() {
            // Offset so a Terminate and a Recurse over the same value hash differently.
            return 31 * Objects.hashCode(b) + 1;
        }

        @Override
        public String toString() {
            return "Terminate{" +
                    "b=" + b +
                    '}';
        }
    }
}
/**
 * Evaluates a tail-recursive step function in a loop instead of on the call stack.
 *
 * <p>{@code fn} is applied to the starting argument. As long as the result is on the {@code A}
 * side, that {@code A} is fed back into {@code fn}. The {@code B} from the first {@code B}-side
 * result is returned. Because this is a plain {@code while} loop, the Java stack depth does not grow
 * with the number of steps.
 *
 * @param <A> the recursion argument type
 * @param <B> the result type
 */
public static final class Trampoline<A, B> implements Fn2<Function<? super A, ? extends CoProduct2<A, B, ?>>, A, B> {

    // The class holds no state, so one shared instance serves every type parameterization.
    private static final Trampoline<?, ?> INSTANCE = new Trampoline<>();

    @Override
    public B apply(Function<? super A, ? extends CoProduct2<A, B, ?>> fn, A a) {
        CoProduct2<? extends A, ? extends B, ?> next = fn.apply(a);
        while (next.match(__ -> true, __ -> false)) {
            // Inside the loop `next` is known to be on the A side, so the B branch is never taken.
            next = fn.apply(next.match(id(), __ -> null));
        }
        // The loop exited, so `next` is on the B side and the A branch is never taken.
        return next.match(__ -> null, id());
    }

    /** Returns the shared trampoline instance, viewed at the requested type parameters. */
    @SuppressWarnings("unchecked")
    public static <A, B> Trampoline<A, B> trampoline() {
        // Safe: the instance stores no values of type A or B.
        return (Trampoline<A, B>) INSTANCE;
    }

    /** Partially applies the trampoline to {@code fn}, yielding a stack-safe function of the start argument. */
    public static <A, B> Fn1<A, B> trampoline(Function<? super A, ? extends CoProduct2<A, B, ?>> fn) {
        return Trampoline.<A, B>trampoline().apply(fn);
    }

    /** Runs {@code fn} from {@code a} to completion and returns the final result. */
    public static <A, B> B trampoline(Function<? super A, ? extends CoProduct2<A, B, ?>> fn, A a) {
        return trampoline(fn).apply(a);
    }
}
public interface Recursive<A, B> extends Fn1<A, ContinuationStackFrame<A, B>> {
default Fn1<A, OptimizedRecursiveCallStack<A, B>> unroll() {
return a -> optimizedRecursiveCallStack(cons(recurse(a), fill(apply(a))
.fmap(unfoldr(tc -> tc.match(next -> just(fill(apply(next))), __ -> nothing())))
.into(Cons::cons)));
}
static <A, B> Recursive<A, B> recursive(
Function<? super A, ? extends CoProduct2<? extends A, ? extends B, ?>> f) {
return a -> f.apply(a).match(ContinuationStackFrame::recurse, ContinuationStackFrame::terminate);
}
default Fn1<A, B> trampoline() {
return Trampoline.trampoline(this);
}
}
/**
 * Fuses two mutually tail-recursive step functions into one {@link Recursive} over their combined
 * argument.
 *
 * <p>{@code f} steps on an {@code A}. It either recurses with another {@code A}, or terminates with
 * a value that is a {@code B} (control passes to {@code g}) or a final {@code C}. {@code g} is the
 * mirror image: it steps on a {@code B} and hands off to {@code f} with an {@code A}.
 *
 * @param <A> argument type of {@code f}
 * @param <B> argument type of {@code g}
 * @param <C> final result type shared by both
 */
public static final class Corecursive<A, B, C> implements Recursive<CoProduct2<A, B, ?>, C> {
    private final Recursive<A, ? extends CoProduct2<? extends B, ? extends C, ?>> f;
    private final Recursive<B, ? extends CoProduct2<? extends A, ? extends C, ?>> g;

    private Corecursive(Recursive<A, ? extends CoProduct2<? extends B, ? extends C, ?>> f,
                        Recursive<B, ? extends CoProduct2<? extends A, ? extends C, ?>> g) {
        this.f = f;
        this.g = g;
    }

    /** Dispatches one step to {@code f} or {@code g}, depending on which side {@code ab} holds. */
    @Override
    public ContinuationStackFrame<CoProduct2<A, B, ?>, C> apply(CoProduct2<A, B, ?> ab) {
        return ab.match(this::stepF, this::stepG);
    }

    // One step of f: keep recursing on A, cross over to the B side, or finish with C.
    private ContinuationStackFrame<CoProduct2<A, B, ?>, C> stepF(A a) {
        return f.apply(a).match(
                nextA -> recurse(Choice2.<A, B>a(nextA)),
                bOrC -> bOrC.match(b -> recurse(Choice2.<A, B>b(b)), ContinuationStackFrame::terminate));
    }

    // One step of g: keep recursing on B, cross over to the A side, or finish with C.
    private ContinuationStackFrame<CoProduct2<A, B, ?>, C> stepG(B b) {
        return g.apply(b).match(
                nextB -> recurse(Choice2.<A, B>b(nextB)),
                aOrC -> aOrC.match(a -> recurse(Choice2.<A, B>a(a)), ContinuationStackFrame::terminate));
    }

    /** Combines {@code f} (stepping on {@code A}) and {@code g} (stepping on {@code B}). */
    public static <A, B, C> Corecursive<A, B, C> corecursive(
            Recursive<A, ? extends CoProduct2<? extends B, ? extends C, ?>> f,
            Recursive<B, ? extends CoProduct2<? extends A, ? extends C, ?>> g
    ) {
        return new Corecursive<>(f, g);
    }
}
/**
 * Demonstrates mutual tail recursion by running the Collatz ("3n + 1") iteration from 9 as two
 * cooperating step functions. The result is evaluated once by unrolling and once by trampolining.
 */
public static void main(String... args) {
    // Halve while even; hand an odd value other than 1 off to `odds`; terminate on 1.
    Recursive<Integer, ContinuationStackFrame<Integer, Integer>> evens = recursive(i -> i == 1
            ? terminate(terminate(i))
            : i % 2 == 0
            ? recurse(i / 2)
            : terminate(recurse(i)));
    // Apply 3n + 1 while odd (for positive inputs); hand an even value off to `evens`; terminate on 1.
    Recursive<Integer, Choice2<Integer, Integer>> odds = recursive(i -> i == 1 ? b(b(i)) : i % 2 == 1 ? a((i * 3) + 1) : b(a(i)));
    Corecursive<Integer, Integer, Integer> conjecture = corecursive(evens, odds);
    // The A side selects `evens` (the first function passed to corecursive) as the starting step.
    Choice2<Integer, Integer> arg = Choice2.a(9);
    // unroll().apply already yields an OptimizedRecursiveCallStack, so it can be rolled directly.
    System.out.println(conjecture.unroll().apply(arg).roll());
    System.out.println(trampoline(conjecture).apply(arg));
}
/**
 * An {@link Iterable} over the frames of an unrolled recursion, such as the one produced by
 * {@link Recursive#unroll()}.
 *
 * @param <A> the recursion argument type
 * @param <B> the result type
 */
public static final class OptimizedRecursiveCallStack<A, B> implements Iterable<ContinuationStackFrame<A, B>> {
    private final Iterable<ContinuationStackFrame<A, B>> frames;

    private OptimizedRecursiveCallStack(Iterable<ContinuationStackFrame<A, B>> frames) {
        this.frames = frames;
    }

    /**
     * Returns the result carried by the last terminate frame.
     *
     * <p>Nothing is cached: each call iterates the wrapped frames again from the start.
     *
     * @return the final result of the recursion
     * @throws NoSuchElementException if no terminate frame is present
     */
    public B roll() {
        // partition(id(), ...) splits the frames by side; _2 holds the values of the terminate frames.
        Iterable<B> terminalResults = partition(id(), frames)._2();
        return last(terminalResults).orElseThrow(NoSuchElementException::new);
    }

    /** Returns an iterator over the wrapped frames, in their original order. */
    @Override
    public Iterator<ContinuationStackFrame<A, B>> iterator() {
        return frames.iterator();
    }

    /** Wraps {@code stack} as a call stack, without copying it. */
    public static <A, B> OptimizedRecursiveCallStack<A, B> optimizedRecursiveCallStack(
            Iterable<ContinuationStackFrame<A, B>> stack) {
        return new OptimizedRecursiveCallStack<>(stack);
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment