# Source: GitHub gist hans/a7fd2ea8cd53a5576fcb9acb9140412a (last active May 9, 2022 15:37).
# Measures how often GPT-2 BPE tokenization splits words inside the critical
# regions of the SyntaxGym test suites.
import json
from pathlib import Path

import pandas as pd
import transformers
from syntaxgym import Suite
# Materials from https://github.com/cpllab/syntactic-generalization
SUITE_DIR = Path("../syntactic-generalization/test_suites/json")
TOKENIZER_NAME = "gpt2"
OUTPUT_CSV = "critical_region_analysis.csv"

# Load every test suite. Paths are sorted so row order (and the CSV) is
# reproducible regardless of filesystem directory ordering.
ds = []
for p in sorted(SUITE_DIR.glob("*.json")):
    with open(p, encoding="utf-8") as f:
        ds.append(json.load(f))
if not ds:
    raise SystemExit(f"No test suite JSON files found in {SUITE_DIR.resolve()}")
ss = [Suite.from_dict(d) for d in ds]

# One row per (suite, item, referenced region): the region's text in the
# condition that the suite's first prediction references.
out = []
for s in ss:
    # NOTE(review): only predictions[0] is inspected; confirm this covers all
    # critical regions of interest for suites that define several predictions.
    for condition, region_number in s.predictions[0].referenced_regions:
        for item in s.items:
            c = next(
                (cx for cx in item["conditions"] if cx["condition_name"] == condition),
                None,
            )
            if c is None:
                raise KeyError(
                    f"suite {s.meta['name']!r}, item {item['item_number']}: "
                    f"no condition named {condition!r}"
                )
            # region_number is treated as 1-based and regions are assumed to be
            # stored in region order -- TODO confirm against the suite schema.
            content = c["regions"][region_number - 1]["content"]
            out.append((s.meta["name"], item["item_number"], region_number, content))
df = pd.DataFrame(out, columns=["suite", "item", "region_number", "content"])

#########
tk = transformers.AutoTokenizer.from_pretrained(TOKENIZER_NAME)
# NOTE(review): each region is encoded in isolation (no preceding context or
# leading space), which may tokenize the first word differently than in a sentence.
decoded = [tk.convert_ids_to_tokens(tk.encode(string)) for string in df.content]
##########

# We only care about items where critical region content differs by condition.
df["matched_content"] = df.groupby(["suite", "item", "region_number"]).content.transform(
    lambda xs: len(set(xs)) == 1
)
df["content_tokenized"] = ["_".join(tokens) for tokens in decoded]
df["num_tokens_bpe"] = [len(tokens) for tokens in decoded]
# split() ignores leading/trailing/repeated spaces and gives 0 words for an
# empty region; count(" ") + 1 reported 1 word for "" (yielding -1 splits).
df["num_tokens_whitespace"] = df.content.str.split().str.len()
df["num_bpe_splits"] = df.num_tokens_bpe - df.num_tokens_whitespace
df.to_csv(OUTPUT_CSV)

unmatched = df[~df.matched_content]
# avg. number of BPE splits per critical region for relevant items/suites, grouped by suite
print(unmatched.groupby("suite").num_bpe_splits.mean())
# on the item level -- how many items with unmatched critical region content have content that contain BPE splits
print((unmatched.groupby(["suite", "item"]).num_bpe_splits.max() > 0).agg(["sum", "mean"]))
# Recorded output (original run) of the per-suite mean above: average number of
# BPE splits per critical region, restricted to items whose critical-region
# content differs by condition.
#
# suite
# center_embed             1.000000
# center_embed_mod         1.000000
# number_orc               0.368421
# number_prep              0.368421
# number_src               0.368421
# reflexive_orc_fem        1.000000
# reflexive_orc_masc       1.000000
# reflexive_prep_fem       1.000000
# reflexive_prep_masc      1.000000
# reflexive_src_fem        1.000000
# reflexive_src_masc       1.000000
# subordination            0.021739
# subordination_orc-orc    0.021739
# subordination_pp-pp      0.021739
# subordination_src-src    0.021739
# Name: num_bpe_splits, dtype: float64
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment