Created
August 5, 2020 13:23
-
-
Save fifar/be214ef3a023e1a59dfc66476dd7b0b2 to your computer and use it in GitHub Desktop.
Changes to core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala between v2.4.4 and v3.0.0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala | |
index 7271004..cdaab59 100644 | |
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala | |
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala | |
@@ -30,6 +30,7 @@ import scala.util.control.NonFatal | |
import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} | |
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} | |
import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput} | |
+import com.esotericsoftware.kryo.pool.{KryoCallback, KryoFactory, KryoPool} | |
import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} | |
import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} | |
import org.apache.avro.generic.{GenericData, GenericRecord} | |
@@ -38,6 +39,8 @@ import org.roaringbitmap.RoaringBitmap | |
import org.apache.spark._ | |
import org.apache.spark.api.python.PythonBroadcast | |
import org.apache.spark.internal.Logging | |
+import org.apache.spark.internal.config.Kryo._ | |
+import org.apache.spark.internal.io.FileCommitProtocol._ | |
import org.apache.spark.network.util.ByteUnit | |
import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} | |
import org.apache.spark.storage._ | |
@@ -57,33 +60,34 @@ class KryoSerializer(conf: SparkConf) | |
with Logging | |
with Serializable { | |
- private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") | |
+ private val bufferSizeKb = conf.get(KRYO_SERIALIZER_BUFFER_SIZE) | |
if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) { | |
- throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + | |
- s"2048 mb, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} mb.") | |
+ throw new IllegalArgumentException(s"${KRYO_SERIALIZER_BUFFER_SIZE.key} must be less than " + | |
+ s"2048 MiB, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} MiB.") | |
} | |
private val bufferSize = ByteUnit.KiB.toBytes(bufferSizeKb).toInt | |
- val maxBufferSizeMb = conf.getSizeAsMb("spark.kryoserializer.buffer.max", "64m").toInt | |
+ val maxBufferSizeMb = conf.get(KRYO_SERIALIZER_MAX_BUFFER_SIZE).toInt | |
if (maxBufferSizeMb >= ByteUnit.GiB.toMiB(2)) { | |
- throw new IllegalArgumentException("spark.kryoserializer.buffer.max must be less than " + | |
- s"2048 mb, got: + $maxBufferSizeMb mb.") | |
+ throw new IllegalArgumentException(s"${KRYO_SERIALIZER_MAX_BUFFER_SIZE.key} must be less " + | |
+ s"than 2048 MiB, got: $maxBufferSizeMb MiB.") | |
} | |
private val maxBufferSize = ByteUnit.MiB.toBytes(maxBufferSizeMb).toInt | |
- private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) | |
- private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) | |
- private val userRegistrators = conf.get("spark.kryo.registrator", "") | |
- .split(',').map(_.trim) | |
+ private val referenceTracking = conf.get(KRYO_REFERENCE_TRACKING) | |
+ private val registrationRequired = conf.get(KRYO_REGISTRATION_REQUIRED) | |
+ private val userRegistrators = conf.get(KRYO_USER_REGISTRATORS) | |
+ .map(_.trim) | |
.filter(!_.isEmpty) | |
- private val classesToRegister = conf.get("spark.kryo.classesToRegister", "") | |
- .split(',').map(_.trim) | |
+ private val classesToRegister = conf.get(KRYO_CLASSES_TO_REGISTER) | |
+ .map(_.trim) | |
.filter(!_.isEmpty) | |
private val avroSchemas = conf.getAvroSchema | |
// whether to use unsafe based IO for serialization | |
- private val useUnsafe = conf.getBoolean("spark.kryo.unsafe", false) | |
+ private val useUnsafe = conf.get(KRYO_USE_UNSAFE) | |
+ private val usePool = conf.get(KRYO_USE_POOL) | |
def newKryoOutput(): KryoOutput = | |
if (useUnsafe) { | |
@@ -92,12 +96,41 @@ class KryoSerializer(conf: SparkConf) | |
new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) | |
} | |
+ @transient | |
+ private lazy val factory: KryoFactory = new KryoFactory() { | |
+ override def create: Kryo = { | |
+ newKryo() | |
+ } | |
+ } | |
+ | |
+ private class PoolWrapper extends KryoPool { | |
+ private var pool: KryoPool = getPool | |
+ | |
+ override def borrow(): Kryo = pool.borrow() | |
+ | |
+ override def release(kryo: Kryo): Unit = pool.release(kryo) | |
+ | |
+ override def run[T](kryoCallback: KryoCallback[T]): T = pool.run(kryoCallback) | |
+ | |
+ def reset(): Unit = { | |
+ pool = getPool | |
+ } | |
+ | |
+ private def getPool: KryoPool = { | |
+ new KryoPool.Builder(factory).softReferences.build | |
+ } | |
+ } | |
+ | |
+ @transient | |
+ private lazy val internalPool = new PoolWrapper | |
+ | |
+ def pool: KryoPool = internalPool | |
+ | |
def newKryo(): Kryo = { | |
val instantiator = new EmptyScalaKryoInstantiator | |
val kryo = instantiator.newKryo() | |
kryo.setRegistrationRequired(registrationRequired) | |
- val oldClassLoader = Thread.currentThread.getContextClassLoader | |
val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) | |
// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. | |
@@ -123,23 +156,22 @@ class KryoSerializer(conf: SparkConf) | |
kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) | |
kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) | |
- try { | |
- // scalastyle:off classforname | |
- // Use the default classloader when calling the user registrator. | |
- Thread.currentThread.setContextClassLoader(classLoader) | |
- // Register classes given through spark.kryo.classesToRegister. | |
- classesToRegister | |
- .foreach { className => kryo.register(Class.forName(className, true, classLoader)) } | |
- // Allow the user to register their own classes by setting spark.kryo.registrator. | |
- userRegistrators | |
- .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) | |
- .foreach { reg => reg.registerClasses(kryo) } | |
- // scalastyle:on classforname | |
- } catch { | |
- case e: Exception => | |
- throw new SparkException(s"Failed to register classes with Kryo", e) | |
- } finally { | |
- Thread.currentThread.setContextClassLoader(oldClassLoader) | |
+ // Use the default classloader when calling the user registrator. | |
+ Utils.withContextClassLoader(classLoader) { | |
+ try { | |
+ // Register classes given through spark.kryo.classesToRegister. | |
+ classesToRegister.foreach { className => | |
+ kryo.register(Utils.classForName(className, noSparkClassLoader = true)) | |
+ } | |
+ // Allow the user to register their own classes by setting spark.kryo.registrator. | |
+ userRegistrators | |
+ .map(Utils.classForName[KryoRegistrator](_, noSparkClassLoader = true). | |
+ getConstructor().newInstance()) | |
+ .foreach { reg => reg.registerClasses(kryo) } | |
+ } catch { | |
+ case e: Exception => | |
+ throw new SparkException(s"Failed to register classes with Kryo", e) | |
+ } | |
} | |
// Register Chill's classes; we do this after our ranges and the user's own classes to let | |
@@ -181,32 +213,8 @@ class KryoSerializer(conf: SparkConf) | |
// We can't load those class directly in order to avoid unnecessary jar dependencies. | |
// We load them safely, ignore it if the class not found. | |
- Seq( | |
- "org.apache.spark.sql.catalyst.expressions.UnsafeRow", | |
- "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData", | |
- "org.apache.spark.sql.catalyst.expressions.UnsafeMapData", | |
- | |
- "org.apache.spark.ml.feature.Instance", | |
- "org.apache.spark.ml.feature.LabeledPoint", | |
- "org.apache.spark.ml.feature.OffsetInstance", | |
- "org.apache.spark.ml.linalg.DenseMatrix", | |
- "org.apache.spark.ml.linalg.DenseVector", | |
- "org.apache.spark.ml.linalg.Matrix", | |
- "org.apache.spark.ml.linalg.SparseMatrix", | |
- "org.apache.spark.ml.linalg.SparseVector", | |
- "org.apache.spark.ml.linalg.Vector", | |
- "org.apache.spark.ml.tree.impl.TreePoint", | |
- "org.apache.spark.mllib.clustering.VectorWithNorm", | |
- "org.apache.spark.mllib.linalg.DenseMatrix", | |
- "org.apache.spark.mllib.linalg.DenseVector", | |
- "org.apache.spark.mllib.linalg.Matrix", | |
- "org.apache.spark.mllib.linalg.SparseMatrix", | |
- "org.apache.spark.mllib.linalg.SparseVector", | |
- "org.apache.spark.mllib.linalg.Vector", | |
- "org.apache.spark.mllib.regression.LabeledPoint" | |
- ).foreach { name => | |
+ KryoSerializer.loadableSparkClasses.foreach { clazz => | |
try { | |
- val clazz = Utils.classForName(name) | |
kryo.register(clazz) | |
} catch { | |
case NonFatal(_) => // do nothing | |
@@ -218,8 +226,14 @@ class KryoSerializer(conf: SparkConf) | |
kryo | |
} | |
+ override def setDefaultClassLoader(classLoader: ClassLoader): Serializer = { | |
+ super.setDefaultClassLoader(classLoader) | |
+ internalPool.reset() | |
+ this | |
+ } | |
+ | |
override def newInstance(): SerializerInstance = { | |
- new KryoSerializerInstance(this, useUnsafe) | |
+ new KryoSerializerInstance(this, useUnsafe, usePool) | |
} | |
private[spark] override lazy val supportsRelocationOfSerializedObjects: Boolean = { | |
@@ -246,14 +260,14 @@ class KryoSerializationStream( | |
this | |
} | |
- override def flush() { | |
+ override def flush(): Unit = { | |
if (output == null) { | |
throw new IOException("Stream is closed") | |
} | |
output.flush() | |
} | |
- override def close() { | |
+ override def close(): Unit = { | |
if (output != null) { | |
try { | |
output.close() | |
@@ -288,7 +302,7 @@ class KryoDeserializationStream( | |
} | |
} | |
- override def close() { | |
+ override def close(): Unit = { | |
if (input != null) { | |
try { | |
// Kryo's Input automatically closes the input stream it is using. | |
@@ -302,7 +316,8 @@ class KryoDeserializationStream( | |
} | |
} | |
-private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boolean) | |
+private[spark] class KryoSerializerInstance( | |
+ ks: KryoSerializer, useUnsafe: Boolean, usePool: Boolean) | |
extends SerializerInstance { | |
/** | |
* A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do | |
@@ -310,22 +325,29 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole | |
* pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are | |
* not synchronized. | |
*/ | |
- @Nullable private[this] var cachedKryo: Kryo = borrowKryo() | |
+ @Nullable private[this] var cachedKryo: Kryo = if (usePool) null else borrowKryo() | |
/** | |
* Borrows a [[Kryo]] instance. If possible, this tries to re-use a cached Kryo instance; | |
* otherwise, it allocates a new instance. | |
*/ | |
private[serializer] def borrowKryo(): Kryo = { | |
- if (cachedKryo != null) { | |
- val kryo = cachedKryo | |
- // As a defensive measure, call reset() to clear any Kryo state that might have been modified | |
- // by the last operation to borrow this instance (see SPARK-7766 for discussion of this issue) | |
+ if (usePool) { | |
+ val kryo = ks.pool.borrow() | |
kryo.reset() | |
- cachedKryo = null | |
kryo | |
} else { | |
- ks.newKryo() | |
+ if (cachedKryo != null) { | |
+ val kryo = cachedKryo | |
+ // As a defensive measure, call reset() to clear any Kryo state that might have | |
+ // been modified by the last operation to borrow this instance | |
+ // (see SPARK-7766 for discussion of this issue) | |
+ kryo.reset() | |
+ cachedKryo = null | |
+ kryo | |
+ } else { | |
+ ks.newKryo() | |
+ } | |
} | |
} | |
@@ -335,8 +357,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole | |
* re-use. | |
*/ | |
private[serializer] def releaseKryo(kryo: Kryo): Unit = { | |
- if (cachedKryo == null) { | |
- cachedKryo = kryo | |
+ if (usePool) { | |
+ ks.pool.release(kryo) | |
+ } else { | |
+ if (cachedKryo == null) { | |
+ cachedKryo = kryo | |
+ } | |
} | |
} | |
@@ -352,7 +378,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole | |
} catch { | |
case e: KryoException if e.getMessage.startsWith("Buffer overflow") => | |
throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + | |
- "increase spark.kryoserializer.buffer.max value.", e) | |
+ s"increase ${KRYO_SERIALIZER_MAX_BUFFER_SIZE.key} value.", e) | |
} finally { | |
releaseKryo(kryo) | |
} | |
@@ -444,7 +470,8 @@ private[serializer] object KryoSerializer { | |
classOf[Array[String]], | |
classOf[Array[Array[String]]], | |
classOf[BoundedPriorityQueue[_]], | |
- classOf[SparkConf] | |
+ classOf[SparkConf], | |
+ classOf[TaskCommitMessage] | |
) | |
private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]]( | |
@@ -459,6 +486,50 @@ private[serializer] object KryoSerializer { | |
} | |
} | |
) | |
+ | |
+ // classForName() is expensive in case the class is not found, so we filter the list of | |
+ // SQL / ML / MLlib classes once and then re-use that filtered list in newInstance() calls. | |
+ private lazy val loadableSparkClasses: Seq[Class[_]] = { | |
+ Seq( | |
+ "org.apache.spark.sql.catalyst.expressions.UnsafeRow", | |
+ "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData", | |
+ "org.apache.spark.sql.catalyst.expressions.UnsafeMapData", | |
+ | |
+ "org.apache.spark.ml.attribute.Attribute", | |
+ "org.apache.spark.ml.attribute.AttributeGroup", | |
+ "org.apache.spark.ml.attribute.BinaryAttribute", | |
+ "org.apache.spark.ml.attribute.NominalAttribute", | |
+ "org.apache.spark.ml.attribute.NumericAttribute", | |
+ | |
+ "org.apache.spark.ml.feature.Instance", | |
+ "org.apache.spark.ml.feature.LabeledPoint", | |
+ "org.apache.spark.ml.feature.OffsetInstance", | |
+ "org.apache.spark.ml.linalg.DenseMatrix", | |
+ "org.apache.spark.ml.linalg.DenseVector", | |
+ "org.apache.spark.ml.linalg.Matrix", | |
+ "org.apache.spark.ml.linalg.SparseMatrix", | |
+ "org.apache.spark.ml.linalg.SparseVector", | |
+ "org.apache.spark.ml.linalg.Vector", | |
+ "org.apache.spark.ml.stat.distribution.MultivariateGaussian", | |
+ "org.apache.spark.ml.tree.impl.TreePoint", | |
+ "org.apache.spark.mllib.clustering.VectorWithNorm", | |
+ "org.apache.spark.mllib.linalg.DenseMatrix", | |
+ "org.apache.spark.mllib.linalg.DenseVector", | |
+ "org.apache.spark.mllib.linalg.Matrix", | |
+ "org.apache.spark.mllib.linalg.SparseMatrix", | |
+ "org.apache.spark.mllib.linalg.SparseVector", | |
+ "org.apache.spark.mllib.linalg.Vector", | |
+ "org.apache.spark.mllib.regression.LabeledPoint", | |
+ "org.apache.spark.mllib.stat.distribution.MultivariateGaussian" | |
+ ).flatMap { name => | |
+ try { | |
+ Some[Class[_]](Utils.classForName(name)) | |
+ } catch { | |
+ case NonFatal(_) => None // do nothing | |
+ case _: NoClassDefFoundError if Utils.isTesting => None // See SPARK-23422. | |
+ } | |
+ } | |
+ } | |
} | |
/** |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment