暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

spark内核源码深度剖析(十四):task原理剖析与源码分析

程序员雨衣 2019-10-25
466

源码

看下Executor类的 launchTask()
方法

  1. def launchTask(

  2. context: ExecutorBackend,

  3. taskId: Long,

  4. attemptNumber: Int,

  5. taskName: String,

  6. serializedTask: ByteBuffer) {

  7. // 对于每一个task,都会创建一个TaskRunner

  8. // TaskRunner继承的是Java多线程中的Runnable接口

  9. val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,

  10. serializedTask)

  11. // 将TaskRunner放入内存缓存

  12. runningTasks.put(taskId, tr)

  13. // Executor内部有一个Java线程池,这里其实将task封装在一个线程中(TaskRunner),直接将线程丢入线程池,进行执行

  14. // 线程池是自动实现了排队机制的,也就是说,如果线程池内的线程暂时没有空闲的,那么丢进去的线程都是要排队的

  15. threadPool.execute(tr)

  16. }

看下 newTaskRunner()

  1. /**

  2. * 从TaskRunner开始,来看Task的运行的工作原理

  3. */

  4. class TaskRunner(

  5. execBackend: ExecutorBackend,

  6. val taskId: Long,

  7. val attemptNumber: Int,

  8. taskName: String,

  9. serializedTask: ByteBuffer)

  10. extends Runnable {


  11. @volatile private var killed = false

  12. @volatile var task: Task[Any] = _

  13. @volatile var attemptedTask: Option[Task[Any]] = None

  14. @volatile var startGCTime: Long = _


  15. def kill(interruptThread: Boolean) {

  16. logInfo(s"Executor is trying to kill $taskName (TID $taskId)")

  17. killed = true

  18. if (task != null) {

  19. task.kill(interruptThread)

  20. }

  21. }


  22. override def run() {

  23. val deserializeStartTime = System.currentTimeMillis()

  24. Thread.currentThread.setContextClassLoader(replClassLoader)

  25. val ser = env.closureSerializer.newInstance()

  26. logInfo(s"Running $taskName (TID $taskId)")

  27. execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)

  28. var taskStart: Long = 0

  29. startGCTime = gcTime


  30. try {

  31. // 对序列化的task数据,进行反序列化

  32. val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)

  33. // 然后,通过网络通信,将需要的文件、资源、jar拷贝过来

  34. updateDependencies(taskFiles, taskJars)

  35. // 最后,通过正式的反序列化操作,将整个task的数据集反序列化回来

  36. // 这里用到了java的ClassLoader,因为java的ClassLoader可以干很多事情,比如,用反射的方式来动态加载一个类,创建这个类的对象,

  37. // 还有比如,可以用于对指定上下文的相关资源,进行加载和读取

  38. task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)


  39. // If this task has been killed before we deserialized it, let's quit now. Otherwise,

  40. // continue executing the task.

  41. if (killed) {

  42. // Throw an exception rather than returning, because returning within a try{} block

  43. // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl

  44. // exception will be caught by the catch block, leading to an incorrect ExceptionFailure

  45. // for the task.

  46. throw new TaskKilledException

  47. }


  48. attemptedTask = Some(task)

  49. logDebug("Task " + taskId + "'s epoch is " + task.epoch)

  50. env.mapOutputTracker.updateEpoch(task.epoch)


  51. // Run the actual task and measure its runtime.

  52. // 计算出task开始的时间

  53. taskStart = System.currentTimeMillis()

  54. // 执行task,用的是task的run()方法

  55. // 这里的value,对于ShuffleMapTask来说,其实就是MapStatus,封装了ShuffleMapTask计算的数据,输出的位置

  56. // 后面还是一个ShuffleMapTask,那么就会去联系MapOutputTracker,来获取上一个ShuffleMapTasks的输出位置,然后通过网络拉取数据

  57. // ResultTask,也是一样的

  58. val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)

  59. // 计算出task结束的时间

  60. val taskFinish = System.currentTimeMillis()


  61. // If the task has been killed, let's fail it.

  62. if (task.killed) {

  63. throw new TaskKilledException

  64. }

  65. // 这个,其实就是针对MapStatus进行了各种序列化和封装,因为后面要发送给Driver(通过网络)

  66. //

  67. val resultSer = env.serializer.newInstance()

  68. val beforeSerialization = System.currentTimeMillis()

  69. val valueBytes = resultSer.serialize(value)

  70. val afterSerialization = System.currentTimeMillis()


  71. // 计算出task相关的一些metrics,就是统计信息,包括运行了多长时间、反序列化消耗了多长时间、java虚拟机gc耗费了多长时间

  72. // 结果的序列化耗费了多长时间,这些东西,其实会在我们的SparkUI上显示

  73. for (m <- task.metrics) {

  74. m.setExecutorDeserializeTime(taskStart - deserializeStartTime)

  75. m.setExecutorRunTime(taskFinish - taskStart)

  76. m.setJvmGCTime(gcTime - startGCTime)

  77. m.setResultSerializationTime(afterSerialization - beforeSerialization)

  78. }


  79. val accumUpdates = Accumulators.values


  80. val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)

  81. val serializedDirectResult = ser.serialize(directResult)

  82. val resultSize = serializedDirectResult.limit


  83. // directSend = sending directly back to the driver

  84. val serializedResult = {

  85. if (maxResultSize > 0 && resultSize > maxResultSize) {

  86. logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +

  87. s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +

  88. s"dropping it.")

  89. ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))

  90. } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {

  91. val blockId = TaskResultBlockId(taskId)

  92. env.blockManager.putBytes(

  93. blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)

  94. logInfo(

  95. s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")

  96. ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))

  97. } else {

  98. logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")

  99. serializedDirectResult

  100. }

  101. }


  102. // 其实就是调用了Executor所在的CoarseGrainedExecutorBackend的statusUpdate()方法

  103. execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)


  104. }

  105. }

看下 updateDependencies()
方法

  1. private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {

  2. // 获取hadoop配置文件

  3. lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)

  4. // 这里,使用java的synchronized进行了多线程并发访问的同步

  5. // 因为task实际上是以java线程的方式,在一个CoarseGrainedExecutorBackend进程内并发运行的

  6. // 如果在执行业务逻辑的时候,要访问一些共享的资源,那么就可能会出现多线程并发访问安全问题

  7. // 所以,spark在这里选择进行了多线程并发访问的同步(synchronized),因为在这里面访问了诸如currentFiles等等这些共享资源


  8. synchronized {

  9. // Fetch missing dependencies

  10. // 遍历要拉取的文件

  11. // 通过Utils的fetchFile()方法,通过网络通信,从远程拉取文件

  12. for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {

  13. logInfo("Fetching " + name + " with timestamp " + timestamp)

  14. // Fetch file with useCache mode, close cache for local mode.

  15. Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,

  16. env.securityManager, hadoopConf, timestamp, useCache = !isLocal)

  17. currentFiles(name) = timestamp

  18. }

  19. // 遍历要拉取的jar

  20. for ((name, timestamp) <- newJars) {

  21. val localName = name.split("/").last

  22. // 判断一下时间戳,要求jar当前时间戳必须小于目标时间戳

  23. // 通过Utils的fetchFile(),拉取jar文件

  24. val currentTimeStamp = currentJars.get(name)

  25. .orElse(currentJars.get(localName))

  26. .getOrElse(-1L)

  27. if (currentTimeStamp < timestamp) {

  28. logInfo("Fetching " + name + " with timestamp " + timestamp)

  29. // Fetch file with useCache mode, close cache for local mode.

  30. Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,

  31. env.securityManager, hadoopConf, timestamp, useCache = !isLocal)

  32. currentJars(name) = timestamp

  33. // Add it to our class loader

  34. val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL

  35. if (!urlClassLoader.getURLs.contains(url)) {

  36. logInfo("Adding " + url + " to class loader")

  37. urlClassLoader.addURL(url)

  38. }

  39. }

  40. }

  41. }

  42. }

看下 task.run()
方法

  1. final def run(taskAttemptId: Long, attemptNumber: Int): T = {

  2. // 创建一个TaskContext,就是task的执行上下文,里面记录了task执行的一些全局性的数据,比如task重试了几次

  3. // 比如task属于哪个stage,task要处理的是rdd的哪个partition等等

  4. context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,

  5. taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)

  6. TaskContextHelper.setTaskContext(context)

  7. context.taskMetrics.setHostname(Utils.localHostName())

  8. taskThread = Thread.currentThread()

  9. if (_killed) {

  10. kill(interruptThread = false)

  11. }

  12. try {

  13. // 调用抽象方法,runTask()

  14. runTask(context)

  15. } finally {

  16. context.markTaskCompleted()

  17. TaskContextHelper.unset()

  18. }

  19. }

看下这个 runTask()
方法

  1. // 调用到了抽象方法,那就意味着这个类,只是一个模板类,或者抽象父类,仅仅封装了一些子类通用的数据和操作

  2. // 而关键的操作,全部都要依赖于子类的实现,task的子类,有ShuffleMapTask、ResultTask

  3. // 要运行子类的runTask()方法,才能执行我们自己定义的算子和逻辑

  4. def runTask(context: TaskContext): T

接下来分别看下ShuffleMapTask和ResultTask的 runTask()
方法 先看ShuffleMapTask的 一个ShuffleMapTask会将一个RDD的元素,切分为多个bucket,基于一个在ShuffleDependency中指定的partitioner,默认就是HashPartition

  1. /**

  2. * ShuffleMapTask的runTask()方法有MapStatus返回值

  3. */

  4. override def runTask(context: TaskContext): MapStatus = {

  5. // Deserialize the RDD using the broadcast variable.

  6. // 对task要处理的rdd相关的数据,做一些反序列化操作

  7. // 这里有一个问题,如何拿到这个要处理的RDD

  8. // 多个task运行在多个Executor中,都是并行运行,或者并发运行的,可能都不在一个地方,但是一个stage的task,其实要处理的rdd是一样,所以task如何拿到自己要处理的rdd数据?

  9. // 这里会通过broadcast variable 直接拿到

  10. val ser = SparkEnv.get.closureSerializer.newInstance()

  11. val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](

  12. ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)


  13. metrics = Some(context.taskMetrics)

  14. var writer: ShuffleWriter[Any, Any] = null

  15. try {

  16. // 获取ShuffleManager

  17. val manager = SparkEnv.get.shuffleManager

  18. // 从ShuffleManager中获取ShuffleWriter

  19. writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)

  20. // 首先调用了,rdd的iterator()方法,并且传入了,当前task要处理哪个partition

  21. // 所以核心的逻辑,就在rdd的iterator()方法中,在这里,实现了针对rdd的某个partition,执行我们自己定义的算子,或者是函数

  22. // 执行完了我们自己定义的算子、或者函数,就相当于是,针对rdd的partition执行了处理,会有返回的数据

  23. // 返回的数据,都是通过ShuffleWriter,经过HashPartitioner进行分区之后,写入自己对应的分区bucket

  24. writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])

  25. // 最后,返回结果MapStatus,MapStatus里面封装了ShuffleMapTask计算后的数据,数据存储在哪里,其实就是BlockManager的相关信息

  26. // BlockManager是Spark底层的内存,数据,磁盘数据管理的组件

  27. return writer.stop(success = true).get

  28. } catch {

  29. case e: Exception =>

  30. try {

  31. if (writer != null) {

  32. writer.stop(success = false)

  33. }

  34. } catch {

  35. case e: Exception =>

  36. log.debug("Could not stop writer", e)

  37. }

  38. throw e

  39. }

  40. }

看下 rdd.iterator()
方法

  1. final def iterator(split: Partition, context: TaskContext): Iterator[T] = {

  2. if (storageLevel != StorageLevel.NONE) {

  3. // cacheManager相关东西

  4. SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)

  5. } else {

  6. // 进行rdd partition的计算

  7. computeOrReadCheckpoint(split, context)

  8. }

  9. }

看下 computeOrReadCheckpoint()
方法

  1. private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =

  2. {

  3. // Checkpointed相关先忽略

  4. if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context)

  5. }

看下 compute()
方法

  1. // 抽象方法,找具体实现类,比如MapPartitionsRDD

  2. def compute(split: Partition, context: TaskContext): Iterator[T]

看下MapPartitionsRDD的 compute()
方法

  1. // 这里,就是针对rdd中的某个partition执行我们给这个rdd定义的算子和函数

  2. // 这里的f,可以理解为我们自己定义的算子和函数,但是是Spark内部进行了封装的,还实现了一些其他的逻辑

  3. // 执行到了这里,就是在针对RDD的partition,执行自定义的计算操作,并返回新的rdd的partition数据

  4. override def compute(split: Partition, context: TaskContext) =

  5. f(context, split.index, firstParent[T].iterator(split, context))

看下ResultTask的 runTask()
方法

  1. override def runTask(context: TaskContext): U = {

  2. // Deserialize the RDD and the func using the broadcast variables.

  3. // 进行了基本的反序列化

  4. val ser = SparkEnv.get.closureSerializer.newInstance()

  5. val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](

  6. ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)


  7. metrics = Some(context.taskMetrics)

  8. // 执行通过rdd的iterator,执行我们定义的算子和函数

  9. func(context, rdd.iterator(partition, context))

  10. }

接下来看看 execBackend.statusUpdate()
方法

  1. // 其实就是调用了Executor所在的CoarseGrainedExecutorBackend的statusUpdate()方法

  2. execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

CoarseGrainedExecutorBackend的 statusUpdate()
方法

  1. override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {

  2. // 向CoarseGrainedSchedulerBackend发送一个StatusUpdate消息

  3. driver ! StatusUpdate(executorId, taskId, state, data)

  4. }

看CoarseGrainedSchedulerBackend的StatusUpdate

  1. // 处理task执行结束的事件

  2. case StatusUpdate(executorId, taskId, state, data) =>

  3. scheduler.statusUpdate(taskId, state, data.value)

  4. if (TaskState.isFinished(state)) {

  5. executorDataMap.get(executorId) match {

  6. case Some(executorInfo) =>

  7. executorInfo.freeCores += scheduler.CPUS_PER_TASK

  8. makeOffers(executorId)

  9. case None =>

  10. // Ignoring the update since we don't know about the executor.

  11. logWarning(s"Ignored task status update ($taskId state $state) " +

  12. "from unknown executor $sender with ID $executorId")

  13. }

  14. }

看看 scheduler.statusUpdate()
方法

  1. def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {

  2. var failedExecutor: Option[String] = None

  3. synchronized {

  4. try {

  5. // 判断如果task是lost了,实际上,可能会经常发现task lost了,这就是因为各种各样的原因,执行失败了

  6. if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {

  7. // We lost this entire executor, so remember that it's gone

  8. // 移除Executor,将它加入失败队列

  9. val execId = taskIdToExecutorId(tid)

  10. if (activeExecutorIds.contains(execId)) {

  11. removeExecutor(execId)

  12. failedExecutor = Some(execId)

  13. }

  14. }

  15. // 获取对应的taskSet

  16. taskIdToTaskSetId.get(tid) match {

  17. case Some(taskSetId) =>

  18. // 如果task结束了,从内存缓存中移除

  19. if (TaskState.isFinished(state)) {

  20. taskIdToTaskSetId.remove(tid)

  21. taskIdToExecutorId.remove(tid)

  22. }

  23. // 如果正常结束,也做相应的处理

  24. activeTaskSets.get(taskSetId).foreach { taskSet =>

  25. if (state == TaskState.FINISHED) {

  26. taskSet.removeRunningTask(tid)

  27. taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)

  28. } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {

  29. taskSet.removeRunningTask(tid)

  30. taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)

  31. }

  32. }

  33. case None =>

  34. logError(

  35. ("Ignoring update with state %s for TID %s because its task set is gone (this is " +

  36. "likely the result of receiving duplicate task finished status updates)")

  37. .format(state, tid))

  38. }

  39. } catch {

  40. case e: Exception => logError("Exception in statusUpdate", e)

  41. }

  42. }

  43. // Update the DAGScheduler without holding a lock on this, since that can deadlock

  44. if (failedExecutor.isDefined) {

  45. dagScheduler.executorLost(failedExecutor.get)

  46. backend.reviveOffers()

  47. }


文章转载自程序员雨衣,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

评论