Last active
June 18, 2019 06:55
-
-
Save dan-zheng/d153d950687820418dbae65517259ff0 to your computer and use it in GitHub Desktop.
Loop differentiation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import tangent | |
from tangent import grad | |
def nested_loop(x):
    # Toy function for exercising automatic differentiation through a `for`
    # loop that contains a `while` loop. Both loops run twice, so the result
    # is x**3 + 2*x**2 + 2*x: value 20 at x=2, 104 at x=4; derivative
    # 3*x**2 + 4*x + 2, i.e. 22 at x=2 and 66 at x=4.
    # Only `*` and `+` are used, so it accepts plain numbers as well as
    # torch tensors.
    # NOTE(review): documented with comments rather than a docstring and kept
    # to plain loops/assignments because tangent differentiates this function
    # from its source; confirm tangent still accepts any further restructuring.
    result = x
    for _ in range(1, 3):
        # Scale the running value by x, then add x once per inner iteration.
        acc = result * x
        step = 1
        while step < 3:
            acc = acc + x
            step += 1
        result = acc
    return result
def test_pytorch(x):
    """Differentiate ``nested_loop`` at *x* with PyTorch autograd.

    Prints the gradient tensor (same output as before) and also returns it,
    so callers can inspect the value instead of parsing stdout.

    Args:
        x: Point at which to evaluate the derivative (a Python number).

    Returns:
        The gradient of ``nested_loop`` with respect to the input, as a
        one-element float32 tensor.
    """
    # torch.tensor is the recommended factory; the legacy torch.FloatTensor
    # constructor is discouraged. dtype float32 matches what FloatTensor made.
    x_t = torch.tensor([x], dtype=torch.float32, requires_grad=True)
    y = nested_loop(x_t)
    # y has exactly one element, so backward() can seed the gradient
    # implicitly without an explicit `gradient=` argument.
    y.backward()
    print(x_t.grad)
    return x_t.grad
def test_tangent(x):
    # Build d(nested_loop)/dx with tangent's source-to-source AD, then
    # evaluate it at x and print the result.
    dnested_loop = grad(nested_loop)
    derivative = dnested_loop(x)
    print(derivative)
# Compare both AD backends at the same sample points: all PyTorch runs first,
# then all tangent runs.
for point in (2., 4.):
    test_pytorch(point)
for point in (2., 4.):
    test_tangent(point)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import StdlibUnittest | |
/// Exercises differentiation through a `for` loop that contains a `while` loop.
///
/// Each of the `count - 1` outer iterations multiplies the running value by `x`,
/// then the inner loop adds `x` another `count - 1` times. With `count == 3`
/// this evaluates to x^3 + 2x^2 + 2x.
///
/// - Parameters:
///   - x: The input value (the variable being differentiated).
///   - count: Bound for both loops. Must be >= 1, because `1..<count` traps
///     otherwise; `count == 1` skips both loops and returns `x`.
/// - Returns: The accumulated value.
func nested_loop(_ x: Float, count: Int) -> Float {
    var result = x
    for _ in 1..<count {
        var acc = result * x
        var step = 1
        while step < count {
            // Plain reassignment (not `+=`) keeps this free of inout mutation.
            acc = acc + x
            step += 1
        }
        result = acc
    }
    return result
}
// nested_loop(x, count: 3) == x^3 + 2x^2 + 2x, so its derivative is 3x^2 + 4x + 2.
// Each check compares the (value, gradient) pair from valueWithGradient
// against those closed forms.
// At x = 2: value 8 + 8 + 4 = 20, gradient 12 + 8 + 2 = 22.
expectEqual((20, 22), valueWithGradient(at: 2, in: { x in nested_loop(x, count: 3) }))
// At x = 4: value 64 + 32 + 8 = 104, gradient 48 + 16 + 2 = 66.
expectEqual((104, 66), valueWithGradient(at: 4, in: { x in nested_loop(x, count: 3) }))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment