Created
May 22, 2024 15:36
-
-
Save ricardoV94/e8902b4c35c26e87e189ab477f8d9288 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Trace where NaNs appear inside an optimized PyTensor graph.

The script compiles a small graph, evaluates every node of the optimized
graph one at a time with ``Op.perform``, and prints the graph with a short
summary next to each variable: the fraction of ``True`` entries for boolean
variables and the fraction of NaN entries for float variables.
"""
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.basic import Constant, io_toposort
from pytensor.compile import SharedVariable
from pytensor.compile.mode import get_mode
from pytensor.printing import _debugprint

# Define graph
x = pt.vector("x")
out = pt.where(x > 1, -pt.nan, pt.log(x))

# Optimize function from inputs to outputs.
# "inplace" and "fusion" rewrites are excluded; presumably so that every
# intermediate value survives (nothing is overwritten in place) and each
# elementwise operation stays a separate, inspectable node.
mode = get_mode("FAST_RUN").excluding("inplace", "fusion")
fgraph = pytensor.function(inputs=[x], outputs=out, mode=mode).maker.fgraph

# Work with the variables owned by the compiled graph itself: they may be
# clones of the user-created `x` / `out`, so look-ups keyed by the fgraph's
# own variables work whether or not compilation cloned the graph.
[x_in] = fgraph.inputs
[optimized_out] = fgraph.outputs

# Seed initial values
evaled_vars = {x_in: np.random.default_rng(37).normal(size=(50,))}

# Compute every intermediate variable, visiting nodes in dependency order
for node in io_toposort(fgraph.inputs, fgraph.outputs):
    input_values = []
    for inp in node.inputs:
        if isinstance(inp, Constant):
            input_values.append(inp.data)
        elif isinstance(inp, SharedVariable):
            input_values.append(inp.get_value(borrow=True))
        else:
            # Graph input or output of a node that was already evaluated
            input_values.append(evaled_vars[inp])

    # `perform` writes each result into the first slot of a one-element list
    output_values = [[None] for _ in node.outputs]
    node.op.perform(node, input_values, output_values)

    # A distinct loop name keeps the graph output variables from being
    # clobbered before the final print.
    for out_var, [out_value] in zip(node.outputs, output_values):
        evaled_vars[out_var] = out_value

# Compile extra information to print next to each node
extra_info = {}
for var, value in evaled_vars.items():
    dtype = var.type.dtype
    if dtype == "bool":
        extra_info[var] = f"true={np.mean(value):.2%}"
    elif dtype.startswith("float"):
        extra_info[var] = f"nan={np.mean(np.isnan(value)):.2%}"

_debugprint(optimized_out, storage_map=extra_info)
# Switch [id A] nan=60.00%
#  ├─ Gt [id B] true=10.00%
#  │  ├─ x [id C] nan=0.00%
#  │  └─ [1] [id D]
#  ├─ [nan] [id E]
#  └─ Log [id F] nan=50.00%
#     └─ x [id C] nan=0.00%
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment