Last active
December 30, 2020 14:24
-
-
Save vtslab/81ded1a7af006100e00bf2a4a70a8147 to your computer and use it in GitHub Desktop.
Converts spark-sql dtypes to a python-friendly format
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# | |
import re | |
import string | |
import pyparsing | |
def pysql_dtypes(dtypes):
    """Represent spark-sql dtypes in terms of python [], {} and Row()
    constructs.

    Converts the ``array<...>``, ``map<...>`` and ``struct<...>`` spellings
    produced by Spark into ``[...]``, ``{...}`` and ``Row(...)``
    respectively, leaving field names and atomic type names intact.

    :param dtypes: [(string, string)] result from pyspark.sql.DataFrame.dtypes
    :return: [(string, string)] same field names, python-friendly schemas
    """
    def assemble(nested):
        # Walk the parse tree produced by pyparsing.nestedExpr: a container
        # keyword ('array', 'map', 'struct') is always immediately followed
        # by a nested list holding its contents, so those two entries are
        # consumed as a pair; any other token is emitted verbatim.
        cur = 0
        assembled = ''
        while cur < len(nested):
            # The container keyword may be glued to a preceding field name
            # (e.g. 'y:array'), so inspect only the last ':'/',' separated
            # part of the token.
            parts = re.findall(r'[^:,]+', nested[cur])
            if not parts:
                # Token consisted solely of separators; keep it as-is.
                parts = [nested[cur]]
            tail = parts[-1]
            if tail == 'array':
                assembled += nested[cur][:-5] + '['    # strip 'array'
                assembled += assemble(nested[cur + 1])
                assembled += ']'
                cur += 2
            elif tail == 'map':
                assembled += nested[cur][:-3] + '{'    # strip 'map'
                assembled += assemble(nested[cur + 1])
                assembled += '}'
                cur += 2
            elif tail == 'struct':
                assembled += nested[cur][:-6] + 'Row('  # strip 'struct'
                assembled += assemble(nested[cur + 1])
                assembled += ')'
                cur += 2
            else:
                assembled += nested[cur]
                cur += 1
        return assembled

    # '<' and '>' delimit nesting; any other printable character may occur
    # inside a type token (field names can contain '.', ':' etc.).
    chars = ''.join([x for x in string.printable if x not in ['<', '>']])
    word = pyparsing.Word(chars)
    parens = pyparsing.nestedExpr('<', '>', content=word)
    dtype = word + pyparsing.Optional(parens)
    result = []
    for name, schema in dtypes:
        tree = dtype.parseString(schema).asList()
        # Normalize every comma separator to ', ' in a single pass.  The
        # previous chained replaces ended in a no-op (', ' -> ', ') and
        # would have produced double spaces for input already containing
        # ', '; re.sub handles both cases correctly.
        pyschema = re.sub(r',\s*', ', ', assemble(tree))
        result.append((name, pyschema))
    return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# | |
from pyspark.sql import types as t | |
from pysql_dtypes import pysql_dtypes | |
class TestPySqlDTypes:
    """Unit tests for pysql_dtypes.

    Run with:
        export SPARK_HOME=/opt/spark3-client
        export PYTHONPATH=`echo $SPARK_HOME/python/lib/py4j-*-src.zip`
        export PYTHONPATH=.:$PYTHONPATH:$SPARK_HOME/python/lib/pyspark.zip
        pytest -vv test_pysql_dtypes.py
    """

    def test_atomic(self):
        # Atomic types must pass through unchanged.
        # NOTE(review): relies on the private pyspark attribute
        # t._atomic_types — verify it still exists on upgrade.
        names = [atomic.typeName() for atomic in t._atomic_types]
        pairs = [('field_' + n, n) for n in names]
        assert pysql_dtypes(pairs) == pairs

    def test_array(self):
        given = [('field', 'array<bigint>')]
        expected = [('field', '[bigint]')]
        assert pysql_dtypes(given) == expected

    def test_map(self):
        given = [('field', 'map<string,bigint>')]
        expected = [('field', '{string, bigint}')]
        assert pysql_dtypes(given) == expected

    def test_struct_with_atom_atom(self):
        given = [('field', 'struct<x:bigint,y:string>')]
        expected = [('field', 'Row(x:bigint, y:string)')]
        assert pysql_dtypes(given) == expected

    def test_struct_with_atom_map(self):
        given = [('field', 'struct<x:bigint,y:map<string,bigint>>')]
        expected = [('field', 'Row(x:bigint, y:{string, bigint})')]
        assert pysql_dtypes(given) == expected

    def test_struct_with_atom_atom_map(self):
        given = [('field', 'struct<x:bigint,y:bigint,z:map<string,bigint>>')]
        expected = [('field', 'Row(x:bigint, y:bigint, z:{string, bigint})')]
        assert pysql_dtypes(given) == expected

    def test_struct_with_atom_array_map(self):
        given = [
            ('field', 'struct<x:bigint,y:array<bigint>,z:map<string,bigint>>')]
        expected = [
            ('field', 'Row(x:bigint, y:[bigint], z:{string, bigint})')]
        assert pysql_dtypes(given) == expected

    def test_array_struct_with_atom_atom(self):
        # Field names containing '.' must survive the conversion.
        given = [('field', 'array<struct<x:string,y.z:array<string>>>')]
        expected = [('field', '[Row(x:string, y.z:[string])]')]
        assert pysql_dtypes(given) == expected

    def test_array_struct_with_atom_map(self):
        given = [
            ('field', 'array<struct<x.y:string,z:map<string,array<string>>>>')]
        expected = [('field', '[Row(x.y:string, z:{string, [string]})]')]
        assert pysql_dtypes(given) == expected

    def test_array_struct_with_arraystruct_atom_atom(self):
        # Deeply nested containers: an array-of-struct inside a struct
        # inside an array.
        given = [(
            'field',
            'array<struct<x:array<struct<x1:string,x2:string>>,'
            'y:string, z:string>>',
        )]
        expected = [(
            'field',
            '[Row(x:[Row(x1:string, x2:string)], y:string, z:string)]',
        )]
        assert pysql_dtypes(given) == expected
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment