Created
May 3, 2020 02:55
-
-
Save mrdrozdov/94644e2ca5b519deccd74f7497b1daad to your computer and use it in GitHub Desktop.
fromstring.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Use at your own risk.
I use nltk for many purposes, but I needed a faster version of nltk.Tree.fromstring.
This implementation is much faster and works for my use case, but it
has not been tested extensively on complex input.
The original fromstring implementation from nltk is here:
https://github.com/nltk/nltk/blob/develop/nltk/tree.py
"""
import nltk | |
def main():
    """Benchmark every parser variant on one example tree.

    For each variant the parsed result is printed once, followed by the
    value returned by ``timeit`` for 10000 runs of the parse callable.
    """
    import timeit

    ex = '(S (Q (X a) (X dog)) (Q (X sings)) (Y (X i) (X added) (Y (X more) (X tokens))) (Y (X i) (X added) (Y (X more) (X tokens))))'

    # Each entry builds a zero-argument parse callable for `ex`. The numbers
    # in the trailing comments are sample timings recorded with this script.
    factories = (
        test_nltk,                    # original NLTK implementation      > 0.4380141739966348
        test_nltk_fast,               # faster parser, same result        > 0.20882822998100892
        test_nltk_fast_preprocessed,  # same, tokens prepared up front    > 0.18404390098294243
        test_span_fast,               # simplified tuple format           > 0.10287073400104418
        test_span_fast_preprocessed,  # same, tokens prepared up front    > 0.08031221095006913
    )
    for make_fn in factories:
        # Show the parse once so the outputs can be compared by eye ...
        print(make_fn(ex)())
        # ... then time a freshly built callable.
        print(timeit.Timer(make_fn(ex)).timeit(10000))
def test_nltk(ex):
    """Return a zero-argument callable that parses ``ex`` with ``nltk.Tree.fromstring``."""
    return lambda: nltk.Tree.fromstring(ex)
def test_nltk_fast(ex):
    """Return a zero-argument callable that parses ``ex`` with ``fromstring_nltk``.

    Tokenization happens inside the callable, so it is included in any timing.
    """
    return lambda: fromstring_nltk(ex)
def test_nltk_fast_preprocessed(ex):
    """Return a zero-argument callable that runs ``fromstring_nltk`` on pre-split tokens.

    The string is tokenized once here, outside the callable, so repeated
    calls measure parsing only.
    """
    # Pad every bracket with a space so that split() isolates it as a token.
    padded = ex.replace('(', '( ').replace(')', ' )')
    tokens = padded.split()
    return lambda: fromstring_nltk(ex, tokens)
def test_span_fast(ex):
    """Return a zero-argument callable that parses ``ex`` with ``fromstring_span``.

    Tokenization happens inside the callable, so it is included in any timing.
    """
    return lambda: fromstring_span(ex)
def test_span_fast_preprocessed(ex):
    """Return a zero-argument callable that runs ``fromstring_span`` on pre-split tokens.

    The string is tokenized once here, outside the callable, so repeated
    calls measure parsing only.
    """
    # Pad every bracket with a space so that split() isolates it as a token.
    padded = ex.replace('(', '( ').replace(')', ' )')
    tokens = padded.split()
    return lambda: fromstring_span(ex, tokens)
def fromstring_nltk(s, tokens=None):
    """Parse a bracketed tree string into the tree built by ``recursive_parse_nltk``.

    Args:
        s: Bracketed tree string, e.g. ``'(S (X a) (X b))'``. Only read when
            ``tokens`` is None.
        tokens: Optional pre-split token sequence in which every bracket is
            its own token; supplying it skips tokenization.

    Returns:
        The tree for the node starting at the first token.
    """
    if tokens is None:
        # Pad every bracket with a space so that split() isolates it.
        padded = s.replace('(', '( ').replace(')', ' )')
        tokens = padded.split()
    # The parser also reports how many tokens it consumed; only the tree
    # is of interest here.
    root, _consumed = recursive_parse_nltk(tokens)
    return root
def fromstring_span(s, tokens=None):
    """Parse a bracketed tree string into nested ``(label, children)`` tuples.

    Each node is a plain tuple rather than an nltk Tree node. This should
    be suitable for most cases and runs much faster.

    Args:
        s: Bracketed tree string, e.g. ``'(S (X a) (X b))'``. Only read when
            ``tokens`` is None.
        tokens: Optional pre-split token sequence in which every bracket is
            its own token; supplying it skips tokenization.

    Returns:
        The tuple tree for the node starting at the first token.
    """
    if tokens is None:
        # Pad every bracket with a space so that split() isolates it.
        padded = s.replace('(', '( ').replace(')', ' )')
        tokens = padded.split()
    # The parser also reports how many tokens it consumed; only the tree
    # is of interest here.
    root, _consumed = recursive_parse(tokens)
    return root
def recursive_parse_nltk(tokens, pos=0):
    """Parse the bracketed node starting at ``tokens[pos]`` into an ``nltk.Tree``.

    ``tokens`` is a sequence in which every ``(`` and ``)`` is its own token.
    A node is either a preterminal ``( label leaf )`` or an internal node
    ``( label child child ... )`` whose children are themselves nodes.

    Args:
        tokens: Token sequence for the bracketed tree.
        pos: Index of the opening ``(`` of the node to parse.

    Returns:
        A ``(node, size)`` pair: the ``nltk.Tree`` for the node and the
        number of tokens it spans (brackets included).

    Raises:
        ValueError: If the tokens at ``pos`` are not a well-formed node, for
            example a missing bracket or label, a node with no content, a
            preterminal with more than one leaf, or input that ends before
            the node is closed.
    """
    try:
        if tokens[pos] != '(':
            raise ValueError(
                "expected '(' at token %d, found %r" % (pos, tokens[pos]))
        label = tokens[pos + 1]
        if label == '(' or label == ')':
            raise ValueError("missing node label at token %d" % (pos + 1))
        first = tokens[pos + 2]
        if first != '(':
            # Preterminal: exactly "( label leaf )", i.e. four tokens.
            if first == ')':
                raise ValueError(
                    "node %r at token %d has no leaf or children"
                    % (label, pos))
            if tokens[pos + 3] != ')':
                # Without this check any extra leaf would be dropped silently
                # and the parent would resume at the wrong token.
                raise ValueError(
                    "expected ')' at token %d, found %r"
                    % (pos + 3, tokens[pos + 3]))
            return nltk.Tree(label, (first,)), 4
        # Internal node: skip "(" and the label, then parse children until
        # the matching ")" is reached.
        size = 2
        nodes = []
        while tokens[pos + size] != ')':
            child, child_size = recursive_parse_nltk(tokens, pos + size)
            size += child_size
            nodes.append(child)
        size += 1  # account for the closing ")"
        return nltk.Tree(label, tuple(nodes)), size
    except IndexError:
        # Ran off the end of the token list: the node was never closed.
        raise ValueError(
            "unexpected end of tokens while parsing node at token %d" % pos
        ) from None
def recursive_parse(tokens, pos=0):
    """Parse the bracketed node starting at ``tokens[pos]`` into nested tuples.

    ``tokens`` is a sequence in which every ``(`` and ``)`` is its own token.
    A node is either a preterminal ``( label leaf )`` or an internal node
    ``( label child child ... )`` whose children are themselves nodes.

    Args:
        tokens: Token sequence for the bracketed tree.
        pos: Index of the opening ``(`` of the node to parse.

    Returns:
        A ``(node, size)`` pair. ``node`` is ``(label, leaf)`` for a
        preterminal and ``(label, children)`` otherwise, where ``children``
        is a tuple of nodes; ``size`` is the number of tokens the node spans
        (brackets included).

    Raises:
        ValueError: If the tokens at ``pos`` are not a well-formed node, for
            example a missing bracket or label, a node with no content, a
            preterminal with more than one leaf, or input that ends before
            the node is closed.
    """
    try:
        if tokens[pos] != '(':
            raise ValueError(
                "expected '(' at token %d, found %r" % (pos, tokens[pos]))
        label = tokens[pos + 1]
        if label == '(' or label == ')':
            raise ValueError("missing node label at token %d" % (pos + 1))
        first = tokens[pos + 2]
        if first != '(':
            # Preterminal: exactly "( label leaf )", i.e. four tokens.
            if first == ')':
                raise ValueError(
                    "node %r at token %d has no leaf or children"
                    % (label, pos))
            if tokens[pos + 3] != ')':
                # Without this check any extra leaf would be dropped silently
                # and the parent would resume at the wrong token.
                raise ValueError(
                    "expected ')' at token %d, found %r"
                    % (pos + 3, tokens[pos + 3]))
            return (label, first), 4
        # Internal node: skip "(" and the label, then parse children until
        # the matching ")" is reached.
        size = 2
        nodes = []
        while tokens[pos + size] != ')':
            child, child_size = recursive_parse(tokens, pos + size)
            size += child_size
            nodes.append(child)
        size += 1  # account for the closing ")"
        return (label, tuple(nodes)), size
    except IndexError:
        # Ran off the end of the token list: the node was never closed.
        raise ValueError(
            "unexpected end of tokens while parsing node at token %d" % pos
        ) from None
# Run the benchmark only when this file is executed as a script, so that
# importing it for its parser functions has no side effects.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment