

Spark的shuffle部分是Spark源码的重要组成部分。shuffle发生在stage的交界处,对Spark的性能有重要影响。随着源码的更新,Spark的shuffle机制也在不断变化,本文分析Spark 2.0的shuffle实现。



shuffle是MapReduce框架中一个特定的phase,介于Map和Reduce之间。shuffle的英文意思是混洗,包含两个部分:shuffle write和shuffle read。这里有一篇文章:《详细探究Spark的shuffle实现》,这篇文章写于2014年,讲的是早期版本的shuffle实现。随着源码的更新,shuffle机制也做出了相应的优化,下面分析spark-2.0的shuffle机制。




这个类实现了sort-based shuffle中hash风格的shuffle fallback path(回退路径,即不满足排序路径条件时退而采用的写路径)。这个write路径把数据写到不同的文件里,每个文件对应一个reduce分区,然后把这些文件合并成一个单独的文件,这个文件的不同区域服务不同的reducer。数据不会缓存在内存中。这个类本质上和之前的HashShuffleWriter相同,区别只是这个类的输出格式可以通过org.apache.spark.shuffle.IndexShuffleBlockResolver来提供和读取。这个写路径对于有许多reduce分区的shuffle来说是不高效的,因为它会同时打开很多serializer和文件流。因此只有在以下情况下才会选择这个路径:

1、没有排序 2、没有聚合操作 3、partition的数量小于bypassMergeThreshold


/**
 * Writes all records of this map task using one disk writer per reduce partition, then merges
 * the per-partition files into a single data file and commits it together with an index file.
 *
 * NOTE(review): this listing is an abridged excerpt; several closing braces and a few
 * statements are not shown, so it does not compile as printed.
 *
 * @param records the key/value records to write
 * @throws IOException declared by the signature; raised by the underlying file operations
 */
public void write(Iterator<Product2<K, V>> records) throws IOException {
  assert (partitionWriters == null);
  // Empty input: commit an all-zero lengths array with no data file (the null argument) and
  // publish the map status.
  // NOTE(review): the early return and the closing brace of this branch are not shown.
  if (!records.hasNext()) {
    partitionLengths = new long[numPartitions];
    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
  final SerializerInstance serInstance = serializer.newInstance();
  // Start timing here so that opening the per-partition writers counts as write time.
  final long openStartTime = System.nanoTime();
  partitionWriters = new DiskBlockObjectWriter[numPartitions];
  // Open one disk writer per reduce partition, each backed by its own temp block id and file.
  // NOTE(review): the right-hand side of the tempShuffleBlockIdPlusFile assignment is not
  // shown in this excerpt.
  for (int i = 0; i < numPartitions; i++) {
    final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
    final File file = tempShuffleBlockIdPlusFile._2();
    final BlockId blockId = tempShuffleBlockIdPlusFile._1();
    partitionWriters[i] =
      blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
  // Creating the file to write to and creating a disk writer both involve interacting with
  // the disk, and can take a long time in aggregate when we open many files, so should be
  // included in the shuffle write time.
  writeMetrics.incWriteTime(System.nanoTime() - openStartTime);

  // Route every record to the writer selected by the partitioner for the record's key.
  while (records.hasNext()) {
    final Product2<K, V> record = records.next();
    final K key = record._1();
    partitionWriters[partitioner.getPartition(key)].write(key, record._2());

  // NOTE(review): the body of this loop over the writers is not shown in the excerpt.
  for (DiskBlockObjectWriter writer : partitionWriters) {

  // Merge the per-partition files into a temp file derived from the final data file, then
  // commit that temp file together with the index.
  File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
  File tmp = Utils.tempFileWith(output);
  try {
    partitionLengths = writePartitionedFile(tmp);
    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
  } finally {
    // If the temp file is still present, remove it; a failed delete is only logged.
    if (tmp.exists() && !tmp.delete()) {
      logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
  mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);


/**
 * Concatenate all of the per-partition files into a single combined file.
 * @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
 */
// Appends the contents of each partition writer's file to outputFile, in partition order, and
// returns the number of bytes recorded for each partition.
// NOTE(review): abridged excerpt; several closing braces are not shown.
private long[] writePartitionedFile(File outputFile) throws IOException {
  // Track location of the partition starts in the output file
  final long[] lengths = new long[numPartitions];
  if (partitionWriters == null) {
    // We were passed an empty iterator
    return lengths;

  // Second constructor argument `true` opens the stream in append mode.
  final FileOutputStream out = new FileOutputStream(outputFile, true);
  final long writeStartTime = System.nanoTime();
  // Stays true unless the whole loop finishes; handed to Closeables.close in the finally block.
  // NOTE(review): presumably Guava's Closeables, where the flag means "swallow an IOException
  // thrown by close()" so it cannot mask the original failure; confirm.
  boolean threwException = true;
  try {
    for (int i = 0; i < numPartitions; i++) {
      final File file = partitionWriters[i].fileSegment().file();
      // Partitions whose file does not exist keep the default length of 0.
      if (file.exists()) {
        final FileInputStream in = new FileInputStream(file);
        boolean copyThrewException = true;
        try {
          // The value returned by copyStream is stored as this partition's length.
          lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
          copyThrewException = false;
        } finally {
          Closeables.close(in, copyThrewException);
        // The per-partition file is deleted after copying; a failed delete is only logged.
        if (!file.delete()) {
          logger.error("Unable to delete file for partition {}", i);
    threwException = false;
  } finally {
    Closeables.close(out, threwException);
    // The time spent merging is added to the shuffle write time whether or not it succeeded.
    writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
  // Drop the writer references once their files have been merged.
  partitionWriters = null;
  return lengths;


/**
 * Write an index file with the offsets of each block, plus a final offset at the end for the
 * end of the output file. This will be used by getBlockData to figure out where each block
 * begins and ends.
 * It will commit the data and index file as an atomic operation, use the existing ones, or
 * replace them with new ones.
 * Note: the `lengths` will be updated to match the existing index file if use the existing ones.
 * */
// Writes a temporary index file from `lengths`, then, inside a synchronized section, either
// adopts index/data files that already exist or renames the temporary files into place.
// NOTE(review): abridged excerpt; the statements that write the offsets, the bodies of several
// `if` blocks and most closing braces are not shown.
def writeIndexFileAndCommit(
    shuffleId: Int,
    mapId: Int,
    lengths: Array[Long],
    dataTmp: File): Unit = {
  val indexFile = getIndexFile(shuffleId, mapId)
  val indexTmp = Utils.tempFileWith(indexFile)
  try {
    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
    Utils.tryWithSafeFinally {
      // We take in lengths of each block, need to convert it to offsets.
      // `offset` is the running sum of the lengths seen so far.
      // NOTE(review): the writes of `offset` to `out` are not shown in this excerpt.
      var offset = 0L
      for (length <- lengths) {
        offset += length
    } {
    // NOTE(review): the contents of the second (finally) block are not shown.

    val dataFile = getDataFile(shuffleId, mapId)
    // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
    // the following check and rename are atomic.
    synchronized {
      // A non-null result means the existing index and data files are reused.
      val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
      if (existingLengths != null) {
        // Another attempt for the same task has already written our map outputs successfully,
        // so just use the existing partition lengths and delete our temporary map outputs.
        // The caller's `lengths` array is overwritten in place.
        System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
        // NOTE(review): the body of this `if` is not shown.
        if (dataTmp != null && dataTmp.exists()) {
      } else {
        // This is the first successful attempt in writing the map outputs for this task,
        // so override any existing index and data files with the ones we wrote.
        // NOTE(review): the bodies of the two exists() checks are not shown.
        if (indexFile.exists()) {
        if (dataFile.exists()) {
        // A failed rename of either temp file is reported as an IOException.
        if (!indexTmp.renameTo(indexFile)) {
          throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
        if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
          throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
  } finally {
    // If the temp index file is still present (e.g. it was not renamed), delete it; a failed
    // delete is only logged.
    if (indexTmp.exists() && !indexTmp.delete()) {
      logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}")






首先描述下大概。因为是sort,所以要排序,这里就用到了ExternalSorter这个数据结构。然后把要处理的数据全部插入到ExternalSorter里面,在插入的过程中是不排序的,就是插入,插入数据是(partitionId,key,value)。然后是调用 sorter.writePartitionedFile,在这里会排序,会按照partitionId和key(或者key的hashcode)进行排序,其他的就和上面bypassShuffleWriter的差不多了,最后也是写到一个indexFile里面。返回mapStatus。

/**
 * Write a bunch of records to this task's output.
 *
 * Builds an ExternalSorter (with the dependency's aggregator and key ordering only when
 * map-side combine is enabled), has the sorter write one partitioned file, commits it with an
 * index file, and records the resulting MapStatus.
 *
 * NOTE(review): abridged excerpt; `records` is not referenced anywhere in the lines shown, so
 * the statement that feeds it into the sorter appears to have been dropped, along with several
 * closing braces.
 */
override def write(records: Iterator[Product2[K, V]]): Unit = {
  sorter = if (dep.mapSideCombine) {
    require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
    new ExternalSorter[K, V, C](
      context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
  } else {
    // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
    // care whether the keys get sorted in each partition; that will be done on the reduce side
    // if the operation being run is sortByKey.
    new ExternalSorter[K, V, V](
      context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)

  // Don't bother including the time to open the merged output file in the shuffle write time,
  // because it just opens a single file, so is typically too fast to measure accurately
  // (see SPARK-3570).
  val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
  val tmp = Utils.tempFileWith(output)
  try {
    // The block id passed to the sorter uses NOOP_REDUCE_ID as its reduce id.
    val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
    val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
  } finally {
    // If the temp file is still present, remove it; a failed delete is only logged.
    if (tmp.exists() && !tmp.delete()) {
      logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")

这里面ExternalSorter是核心。看它的源码,它存数据是使用的两种数据结构:PartitionedAppendOnlyMap和PartitionedPairBuffer,其中有聚合操作使用map,没有聚合操作使用buffer。PartitionedAppendOnlyMap 继承了SizeTrackingAppendOnlyMap 和WritablePartitionedPairCollection 。 其中SizeTrackingAppendOnlyMap是用于预测空间(SizeTracker),然后加存储数据(AppendOnlyMap),然后WritablePartitionedPairCollection是用于插入数据时候插入partitionId(insert(partition: Int, key: K, value: V))加上里面实现了对数据按照partitionId和Key排序的方法。


看源码,它存储数据是private var data = new Array[AnyRef](2 * capacity),是使用数组存储的,key和value挨着,这样做是为了节省空间。


/**
 * Set the value for a key.
 *
 * The null key is kept outside the array in `nullValue` / `haveNullValue`. Every other key is
 * stored in `data` with the key at index 2 * pos and its value at index 2 * pos + 1.
 *
 * NOTE(review): abridged excerpt; the body of `if (!haveNullValue)`, the return statements and
 * the closing braces are not shown, so as printed the probing loop has no exit.
 */
def update(key: K, value: V): Unit = {
  assert(!destroyed, destructionMessage)
  val k = key.asInstanceOf[AnyRef]
  // Null key: stored in dedicated fields instead of the array.
  if (k.eq(null)) {
    if (!haveNullValue) {
    nullValue = value
    haveNullValue = true
  // Starting slot: rehashed hash code masked into the table's index range.
  var pos = rehash(key.hashCode) & mask
  // `i` is the probe step; it grows by one after every collision.
  var i = 1
  while (true) {
    val curKey = data(2 * pos)
    if (curKey.eq(null)) {
      // Empty slot: insert the new key/value pair here.
      data(2 * pos) = k
      data(2 * pos + 1) = value.asInstanceOf[AnyRef]
      incrementSize()  // Since we added a new key
    } else if (k.eq(curKey) || k.equals(curKey)) {
      // Key already present: overwrite its value only.
      data(2 * pos + 1) = value.asInstanceOf[AnyRef]
    } else {
      // Slot holds a different key: advance by the current step (wrapping via mask) and retry.
      val delta = i
      pos = (pos + delta) & mask
      i += 1



然后writePartitionedFile 里面调用collection.destructiveSortedWritablePartitionedIterator(comparator) 会对数据排序,之后就跟上一小节里面的writePartitionedFile差不多了,无非就是把内存里面的数据和spill的数据合并之后写入大文件里面,之后的writeIndexFile是一样的,就不细说。



// Consumes `records` and then, in the finally block, runs a cleanup step on `sorter` when it
// is non-null.
// NOTE(review): abridged excerpt; the body of the while loop, the statements after it, the
// contents of the inner try block and the closing braces are not shown.
public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
  // Keep track of success so we know if we encountered an exception
  // We do this rather than a standard try/catch/re-throw to handle
  // generic throwables.
  boolean success = false;
  try {
    while (records.hasNext()) {
    success = true;
  } finally {
    if (sorter != null) {
      try {
      } catch (Exception e) {
        // Only throw this error if we won't be masking another
        // error.
        // If the write itself succeeded the cleanup failure is rethrown; otherwise it is only
        // logged so that the original write failure is the one that propagates.
        if (success) {
          throw e;
        } else {
          logger.error("In addition to a failure during writing, we failed during " +
                       "cleanup.", e);

除了是写到堆外,其他应该跟sortShuffleWriter 差不多吧,懒得写了,以后发现有什么特别之处再补充。



/**
 * Read the combined key-values for this reduce task.
 *
 * Steps visible in this listing: create a ShuffleBlockFetcherIterator for the requested
 * partition range, wrap each fetched stream for compression, turn the streams into a record
 * iterator, attach read metrics and interruption support, optionally aggregate, and optionally
 * sort by key.
 *
 * NOTE(review): abridged excerpt; the bodies of several closures, the `case None` branch and
 * most closing braces are not shown.
 */
override def read(): Iterator[Product2[K, C]] = {
  val blockFetcherItr = new ShuffleBlockFetcherIterator(
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
    // The MB value (default "48m") is multiplied out to bytes.
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    // Defaults to Int.MaxValue when the setting is absent.
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

  // Wrap the streams for compression based on configuration
  val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
    serializerManager.wrapForCompression(blockId, inputStream)

  val serializerInstance = dep.serializer.newInstance()

  // Create a key/value iterator for each stream
  // NOTE(review): the closure body (the deserialization call) is not shown in this excerpt.
  val recordIter = wrappedStreams.flatMap { wrappedStream =>
    // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
    // NextIterator. The NextIterator makes sure that close() is called on the
    // underlying InputStream when all records have been read.

  // Update the context task metrics for each record read.
  // NOTE(review): the per-record closure body and the completion callback are not shown.
  val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map { record =>

  // An interruptible iterator must be used here in order to support task cancellation
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

  // Choose how to combine: merge combiners when the map side already combined, combine raw
  // values when it did not, or pass records through when no aggregator is defined.
  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      // We are reading values that are already combined
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // We don't know the value type, but also don't care -- the dependency *should*
      // have made sure its compatible w/ this aggregator, which will convert the value
      // type to the combined type C
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
      dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
  } else {
    require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]

  // Sort the output if there is a sort ordering defined.
  dep.keyOrdering match {
    case Some(keyOrd: Ordering[K]) =>
      // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
      // the ExternalSorter won't spill to disk.
      // NOTE(review): the statement that inserts aggregatedIter into the sorter is not shown.
      val sorter =
        new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
      // sorter.stop() is passed as the completion action of the returned iterator.
      CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
    case None =>

首先是建立一个ShuffleBlockFetcherIterator,传入的参数有mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),这个是必须的,只取需要的partition的数据。


// Sets up the iterator: registers cleanup on task completion, splits blocks into local and
// remote, queues the remote requests in random order, and logs the progress.
// NOTE(review): abridged excerpt; the calls that follow the "Send out initial requests" and
// "Get Local Blocks" comments, and the closing brace, are not shown.
private[this] def initialize(): Unit = {
  // Add a task completion callback (called in both success case and failure case) to cleanup.
  context.addTaskCompletionListener(_ => cleanup())

  // Split local and remote blocks.
  val remoteRequests = splitLocalRemoteBlocks()
  // Add the remote requests into our queue in a random order
  fetchRequests ++= Utils.randomize(remoteRequests)
  // Sanity check: the two in-flight counters must be zero together or non-zero together.
  assert ((0 == reqsInFlight) == (0 == bytesInFlight),
    "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
    ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)

  // Send out initial requests for blocks, up to our maxBytesInFlight

  // Requests that are no longer in the queue are counted as started.
  val numFetches = remoteRequests.size - fetchRequests.size
  logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

  // Get Local Blocks
  logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))


// Loops while requests remain queued and either nothing is in flight, or the next request fits
// within both the request-count limit and the byte limit.
// NOTE(review): abridged excerpt; the loop body and the closing braces are not shown.
private def fetchUpToMaxBytes(): Unit = {
  // Send fetch requests up to maxBytesInFlight
  // When bytesInFlight is 0 the front request passes regardless of its size.
  while (fetchRequests.nonEmpty &&
    (bytesInFlight == 0 ||
      (reqsInFlight + 1 <= maxReqsInFlight &&
        bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {


// Issues one fetch request through shuffleClient and registers a listener that puts a success
// or failure result on the `results` queue for each block.
// NOTE(review): abridged excerpt; the remaining arguments of the SuccessFetchResult
// constructor, the ref-count increment mentioned in the comment below, and the closing
// braces/parentheses are not shown.
private[this] def sendRequest(req: FetchRequest) {
  logDebug("Sending request for %d blocks (%s) from %s".format(
    req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
  // Account for this request in the in-flight counters before sending it.
  bytesInFlight += req.size
  reqsInFlight += 1

  // so we can look up the size of each blockID
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  // Block ids still outstanding for this request; entries are removed as blocks arrive.
  val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
  val blockIds = req.blocks.map(_._1.toString)

  val address = req.address
  shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
    new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        // Only add the buffer to results queue if the iterator is not zombie,
        // i.e. cleanup() has not been called yet.
        ShuffleBlockFetcherIterator.this.synchronized {
          if (!isZombie) {
            // Increment the ref count because we need to pass this to a different thread.
            // This needs to be released after use.
            remainingBlocks -= blockId
            results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
            logDebug("remainingBlocks: " + remainingBlocks)
        logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))

      // A failure is logged and put on the results queue; it is not thrown from this callback.
      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
        results.put(new FailureFetchResult(BlockId(blockId), address, e))





