Skip to content

Instantly share code, notes, and snippets.

@hackwaly
Last active August 6, 2023 04:13
Show Gist options
  • Save hackwaly/f083d21df7eb3ac40eafb49f7d38e04a to your computer and use it in GitHub Desktop.
Save hackwaly/f083d21df7eb3ac40eafb49f7d38e04a to your computer and use it in GitHub Desktop.
PPX for polymorphic GADTs
; dune stanza building the ppx_poly_gadt rewriter as a library.
(library
(name ppx_poly_gadt)
; Marks the library as a ppx rewriter so other stanzas can list it in (pps ...).
(kind ppx_rewriter)
; The rewriter's own sources use metaquot quotations ([%expr ...], [%pat? ...]),
; so they must themselves be preprocessed with ppxlib.metaquot.
(preprocess
(pps ppxlib.metaquot))
(libraries ppxlib))
open Ppxlib
module B = Ast_builder.Default
open Ast_builder.Make (struct
let loc = Location.none
end)
(** [derive_type_name ~prefix s] turns a constructor-style name into a
    lowercase, underscore-separated identifier and prepends [prefix].

    Every character is lowercased with [Char.lowercase_ascii]. An underscore is
    inserted in front of a character that is changed by lowercasing whenever
    the character before it was left unchanged by lowercasing (so ["FooBar"]
    becomes ["foo_bar"], while ["HTTPServer"] becomes ["httpserver"]).

    If the prefixed result is an OCaml keyword according to
    [Keyword.is_keyword], a trailing ["_"] is appended. *)
let derive_type_name ~prefix s =
  let out = Buffer.create (String.length s) in
  (* Whether the previously visited character was equal to its lowercase form. *)
  let after_unchanged = ref false in
  String.iter
    (fun ch ->
      let lowered = Char.lowercase_ascii ch in
      let unchanged = lowered = ch in
      if !after_unchanged && not unchanged then Buffer.add_char out '_';
      Buffer.add_char out lowered;
      after_unchanged := unchanged)
    s;
  let name = prefix ^ Buffer.contents out in
  if Keyword.is_keyword name then name ^ "_" else name
(** [preprocess_impl _ctxt str] rewrites the structure [str] item by item.

    {b Type items.} Each declaration in a [type ... and ...] group carrying a
    [poly_gadt] attribute, and having a variant body with exactly one named
    type parameter, is rewritten so that every constructor [C] gets the
    explicit result type [<type name> [> `Tag ]]. The tag is derived from the
    constructor name with {!derive_type_name}, prefixed by the type variable
    name. For each constructor a closed polymorphic-variant alias
    [type <derived> = [ `Tag ]] is generated, plus [type <prefix>any] listing
    all tags; these generated items are emitted before the rewritten group.
    The [poly_gadt] attribute is removed from the rewritten declaration.

    Other declarations in the same group that are aliases of a known type, or
    closed polymorphic variants made only of known tags, are remembered so
    that [poly_field] can later be used with them. Declarations that refer to
    anything unknown are left alone and simply not remembered.

    {b Value items.} [let f = [%poly_field (fld : index ty)]] is replaced by
    [let f = fun (x : index ty) -> match x with C1 { fld = y; _ } | ... -> y],
    using the constructors remembered for [index].

    Every other structure item is returned unchanged.

    @raise Location.Error (via [Location.raise_errorf]) when a [poly_field]
    payload does not have the expected shape or names a type for which no
    constructors were recorded. *)
let preprocess_impl _ctxt (str : structure) =
  (* Derived singleton type name (e.g. "a_foo") -> constructor name. *)
  let ctor_by_ty = Hashtbl.create 0 in
  (* Type name -> names of the constructors selected by that type. *)
  let ctors_by_ty = Hashtbl.create 0 in
  List.concat_map
    (fun (stri : structure_item) ->
      match stri.pstr_desc with
      | Pstr_type (rec_flag, tds) ->
          let generated_stris = ref [] in
          (* Declarations without [poly_gadt], most recent first. *)
          let deferred = ref [] in
          let has_poly_gadt = ref false in
          (* Constructor designated by one polymorphic-variant row, if known. *)
          let ctor_of_row row =
            match row.prf_desc with
            | Rtag ({ txt; _ }, _, _) ->
                Hashtbl.find_opt ctor_by_ty (String.lowercase_ascii txt)
            | Rinherit { ptyp_desc = Ptyp_constr (ty_lid, []); _ } ->
                Hashtbl.find_opt ctor_by_ty (Longident.name ty_lid.txt)
            | Rinherit _ -> None
          in
          (* Remember the constructors of [ty_name], but only when every row
             resolves: unrelated variant types declared in the same group must
             not abort the rewrite. *)
          let record_rows ty_name rows =
            let ctors = List.filter_map ctor_of_row rows in
            if List.compare_lengths ctors rows = 0 then
              Hashtbl.replace ctors_by_ty ty_name.txt ctors
          in
          let add_td_stri ty_name rows =
            let td =
              type_declaration ~name:ty_name ~params:[] ~cstrs:[]
                ~kind:Ptype_abstract ~private_:Public
                ~manifest:(Some (ptyp_variant rows Closed None))
            in
            generated_stris := pstr_type Nonrecursive [ td ] :: !generated_stris;
            record_rows ty_name rows
          in
          let proc_td (td : type_declaration) =
            match td with
            | {
             ptype_kind = Ptype_variant cds;
             ptype_params = [ ({ ptyp_desc = Ptyp_var tv_name; _ }, _) ];
             _;
            } ->
                (* NOTE(review): an anonymous parameter ([type _ t]) is
                   presumably parsed as [Ptyp_any] and would not reach the
                   ["_"] case below — confirm whether it should be supported. *)
                let prefix =
                  match tv_name with
                  | "_" -> ""
                  | _ ->
                      if String.ends_with ~suffix:"_" tv_name then tv_name
                      else tv_name ^ "_"
                in
                let all_rows = ref [] in
                let cds' =
                  List.map
                    (fun (cd : constructor_declaration) ->
                      let ty_name =
                        Located.mk (derive_type_name ~prefix cd.pcd_name.txt)
                      in
                      let row =
                        rtag
                          (Located.mk (String.capitalize_ascii ty_name.txt))
                          true []
                      in
                      all_rows := row :: !all_rows;
                      let cd' =
                        {
                          cd with
                          pcd_res =
                            Some
                              (ptyp_constr
                                 (Located.mk (lident td.ptype_name.txt))
                                 [ ptyp_variant [ row ] Open None ]);
                        }
                      in
                      (* Must be registered before [add_td_stri], which looks
                         the tag up again through [record_rows]. *)
                      Hashtbl.replace ctor_by_ty ty_name.txt cd.pcd_name.txt;
                      add_td_stri ty_name [ row ];
                      cd')
                    cds
                in
                add_td_stri (Located.mk (prefix ^ "any")) !all_rows;
                {
                  td with
                  ptype_kind = Ptype_variant cds';
                  ptype_attributes =
                    td.ptype_attributes
                    |> List.filter (fun it -> it.attr_name.txt <> "poly_gadt");
                }
            | _ -> td
          in
          (* Record the constructor sets of the non-[poly_gadt] declarations of
             the group, independently of the order in which they are written:
             variant bodies first, then aliases until no more can be resolved. *)
          let analyze_deferred (plain_tds : type_declaration list) =
            let aliases =
              List.filter_map
                (fun (td : type_declaration) ->
                  match td with
                  | {
                   ptype_kind = Ptype_abstract;
                   ptype_manifest =
                     Some { ptyp_desc = Ptyp_variant (rows, Closed, _); _ };
                   _;
                  } ->
                      record_rows td.ptype_name rows;
                      None
                  | {
                   ptype_kind = Ptype_abstract;
                   ptype_manifest =
                     Some { ptyp_desc = Ptyp_constr (ty_lid, []); _ };
                   _;
                  } ->
                      Some (td.ptype_name.txt, Longident.name ty_lid.txt)
                  | _ -> None)
                plain_tds
            in
            let rec settle pending =
              let unresolved =
                List.filter
                  (fun (name, target) ->
                    match Hashtbl.find_opt ctors_by_ty target with
                    | Some ctors ->
                        Hashtbl.replace ctors_by_ty name ctors;
                        false
                    | None -> true)
                  pending
              in
              (* Iterate only while progress is made, so aliases of unrelated
                 types (e.g. [and n = int]) are silently dropped. *)
              if List.compare_lengths unresolved pending < 0 then
                settle unresolved
            in
            settle aliases
          in
          let tds' =
            List.map
              (fun (td : type_declaration) ->
                if
                  List.exists
                    (fun attr -> attr.attr_name.txt = "poly_gadt")
                    td.ptype_attributes
                then (
                  has_poly_gadt := true;
                  proc_td td)
                else (
                  deferred := td :: !deferred;
                  td))
              tds
          in
          if !has_poly_gadt then analyze_deferred (List.rev !deferred);
          (* The rewritten constructors mention the type being defined, so the
             group has to be recursive; groups without [poly_gadt] keep the
             flag the user wrote (e.g. [type nonrec]). *)
          let rec_flag = if !has_poly_gadt then Recursive else rec_flag in
          !generated_stris
          @ [ { stri with pstr_desc = Pstr_type (rec_flag, tds') } ]
      | Pstr_value
          ( Nonrecursive,
            [
              ({
                 pvb_expr =
                   {
                     pexp_desc =
                       Pexp_extension ({ txt = "poly_field"; _ }, payload);
                     pexp_loc;
                     _;
                   };
                 _;
               } as binding);
            ] ) -> (
          (* Parse against the extension's own location so that malformed
             payloads are reported at the right place. *)
          let expr =
            Ast_pattern.(parse (single_expr_payload __) pexp_loc payload Fun.id)
          in
          match expr with
          | {
           pexp_desc =
             Pexp_constraint
               ( { pexp_desc = Pexp_ident fld_lid; _ },
                 ({
                    ptyp_desc =
                      Ptyp_constr
                        (_, [ { ptyp_desc = Ptyp_constr (ty_arg_lid, _); _ } ]);
                    _;
                  } as typ) );
           _;
          } ->
              let ty_arg_name = Longident.name ty_arg_lid.txt in
              (* Pattern [Ctor { fld = y; _ }] for one constructor. *)
              let field_pat ctor =
                ppat_construct
                  (Located.mk (lident ctor))
                  (Some
                     (ppat_record
                        [ (Located.mk fld_lid.txt, [%pat? y]) ]
                        Open))
              in
              let pat =
                match Hashtbl.find_opt ctors_by_ty ty_arg_name with
                | Some (first :: rest) ->
                    List.fold_left
                      (fun acc ctor -> ppat_or acc (field_pat ctor))
                      (field_pat first) rest
                | Some [] | None ->
                    Location.raise_errorf ~loc:ty_arg_lid.loc
                      "poly_field: no constructors are known for type %s"
                      ty_arg_name
              in
              let match_ =
                pexp_match [%expr x]
                  [ case ~lhs:pat ~guard:None ~rhs:[%expr y] ]
              in
              let fn = [%expr fun (x : [%t typ]) -> [%e match_]] in
              [
                {
                  stri with
                  pstr_desc =
                    Pstr_value
                      (Nonrecursive, [ { binding with pvb_expr = fn } ]);
                };
              ]
          | _ ->
              Location.raise_errorf ~loc:expr.pexp_loc
                "poly_field: expected a payload of the form (field : index ty)")
      | _ -> [ stri ])
    str
(* Register the rewriter with the ppxlib driver under the name "poly_gadt".
   [let () =] rather than [let _ =] makes the compiler check that the call is
   fully applied and returns [unit]. *)
let () = Driver.V2.register_transformation ~preprocess_impl "poly_gadt"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment