Created
June 8, 2023 21:37
-
-
Save forax/b9cd176b472f064ecefbae3178b748de to your computer and use it in GitHub Desktop.
A universal record builder with a type-safe API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.Serializable; | |
import java.lang.invoke.MethodHandle; | |
import java.lang.invoke.MethodHandles; | |
import java.lang.invoke.MethodHandles.Lookup; | |
import java.lang.invoke.MethodType; | |
import java.lang.invoke.SerializedLambda; | |
import java.lang.reflect.RecordComponent; | |
import java.lang.reflect.UndeclaredThrowableException; | |
import java.util.Arrays; | |
import java.util.Map; | |
import java.util.Objects; | |
import java.util.concurrent.ConcurrentHashMap; | |
import java.util.function.Function; | |
import java.util.stream.Collectors; | |
import java.util.stream.IntStream; | |
/**
 * A builder that can create an instance of any record type. Components are selected with
 * method references to the record accessors, e.g. {@code .with(Foo::name, "hello")}.
 *
 * <p>Components that are never set keep the default value of their type: {@code null} for
 * references, zero for numeric primitives, {@code false} for {@code boolean}.
 *
 * <p>A {@code RecordBuilder} instance holds a mutable, unsynchronized array of component
 * values. It is meant to be used by one thread at a time.
 *
 * @param <R> the record type being built
 */
public class RecordBuilder<R extends Record> {
  /**
   * A {@link Function} that is also {@link Serializable}.
   *
   * <p>Because the functional interface is serializable, the lambda runtime gives each lambda
   * class a {@code writeReplace} method returning a {@link SerializedLambda}. That method is
   * what lets {@link Factory} find out which accessor a method reference points to.
   */
  @FunctionalInterface
  public interface SerializableFunction<T, R> extends Function<T, R>, Serializable { }

  /**
   * Creates {@link RecordBuilder}s and caches per-record and per-accessor metadata.
   *
   * <p>All class and member resolution uses the {@link Lookup} given to
   * {@link RecordBuilder#factory(Lookup)}. That lookup must be able to access the record's
   * canonical constructor and the lambda classes of the accessors passed to
   * {@link RecordBuilder#with}.
   */
  public static final class Factory {
    /**
     * Reflective summary of one record type.
     *
     * @param slotMap component name to its index in canonical-constructor parameter order
     * @param constructorType type of the canonical constructor
     * @param prototype default value of each component, in slot order (never handed out
     *     directly, always copied)
     */
    private record RecordData(Map<String, Integer> slotMap, MethodType constructorType, Object[] prototype) {}

    /** Lazily computed {@link RecordData} per record class; rejects non-record classes. */
    private static final ClassValue<RecordData> RECORD_DATA_CLASS_VALUE = new ClassValue<>() {
      @Override
      protected RecordData computeValue(Class<?> type) {
        var components = type.getRecordComponents();
        if (components == null) {
          throw new IllegalArgumentException("not a record " + type.getName());
        }
        var slotMap = IntStream.range(0, components.length)
            .boxed()
            .collect(Collectors.toMap(i -> components[i].getName(), i -> i));
        var prototype = Arrays.stream(components)
            .map(component -> defaultValue(component.getType()))
            .toArray();
        var parameterTypes = Arrays.stream(components)
            .map(RecordComponent::getType)
            .toArray(Class<?>[]::new);
        var constructorType = MethodType.methodType(void.class, parameterTypes);
        return new RecordData(slotMap, constructorType, prototype);
      }
    };

    /** Returns the zero value of {@code type}: {@code null} for references, boxed zero/false for primitives. */
    private static Object defaultValue(Class<?> type) {
      if (!type.isPrimitive()) {
        return null;
      }
      return switch (type.getName()) {
        case "boolean" -> false;
        case "byte" -> (byte) 0;
        case "short" -> (short) 0;
        case "char" -> '\0';
        case "int" -> 0;
        case "long" -> 0L;
        case "float" -> 0f;
        case "double" -> 0.;
        default -> throw new AssertionError("unknown primitive type");
      };
    }

    private final Lookup lookup;

    /** Canonical constructor handle per record class, resolved with this factory's lookup. */
    private final ClassValue<MethodHandle> constructorCache = new ClassValue<>() {
      @Override
      protected MethodHandle computeValue(Class<?> type) {
        var recordData = RECORD_DATA_CLASS_VALUE.get(type);
        try {
          return lookup.findConstructor(type, recordData.constructorType);
        } catch (NoSuchMethodException e) {
          throw (NoSuchMethodError) new NoSuchMethodError().initCause(e);
        } catch (IllegalAccessException e) {
          throw (IllegalAccessError) new IllegalAccessError().initCause(e);
        }
      }
    };

    /**
     * Slot index per accessor instance. The key is the lambda object itself, so caching relies
     * on the runtime reusing the same instance for a non-capturing lambda call site.
     */
    private final ConcurrentHashMap<SerializableFunction<?, ?>, Integer> lambdaSlotCache =
        new ConcurrentHashMap<>();

    /**
     * Returns the constructor slot of the record component read by {@code accessor},
     * computing and caching it on first use.
     */
    int getAccessorSlot(SerializableFunction<?, ?> accessor) {
      return lambdaSlotCache.computeIfAbsent(accessor, this::computeLambdaSlot);
    }

    /**
     * Maps a method reference such as {@code Foo::name} to the index of the {@code name}
     * component of {@code Foo}.
     *
     * @throws IllegalArgumentException if {@code accessor} is not a lambda, captures values,
     *     is not implemented by a method of a record, or that method is not a component accessor
     */
    private int computeLambdaSlot(SerializableFunction<?, ?> accessor) {
      var lambda = serializedForm(accessor);
      if (lambda.getCapturedArgCount() != 0) {
        throw new IllegalArgumentException("capturing lambda");
      }
      var recordType = implementationClass(lambda);
      var recordData = RECORD_DATA_CLASS_VALUE.get(recordType);
      var slot = recordData.slotMap.get(lambda.getImplMethodName());
      // A component accessor takes no argument and is named after its component. Anything else
      // (toString, hashCode, a lambda body compiled into the record...) would otherwise unbox null.
      if (slot == null || !lambda.getImplMethodSignature().startsWith("()")) {
        throw new IllegalArgumentException("not a record accessor "
            + recordType.getName() + "::" + lambda.getImplMethodName());
      }
      return slot;
    }

    /** Calls the lambda class's {@code writeReplace} method to obtain its {@link SerializedLambda}. */
    private SerializedLambda serializedForm(SerializableFunction<?, ?> accessor) {
      var lambdaType = accessor.getClass();
      if (!lambdaType.isHidden()) {
        throw new IllegalArgumentException("not a lambda");
      }
      MethodHandle writeReplace;
      try {
        writeReplace = lookup.findVirtual(lambdaType, "writeReplace", MethodType.methodType(Object.class));
      } catch (NoSuchMethodException e) {
        throw (NoSuchMethodError) new NoSuchMethodError().initCause(e);
      } catch (IllegalAccessException e) {
        throw (IllegalAccessError) new IllegalAccessError().initCause(e);
      }
      try {
        // the (Object) cast keeps the call-site return type equal to writeReplace's declared type
        return (SerializedLambda) (Object) writeReplace.invoke(accessor);
      } catch (RuntimeException | Error e) {
        throw e;
      } catch (Throwable e) {
        throw new UndeclaredThrowableException(e);
      }
    }

    /** Loads the class declaring the lambda's implementation method. */
    private Class<?> implementationClass(SerializedLambda lambda) {
      // getImplClass() is an internal name ("com/example/Foo"); findClass() needs a binary name
      var className = lambda.getImplClass().replace('/', '.');
      try {
        return lookup.findClass(className);
      } catch (ClassNotFoundException e) {
        throw (NoClassDefFoundError) new NoClassDefFoundError(className).initCause(e);
      } catch (IllegalAccessException e) {
        throw (IllegalAccessError) new IllegalAccessError().initCause(e);
      }
    }

    private Factory(Lookup lookup) {
      this.lookup = lookup;
    }

    /**
     * Returns a new builder for {@code recordType} with every component set to its default value.
     *
     * @throws NullPointerException if {@code recordType} is null
     * @throws IllegalArgumentException if {@code recordType} is not a record class
     * @throws IllegalAccessError if this factory's lookup cannot access the canonical constructor
     */
    public <R extends Record> RecordBuilder<R> builder(Class<R> recordType) {
      Objects.requireNonNull(recordType);
      var recordData = RECORD_DATA_CLASS_VALUE.get(recordType);
      var constructor = constructorCache.get(recordType);
      var prototype = recordData.prototype;
      return new RecordBuilder<>(this, constructor, Arrays.copyOf(prototype, prototype.length));
    }
  }

  /**
   * Creates a factory that resolves constructors and lambda classes through {@code lookup}.
   *
   * @throws NullPointerException if {@code lookup} is null
   */
  public static Factory factory(Lookup lookup) {
    Objects.requireNonNull(lookup);
    return new Factory(lookup);
  }

  private final Factory factory;
  private final MethodHandle constructor;
  private final Object[] data;

  private RecordBuilder(Factory factory, MethodHandle constructor, Object[] data) {
    this.factory = factory;
    this.constructor = constructor;
    this.data = data;
  }

  /**
   * Sets the component read by {@code accessor} to {@code value}. A later call for the same
   * component overwrites the earlier value.
   *
   * <p>The value is stored without checking it against the component's declared type. The
   * compiler may infer {@code T} as a common supertype of the component type and the value's
   * type, so a mismatch only surfaces when {@link #build()} invokes the constructor.
   *
   * @param accessor a non-capturing method reference to a component accessor of {@code R}
   * @param value the new component value
   * @return this builder
   * @throws IllegalArgumentException if {@code accessor} is not such a method reference
   */
  public <T> RecordBuilder<R> with(SerializableFunction<? super R, ? extends T> accessor, T value) {
    var slot = factory.getAccessorSlot(accessor);
    data[slot] = value;
    return this;
  }

  /**
   * Invokes the canonical constructor with the current component values.
   *
   * <p>The values are not reset, so the builder can be reused to create further records.
   * Unchecked exceptions thrown by the constructor propagate unchanged. Checked ones are
   * wrapped in {@link UndeclaredThrowableException}.
   */
  @SuppressWarnings("unchecked") // the constructor handle was looked up for R.class
  public R build() {
    try {
      return (R) constructor.invokeWithArguments(data);
    } catch (RuntimeException | Error e) {
      throw e;
    } catch (Throwable e) {
      throw new UndeclaredThrowableException(e);
    }
  }

  // --- Use at your own risk

  // --- Example
  private static final RecordBuilder.Factory FACTORY = RecordBuilder.factory(MethodHandles.lookup());

  public static void main(String[] args) {
    record Foo(int x, String name, long def) {}
    var foo = FACTORY.builder(Foo.class)
        .with(Foo::x, 3)
        .with(Foo::name, "hello")
        .build();
    System.out.println(foo);
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment