Skip to content

Instantly share code, notes, and snippets.

@dan-zheng
Last active June 15, 2019 09:22
Show Gist options
  • Save dan-zheng/68e92b823eb85148744e22e552b16948 to your computer and use it in GitHub Desktop.
Save dan-zheng/68e92b823eb85148744e22e552b16948 to your computer and use it in GitHub Desktop.
Generic Adam optimizer experiment.

I tried to define a generic Adam optimizer using ElementaryFunctions, Differentiable, and VectorProtocol (the latter two are being incubated on the `tensorflow` branch).

It doesn't work due to missing scalar-vector and vector-vector operations.

$ swift adam.swift
adam.swift:82:95: error: binary operator '+' cannot be applied to operands of type 'Model.TangentVector' and 'Model.TangentVector.VectorSpaceScalar'
        model.move(along: -stepSize * firstMoments / (Model.TangentVector.sqrt(secondMoments) + epsilon))
                                                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ^ ~~~~~~~
adam.swift:82:95: note: expected an argument list of type '(Self, Self)'
        model.move(along: -stepSize * firstMoments / (Model.TangentVector.sqrt(secondMoments) + epsilon))
                                                                                              ^

There's another unhandled operation on line 82, `/` applied to two `Model.TangentVector` operands:

        model.move(along: -stepSize * firstMoments / (Model.TangentVector.sqrt(secondMoments) + epsilon))
                                       ~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
$ swift --version
Apple Swift version 5.1-dev (LLVM 082dec2e22, Swift fc0ab61888)
Target: x86_64-apple-darwin18.6.0
// `Differentiable` and `VectorProtocol` adapted from:
// https://github.com/apple/swift/blob/tensorflow/stdlib/public/core/AutoDiff.swift
/// A type whose values can be moved in place along a tangent direction.
public protocol Differentiable {
    /// The type of the directions accepted by `move(along:)`.
    ///
    /// It must itself be `Differentiable` and support addition and
    /// subtraction, and it must be its own tangent vector type.
    associatedtype TangentVector: Differentiable & AdditiveArithmetic
        where TangentVector == TangentVector.TangentVector

    /// Mutates `self` using the given tangent `direction`.
    ///
    /// - Parameter direction: The tangent vector to move along.
    mutating func move(along direction: TangentVector)
}
/// An additive type that can also be multiplied by a scalar.
public protocol VectorProtocol: AdditiveArithmetic {
    /// The scalar type that values of this type can be scaled by.
    associatedtype VectorSpaceScalar: AdditiveArithmetic

    /// Returns a copy of `self` scaled by `scalar`.
    ///
    /// - Parameter scalar: The scaling factor.
    func scaled(by scalar: VectorSpaceScalar) -> Self

    /// Scales `self` in place by `scalar`.
    ///
    /// - Parameter scalar: The scaling factor.
    mutating func scale(by scalar: VectorSpaceScalar)
}
/// Default `scale(by:)` implementation and scalar-multiplication operators
/// for every `VectorProtocol` conformer, all expressed through `scaled(by:)`.
public extension VectorProtocol {
    /// Scales `self` in place by replacing it with `scaled(by: scalar)`.
    ///
    /// - Parameter scalar: The scaling factor.
    mutating func scale(by scalar: VectorSpaceScalar) {
        let scaledCopy = scaled(by: scalar)
        self = scaledCopy
    }

    /// Returns `vector` scaled by `factor` (vector on the left-hand side).
    static func * (vector: Self, factor: VectorSpaceScalar) -> Self {
        return vector.scaled(by: factor)
    }

    /// Returns `vector` scaled by `factor` (scalar on the left-hand side).
    static func * (factor: VectorSpaceScalar, vector: Self) -> Self {
        return vector.scaled(by: factor)
    }

    /// Scales `vector` in place by `factor` via `scale(by:)`.
    static func *= (vector: inout Self, factor: VectorSpaceScalar) {
        vector.scale(by: factor)
    }
}
/// Negation for vectors whose scalar type is signed.
public extension VectorProtocol where VectorSpaceScalar: SignedNumeric {
    /// Returns the additive inverse of `operand`, computed as `zero - operand`.
    static prefix func - (operand: Self) -> Self {
        return Self.zero - operand
    }
}
/// Adam optimizer, generic over any `Differentiable` model whose tangent vector is a
/// `VectorProtocol` with elementary functions and a binary floating-point scalar.
///
/// The optimizer keeps two running estimates in the model's tangent space
/// (`firstMoments` and `secondMoments`) and uses them in `update(_:along:)` to
/// decide how far to move the model.
///
/// Reference: "Adam - A Method for Stochastic Optimization".
/// https://arxiv.org/abs/1412.6980v8
///
/// NOTE(review): availability is pinned to a placeholder macOS version, presumably
/// because `ElementaryFunctions` is not available on any released OS. TODO confirm.
///
/// NOTE(review): the final statement of `update(_:along:)` does not type-check against
/// the declarations visible in this file; see the comment on that statement.
@available(macOS 9999, *)
public class Adam<Model: Differentiable>
where
Model.TangentVector: VectorProtocol & ElementaryFunctions,
Model.TangentVector.VectorSpaceScalar: BinaryFloatingPoint & ElementaryFunctions
{
/// The scalar type of the model's tangent vector; all hyperparameters use it.
public typealias Scalar = Model.TangentVector.VectorSpaceScalar
/// Base learning rate, before `decay` is applied in `update(_:along:)`.
public var learningRate: Scalar
/// Weight given to the previous `firstMoments` value on each update.
public var beta1: Scalar
/// Weight given to the previous `secondMoments` value on each update.
public var beta2: Scalar
/// Scalar added to the square root of `secondMoments` in the update denominator.
public var epsilon: Scalar
/// Learning-rate decay; the effective rate is `learningRate / (1 + decay * step)`.
public var decay: Scalar
/// Number of times `update(_:along:)` has been called.
public var step: Int = 0
/// Running weighted average of the directions passed to `update(_:along:)`.
public var firstMoments: Model.TangentVector = .zero
/// Running weighted average of the squared directions passed to `update(_:along:)`.
public var secondMoments: Model.TangentVector = .zero
/// Creates an optimizer with the given hyperparameters.
///
/// - Parameters:
///   - model: Not read by the initializer body; it appears only in the signature.
///   - learningRate: Base learning rate. Must be non-negative.
///   - beta1: Weight for the first-moment running average. Must lie in `0...1`.
///   - beta2: Weight for the second-moment running average. Must lie in `0...1`.
///   - epsilon: Scalar added to the update denominator. Not validated here.
///   - decay: Learning-rate decay. Must be non-negative.
/// - Precondition: `learningRate >= 0`, `0 <= beta1 <= 1`, `0 <= beta2 <= 1`, `decay >= 0`.
///
/// NOTE(review): `beta1 == 1` passes the precondition but makes `1 - pow(beta1, step)`
/// zero, so `update(_:along:)` divides the step size by zero.
public init(
for model: __shared Model,
learningRate: Scalar = 1e-3,
beta1: Scalar = 0.9,
beta2: Scalar = 0.999,
epsilon: Scalar = 1e-8,
decay: Scalar = 0
) {
precondition(learningRate >= 0, "Learning rate must be non-negative")
precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
precondition(0 <= beta2 && beta2 <= 1, "Beta parameter must be between 0 and 1")
precondition(decay >= 0, "Learning rate decay must be non-negative")
self.learningRate = learningRate
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.decay = decay
}
/// Advances the optimizer by one step and moves `model` accordingly.
///
/// In order: increments `step`, derives the decayed learning rate and the step size,
/// folds `direction` into `firstMoments` and `secondMoments`, then calls
/// `model.move(along:)` with the resulting update.
///
/// - Parameters:
///   - model: The model to update; mutated in place through `move(along:)`.
///   - direction: Tangent vector folded into both moment estimates. Presumably the
///     gradient of the loss with respect to `model`; confirm against callers.
public func update(_ model: inout Model, along direction: Model.TangentVector) {
step += 1
// Local, decayed learning rate; intentionally shadows the stored property.
let learningRate = self.learningRate / (1 + decay * Scalar(step))
// Note: `stepSize` is split into two lines to avoid the "compiler is unable to type-check
// this expression in reasonable time" error.
// Overall: stepSize = learningRate * sqrt(1 - beta2^step) / (1 - beta1^step).
var stepSize = learningRate * Scalar.sqrt(1 - Scalar.pow(beta2, step))
stepSize = stepSize / (1 - Scalar.pow(beta1, step))
// Blend the new direction into the running first moment.
firstMoments = firstMoments * beta1 + (1 - beta1) * direction
// Blend the squared direction into the running second moment.
secondMoments = secondMoments * beta2 + (1 - beta2) * Model.TangentVector.pow(direction, 2)
// NOTE(review): this statement does not type-check with the operators visible in this
// file. `VectorProtocol` only provides vector * scalar, and `AdditiveArithmetic` only
// provides `+`/`-` between two vectors, so neither `TangentVector + Scalar` (adding
// `epsilon`) nor `TangentVector / TangentVector` resolves.
model.move(along: -stepSize * firstMoments / (Model.TangentVector.sqrt(secondMoments) + epsilon))
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment