@dan-zheng
Last active February 18, 2019 01:19
Proof-of-concept API change that avoids manually spelling out a very long generic type parameter.
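```swift
import TensorFlow

extension SGD {
    // Take the `Model` metatype as an argument so that `Model` is inferred
    // from the call site instead of being written out explicitly as a
    // generic parameter.
    // IMPORTANT: This API change might not be desirable, because call sites
    // must obtain the model's dynamic type via `type(of:)`.
    convenience init(_ modelType: Model.Type, learningRate: Scalar) {
        self.init(learningRate: learningRate)
    }
}

let dense1 = Dense<Float>(inputSize: 3, outputSize: 4, activation: relu)
let dense2 = Dense<Float>(inputSize: 4, outputSize: 5, activation: relu)
// Composing layers yields a deeply nested generic type, which is what this
// extension avoids writing out by hand in `SGD<...>`.
let model = dense1 >> dense2
// `Model` is inferred from `type(of: model)`. Annotating the literal as `Float`
// pins `Scalar` to the model's scalar type; if nothing else constrains
// `Scalar`, an unannotated floating-point literal would default to `Double`.
let opt = SGD(type(of: model), learningRate: 0.02 as Float)
```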
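The metatype-argument pattern above does not depend on TensorFlow, so it can be sketched in plain Swift. In this sketch `ToyLayer`, `ToyComposed`, `ToyOptimizer`, `ToyBase` and `ToyDerived` are hypothetical stand-ins, not Swift for TensorFlow types, and `init(for:)` is an assumed instance-based alternative that avoids `type(of:)`. The last lines illustrate the caveat in the gist's comment: the value returned by `type(of:)` is the dynamic type, but generic inference only sees the expression's static type.

```swift
// Hypothetical stand-ins (NOT Swift for TensorFlow types): just enough
// structure to show how a metatype argument drives generic inference.
struct ToyLayer<Scalar: BinaryFloatingPoint> {
    var weight: Scalar
}

// A nested generic type, standing in for a composed model.
struct ToyComposed<First, Second> {
    var first: First
    var second: Second
}

final class ToyOptimizer<Model, Scalar: BinaryFloatingPoint> {
    let learningRate: Scalar

    init(learningRate: Scalar) {
        self.learningRate = learningRate
    }
}

extension ToyOptimizer {
    // The metatype is never used at runtime; it only lets the compiler
    // infer `Model` at the call site.
    convenience init(_ modelType: Model.Type, learningRate: Scalar) {
        self.init(learningRate: learningRate)
    }

    // Assumed instance-based alternative: `Model` is inferred from a value,
    // so callers do not need `type(of:)` at all.
    convenience init(for model: Model, learningRate: Scalar) {
        self.init(learningRate: learningRate)
    }
}

let model = ToyComposed(first: ToyLayer<Float>(weight: 1),
                        second: ToyLayer<Float>(weight: 2))

// 1. Spelling out every generic parameter by hand.
let explicit = ToyOptimizer<ToyComposed<ToyLayer<Float>, ToyLayer<Float>>, Float>(
    learningRate: 0.02)

// 2. Letting the metatype drive inference (the pattern in this gist).
let viaMetatype = ToyOptimizer(type(of: model), learningRate: 0.02 as Float)

// 3. Letting the model instance drive inference.
let viaInstance = ToyOptimizer(for: model, learningRate: 0.02 as Float)

print(type(of: explicit) == type(of: viaMetatype))    // true
print(type(of: viaMetatype) == type(of: viaInstance)) // true

// The `type(of:)` caveat: the runtime value is the dynamic type, but
// generic inference uses the expression's static type (`ToyBase.Type`).
class ToyBase {}
final class ToyDerived: ToyBase {}

let base: ToyBase = ToyDerived()
let fromBase = ToyOptimizer(type(of: base), learningRate: 0.1 as Double)

print(type(of: base) == ToyDerived.self)                        // true
print(type(of: fromBase) == ToyOptimizer<ToyBase, Double>.self) // true
```

In this sketch the instance-based `init(for:)` variant sidesteps the `type(of:)` concern entirely, at the cost of needing a model value when the optimizer is constructed.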