""" | |
Repro for ownership issue with Python stores for torch.distributed. | |
rm -rf /tmp/store_bug; for i in {0..1}; do python -u store_bug.py $i 2 & done && wait | |
""" | |
import datetime
import fcntl
import os
import pathlib
import sys
import time

import numpy as np
import torch
DEBUG = False
POLL_SLEEP_TIME = 0.5
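
# Map a store key to a filesystem-safe file name: keep alphanumerics and
# "._-()", replace everything else with "_".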
def _sanitize(key):
    return "".join([x if x.isalnum() or x in "._-()" else "_" for x in key])
class Store(torch.distributed.Store):
    def __init__(self, path, rank, timeout=5):
        super().__init__()
        self.path = path
        self.path.mkdir(exist_ok=True)
        self.rank = rank
        self._timeout = timeout  # self.timeout not writable here.
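
    # set(): write ``value`` verbatim to the key's file under an exclusive lock.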
    def set(self, key, value):
        if DEBUG:
            print("set with rank %i, key %s, value %s" % (self.rank, key, value))
        if not isinstance(key, (str, bytes)):
            raise AssertionError("Expected set to be called with string key")
        if type(value) is not bytes:
            raise AssertionError("Expected set to be called with bytes value")
        with (self.path / _sanitize(key)).open("bw") as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            f.truncate()
            f.write(value)
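
    # get(): poll until the key's file exists, then read it under an exclusive lock.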
    def get(self, key):
        path = self.path / _sanitize(key)
        if DEBUG:
            print("get on rank %i, key %s" % (self.rank, key))
        # Busy loop.
        start = time.time()
        while True:
            if self._timeout and time.time() - start > self._timeout:
                raise RuntimeError(f"Timeout {self._timeout} hit on get({key!r})")
            try:
                with path.open("br") as f:
                    fcntl.flock(f, fcntl.LOCK_EX)
                    value = f.read()
                if DEBUG:
                    print("%i get done, key %s, value %s" % (self.rank, key, repr(value)))
                return value
            except FileNotFoundError:
                time.sleep(POLL_SLEEP_TIME)
                continue
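
    # add(): read-modify-write an int64 counter in the key's file under an
    # exclusive lock; returns the updated value.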
    def add(self, key, value):
        if DEBUG:
            print("add with rank %i, key %s, value %s" % (self.rank, key, repr(value)))
        # Plain open() has no mode that both reads/writes and creates a missing
        # file, so use os.open with O_RDWR | O_CREAT.
        with os.fdopen(os.open(self.path / _sanitize(key), os.O_RDWR | os.O_CREAT), "rb+") as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            result = np.fromfile(f, dtype=np.int64, count=1)
            if not result.size:  # File was empty (just created).
                result = np.zeros([], dtype=np.int64)
            result += value
            f.truncate()
            f.seek(0)
            result.tofile(f)
        time.sleep(POLL_SLEEP_TIME)
        return result.item()
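
    # wait(): poll until all of ``keys`` have backing files, honoring an
    # optional timedelta timeout.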
    def wait(self, keys, timeout=None):
        paths = [self.path / _sanitize(key) for key in keys]
        if timeout is not None:
            timeout = timeout.total_seconds()  # timeout is a timedelta object.
        else:
            timeout = self._timeout
        if DEBUG:
            print("wait with rank %i, keys %s, timeout %s" % (self.rank, keys, timeout))
        start = time.time()
        while True:
            if timeout and time.time() - start > timeout:
                raise RuntimeError(f"Timeout {timeout} hit on wait({keys})")
            values = []
            for path in paths:
                try:
                    with path.open("br") as f:
                        fcntl.flock(f, fcntl.LOCK_EX)
                        values.append(f.read())
                except FileNotFoundError:
                    values.append(None)
            # This assumes keys once read as non-None don't get deleted while we
            # wait. Keep keys and paths aligned when dropping satisfied keys.
            missing = [(k, p) for k, p, v in zip(keys, paths, values) if v is None]
            if not missing:
                return
            keys = [k for k, _ in missing]
            paths = [p for _, p in missing]
            time.sleep(POLL_SLEEP_TIME)
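
# Create the file-backed store and pass it to init_process_group. The
# keyword-only ``_cache`` dict is a mutable default shared across calls; with
# use_workaround=True it keeps the store alive after init() returns.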
def init(rank, world_size, use_workaround=False, *, _cache={}):
    store = Store(pathlib.Path("/tmp/store_bug"), rank)
    torch.distributed.init_process_group(
        backend="gloo",
        world_size=world_size,
        rank=rank,
        timeout=datetime.timedelta(seconds=3.0),
        store=store,
    )
    if use_workaround:
        # Keep a Python reference to the store so it outlives this function.
        _cache["store"] = store
def main():
    rank = int(sys.argv[1])
    world_size = int(sys.argv[2])
    # Any third command-line argument enables the workaround.
    use_workaround = len(sys.argv) > 3
    print(f"starting rank {rank}/{world_size} with use_workaround={use_workaround}")
    init(rank, world_size, use_workaround)
    g = torch.distributed.new_group(timeout=datetime.timedelta(seconds=2.0))
    payload = torch.tensor([3 * rank + 1], dtype=torch.int64)
    torch.distributed.all_reduce(payload, group=g)
    print(f"Rank {rank}, sum is {payload}")


if __name__ == "__main__":
    main()