Created
August 21, 2020 18:08
-
-
Save jruizvar/d3f72e7a2ca8dce9522b1f84473143b0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" SQL Code Generator. | |
Analisa as regras de uma árvore de decisão ajustada pelos modelos | |
Random Forest ou Gradient Boosted do Spark, para gerar o | |
correspondente código em SQL. | |
O conjunto de árvores pode ser extraido com o método `toDebugString` | |
e formatado como uma lista: | |
>> ensemble = model.trees | |
>> list_of_trees = [tree.toDebugString.split(\n) for tree in ensemble] | |
Finalmente, a lista de árvores pode ser iterada para gerar o código em SQL: | |
>> from sql_code_generator import * | |
>> sql_code = [sql_code_generator(tree) for tree in list_of_trees] | |
""" | |
import re | |
class Node: | |
def __init__(self, header, left=None, right=None): | |
""" Define os parâmetros de um nó da árvore | |
""" | |
self.header = header | |
self.left = left | |
self.right = right | |
def __repr__(self): | |
""" Acerta a formatação na linguagem SQL | |
""" | |
header = self.header.replace("{", "(").replace("}", ")") | |
return f"CASE WHEN {header} THEN {self.left} ELSE {self.right} END" | |
def parser(line, letter="I"): | |
""" Captura os parâmetros de uma linha. | |
As linhas que começam com `I` geram um novo nó. | |
As linhas que começam com `P` atualizam um nó existente. | |
""" | |
if letter == "I": | |
m = re.search(r"If \((.+)\)", line) | |
return Node(m.group(1)) | |
m = re.search(r"Predict: (.+)", line) | |
return m.group(1) | |
def rule(line, root, depth): | |
""" Analisa cada linha da árvore. | |
Utiliza uma regex para obter a primeira letra da | |
linha e a profundidade de cada nó dentro da árvore. | |
""" | |
if not root: | |
return parser(line), ["left"] | |
m = re.search(r"(\s+)(\w)\w", line) | |
letter = m.group(2) | |
if letter == "E": | |
indentation = 2 | |
s = m.group(1).count(" ") - indentation | |
d = depth[:s] + ["right"] | |
return root, d | |
exec("root." + ".".join(depth) + "= parser(line, letter)") | |
if letter == "I": | |
d = depth + ["left"] | |
else: | |
d = depth[:-1] | |
return root, d | |
def sql_code_generator(tree, root=None, depth=None): | |
""" Percorre as linhas do modelo Spark de forma recursiva. | |
No final, retorna uma string com o código em SQL. | |
""" | |
if not tree: | |
return str(root) | |
r, d = rule(tree[0], root, depth) | |
return sql_code_generator(tree[1:], r, d) | |
if __name__ == "__main__": | |
""" A modo de exemplo analisamos a seguinte árvore: | |
""" | |
tree = [ | |
" If (a <= 0.5)", # root = Node("a <= 0.5") | |
" If (b in {1.0,2.0})", # root.left = Node(b in {1.0,2.0}) | |
" Predict: 0.0", # root.left.left = "0.0" | |
" Else (b not in {1.0,2.0})", | |
" Predict: 1.0", # root.left.right = "1.0" | |
" Else (a > 0.5)", | |
" Predict: 0.0" # root.right = "0.0" | |
] | |
sql_code = sql_code_generator(tree) | |
print(sql_code) | |
""" Resultado: | |
CASE | |
WHEN | |
a <= 0.5 | |
THEN | |
CASE | |
WHEN | |
b in | |
( | |
1.0, 2.0 | |
) | |
THEN | |
0.0 | |
ELSE | |
1.0 | |
END | |
ELSE | |
0.0 | |
END | |
""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment