Last active
August 29, 2023 23:40
-
-
Save Jokeren/fc756ad6b3b22c6dfcec32d5460a1e03 to your computer and use it in GitHub Desktop.
RecordFunction forwardThreadId() reproducer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Python driver for the RecordFunction reproducer: runs ten forward/backward
# passes of torch.add on CPU tensors, so that the forward add and its autograd
# backward are both executed (and thus visible to any registered observer).
# NOTE(review): this script never loads the C++ observer itself; presumably the
# built extension is loaded by the (collapsed) Build/Run steps, e.g. via
# LD_PRELOAD or by importing the built module — confirm.
import torch
import sys

device = torch.device('cpu')


def _cpu_zeros(requires_grad=False):
    """Return a fresh 100-element zero tensor on ``device``."""
    return torch.zeros(100, device=device, requires_grad=requires_grad)


# Two leaf inputs tracked by autograd, plus the upstream gradient for backward().
left, right = _cpu_zeros(requires_grad=True), _cpu_zeros(requires_grad=True)
grad = _cpu_zeros()

remaining = 10
while remaining > 0:
    output = torch.add(left, right)
    # The upstream gradient is all zeros, so the accumulated left.grad and
    # right.grad remain all zeros.
    output.backward(grad)
    remaining -= 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <torch/extension.h> | |
#include <iostream> | |
int driver_register() { | |
at::addGlobalCallback( | |
at::RecordFunctionCallback( | |
[](const at::RecordFunction& fn) | |
-> std::unique_ptr<at::ObserverContext> { | |
std::cout << fn.forwardThreadId() << std::endl; | |
return nullptr; | |
}, | |
[](const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) { | |
return; | |
}) | |
.needsInputs(false) // TODO(Keren): monitor inputs if needed? | |
.needsOutputs(false) // TODO(Keren): monitor outputs if needed? | |
.scopes({})); | |
return 0; | |
} | |
int _ret = driver_register(); |
Author
Jokeren
commented
Aug 29, 2023
•
- Build
- Run
Just to consolidate the discussion here: @Jokeren said, "I always get fwd_thread_id = 18446744073709551615, which should be 0 or 1." (18446744073709551615 is 2^64 − 1, i.e. UINT64_MAX, the value of (uint64_t)-1.)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment