Complex-valued neural networks are widely used across the sciences, for example for complex-valued wave functions in quantum physics and molecular chemistry, complex-valued Fourier coefficients in signal processing, and manifold learning on tori that carry a complex structure. A recent survey is J. Bassey et al., arXiv:2101.12249.
One of the differences between complex- and real-valued neural networks is optimization. There is not yet a widely accepted way to generalize most optimizers to the complex domain, and most machine learning (ML) frameworks have not properly implemented them.
In the JAX community, Optax is the only actively maintained package dedicated to optimization, and it is used alongside packages dedicated to neural network construction such as Flax and Haiku. In particular, the Flax developers have proposed replacing `flax.optim` with Optax (see FLIP 1009). Optax is also used in various downstream projects such as NetKet, JAX MD, RLax, and Brax. We would like to implement complex-valued optimization in Optax to benefit the whole community.
- PyTorch: There has recently been a long discussion about complex-valued optimizers in pytorch#59998, and they are now being implemented. Given the popularity of PyTorch, we may follow some of their design choices.
- TensorFlow: The maintainers refused to implement them in the main repo (see tensorflow#38541 comment), but it is possible to contribute to TF Addons.
- Julia language: Many kinds of complex-valued optimizers have been implemented in Optim.jl and GalacticOptim.jl. However, common optimizers for neural networks implemented in Flux or other packages are unaware of complex numbers, and there is no issue about them yet.
Many optimizers make use of the norm of variables (parameters and gradients). For example, in Adam we accumulate the first moment of the gradients, $m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t$, in the numerator, and the second moment of the norm of the gradients, $v_t = \beta_2 v_{t-1} + (1-\beta_2) |g_t|^2$, in the denominator. Most ML frameworks, including Optax, incorrectly assume that the square of the norm is `g**2`, which is true only in the real domain. As a result, complex parameters will not be correctly optimized, as shown in pytorch#59998.
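To make the discrepancy concrete, here is a minimal sketch in plain Python (no JAX required). The sample gradient `g` is a made-up value chosen only for illustration.

```python
# Hypothetical sample gradient; any value with a nonzero imaginary part works.
g = 3 + 4j

naive_square = g ** 2                    # what a real-only optimizer computes
norm_square = (g.conjugate() * g).real   # |g|^2, the squared complex norm

print(naive_square)  # (-7+24j): complex, so it cannot serve as a second moment
print(norm_square)   # 25.0: real and non-negative

# For a real gradient the two expressions agree.
r = -3.0
real_naive = r ** 2
real_norm = (complex(r).conjugate() * complex(r)).real
print(real_naive == real_norm)  # True
```

With the naive square, the accumulated second moment becomes complex, so the square root in Adam's denominator no longer acts as a per-parameter scale.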
To generalize those optimizers to the complex domain, there are two natural choices: the complex norm and the split real norm. We may decide to implement either or both of them.
We define the complex norm through its square, `norm(g: complex) ** 2 = (g.conj() * g).real`. This is the standard norm on the complex plane, and it reduces to the standard real norm when the imaginary part of `g` is zero.
To implement this norm, we take Optax's Adam optimizer as an example:

    def update_fn(updates, state, params=None):
      del params
      mu = _update_moment(updates, state.mu, b1, 1)
      nu = _update_moment(updates, state.nu, b2, 2)
      count_inc = numerics.safe_int32_increment(state.count)
      mu_hat = utils.cast_tree(_bias_correction(mu, b1, count_inc), mu_dtype)
      nu_hat = _bias_correction(nu, b2, count_inc)
      updates = jax.tree_multimap(
          lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
      return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)
The only change needed is for `nu` to accumulate the complex norm of the gradients, so we replace the `_update_moment` call for `nu` with another function, defined as

    def _update_norm_moment(updates, moments, decay, order):
      """Compute the exponential moving average of the `order-th` moment of the norm."""

      def orderth_norm(g):
        if jnp.isrealobj(g):
          return g ** order
        else:
          return (g.conj() * g).real ** (order / 2)

      return jax.tree_multimap(
          lambda g, t: (1 - decay) * orderth_norm(g) + decay * t, updates, moments)

which is semantically different from `_update_moment`, because `_update_moment` accumulates the variable itself, not its norm.
This change is non-breaking, in the sense that it has no effect on users who only do real-valued optimization.
Another choice of norm follows the principle, quoting a PyTorch developer's comment, that "Optimizers on complex tensors should behave the same way as if they were running on two real tensors". For example, in Adam we separately normalize `g.real` and `g.imag` for each complex gradient `g`, as if they were two real parameters. This principle is consistent with the behavior of `vjp` chosen by JAX.
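The two norms give visibly different updates. The following plain-Python sketch compares the direction of the very first Adam step for a single scalar gradient, where bias correction reduces the moments to `m_hat = g` and `v_hat = norm_sq(g)`. The helper names `adam_first_step` and `complex_norm_sq` and the sample gradient are assumptions made for this illustration, and the learning rate and sign are left out.

```python
import math

def adam_first_step(g, norm_sq, eps=1e-8):
    """Direction of the first bias-corrected Adam update for a scalar gradient.

    At step 1, bias correction gives m_hat = g and v_hat = norm_sq(g).
    """
    return g / (math.sqrt(norm_sq(g)) + eps)

def complex_norm_sq(g):
    return (g.conjugate() * g).real

g = 3 + 4j  # hypothetical sample gradient

# Complex norm: the complex number is normalized as a whole.
u_complex = adam_first_step(g, complex_norm_sq)

# Split real norm: real and imaginary parts are treated as two real gradients.
u_split = complex(
    adam_first_step(g.real, lambda x: x * x),
    adam_first_step(g.imag, lambda x: x * x),
)

print(u_complex)  # approximately 0.6+0.8j: the phase of g is preserved
print(u_split)    # approximately 1+1j: each component is normalized on its own
```

The complex norm keeps the direction of `g` in the complex plane, while the split real norm rescales the two components independently.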
To implement this following the composable design of Optax, we can write an optimizer wrapper that splits the complex parameters into pairs of real parameters before sending them to the `update` of the wrapped optimizer, and merges the pairs of real updates into complex updates afterward:

    def split_real_and_imaginary(inner):

      def init_fn(params):
        params = jax.tree_map(_complex_to_real_pair, params)
        inner_state = inner.init(params)
        return SplitRealAndImaginaryState(inner_state)

      def update_fn(updates, state, params=None):
        inner_state = state.inner_state
        updates = jax.tree_map(_complex_to_real_pair, updates)
        params = jax.tree_map(_complex_to_real_pair, params)
        updates, inner_state = inner.update(updates, inner_state, params)
        updates = jax.tree_map(_real_pair_to_complex, updates, is_leaf=_is_real_pair)
        return updates, SplitRealAndImaginaryState(inner_state)

      return base.GradientTransformation(init_fn, update_fn)

The usage is, for example, `optimizer = optax.split_real_and_imaginary(optax.adam(learning_rate))`.
There is no change to the existing API.
It is also possible to implement both the complex norm and the split real norm as proposed above, and let users choose between them according to their needs. A user who needs the complex norm can ignore `split_real_and_imaginary` and directly use `optimizer = optax.adam(learning_rate)`. A user who needs the split real norm can use `optimizer = optax.split_real_and_imaginary(optax.adam(learning_rate))`, so that `adam` only processes real gradients and the above change to `adam` is irrelevant.
JAX takes the convention that the output of `jax.grad` needs to be conjugated before being added to the parameter in gradient descent optimization. This originates from a choice of convention for the Wirtinger derivatives, and it differs from the convention used by PyTorch, TensorFlow and Flux (see jax#4891 and pytorch#41857). Here is an example to demonstrate the difference:

    import torch
    from torch.autograd import grad
    x = torch.tensor(1j, requires_grad=True)
    print(grad(abs(x), x))  # 1j

    from jax import numpy as jnp
    from jax import grad
    x = jnp.array(1j)
    print(grad(abs)(x))  # -1j
In Optax, we should set up a guideline on how to apply the conjugate during optimization. One choice is to explicitly take the conjugate before `optimizer.update`:

    grads = jax.grad(compute_loss)(params)
    grads = jax.tree_map(lambda x: x.conj(), grads)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

Some users have already implemented this in their existing code, for example the wrapped `vjp` in NetKet.
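To see why the conjugate matters, here is a plain-Python sketch of gradient descent on `f(z) = |z|^2`. It does not call JAX. Instead, `jax_style_grad` is a hand-written stand-in that returns `2 * conj(z)`, which is what JAX's convention yields for this loss (consistent with `grad(abs)(1j)` giving `-1j` above). The function names, starting point and learning rate are assumptions chosen for illustration.

```python
def jax_style_grad(z):
    # Hand-written stand-in for jax.grad of f(z) = |z|^2 under JAX's
    # convention (df/dx - 1j * df/dy), which equals 2 * conj(z).
    return 2 * z.conjugate()

def descend(z, steps, lr, conjugate):
    """Plain gradient descent, optionally conjugating the gradient first."""
    for _ in range(steps):
        g = jax_style_grad(z)
        if conjugate:
            g = g.conjugate()
        z = z - lr * g
    return z

z0 = 1 + 1j
z_conj = descend(z0, steps=10, lr=0.25, conjugate=True)
z_raw = descend(z0, steps=10, lr=0.25, conjugate=False)

print(abs(z_conj))  # shrinks toward 0: each step halves z
print(abs(z_raw))   # grows: the imaginary part is multiplied by 1.5 per step
```

Without the conjugate, the real part still decays but the imaginary part moves uphill, so the iteration diverges even on this simple convex loss.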
Another choice is to implement the conjugate in `optax.apply_updates`. Specifically, we add a `conj` to the update `u`:

    def apply_updates(params, updates):
      return jax.tree_multimap(
          lambda p, u: jnp.asarray(p + u.conj()).astype(jnp.asarray(p).dtype),
          params, updates)

This change does not affect users who only do real-valued optimization, and there is no performance regression because `jax.jit` can eliminate the dispatch overhead of `conj`. Existing users who already apply the conjugate themselves would need to modify their code accordingly. This change saves users one line of code, but it may break some semantics of the gradient in JAX and Optax. For now, we do not intend to implement it until further consensus is reached.
The gradient clipping transformations also depend on the choice of the norm for complex numbers. If we decide to implement the complex norm (or both norms), we need to implement it in the gradient clipping accordingly, as in PR #161.
If we decide on the split real norm, the usage of gradient clipping becomes

    optimizer = optax.split_real_and_imaginary(
        optax.chain(
            optax.clip_by_global_norm(max_norm),
            optax.adam(learning_rate)))

so that `clip_by_global_norm` only processes real gradients, and there is no change to be made.
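As a rough illustration of global-norm clipping with complex gradients, the plain-Python sketch below works on a flat list of scalars rather than a pytree. The names `global_norm` and `clip_scalars_by_global_norm` are hypothetical stand-ins, not the Optax functions, and the sample gradients are made up.

```python
import math

def global_norm(grads):
    """Global norm of a flat list of (possibly complex) scalar gradients."""
    return math.sqrt(sum((g.conjugate() * g).real for g in grads))

def clip_scalars_by_global_norm(grads, max_norm):
    """Rescale all gradients so that their global norm is at most max_norm."""
    norm = global_norm(grads)
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grads]

grads = [3 + 4j, 12.0]          # hypothetical sample gradients
split_grads = [3.0, 4.0, 12.0]  # the same gradients split into real parts

print(global_norm(grads))        # 13.0
print(global_norm(split_grads))  # 13.0

clipped = clip_scalars_by_global_norm(grads, max_norm=6.5)
print(clipped)  # [(1.5+2j), 6.0]
```

For the global norm specifically, the complex norm and the split real norm give the same number, because `|g|^2 = g.real**2 + g.imag**2`. The result only goes wrong when `g**2` is summed without the conjugate.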
Thanks a lot for this nice proposal!
I'd personally vote for the split real norm because, as you pointed out, it is consistent with the behaviour of `jax.vjp`. Another reason I prefer that approach is that users who want complex numbers in neural networks can currently get them by creating their own Haiku classes that split the real and imaginary parts of all inputs. For example, a user could write:
The split real norm approach would allow them to use the normal `hk.Linear` instead, while having their network and optimisation behave the same way as with the above workaround.