Last active
February 19, 2021 23:47
-
-
Save odashi/813810a5bc06724ea3643456f8d3942d to your computer and use it in GitHub Desktop.
Augmented dataclass that registers itself as a JAX pytree.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dataclasses as dc | |
from jax import tree_util as jt | |
def register_jax_dataclass(cls):
    """Register a dataclass type as a JAX pytree node.

    Every dataclass field, in declaration order, becomes a pytree child, so
    ``jax.tree_util`` functions traverse into the field values. No static
    auxiliary data is recorded.

    Fields declared with ``init=False`` are supported: they are flattened like
    any other field and restored after construction on unflatten. This also
    works for ``frozen=True`` dataclasses.

    NOTE(review): ``InitVar`` pseudo-fields are not pytree children. A
    dataclass whose ``__init__`` requires an ``InitVar`` without a default
    cannot be rebuilt on unflatten.

    Args:
        cls: A class already processed by ``dataclasses.dataclass``.

    Returns:
        ``cls`` itself, so this function can be used as a decorator.

    Raises:
        TypeError: If ``cls`` is not a dataclass *type* (a dataclass instance
            is rejected too, since only types can be registered).
    """
    # dc.is_dataclass() is also True for instances; only classes can be
    # registered as pytree node types.
    if not (isinstance(cls, type) and dc.is_dataclass(cls)):
        raise TypeError(f'{cls} is not a dataclass.')

    fields = dc.fields(cls)
    keys = tuple(field.name for field in fields)
    # Fields with init=False are rejected by __init__, so they cannot be passed
    # as keyword arguments; they are set after construction instead.
    init_keys = tuple(field.name for field in fields if field.init)
    non_init_keys = tuple(field.name for field in fields if not field.init)

    def _flatten(obj):
        # Children are the field values; aux data is unused (None).
        return [getattr(obj, key) for key in keys], None

    def _unflatten(_, children):
        values = dict(zip(keys, children))
        obj = cls(**{key: values[key] for key in init_keys})
        # object.__setattr__ bypasses the frozen-dataclass guard.
        for key in non_init_keys:
            object.__setattr__(obj, key, values[key])
        return obj

    jt.register_pytree_node(cls, _flatten, _unflatten)
    return cls
def jax_dataclass(cls=None, **kwargs):
    """Decorator: make ``cls`` a dataclass and register it as a JAX pytree.

    Usable bare (``@jax_dataclass``) or with ``dataclasses.dataclass`` keyword
    options (``@jax_dataclass(frozen=True)``), mirroring ``dataclasses.dataclass``.

    Args:
        cls: The class to decorate. When omitted, a decorator is returned
            that applies the given ``kwargs``.
        **kwargs: Keyword options forwarded unchanged to
            ``dataclasses.dataclass``.

    Returns:
        The registered dataclass, or a decorator when ``cls`` is ``None``.
    """
    def _wrap(target):
        return register_jax_dataclass(dc.dataclass(target, **kwargs))

    # Called with options only, e.g. @jax_dataclass(frozen=True).
    if cls is None:
        return _wrap
    return _wrap(cls)
class Data:
    """Minimal example type turned into a JAX-pytree dataclass by jax_dataclass."""

    foo: int
    bar: float


# Explicit application; equivalent to decorating the class with @jax_dataclass.
Data = jax_dataclass(Data)
if __name__ == '__main__':
    # Round-trip demo, run only when executed as a script so that importing
    # this module has no side effects beyond definitions.
    a = Data(1, 2.3)
    leaves, treedef = jt.tree_flatten(a)
    # Leaves follow the dataclass field declaration order (foo, bar).
    assert leaves == [1, 2.3]
    b = jt.tree_unflatten(treedef, leaves)
    assert a == b
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment