Last active
November 28, 2016 21:38
-
-
Save bigorn0/fbc32949d3ca90d60d3f2aafebd04369 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.sql._

// Tunables — fill these in for your cluster and database.
val numWorkers: Int = 8      // number of partitions == number of concurrent DB connections
val rowsPerBatch: Int = 1000 // rows accumulated per JDBC batch before executeBatch()

// coalesce(numWorkers) caps the number of partitions (and thus open connections).
// mapPartitions(d => Iterator(d)) wraps each partition iterator in a single
// element, so the foreach body runs once per partition instead of once per row.
dataframe.coalesce(numWorkers).mapPartitions(d => Iterator(d)).foreach { batch =>
  val dbc: Connection = DriverManager.getConnection("JDBCURL")
  try {
    val st: PreparedStatement = dbc.prepareStatement("YOUR PREPARED STATEMENT")
    try {
      // Send rows to the DB in fixed-size batches.
      batch.grouped(rowsPerBatch).foreach { session =>
        session.foreach { x =>
          st.setDouble(1, x.getDouble(1))
          st.addBatch()
        }
        st.executeBatch()
      }
    } finally st.close() // release statement resources even on failure
  } finally dbc.close()  // close the connection even if a batch throws, so workers don't leak connections
}
This will execute batches on each worker and close the DB connection afterwards. It gives you control over the number of workers and the number of rows per batch, and lets you work within those confines.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.InputStream | |
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils | |
import org.apache.spark.sql.{ DataFrame, Row } | |
import org.postgresql.copy.CopyManager | |
import org.postgresql.core.BaseConnection | |
// JDBC endpoint of the target Postgres instance (credentials elided).
val jdbcUrl = s"jdbc:postgresql://..." // db credentials elided

// Force the plain Postgres driver. Spark reads the "driver" property to let
// callers override its default driver choice; without this it would pick the
// Redshift driver, which does not support the JDBC CopyManager.
// https://github.com/apache/spark/blob/v1.6.1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala#L44-51
val connectionProperties = {
  val p = new java.util.Properties()
  p.setProperty("driver", "org.postgresql.Driver")
  p
}

// Factory that yields a fresh Connection on each call (used once per partition below).
val cf: () => Connection = JdbcUtils.createConnectionFactory(jdbcUrl, connectionProperties)
// Convert every partition (an `Iterator[Row]`) to bytes (InputStream) | |
// Convert a partition (an `Iterator[Row]`) into a lazy InputStream of
// delimiter-separated text lines, so CopyManager can stream the partition
// without materializing it in memory.
//
// Note: the original snippet was missing the `new InputStream { ... }`
// wrapper around `read()`, so it did not compile as shown.
def rowsToInputStream(rows: Iterator[Row], delimiter: String): InputStream = {
  // Render each row as one line and flatten to a byte iterator. Use an
  // explicit charset so output does not depend on the JVM default encoding.
  val bytes: Iterator[Byte] = rows.flatMap { row =>
    (row.mkString(delimiter) + "\n").getBytes(java.nio.charset.StandardCharsets.UTF_8)
  }
  new InputStream {
    // InputStream contract: return an unsigned byte in 0-255, or -1 at EOF.
    override def read(): Int =
      if (bytes.hasNext) bytes.next() & 0xff // mask: signed byte -> unsigned int
      else -1
  }
}
// Beware: this will open a db connection for every partition of your DataFrame. | |
// Stream each partition straight into Postgres via COPY.
// Beware: this opens one DB connection for every partition of the DataFrame.
frame.foreachPartition { rows =>
  val conn = cf()
  try {
    // CopyManager requires the driver-level BaseConnection, hence the cast;
    // cf() is configured with org.postgresql.Driver above, so the cast holds.
    val cm = new CopyManager(conn.asInstanceOf[BaseConnection])
    cm.copyIn(
      // Adjust COPY options as you desire, options from
      // https://www.postgresql.org/docs/9.5/static/sql-copy.html
      """COPY my_schema._mytable FROM STDIN WITH (NULL 'null', FORMAT CSV, DELIMITER E'\t')""",
      rowsToInputStream(rows, "\t"))
  } finally conn.close() // close even if copyIn throws, so connections don't leak
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment