-
-
Save shwang/74a642161146ecdc702913ac2f992bdb to your computer and use it in GitHub Desktop.
Extract Sacred config entries
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2019 Google LLC. | |
# SPDX-License-Identifier: Apache-2.0 | |
import collections | |
import json | |
import os | |
import pickle | |
import re | |
from typing import Any, Callable, Iterable, Mapping, Optional, Tuple | |
import numpy as np | |
import pandas as pd | |
# A tuple (in practice a `KeyTuple` namedtuple) holding a subset of one
# Sacred config's entries.
Config = Tuple[Any, ...]
# Maps a results directory path to its config subset.
DirConfigMapping = Mapping[str, Config]
# Maps a results directory path to its full Sacred config, as parsed from
# "sacred/config.json".
DirStatsMapping = Mapping[str, Any]
def find_sacred_results(root_dir: str) -> Mapping[str, Any]:
    """Find result directories in root_dir, loading the associated config.

    Finds all directories in `root_dir` that contain a subdirectory "sacred".
    For each such directory, the config in "sacred/config.json" is loaded.

    Args:
        root_dir: A path to recursively search in.

    Returns:
        A dictionary of directory paths to Sacred configs. Paths are exactly as
        yielded by `os.walk(root_dir)`; entries are ordered by path so the
        result is deterministic.

    Raises:
        ValueError: if a results directory contains another results directory,
            at any depth (not only as an immediate child).
        FileNotFoundError: if no results directories are found.
    """
    results = set()
    for root, dirs, _ in os.walk(root_dir):
        if "sacred" in dirs:
            results.add(root)
    if not results:
        raise FileNotFoundError(f"No Sacred results directories in '{root_dir}'.")

    # Sanity check: not expecting nested experiments. Paths are compared in
    # absolute, normalized form so spelling differences (e.g. a trailing slash
    # or "." as root_dir) cannot hide a match. We walk up through *every*
    # ancestor: `os.path.split` only yields (head, tail), so iterating over its
    # components would only ever catch an immediate parent.
    by_abspath = {os.path.abspath(result): result for result in results}
    for abs_result in sorted(by_abspath):
        child, parent = abs_result, os.path.dirname(abs_result)
        while parent != child:  # dirname() is a fixed point at the fs root
            if parent in by_abspath:
                raise ValueError(
                    f"Parent {by_abspath[parent]} to {by_abspath[abs_result]} "
                    "also a result directory")
            child, parent = parent, os.path.dirname(parent)

    configs = {}
    for result in sorted(results):
        config_path = os.path.join(result, "sacred", "config.json")
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            configs[result] = json.load(f)
    return configs
def dict_to_tuple(d, keys: Optional[Iterable[str]] = None):
    """Recursively convert dicts to namedtuples, leaving other values intact.

    Args:
        d: The value to convert. Only `dict` instances are converted; any
            other value (including lists) is returned unchanged.
        keys: Fields to extract from the top-level dict, in this order. Any
            iterable is accepted, including a one-shot generator. Defaults to
            all keys of `d`, sorted. Nested dicts always use all of their own
            keys, sorted.

    Returns:
        A `KeyTuple` namedtuple if `d` is a dict, otherwise `d` itself. A key
        that is not a valid namedtuple field name (e.g. "learning-rate", a
        keyword, or a leading underscore) is renamed to `_<index>` instead of
        raising. Tuple equality and hashing are positional, so renaming does
        not affect comparisons.

    Raises:
        KeyError: if an entry of `keys` is missing from `d`.
    """
    if not isinstance(d, dict):
        return d
    # Materialize once: `namedtuple` consumes its field-name iterable, so a
    # generator would otherwise be exhausted before the values are looked up.
    keys = sorted(d.keys()) if keys is None else list(keys)
    key_tuple_cls = collections.namedtuple("KeyTuple", keys, rename=True)
    # Construct positionally, because `rename=True` may change field names.
    return key_tuple_cls(*(dict_to_tuple(d[k]) for k in keys))
def subset_keys(configs: Mapping[str, Any],
                keys: Iterable[str]) -> DirConfigMapping:
    """Extracts the subset of `keys` from each config in `configs`.

    Args:
        configs: Paths mapping to full Sacred configs, as returned by
            `find_sacred_results`.
        keys: The subset of keys to retain from the config. Any iterable is
            accepted, including a one-shot generator.

    Returns:
        A mapping from paths to tuples of the keys.

    Raises:
        ValueError: If any of the config subsets are duplicates of each other.
        KeyError: If a key in `keys` is missing from some config.
    """
    # `keys` is reused for every config; materialize it so a one-shot
    # iterator is not exhausted after the first config. (An exhausted
    # iterator would yield empty subsets that all compare equal.)
    keys = list(keys)
    res = {}
    # Hashable subsets are tracked in a set for O(1) lookups. Subsets holding
    # unhashable values (e.g. lists parsed from JSON) cannot go in a set, so
    # they are compared linearly instead. For JSON-derived values a list never
    # equals a hashable value, so the two pools need no cross-check.
    seen_hashable = set()
    seen_unhashable = []
    for path, config in configs.items():
        subset = dict_to_tuple(config, keys)  # type: Tuple[Any, ...]
        try:
            hash(subset)
        except TypeError:
            is_duplicate = subset in seen_unhashable
            remember = seen_unhashable.append
        else:
            is_duplicate = subset in seen_hashable
            remember = seen_hashable.add
        if is_duplicate:
            raise ValueError(f"Duplicate config '{subset}'")
        remember(subset)
        res[path] = subset
    return res
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.