-
-
Save hackwaly/f083d21df7eb3ac40eafb49f7d38e04a to your computer and use it in GitHub Desktop.
ppx for GADTs indexed by polymorphic variants ([@@poly_gadt], [%poly_field])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
; ppxlib rewriter implementing the [@@poly_gadt] attribute and the
; [%poly_field] extension.
(library
 (name ppx_poly_gadt)
 (kind ppx_rewriter)
 ; metaquot provides the [%expr ...] / [%pat? ...] quotations used in the
 ; rewriter's source.
 (preprocess
  (pps ppxlib.metaquot))
 (libraries ppxlib))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
open Ppxlib | |
module B = Ast_builder.Default | |
open Ast_builder.Make (struct | |
let loc = Location.none | |
end) | |
(** [derive_type_name ~prefix s] turns the constructor name [s] from
    CamelCase into snake_case and prepends [prefix], e.g.
    [derive_type_name ~prefix:"a_" "FooBar" = "a_foo_bar"].

    An underscore is inserted before each uppercase ASCII letter that
    directly follows a character which is neither uppercase nor ['_']; an
    existing ['_'] already separates words, so ["Foo_Bar"] gives ["foo_bar"]
    rather than ["foo__bar"]. Runs of capitals are not split
    (["HTTPServer"] gives ["httpserver"]). [prefix] itself is copied
    verbatim. If the result is an OCaml keyword, ['_'] is appended so that
    it stays usable as a type name. *)
let derive_type_name ~prefix s =
  let buf = Buffer.create (String.length prefix + String.length s) in
  Buffer.add_string buf prefix;
  let after_word_char = ref false in
  String.iter
    (fun c ->
      let lower_c = Char.lowercase_ascii c in
      let is_upper = lower_c <> c in
      if is_upper && !after_word_char then Buffer.add_char buf '_';
      Buffer.add_char buf lower_c;
      (* '_' is itself a separator: it must not trigger another one. *)
      after_word_char := (not is_upper) && c <> '_')
    s;
  let name = Buffer.contents buf in
  if Keyword.is_keyword name then name ^ "_" else name
(** [preprocess_impl _ctxt str] rewrites a structure, implementing two
    extensions.

    - [[@@poly_gadt]] on a variant type with exactly one type parameter,
      e.g. [type 'a t = Foo | Bar [@@poly_gadt]]. Each constructor [Foo]
      gets the GADT result type [[> `A_foo ] t], where the tag is the
      parameter-derived prefix (["a_"] here; none for a [_] parameter)
      followed by {!derive_type_name} of the constructor. Before the type
      group it emits [type a_foo = [ `A_foo ]] for every constructor and
      [type a_any] listing all of the tags. Plain aliases of, and closed
      unions over, those index types declared in the same
      [type ... and ...] group are recorded as well.
    - [let f = [%poly_field (fld : idx t)]], where [idx] is a recorded
      index type, becomes [fun (x : idx t) -> match x with C1 { fld = y; _ }
      | C2 { fld = y; _ } ... -> y] over every constructor [idx] admits.

    Type groups without [[@@poly_gadt]] are returned unchanged, including
    their [rec]/[nonrec] flag.

    Raises a located error (via [Location.raise_errorf]) when
    [[@@poly_gadt]] is put on an unsupported type, when a [[%poly_field]]
    payload has an unexpected shape, or when its index type is unknown. *)
let preprocess_impl _ctxt (str : structure) =
  (* [ctor_by_tag]: tag of a generated single-tag index type -> its GADT
     constructor. [ctors_by_ty]: recorded index type -> constructors it
     admits. Both live for the whole structure so that a [%poly_field] can
     use index types declared earlier in the file. *)
  let ctor_by_tag : (string, string) Hashtbl.t = Hashtbl.create 16 in
  let ctors_by_ty : (string, string list) Hashtbl.t = Hashtbl.create 16 in
  let is_poly_gadt_attr attr = String.equal attr.attr_name.txt "poly_gadt" in
  let has_poly_gadt_attr (td : type_declaration) =
    List.exists is_poly_gadt_attr td.ptype_attributes
  in
  (* Constructors admitted by one row of a closed polymorphic variant, or
     [None] when the row does not refer to a recorded index type. *)
  let ctors_of_row row =
    match row.prf_desc with
    | Rtag ({ txt; _ }, _, _) ->
        Hashtbl.find_opt ctor_by_tag txt |> Option.map (fun ctor -> [ ctor ])
    | Rinherit { ptyp_desc = Ptyp_constr (ty_lid, []); _ } ->
        (* Use [ctors_by_ty]: the inherited type may itself be a union such
           as [a_any], which has no single constructor. *)
        Hashtbl.find_opt ctors_by_ty (Longident.name ty_lid.txt)
    | Rinherit _ -> None
  in
  (* Constructors admitted by all [rows], or [None] if any row is unknown. *)
  let ctors_of_rows rows =
    List.fold_right
      (fun row acc ->
        match (ctors_of_row row, acc) with
        | Some ctors, Some rest -> Some (ctors @ rest)
        | _ -> None)
      rows (Some [])
  in
  (* Rewrites one [@@poly_gadt] declaration and records its index types.
     Returns the rewritten declaration and the index type items that must
     be emitted before its group. *)
  let rewrite_poly_gadt (td : type_declaration) =
    let prefix, cds =
      match (td.ptype_kind, td.ptype_params) with
      | Ptype_variant cds, [ ({ ptyp_desc = Ptyp_any | Ptyp_var "_"; _ }, _) ]
        ->
          ("", cds)
      | Ptype_variant cds, [ ({ ptyp_desc = Ptyp_var tv_name; _ }, _) ] ->
          let prefix =
            if String.ends_with ~suffix:"_" tv_name then tv_name
            else tv_name ^ "_"
          in
          (prefix, cds)
      | _ ->
          (* "%s" keeps '@' out of the Format string, where "@@" means '@'. *)
          Location.raise_errorf ~loc:td.ptype_loc
            "%s expects a variant type with exactly one type parameter"
            "[@@poly_gadt]"
    in
    let index_items = ref [] in
    let emit_index ty_name rows ctors =
      let decl =
        type_declaration ~name:(Located.mk ty_name) ~params:[] ~cstrs:[]
          ~kind:Ptype_abstract ~private_:Public
          ~manifest:(Some (ptyp_variant rows Closed None))
      in
      index_items := pstr_type Nonrecursive [ decl ] :: !index_items;
      Hashtbl.replace ctors_by_ty ty_name ctors
    in
    let rewrite_ctor (cd : constructor_declaration) =
      let ty_name = derive_type_name ~prefix cd.pcd_name.txt in
      let tag = String.capitalize_ascii ty_name in
      let row = rtag (Located.mk tag) true [] in
      Hashtbl.replace ctor_by_tag tag cd.pcd_name.txt;
      emit_index ty_name [ row ] [ cd.pcd_name.txt ];
      let res =
        ptyp_constr
          (Located.mk (lident td.ptype_name.txt))
          [ ptyp_variant [ row ] Open None ]
      in
      ({ cd with pcd_res = Some res }, row)
    in
    (* List.map applies [rewrite_ctor] left to right, so registration
       follows declaration order. *)
    let cds', rows = List.split (List.map rewrite_ctor cds) in
    emit_index (prefix ^ "any") rows
      (List.map (fun (cd : constructor_declaration) -> cd.pcd_name.txt) cds);
    let td' =
      {
        td with
        ptype_kind = Ptype_variant cds';
        ptype_attributes =
          List.filter (fun a -> not (is_poly_gadt_attr a)) td.ptype_attributes;
      }
    in
    (td', List.rev !index_items)
  in
  (* Records a user-written alias of, or closed union over, index types.
     Declarations not built from recorded index types are left alone. *)
  let register_user_type (td : type_declaration) =
    let ctors =
      match (td.ptype_kind, td.ptype_manifest) with
      | Ptype_abstract, Some { ptyp_desc = Ptyp_constr (ty_lid, []); _ } ->
          Hashtbl.find_opt ctors_by_ty (Longident.name ty_lid.txt)
      | Ptype_abstract, Some { ptyp_desc = Ptyp_variant (rows, Closed, _); _ }
        ->
          ctors_of_rows rows
      | _ -> None
    in
    Option.iter (Hashtbl.replace ctors_by_ty td.ptype_name.txt) ctors
  in
  (* Rewrites one [type ... and ...] group, keeping its rec flag. *)
  let rewrite_type_group (stri : structure_item) rec_flag tds =
    if not (List.exists has_poly_gadt_attr tds) then [ stri ]
    else
      let rewritten =
        List.map
          (fun td ->
            if has_poly_gadt_attr td then rewrite_poly_gadt td else (td, []))
          tds
      in
      (* The group's other declarations are analysed once all its index
         types exist, in declaration order, so a union can build on one
         declared before it in the same group. *)
      List.iter
        (fun td -> if not (has_poly_gadt_attr td) then register_user_type td)
        tds;
      let tds', items = List.split rewritten in
      List.concat items
      @ [ { stri with pstr_desc = Pstr_type (rec_flag, tds') } ]
  in
  (* Expands [let f = [%poly_field (fld : idx t)]] into a field accessor. *)
  let expand_poly_field ~ext_loc (stri : structure_item)
      (binding : value_binding) payload =
    let expr =
      Ast_pattern.(parse (single_expr_payload __) ext_loc payload Fun.id)
    in
    match expr with
    | {
     pexp_desc =
       Pexp_constraint
         ( { pexp_desc = Pexp_ident fld_lid; _ },
           ({
              ptyp_desc =
                Ptyp_constr
                  (_, [ { ptyp_desc = Ptyp_constr (ty_arg_lid, _); _ } ]);
              _;
            } as typ) );
     _;
    } ->
        let ty_arg = Longident.name ty_arg_lid.txt in
        let first, rest =
          match Hashtbl.find_opt ctors_by_ty ty_arg with
          | Some (first :: rest) -> (first, rest)
          | Some [] | None ->
              Location.raise_errorf ~loc:ty_arg_lid.loc
                "%s: %s is not a known poly_gadt index type" "[%poly_field]"
                ty_arg
        in
        let field_pat ctor =
          ppat_construct
            (Located.mk (lident ctor))
            (Some (ppat_record [ (Located.mk fld_lid.txt, [%pat? y]) ] Open))
        in
        let pat =
          List.fold_left
            (fun acc ctor -> ppat_or acc (field_pat ctor))
            (field_pat first) rest
        in
        let match_ =
          pexp_match [%expr x] [ case ~lhs:pat ~guard:None ~rhs:[%expr y] ]
        in
        let fn = [%expr fun (x : [%t typ]) -> [%e match_]] in
        [
          {
            stri with
            pstr_desc =
              Pstr_value (Nonrecursive, [ { binding with pvb_expr = fn } ]);
          };
        ]
    | _ ->
        Location.raise_errorf ~loc:ext_loc
          "%s expects a payload of the form (field : index t)" "[%poly_field]"
  in
  List.concat_map
    (fun (stri : structure_item) ->
      match stri.pstr_desc with
      | Pstr_type (rec_flag, tds) -> rewrite_type_group stri rec_flag tds
      | Pstr_value
          ( Nonrecursive,
            [
              ({
                 pvb_expr =
                   {
                     pexp_desc =
                       Pexp_extension
                         ({ txt = "poly_field"; loc = ext_loc }, payload);
                     _;
                   };
                 _;
               } as binding);
            ] ) ->
          expand_poly_field ~ext_loc stri binding payload
      | _ -> [ stri ])
    str
(* Register the rewriter with the ppxlib driver under the name "poly_gadt".
   [register_transformation] returns [unit]; binding it with [let ()]
   makes the compiler check that. *)
let () = Driver.V2.register_transformation ~preprocess_impl "poly_gadt"
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment