(* Gist TyOverby/52c283d5d082f674be93024ee31a118d, created December 12, 2020
   19:37. GitHub warns that this file contains bidirectional Unicode text that
   may be interpreted or compiled differently than it appears; review it in an
   editor that reveals hidden Unicode characters. *)
open! Base
(** Binary arithmetic operators: addition, subtraction, multiplication and
    division. *)
type binop = [ `Add | `Sub | `Mul | `Div ] [@@deriving compare, equal, hash, sexp]
(** Unary operators: square root, negation and squaring. *)
type unop = [ `Sqrt | `Neg | `Square ] [@@deriving compare, equal, hash, sexp]
(** One layer of an arithmetic expression. The type parameter is whatever a
    node uses to refer to its children (a nested tree in [Tree], a node id
    string in [Graph]). *)
module Structure = struct
  type 'recursive t =
    | Const of float
    | Var of int
    | Binop of binop * 'recursive * 'recursive
    | Unop of unop * 'recursive
  [@@deriving compare, equal, hash, sexp]

  (** [label node] is the display text of [node] itself, ignoring its
      children: a constant printed with [%f], a variable index in square
      brackets, or the operator's symbol/name. *)
  let label : _ t -> string = function
    | Const value -> Printf.sprintf "%f" value
    | Var index -> Printf.sprintf "[%d]" index
    | Binop (op, _, _) ->
      (match op with
       | `Add -> "+"
       | `Sub -> "-"
       | `Mul -> "*"
       | `Div -> "/")
    | Unop (op, _) ->
      (match op with
       | `Sqrt -> "sqrt"
       | `Neg -> "neg"
       | `Square -> "^2")
  ;;
end
(** Expression trees in which every node owns its children directly, so a
    repeated subexpression is stored (and rendered) once per occurrence. *)
module Tree = struct
  type t = T of t Structure.t [@@deriving equal, compare]

  let rec sexp_of_t (T t) = Structure.sexp_of_t sexp_of_t t

  (** [label tree] is [Structure.label] of the root node of [tree]. *)
  let label (T inner) = Structure.label inner

  (** [to_graphviz tree] renders [tree] as a DOT digraph. Each visited node
      gets a fresh id ([t_1], [t_2], ... in pre-order, left operand first), so
      repeated subexpressions become separate vertices. Edges run from each
      child to its parent. *)
  let to_graphviz tree =
    let id = ref 0 in
    let buf = Buffer.create 1024 in
    let rec loop (T t) =
      Int.incr id;
      let id = Printf.sprintf "t_%d" !id in
      Printf.bprintf buf "%s [label=\"%s\"];\n" id (label (T t));
      match t with
      | Const _ | Var _ -> id
      | Binop (_, a, b) ->
        (* Sequence the recursive calls explicitly: OCaml leaves the evaluation
           order of tuple components unspecified, so relying on it would make
           the id numbering compiler-dependent. *)
        let a = loop a in
        let b = loop b in
        Printf.bprintf buf "%s -> %s;\n" a id;
        Printf.bprintf buf "%s -> %s;\n" b id;
        id
      | Unop (_, a) ->
        let a = loop a in
        Printf.bprintf buf "%s -> %s;\n" a id;
        id
    in
    let (_ : string) = loop tree in
    "digraph G {\n" ^ Buffer.contents buf ^ "}"
  ;;

  (** [optimize tree] simplifies [tree] bottom-up. Operands are optimized
      first, then local rewrites are applied to the node:
      - constant folding for every operator;
      - identities [a + 0 = 0 + a = a], [a - 0 = a], [0 - a = -a],
        [a * 1 = 1 * a = a], [a / 1 = a] and [a * 0 = 0 * a = 0];
      - [a + a = 2 * a], [a - a = 0], [a * a = a^2], [a / a = 1], [0 / a = 0];
      - [-(-a) = a] and [(sqrt a)^2 = a].

      Some rewrites assume the expression is defined. [a / a = 1] and
      [0 / a = 0] assume [a <> 0], and [(sqrt a)^2 = a] assumes [a >= 0].
      [a * 0 = 0] and [a - a = 0] ignore infinities and NaN.

      [sqrt (a^2)] is deliberately NOT rewritten to [a], because it equals
      [|a|]. A non-constant [a / 0] is kept as is, because its value (+inf,
      -inf or NaN) depends on [a]. *)
  let optimize : t -> t =
    (* Smart constructors for the unary operators. Each expects an
       already-optimized operand and applies only the rewrites for its own
       node. *)
    let mk_neg : t -> t = function
      | T (Const c) -> T (Const (Float.neg c))
      | T (Unop (`Neg, a)) -> a
      | a -> T (Unop (`Neg, a))
    in
    let mk_square : t -> t = function
      | T (Const c) -> T (Const (c *. c))
      | T (Unop (`Sqrt, a)) -> a
      | a -> T (Unop (`Square, a))
    in
    let mk_sqrt : t -> t = function
      | T (Const c) -> T (Const (Float.sqrt c))
      | a -> T (Unop (`Sqrt, a))
    in
    let rec go : t -> t = function
      | T ((Const _ | Var _) as t) -> T t
      | T (Binop (`Add, a, b)) ->
        (match go a, go b with
         | T (Const a), T (Const b) -> T (Const (a +. b))
         | T (Const 0.0), a | a, T (Const 0.0) -> a
         | a, b when equal a b -> T (Binop (`Mul, T (Const 2.0), a))
         | a, b -> T (Binop (`Add, a, b)))
      | T (Binop (`Sub, a, b)) ->
        (match go a, go b with
         | T (Const a), T (Const b) -> T (Const (a -. b))
         | a, T (Const 0.0) -> a
         | T (Const 0.0), b -> mk_neg b
         | a, b when equal a b -> T (Const 0.0)
         | a, b -> T (Binop (`Sub, a, b)))
      | T (Binop (`Mul, a, b)) ->
        (match go a, go b with
         | T (Const a), T (Const b) -> T (Const (a *. b))
         | T (Const 0.0), _ | _, T (Const 0.0) -> T (Const 0.0)
         | T (Const 1.0), a | a, T (Const 1.0) -> a
         | a, b when equal a b -> mk_square a
         | a, b -> T (Binop (`Mul, a, b)))
      | T (Binop (`Div, a, b)) ->
        (match go a, go b with
         (* Fold constants first so IEEE semantics apply: 0/0 is NaN and c/0
            (c <> 0) is +/-infinity. *)
         | T (Const a), T (Const b) -> T (Const (a /. b))
         | T (Const 0.0), _ -> T (Const 0.0)
         | a, T (Const 1.0) -> a
         | a, b when equal a b -> T (Const 1.0)
         | a, b -> T (Binop (`Div, a, b)))
      | T (Unop (`Sqrt, a)) -> mk_sqrt (go a)
      | T (Unop (`Neg, a)) -> mk_neg (go a)
      | T (Unop (`Square, a)) -> mk_square (go a)
    in
    go
  ;;
end
(** Expression DAGs built by hash-consing. Structurally identical
    subexpressions share a single node. The operands of [`Add] and [`Mul] are
    first put in a canonical order, so [a + b] and [b + a] also share a node. *)
module Graph = struct
  module Node = struct
    (** A single node whose children are referenced by node id. *)
    type t = string Structure.t [@@deriving equal, sexp, compare, hash]
  end

  (** [nodes] lists every [(id, node)] pair, children before their parents.
      [root] is the id of the node for the whole expression. *)
  type t =
    { nodes : (string * Node.t) list
    ; root : string
    }
  [@@deriving equal, sexp, compare]

  (** [collide_reorderable node] orders the operands of [`Add] and [`Mul]
      nodes (larger id first), so that commuted operands produce equal nodes.
      Every other node is returned unchanged. *)
  let collide_reorderable = function
    | Structure.Binop (((`Add | `Mul) as op), a, b) ->
      let a' = String.max a b in
      let b' = String.min a b in
      Structure.Binop (op, a', b')
    | other -> other
  ;;

  (** [of_tree tree] converts [tree] to a graph. New nodes are assigned ids
      [t_1], [t_2], ... in creation order, and the existing id is reused for
      any node (after operand canonicalization) that was already created. *)
  let of_tree tree =
    let id = ref 0 in
    let nodes = ref [] in
    let seen = Hashtbl.create (module Node) in
    (* Return the id of [node], creating and recording a new node the first
       time this (canonicalized) node is seen. *)
    let find_or_push node =
      let node = collide_reorderable node in
      match Hashtbl.find seen node with
      | Some id -> id
      | None ->
        Int.incr id;
        let id = Printf.sprintf "t_%d" !id in
        nodes := (id, node) :: !nodes;
        Hashtbl.set seen ~key:node ~data:id;
        id
    in
    let rec loop : Tree.t -> string = function
      | T (Const c) -> find_or_push (Const c)
      | T (Var c) -> find_or_push (Var c)
      | T (Binop (op, a, b)) ->
        let a = loop a in
        let b = loop b in
        find_or_push (Binop (op, a, b))
      | T (Unop (op, a)) ->
        let a = loop a in
        find_or_push (Unop (op, a))
    in
    let root = loop tree in
    (* [nodes] was built by consing; reverse it back into creation order. *)
    { nodes = List.rev !nodes; root }
  ;;

  (** [to_graphviz graph] renders [graph] as a DOT digraph. There is one vertex
      per node, labelled with [Structure.label], and an edge from each child to
      its parent. Every statement is terminated by [;]. [root] is not needed
      for rendering. *)
  let to_graphviz { root = _; nodes } =
    let buf = Buffer.create 1024 in
    (* Single place that formats edges, so every edge gets the same syntax. *)
    let edge ~child ~parent = Printf.bprintf buf "%s -> %s;\n" child parent in
    List.iter nodes ~f:(fun (id, node) ->
      Printf.bprintf buf "%s [label=\"%s\"];\n" id (Structure.label node);
      match (node : Node.t) with
      | Var _ | Const _ -> ()
      | Binop (_op, a, b) ->
        edge ~child:a ~parent:id;
        edge ~child:b ~parent:id
      | Unop (_op, a) -> edge ~child:a ~parent:id);
    "digraph G {\n" ^ Buffer.contents buf ^ "}"
  ;;
end
(* Commenting on this gist requires a GitHub account: sign up for free, or sign
   in if you already have an account. *)