Skip to content

Instantly share code, notes, and snippets.

@dan-zheng
Last active December 15, 2019 06:03
Show Gist options
  • Save dan-zheng/c82f371a225f89a40306c9494f3235da to your computer and use it in GitHub Desktop.
Save dan-zheng/c82f371a225f89a40306c9494f3235da to your computer and use it in GitHub Desktop.
Dynamic parameters and `allKeyPaths` synthesis
// Demonstrates an array of parameters and how to synthesize `allKeyPaths` for it.
import TensorFlow
/// A parameter aggregate backed by a dynamically sized array of weight tensors.
struct Parameters : ParameterAggregate {
    var weights: [Tensor<Float>]

    // The `allKeyPaths` getter should be marked `@inlinable` so that it can be
    // compiler-evaluated. Because `weights` is dynamic, `allKeyPaths` is not
    // generally compiler-evaluable.
    /// One writable key path per element of `weights`, in index order.
    var allKeyPaths: [WritableKeyPath<Parameters, Tensor<Float>>] {
        return weights.indices.map { index in \Parameters.weights[index] }
    }
}
// Build a three-element parameter set, then multiply every parameter by 10
// by writing through each of its key paths.
var parameters = Parameters(weights: [Tensor(1), Tensor(2), Tensor(3)])
for keyPath in parameters.allKeyPaths {
    parameters[keyPath: keyPath] *= 10
}
print(parameters)
// Parameters(weights: [10.0, 20.0, 30.0])
// Demonstrates generic `Parameterized` types as parameters, and how to synthesize `allKeyPaths` for them.
// `ParameterAggregate` and `Parameterized` are defined in the TensorFlow module.
// I show them here for clarity.
/// A type that can list writable key paths from itself to values of a single
/// associated `Parameter` type.
protocol ParameterAggregate {
/// The type of value that every key path in `allKeyPaths` leads to.
associatedtype Parameter
/// Writable key paths rooted at `Self`, each leading to a `Parameter` value.
/// Only a getter is required.
var allKeyPaths: [WritableKeyPath<Self, Parameter>] { get }
}
/// A type that exposes its parameters as a single readable and writable value.
protocol Parameterized {
/// The type of the value held by `allParameters`.
associatedtype Parameters
/// The parameters of this value; both a getter and a setter are required.
var allParameters: Parameters { get set }
}
// Actual program begins here.
import TensorFlow
// Rough sketch of a general "module" abstraction.
/// A `Parameterized` type whose `Parameters` conform to `ParameterAggregate` and
/// which declares an `applied(to:)` function from `Input` to `Output`.
protocol Module : Parameterized where Parameters : ParameterAggregate {
/// The type of value accepted by `applied(to:)`.
associatedtype Input
/// The type of value returned by `applied(to:)`.
associatedtype Output
/// Takes an `Input` and returns an `Output`.
func applied(to: Input) -> Output
}
/// A `Parameterized` composite of two modules whose parameters, inputs and outputs
/// are all `Tensor<Scalar>`.
struct PolicyCore<Scalar : BinaryFloatingPoint, VisionModule : Module, OtherModule : Module> : Parameterized
    // Every generic parameter conforming to `Module` needs three constraints in the
    // `where` clause (four, counting the `Module` conformance itself).
    where VisionModule.Parameters.Parameter == Tensor<Scalar>,
          VisionModule.Input == Tensor<Scalar>,
          VisionModule.Output == Tensor<Scalar>,
          OtherModule.Parameters.Parameter == Tensor<Scalar>,
          OtherModule.Input == Tensor<Scalar>,
          OtherModule.Output == Tensor<Scalar> {
    @TFParameter var visionModule: VisionModule
    @TFParameter var otherModule: OtherModule

    // Compiler synthesizes the `Parameters` struct and the `allParameters` instance.
    /// The parameters of both sub-modules, stored side by side.
    struct Parameters : ParameterAggregate {
        var visionModule: VisionModule.Parameters
        var otherModule: OtherModule.Parameters

        typealias Parameter = Tensor<Scalar>

        // Evaluating `allKeyPaths` at compile time is going to be tough, but it can
        // at least be represented. What matters is that the underlying "effective"
        // parameter type is the same for every member (here, `Tensor<Scalar>`).
        /// The key paths of `visionModule` followed by those of `otherModule`, each
        /// re-rooted at `Parameters`.
        var allKeyPaths: [WritableKeyPath<PolicyCore.Parameters, Tensor<Scalar>>] {
            // Prefix each member's own key paths with the path to that member, using
            // the `WritableKeyPath.appending(path:)` API:
            // https://developer.apple.com/documentation/swift/writablekeypath
            let visionKeyPaths: [WritableKeyPath<Parameters, Tensor<Scalar>>] =
                visionModule.allKeyPaths.map { (\Parameters.visionModule).appending(path: $0) }
            let otherKeyPaths: [WritableKeyPath<Parameters, Tensor<Scalar>>] =
                otherModule.allKeyPaths.map { (\Parameters.otherModule).appending(path: $0) }
            return visionKeyPaths + otherKeyPaths
        }
    }

    /// Reads the parameters of both sub-modules as one `Parameters` value, or writes
    /// a `Parameters` value back into both sub-modules.
    var allParameters: Parameters {
        get {
            let vision = visionModule.allParameters
            let other = otherModule.allParameters
            return Parameters(visionModule: vision, otherModule: other)
        }
        set(updated) {
            visionModule.allParameters = updated.visionModule
            otherModule.allParameters = updated.otherModule
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment