@fifar
Created August 5, 2020 13:23
Changes to core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala between Spark v2.4.4 and v3.0.0
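In short, v3.0.0 replaces the hard-coded configuration strings with typed entries from org.apache.spark.internal.config.Kryo, adds an optional soft-reference KryoPool so serializer instances can share Kryo objects (gated by the new KRYO_USE_POOL flag), moves the context-classloader save/restore into Utils.withContextClassLoader, resolves the optional SQL/ML/MLlib registrations once into the cached KryoSerializer.loadableSparkClasses list (which also gains the ml.attribute and MultivariateGaussian classes), and registers TaskCommitMessage by default. A few hedged usage sketches follow the diff.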
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 7271004..cdaab59 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -30,6 +30,7 @@ import scala.util.control.NonFatal
import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer}
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput}
+import com.esotericsoftware.kryo.pool.{KryoCallback, KryoFactory, KryoPool}
import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
import org.apache.avro.generic.{GenericData, GenericRecord}
@@ -38,6 +39,8 @@ import org.roaringbitmap.RoaringBitmap
import org.apache.spark._
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.Kryo._
+import org.apache.spark.internal.io.FileCommitProtocol._
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus}
import org.apache.spark.storage._
@@ -57,33 +60,34 @@ class KryoSerializer(conf: SparkConf)
with Logging
with Serializable {
- private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k")
+ private val bufferSizeKb = conf.get(KRYO_SERIALIZER_BUFFER_SIZE)
if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) {
- throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " +
- s"2048 mb, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} mb.")
+ throw new IllegalArgumentException(s"${KRYO_SERIALIZER_BUFFER_SIZE.key} must be less than " +
+ s"2048 MiB, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} MiB.")
}
private val bufferSize = ByteUnit.KiB.toBytes(bufferSizeKb).toInt
- val maxBufferSizeMb = conf.getSizeAsMb("spark.kryoserializer.buffer.max", "64m").toInt
+ val maxBufferSizeMb = conf.get(KRYO_SERIALIZER_MAX_BUFFER_SIZE).toInt
if (maxBufferSizeMb >= ByteUnit.GiB.toMiB(2)) {
- throw new IllegalArgumentException("spark.kryoserializer.buffer.max must be less than " +
- s"2048 mb, got: + $maxBufferSizeMb mb.")
+ throw new IllegalArgumentException(s"${KRYO_SERIALIZER_MAX_BUFFER_SIZE.key} must be less " +
+ s"than 2048 MiB, got: $maxBufferSizeMb MiB.")
}
private val maxBufferSize = ByteUnit.MiB.toBytes(maxBufferSizeMb).toInt
- private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true)
- private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false)
- private val userRegistrators = conf.get("spark.kryo.registrator", "")
- .split(',').map(_.trim)
+ private val referenceTracking = conf.get(KRYO_REFERENCE_TRACKING)
+ private val registrationRequired = conf.get(KRYO_REGISTRATION_REQUIRED)
+ private val userRegistrators = conf.get(KRYO_USER_REGISTRATORS)
+ .map(_.trim)
.filter(!_.isEmpty)
- private val classesToRegister = conf.get("spark.kryo.classesToRegister", "")
- .split(',').map(_.trim)
+ private val classesToRegister = conf.get(KRYO_CLASSES_TO_REGISTER)
+ .map(_.trim)
.filter(!_.isEmpty)
private val avroSchemas = conf.getAvroSchema
// whether to use unsafe based IO for serialization
- private val useUnsafe = conf.getBoolean("spark.kryo.unsafe", false)
+ private val useUnsafe = conf.get(KRYO_USE_UNSAFE)
+ private val usePool = conf.get(KRYO_USE_POOL)
def newKryoOutput(): KryoOutput =
if (useUnsafe) {
@@ -92,12 +96,41 @@ class KryoSerializer(conf: SparkConf)
new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize))
}
+ @transient
+ private lazy val factory: KryoFactory = new KryoFactory() {
+ override def create: Kryo = {
+ newKryo()
+ }
+ }
+
+ private class PoolWrapper extends KryoPool {
+ private var pool: KryoPool = getPool
+
+ override def borrow(): Kryo = pool.borrow()
+
+ override def release(kryo: Kryo): Unit = pool.release(kryo)
+
+ override def run[T](kryoCallback: KryoCallback[T]): T = pool.run(kryoCallback)
+
+ def reset(): Unit = {
+ pool = getPool
+ }
+
+ private def getPool: KryoPool = {
+ new KryoPool.Builder(factory).softReferences.build
+ }
+ }
+
+ @transient
+ private lazy val internalPool = new PoolWrapper
+
+ def pool: KryoPool = internalPool
+
def newKryo(): Kryo = {
val instantiator = new EmptyScalaKryoInstantiator
val kryo = instantiator.newKryo()
kryo.setRegistrationRequired(registrationRequired)
- val oldClassLoader = Thread.currentThread.getContextClassLoader
val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
@@ -123,23 +156,22 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas))
kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas))
- try {
- // scalastyle:off classforname
- // Use the default classloader when calling the user registrator.
- Thread.currentThread.setContextClassLoader(classLoader)
- // Register classes given through spark.kryo.classesToRegister.
- classesToRegister
- .foreach { className => kryo.register(Class.forName(className, true, classLoader)) }
- // Allow the user to register their own classes by setting spark.kryo.registrator.
- userRegistrators
- .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator])
- .foreach { reg => reg.registerClasses(kryo) }
- // scalastyle:on classforname
- } catch {
- case e: Exception =>
- throw new SparkException(s"Failed to register classes with Kryo", e)
- } finally {
- Thread.currentThread.setContextClassLoader(oldClassLoader)
+ // Use the default classloader when calling the user registrator.
+ Utils.withContextClassLoader(classLoader) {
+ try {
+ // Register classes given through spark.kryo.classesToRegister.
+ classesToRegister.foreach { className =>
+ kryo.register(Utils.classForName(className, noSparkClassLoader = true))
+ }
+ // Allow the user to register their own classes by setting spark.kryo.registrator.
+ userRegistrators
+ .map(Utils.classForName[KryoRegistrator](_, noSparkClassLoader = true).
+ getConstructor().newInstance())
+ .foreach { reg => reg.registerClasses(kryo) }
+ } catch {
+ case e: Exception =>
+ throw new SparkException(s"Failed to register classes with Kryo", e)
+ }
}
// Register Chill's classes; we do this after our ranges and the user's own classes to let
@@ -181,32 +213,8 @@ class KryoSerializer(conf: SparkConf)
// We can't load those classes directly, in order to avoid unnecessary jar dependencies.
// We load them safely, and ignore a class if it is not found.
- Seq(
- "org.apache.spark.sql.catalyst.expressions.UnsafeRow",
- "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData",
- "org.apache.spark.sql.catalyst.expressions.UnsafeMapData",
-
- "org.apache.spark.ml.feature.Instance",
- "org.apache.spark.ml.feature.LabeledPoint",
- "org.apache.spark.ml.feature.OffsetInstance",
- "org.apache.spark.ml.linalg.DenseMatrix",
- "org.apache.spark.ml.linalg.DenseVector",
- "org.apache.spark.ml.linalg.Matrix",
- "org.apache.spark.ml.linalg.SparseMatrix",
- "org.apache.spark.ml.linalg.SparseVector",
- "org.apache.spark.ml.linalg.Vector",
- "org.apache.spark.ml.tree.impl.TreePoint",
- "org.apache.spark.mllib.clustering.VectorWithNorm",
- "org.apache.spark.mllib.linalg.DenseMatrix",
- "org.apache.spark.mllib.linalg.DenseVector",
- "org.apache.spark.mllib.linalg.Matrix",
- "org.apache.spark.mllib.linalg.SparseMatrix",
- "org.apache.spark.mllib.linalg.SparseVector",
- "org.apache.spark.mllib.linalg.Vector",
- "org.apache.spark.mllib.regression.LabeledPoint"
- ).foreach { name =>
+ KryoSerializer.loadableSparkClasses.foreach { clazz =>
try {
- val clazz = Utils.classForName(name)
kryo.register(clazz)
} catch {
case NonFatal(_) => // do nothing
@@ -218,8 +226,14 @@ class KryoSerializer(conf: SparkConf)
kryo
}
+ override def setDefaultClassLoader(classLoader: ClassLoader): Serializer = {
+ super.setDefaultClassLoader(classLoader)
+ internalPool.reset()
+ this
+ }
+
override def newInstance(): SerializerInstance = {
- new KryoSerializerInstance(this, useUnsafe)
+ new KryoSerializerInstance(this, useUnsafe, usePool)
}
private[spark] override lazy val supportsRelocationOfSerializedObjects: Boolean = {
@@ -246,14 +260,14 @@ class KryoSerializationStream(
this
}
- override def flush() {
+ override def flush(): Unit = {
if (output == null) {
throw new IOException("Stream is closed")
}
output.flush()
}
- override def close() {
+ override def close(): Unit = {
if (output != null) {
try {
output.close()
@@ -288,7 +302,7 @@ class KryoDeserializationStream(
}
}
- override def close() {
+ override def close(): Unit = {
if (input != null) {
try {
// Kryo's Input automatically closes the input stream it is using.
@@ -302,7 +316,8 @@ class KryoDeserializationStream(
}
}
-private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boolean)
+private[spark] class KryoSerializerInstance(
+ ks: KryoSerializer, useUnsafe: Boolean, usePool: Boolean)
extends SerializerInstance {
/**
* A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do
@@ -310,22 +325,29 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole
* pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are
* not synchronized.
*/
- @Nullable private[this] var cachedKryo: Kryo = borrowKryo()
+ @Nullable private[this] var cachedKryo: Kryo = if (usePool) null else borrowKryo()
/**
* Borrows a [[Kryo]] instance. If possible, this tries to re-use a cached Kryo instance;
* otherwise, it allocates a new instance.
*/
private[serializer] def borrowKryo(): Kryo = {
- if (cachedKryo != null) {
- val kryo = cachedKryo
- // As a defensive measure, call reset() to clear any Kryo state that might have been modified
- // by the last operation to borrow this instance (see SPARK-7766 for discussion of this issue)
+ if (usePool) {
+ val kryo = ks.pool.borrow()
kryo.reset()
- cachedKryo = null
kryo
} else {
- ks.newKryo()
+ if (cachedKryo != null) {
+ val kryo = cachedKryo
+ // As a defensive measure, call reset() to clear any Kryo state that might have
+ // been modified by the last operation to borrow this instance
+ // (see SPARK-7766 for discussion of this issue)
+ kryo.reset()
+ cachedKryo = null
+ kryo
+ } else {
+ ks.newKryo()
+ }
}
}
@@ -335,8 +357,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole
* re-use.
*/
private[serializer] def releaseKryo(kryo: Kryo): Unit = {
- if (cachedKryo == null) {
- cachedKryo = kryo
+ if (usePool) {
+ ks.pool.release(kryo)
+ } else {
+ if (cachedKryo == null) {
+ cachedKryo = kryo
+ }
}
}
@@ -352,7 +378,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole
} catch {
case e: KryoException if e.getMessage.startsWith("Buffer overflow") =>
throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " +
- "increase spark.kryoserializer.buffer.max value.", e)
+ s"increase ${KRYO_SERIALIZER_MAX_BUFFER_SIZE.key} value.", e)
} finally {
releaseKryo(kryo)
}
@@ -444,7 +470,8 @@ private[serializer] object KryoSerializer {
classOf[Array[String]],
classOf[Array[Array[String]]],
classOf[BoundedPriorityQueue[_]],
- classOf[SparkConf]
+ classOf[SparkConf],
+ classOf[TaskCommitMessage]
)
private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]](
@@ -459,6 +486,50 @@ private[serializer] object KryoSerializer {
}
}
)
+
+ // classForName() is expensive in case the class is not found, so we filter the list of
+ // SQL / ML / MLlib classes once and then re-use that filtered list in newInstance() calls.
+ private lazy val loadableSparkClasses: Seq[Class[_]] = {
+ Seq(
+ "org.apache.spark.sql.catalyst.expressions.UnsafeRow",
+ "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData",
+ "org.apache.spark.sql.catalyst.expressions.UnsafeMapData",
+
+ "org.apache.spark.ml.attribute.Attribute",
+ "org.apache.spark.ml.attribute.AttributeGroup",
+ "org.apache.spark.ml.attribute.BinaryAttribute",
+ "org.apache.spark.ml.attribute.NominalAttribute",
+ "org.apache.spark.ml.attribute.NumericAttribute",
+
+ "org.apache.spark.ml.feature.Instance",
+ "org.apache.spark.ml.feature.LabeledPoint",
+ "org.apache.spark.ml.feature.OffsetInstance",
+ "org.apache.spark.ml.linalg.DenseMatrix",
+ "org.apache.spark.ml.linalg.DenseVector",
+ "org.apache.spark.ml.linalg.Matrix",
+ "org.apache.spark.ml.linalg.SparseMatrix",
+ "org.apache.spark.ml.linalg.SparseVector",
+ "org.apache.spark.ml.linalg.Vector",
+ "org.apache.spark.ml.stat.distribution.MultivariateGaussian",
+ "org.apache.spark.ml.tree.impl.TreePoint",
+ "org.apache.spark.mllib.clustering.VectorWithNorm",
+ "org.apache.spark.mllib.linalg.DenseMatrix",
+ "org.apache.spark.mllib.linalg.DenseVector",
+ "org.apache.spark.mllib.linalg.Matrix",
+ "org.apache.spark.mllib.linalg.SparseMatrix",
+ "org.apache.spark.mllib.linalg.SparseVector",
+ "org.apache.spark.mllib.linalg.Vector",
+ "org.apache.spark.mllib.regression.LabeledPoint",
+ "org.apache.spark.mllib.stat.distribution.MultivariateGaussian"
+ ).flatMap { name =>
+ try {
+ Some[Class[_]](Utils.classForName(name))
+ } catch {
+ case NonFatal(_) => None // do nothing
+ case _: NoClassDefFoundError if Utils.isTesting => None // See SPARK-23422.
+ }
+ }
+ }
}
/**
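
A minimal sketch of how the new pooling behavior is surfaced through configuration. It assumes the config keys behind the typed entries above (spark.kryo.pool for KRYO_USE_POOL, spark.kryoserializer.buffer for KRYO_SERIALIZER_BUFFER_SIZE, and so on); the setup itself is illustrative, not taken from the diff.

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

// Illustrative setup; the keys are the ones the typed entries in
// org.apache.spark.internal.config.Kryo are assumed to wrap.
val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .set("spark.kryo.pool", "true")                 // KRYO_USE_POOL: share Kryo instances via the pool
  .set("spark.kryoserializer.buffer", "64k")      // KRYO_SERIALIZER_BUFFER_SIZE
  .set("spark.kryoserializer.buffer.max", "64m")  // KRYO_SERIALIZER_MAX_BUFFER_SIZE

val serializer = new KryoSerializer(conf)
// With usePool = true, this instance borrows from ks.pool instead of
// caching a single Kryo in cachedKryo.
val instance = serializer.newInstance()

With the pool disabled, behavior falls back to the pre-3.0 one-instance cache, as the two branches of borrowKryo()/releaseKryo() above show.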
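PoolWrapper delegates to Kryo's own pool API. Below is a self-contained sketch of that API (Kryo 4.x, com.esotericsoftware.kryo.pool), using a plain new Kryo() where KryoSerializer's factory calls newKryo():

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.pool.{KryoFactory, KryoPool}

object KryoPoolSketch {
  // Stand-in factory; KryoSerializer's factory returns newKryo() here instead.
  private val factory: KryoFactory = new KryoFactory {
    override def create: Kryo = new Kryo()
  }

  // Soft references let the JVM reclaim pooled instances under memory
  // pressure, matching the softReferences choice in PoolWrapper.getPool.
  private val pool: KryoPool = new KryoPool.Builder(factory).softReferences.build

  def main(args: Array[String]): Unit = {
    val kryo = pool.borrow() // created lazily via the factory, re-used afterwards
    try {
      kryo.reset() // clear per-use state, as borrowKryo() does when pooling
      // ... serialize / deserialize with `kryo` ...
    } finally {
      pool.release(kryo) // return it so later borrows can re-use it
    }
  }
}

This also explains PoolWrapper.reset(): swapping in a fresh pool after setDefaultClassLoader() guarantees that subsequently borrowed instances are built against the new classloader.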
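Finally, the registration block now wraps its body in Utils.withContextClassLoader rather than hand-rolling the save/restore seen on the removed lines. A sketch of the idiom that helper presumably encapsulates (the helper itself is Spark-internal; this is the classic pattern, not its verbatim source):

// Hypothetical re-creation of the save/restore idiom; Spark's actual
// Utils.withContextClassLoader may differ in detail.
def withContextClassLoader[T](loader: ClassLoader)(body: => T): T = {
  val old = Thread.currentThread().getContextClassLoader
  try {
    Thread.currentThread().setContextClassLoader(loader)
    body
  } finally {
    // Always restore, even if user registrators throw.
    Thread.currentThread().setContextClassLoader(old)
  }
}

Compared with the v2.4.4 code, the finally-restore now lives in one place, so the try/catch in newKryo() only has to translate registration failures into SparkException.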