This gist is an in-depth exploration of the Swift parameter update design, focusing on dynamic parameters and how to synthesize `allKeyPaths` for them.
Last active December 15, 2019 06:03
-
-
Save dan-zheng/c82f371a225f89a40306c9494f3235da to your computer and use it in GitHub Desktop.
Dynamic parameters and `allKeyPaths` synthesis
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Demonstrates an array of parameters and how to synthesize `allKeyPaths` for it. | |
import TensorFlow | |
/// A parameter aggregate holding a run-time-sized list of weight tensors.
struct Parameters : ParameterAggregate {
    var weights: [Tensor<Float>]

    /// One writable key path per element of `weights`, in index order.
    ///
    /// Ideally this getter would be marked `@inlinable` so that the compiler
    /// could evaluate it. Because the number of elements in `weights` is only
    /// known at run time, the result is not generally compiler-evaluable.
    var allKeyPaths: [WritableKeyPath<Parameters, Tensor<Float>>] {
        return weights.indices.map { index in \Parameters.weights[index] }
    }
}
// Build three scalar weights, then multiply each one by 10 in place by
// addressing it through the key paths returned by `allKeyPaths`.
var parameters = Parameters(weights: [Tensor(1), Tensor(2), Tensor(3)])
parameters.allKeyPaths.forEach { keyPath in
    parameters[keyPath: keyPath] *= 10
}
print(parameters)
// Expected output:
// Parameters(weights: [10.0, 20.0, 30.0])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Demonstrates generic `Parameterized` types as parameters, and how to synthesize `allKeyPaths` for them. | |
// `ParameterAggregate` and `Parameterized` are defined in the TensorFlow module. | |
// I show them here for clarity. | |
/// A type whose parameters can all be reached through writable key paths
/// that share a single value type, `Parameter`.
protocol ParameterAggregate {
    /// The common value type of every key path in `allKeyPaths`.
    associatedtype Parameter

    /// Writable key paths from `Self` to each of its parameters.
    var allKeyPaths: [WritableKeyPath<Self, Self.Parameter>] { get }
}
/// A type that exposes all of its parameters as a single readable and
/// writable value of type `Parameters`.
protocol Parameterized {
    /// The type bundling every parameter of the conforming type.
    associatedtype Parameters

    /// All parameters, gettable and settable as one value.
    var allParameters: Self.Parameters { get set }
}
// Actual program begins here. | |
import TensorFlow | |
// Rough sketch of a general "module" abstraction. | |
/// A rough sketch of a general "module" abstraction: a `Parameterized` type
/// whose `Parameters` form a `ParameterAggregate`, and which maps an `Input`
/// to an `Output`.
protocol Module : Parameterized where Self.Parameters : ParameterAggregate {
    /// The type accepted by `applied(to:)`.
    associatedtype Input
    /// The type returned by `applied(to:)`.
    associatedtype Output

    /// Returns the module's output for `input`.
    func applied(to input: Input) -> Output
}
/// A policy composed of two sub-modules that both consume and produce
/// `Tensor<Scalar>` and whose parameters are all `Tensor<Scalar>`.
///
/// The `where` clause needs three constraints per generic parameter that
/// conforms to `Module` (four, if the `Module` conformance itself is counted).
struct PolicyCore<Scalar : BinaryFloatingPoint, VisionModule : Module, OtherModule : Module> : Parameterized
    where VisionModule.Parameters.Parameter == Tensor<Scalar>,
          VisionModule.Input == Tensor<Scalar>,
          VisionModule.Output == Tensor<Scalar>,
          OtherModule.Parameters.Parameter == Tensor<Scalar>,
          OtherModule.Input == Tensor<Scalar>,
          OtherModule.Output == Tensor<Scalar> {
    @TFParameter var visionModule: VisionModule
    @TFParameter var otherModule: OtherModule

    // The `Parameters` struct and the `allParameters` property below are what
    // the compiler is meant to synthesize for the `@TFParameter` properties.

    /// The parameters of both sub-modules, side by side.
    struct Parameters : ParameterAggregate {
        typealias Parameter = Tensor<Scalar>

        var visionModule: VisionModule.Parameters
        var otherModule: OtherModule.Parameters

        /// Key paths to every `Tensor<Scalar>` parameter of both sub-modules:
        /// all of `visionModule`'s first, followed by all of `otherModule`'s.
        ///
        /// Each sub-module key path is prefixed with the key path to that
        /// sub-module's field via `WritableKeyPath.appending(path:)`
        /// (https://developer.apple.com/documentation/swift/writablekeypath).
        /// This works because every sub-module shares the same effective
        /// parameter type, `Tensor<Scalar>`. Evaluating it at compile time
        /// would be difficult, but the result is representable.
        var allKeyPaths: [WritableKeyPath<Parameters, Tensor<Scalar>>] {
            let visionKeyPaths: [WritableKeyPath<Parameters, Tensor<Scalar>>] =
                visionModule.allKeyPaths.map { (\Parameters.visionModule).appending(path: $0) }
            let otherKeyPaths: [WritableKeyPath<Parameters, Tensor<Scalar>>] =
                otherModule.allKeyPaths.map { (\Parameters.otherModule).appending(path: $0) }
            return visionKeyPaths + otherKeyPaths
        }
    }

    /// Reads both sub-modules' parameters into one `Parameters` value, and
    /// writes a `Parameters` value back into the two sub-modules.
    var allParameters: Parameters {
        get {
            let vision = visionModule.allParameters
            let other = otherModule.allParameters
            return Parameters(visionModule: vision, otherModule: other)
        }
        set {
            visionModule.allParameters = newValue.visionModule
            otherModule.allParameters = newValue.otherModule
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.