Cleansing and transforming schema-drifted CSV files into relational data with incremental loads in Azure Databricks
# Python/PySpark code for cleansing and transforming schema-drifted CSV files into relational data with incremental loads in Azure Databricks
# Author: Dhyanendra Singh Rathore

# Define the variables used for creating connection strings
adlsAccountName = "dlscsvdataproject"
adlsContainerName = "csv-data-store"
adlsFolderName = "covid19-data"
mountPoint = "/mnt/csvFiles"

# Application (Client) ID
applicationId = dbutils.secrets.get(scope="CSVProjectKeyVault", key="ClientId")
# Application (Client) Secret Key
authenticationKey = dbutils.secrets.get(scope="CSVProjectKeyVault", key="ClientSecret")
# Directory (Tenant) ID
tenantId = dbutils.secrets.get(scope="CSVProjectKeyVault", key="TenantId")

endpoint = "https://login.microsoftonline.com/" + tenantId + "/oauth2/token"
source = "abfss://" + adlsContainerName + "@" + adlsAccountName + ".dfs.core.windows.net/" + adlsFolderName

# Connect using Service Principal secrets and OAuth
configs = {"fs.azure.account.auth.type": "OAuth",
           "fs.azure.account.oauth.provider.type": "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider",
           "fs.azure.account.oauth2.client.id": applicationId,
           "fs.azure.account.oauth2.client.secret": authenticationKey,
           "fs.azure.account.oauth2.client.endpoint": endpoint}

# Mount ADLS storage to DBFS, but only if it's not already mounted
if not any(mount.mountPoint == mountPoint for mount in dbutils.fs.mounts()):
    dbutils.fs.mount(
        source=source,
        mount_point=mountPoint,
        extra_configs=configs)
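
# With the variables above, 'source' resolves to
# abfss://csv-data-store@dlscsvdataproject.dfs.core.windows.net/covid19-data
# Optional sanity check (an illustrative addition, not part of the original gist):
# list the mounted folder to confirm the CSV files are visible through DBFS.
# display(dbutils.fs.ls(mountPoint))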
# Declare variables for creating the JDBC URL
jdbcHostname = "sql-csv-data-server.database.windows.net"
jdbcPort = 1433
jdbcDatabase = "syn-csv-data-dw"
jdbcTable = "csvData.covidcsvdata"

# Connection secrets from the vault
jdbcUsername = dbutils.secrets.get(scope="CSVProjectKeyVault", key="SQLAdmin")
jdbcPassword = dbutils.secrets.get(scope="CSVProjectKeyVault", key="SQLAdminPwd")

# Create the JDBC URL
jdbcUrl = "jdbc:sqlserver://{0}:{1};database={2}".format(jdbcHostname, jdbcPort, jdbcDatabase)
connectionProperties = {
    "user": jdbcUsername,
    "password": jdbcPassword,
    "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver"
}
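
# Optional connectivity check (a minimal sketch added for illustration; not in the
# original gist): run a trivial query through the same JDBC settings to fail fast
# on bad credentials before the main load. 'testDf' is a hypothetical name.
testDf = spark.read.jdbc(url=jdbcUrl, table="(SELECT 1 AS ok) t", properties=connectionProperties)
# testDf.show()  # expect a single row with ok = 1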
# Import lit() to create new columns in our dataframe
from pyspark.sql.functions import lit

# Function to flatten the column names by removing ' ', '/', and '_' and converting them to lowercase
def rename_columns(rename_df):
    for column in rename_df.columns:
        new_column = column.replace(' ', '').replace('/', '').replace('_', '')
        rename_df = rename_df.withColumnRenamed(column, new_column.lower())
    return rename_df
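
# Illustrative example (header names assume CSVs shaped like the Johns Hopkins
# COVID-19 daily reports, which the target columns below suggest):
#   'Province/State' -> 'provincestate'
#   'Country/Region' -> 'countryregion'
#   'Last Update'    -> 'lastupdate'
#   'Long_'          -> 'long'   (renamed to 'longitude' further down)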
# List all the files in the store so we can iterate through them one by one
file_list = [file.name for file in dbutils.fs.ls("dbfs:{}".format(mountPoint))]

# Find the last loaded file to use as the cut-off for the incremental load
lastLoadedFileQuery = "(SELECT MAX(sourcefile) as sourcefile FROM csvData.covidcsvdata) t"
lastFileDf = spark.read.jdbc(url=jdbcUrl, table=lastLoadedFileQuery, properties=connectionProperties)
lastFile = lastFileDf.collect()[0][0]

# Find the index of that file in the list; start from the beginning if nothing has been loaded yet
loadFrom = file_list.index('{}.csv'.format(lastFile)) + 1 if lastFile else 0

# Trim the list, keeping only the files that still need to be processed
file_list = file_list[loadFrom:]
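
# Illustrative example (the file names are assumptions): if MAX(sourcefile) in the
# table is '01-22-2020' and the mount holds ['01-21-2020.csv', '01-22-2020.csv',
# '01-23-2020.csv'], loadFrom becomes 2 and only '01-23-2020.csv' is processed.
# On the very first run lastFile is None, so the whole list is loaded.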
# Iterate through the files
for file in file_list:
    loadFile = "{0}/{1}".format(mountPoint, file)
    # Read the CSV file with the first line as header, comma (,) as separator, and the schema inferred from the file
    csvDf = spark.read.format("csv") \
        .option("inferSchema", "true") \
        .option("header", "true") \
        .option("sep", ",") \
        .load(loadFile)
    csvDf = rename_columns(csvDf)

    # Check the dataframe and add/rename columns to fit the database table structure
    if 'lat' in csvDf.columns:
        csvDf = csvDf.withColumnRenamed('lat', 'latitude')
    if 'long' in csvDf.columns:
        csvDf = csvDf.withColumnRenamed('long', 'longitude')
    if 'active' not in csvDf.columns:
        csvDf = csvDf.withColumn('active', lit(None).cast("int"))
    if 'latitude' not in csvDf.columns:
        csvDf = csvDf.withColumn('latitude', lit(None).cast("decimal"))
    if 'longitude' not in csvDf.columns:
        csvDf = csvDf.withColumn('longitude', lit(None).cast("decimal"))

    # Add the source file name (without the extension) as an additional column to help us keep track of the data source
    csvDf = csvDf.withColumn("sourcefile", lit(file.split('.')[0]))
    csvDf = csvDf.select("provincestate", "countryregion", "lastupdate", "confirmed", "deaths", "recovered", "active", "latitude", "longitude", "sourcefile")

    # Write the cleansed data to the database
    csvDf.createOrReplaceTempView("finalCsvData")
    spark.table("finalCsvData").write.mode("append").jdbc(url=jdbcUrl, table=jdbcTable, properties=connectionProperties)
# Unmount only if the directory is mounted
if any(mount.mountPoint == mountPoint for mount in dbutils.fs.mounts()):
    dbutils.fs.unmount(mountPoint)