Skip to content

Instantly share code, notes, and snippets.

@serenamm
Last active September 9, 2019 18:41
Show Gist options
  • Save serenamm/65d721e3e42b31a2b6086d0467635254 to your computer and use it in GitHub Desktop.
Save serenamm/65d721e3e42b31a2b6086d0467635254 to your computer and use it in GitHub Desktop.
from mock import mock
# SQL template (filled with str.format in create_new_table) that keeps, for
# each item_id_1 in {similarity_table} (aliased `s`), the rows ranked
# 1..{num_items} by similarity_score, highest first, using ROW_NUMBER over a
# per-item_id_1 partition.
# {same_category_q} is spliced in verbatim right after the FROM clause; it is
# either '' or the JOIN/WHERE fragment built by make_query, which refers to `s`.
# NOTE(review): ROW_NUMBER breaks ties in similarity_score arbitrarily, so which
# rows survive at the {num_items} cutoff may vary between runs when scores tie.
create_table_query = '''
SELECT
item_id_1,
item_id_2
FROM (
SELECT
item_id_1,
item_id_2,
ROW_NUMBER() OVER(PARTITION BY item_id_1 ORDER BY similarity_score DESC) as row_num
FROM {similarity_table} s
{same_category_q}
)
WHERE row_num <= {num_items}
'''
def create_new_table(spark, table_paths, params, same_category_q):
    """Run the top-N similarity query and save the result as ORC.

    Args:
        spark: Object exposing ``sql(query)`` that returns a DataFrame-like
            result (presumably a SparkSession -- confirm with callers).
        table_paths: Mapping read at ``["product_similarity"]["table"]`` (name
            of the similarity table to query) and ``["created_table"]["path"]``
            (where the output is saved).
        params: Mapping read at ``["num_items"]``, the per-item row cap.
        same_category_q: SQL fragment placed after the FROM clause, e.g. the
            string returned by ``make_query``.

    All three values are interpolated into the SQL with ``str.format`` (no
    escaping), so they must come from trusted configuration.
    """
    query = create_table_query.format(
        similarity_table=table_paths["product_similarity"]["table"],
        same_category_q=same_category_q,
        num_items=params["num_items"],
    )
    result = spark.sql(query)

    # coalesce(1) collapses the result to a single partition before writing.
    writer = result.coalesce(1).write
    writer.save(table_paths["created_table"]["path"],
                format="orc", mode="Overwrite")
def make_query(same_category, table_paths):
    """Return the SQL fragment that limits similarity pairs to one category.

    Args:
        same_category: When truthy, build the restricting fragment; when falsy,
            return ``''`` so the similarity query is left unfiltered.
        table_paths: Mapping read at ``["products"]["table"]`` for the product
            table name. Only accessed when ``same_category`` is truthy.

    Returns:
        str: ``''``, or two inner joins of the product table (aliases ``p`` and
        ``q``) on the pair's item ids, followed by a WHERE clause keeping
        distinct items that share a ``category_id``. The fragment refers to the
        similarity table as ``s``, matching the alias in ``create_table_query``.
    """
    # Truthiness instead of ``is True``: values such as ``numpy.True_`` or 1
    # are truthy but are not the ``True`` singleton, and an identity check
    # would silently return '' (no category filter) for them.
    if not same_category:
        return ''
    return '''
INNER JOIN {product_table} p
ON s.item_id_1 = p.item_id
INNER JOIN {product_table} q
ON s.item_id_2 = q.item_id
WHERE item_id_1 != item_id_2
AND p.category_id = q.category_id
'''.format(product_table=table_paths["products"]["table"])
def test_get_queries_true(mocker):
    """make_query(True, ...) returns a non-empty fragment joining the product table.

    ``mocker`` (presumably the pytest-mock fixture) is requested but unused.
    """
    # Fake table paths, keyed exactly as the production code reads them:
    # make_query uses table_paths["products"]["table"].
    test_paths = {
        "products": {
            "table": "products",
        },
        "product_similarity": {
            "table": "product_similarity",
        },
    }
    same_category_q = make_query(True, test_paths)
    # The fragment must be non-empty and must join against the configured table.
    assert same_category_q != ''
    assert "INNER JOIN products p" in same_category_q
def test_get_queries_false(mocker):
    """make_query(False, ...) returns an empty fragment (no category filter).

    ``mocker`` (presumably the pytest-mock fixture) is requested but unused.
    """
    # Same key schema the production code reads, so this test stays valid
    # even if the False branch ever starts consulting table_paths.
    test_paths = {
        "products": {
            "table": "products",
        },
        "product_similarity": {
            "table": "product_similarity",
        },
    }
    same_category_q = make_query(False, test_paths)
    # With the flag off, nothing should be spliced into the similarity query.
    assert same_category_q == ''
def test_create_new_table(mocker):
    """create_new_table runs one query and saves its coalesced result as ORC.

    ``mocker`` (presumably the pytest-mock fixture) is requested but unused;
    the stand-ins below come from the ``mock`` module imported at file level.
    """
    # Stand-ins for the Spark session and the DataFrame chain it returns.
    mock_spark = mock.Mock()
    mock_created_table = mock.Mock()
    mock_created_table_coalesced = mock.Mock()
    mock_write = mock.Mock()
    # spark.sql(...) hands back the created table (only one call is expected).
    mock_spark.sql.side_effect = [mock_created_table]
    # created_table.coalesce(1) -> coalesced table; coalesced.write -> writer.
    mock_created_table.coalesce.return_value = mock_created_table_coalesced
    mock_created_table_coalesced.write = mock_write

    # A real string (not a Mock) so we can check it is spliced into the SQL.
    test_category_q = "WHERE s.item_id_1 != s.item_id_2"
    # Keys must match what create_new_table reads:
    # table_paths["product_similarity"]["table"] and
    # table_paths["created_table"]["path"].
    test_paths = {
        "products": {
            "table": "products",
        },
        "product_similarity": {
            "table": "product_similarity",
        },
        "created_table": {
            "path": "path_to_table",
        },
    }
    test_params = {
        "num_items": 10,
    }

    create_new_table(mock_spark, test_paths, test_params, test_category_q)

    # Exactly one query, built from the configured table, fragment and limit.
    assert 1 == mock_spark.sql.call_count
    sql_text = mock_spark.sql.call_args[0][0]
    assert "FROM product_similarity s" in sql_text
    assert test_category_q in sql_text
    assert "WHERE row_num <= 10" in sql_text
    # The result is collapsed to one partition and saved to the configured path.
    mock_created_table.coalesce.assert_called_with(1)
    mock_write.save.assert_called_with(test_paths["created_table"]["path"],
                                       format="orc", mode="Overwrite")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment