import org.apache.spark.TaskContext
TaskContext
TaskContext allows a task to access contextual information about itself and to register task listeners.
Using TaskContext you can access local properties that were set by the driver. You can also access task metrics.
You can access the active TaskContext instance using the TaskContext.get method.
TaskContext belongs to the org.apache.spark package.
Note
|
TaskContext is serializable.
|
Contextual Information
- stageId is the id of the stage the task belongs to.
- partitionId is the id of the partition computed by the task.
- attemptNumber denotes how many times the task has been attempted (starting from 0).
- taskAttemptId is the id of the attempt of the task.
- isCompleted returns true when a task is completed.
- isInterrupted returns true when a task was killed.
All these attributes are accessible as methods of the same name on a TaskContext instance, e.g. partitionId for the partition id. The TaskContext companion object additionally offers getPartitionId, which returns the partition id of the active task (or 0 when there is no active task).
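As a rough illustration of reading these attributes, the sketch below gathers them inside each task and returns them to the driver. It assumes a spark-shell session where sc is already defined; infos is a name made up for this example.

```scala
import org.apache.spark.TaskContext

// Assumes spark-shell, where sc (the SparkContext) is predefined.
val rdd = sc.range(0, 4, numSlices = 2)

// Read the attributes inside the task and ship them back to the driver as tuples.
val infos = rdd.mapPartitions { _ =>
  val tc = TaskContext.get()
  Iterator((tc.partitionId(), tc.stageId(), tc.attemptNumber(), tc.taskAttemptId()))
}.collect()

infos.foreach { case (partitionId, stageId, attemptNumber, taskAttemptId) =>
  println(s"partition=$partitionId stage=$stageId attempt=$attemptNumber taskAttemptId=$taskAttemptId")
}
```

Collecting the values instead of printing them inside the task keeps the output on the driver, which can be handy when executors run on remote machines.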
Registering Task Listeners
Using a TaskContext object you can register two kinds of task listeners: completion listeners, which are executed regardless of the final state of the task, and failure listeners, which are executed only when the task fails.
addTaskCompletionListener Method
addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext
addTaskCompletionListener methods register a TaskCompletionListener listener to be executed on task completion.
Note
|
It will be executed regardless of the final state of a task - success, failure, or cancellation. |
import org.apache.spark.TaskContext
val rdd = sc.range(0, 5, numSlices = 1)
// Prints the contextual information of the task that has just completed.
val printTaskInfo = (tc: TaskContext) => {
  val msg = s"""|-------------------
                |partitionId: ${tc.partitionId}
                |stageId: ${tc.stageId}
                |attemptNum: ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}
rdd.foreachPartition { _ =>
  val tc = TaskContext.get
  // Register the listener from within the task; it fires when the task ends.
  tc.addTaskCompletionListener(printTaskInfo)
}
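The first overload accepts a TaskCompletionListener object rather than a function. A minimal sketch of that variant follows, used here for per-task cleanup. It assumes a spark-shell session where sc is already defined; the StringBuilder named resource is only a stand-in for something like a connection or a file handle.

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

// Assumes spark-shell, where sc (the SparkContext) is predefined.
val rdd = sc.range(0, 4, numSlices = 2)

val sums = rdd.mapPartitions { it =>
  // Hypothetical per-task resource standing in for a connection or file handle.
  val resource = new StringBuilder("open")
  TaskContext.get().addTaskCompletionListener(new TaskCompletionListener {
    // Runs when the task ends, whether it succeeded, failed, or was cancelled.
    override def onTaskCompletion(context: TaskContext): Unit = {
      resource.clear()
      println(s"Released resource for partition ${context.partitionId()}")
    }
  })
  // Sum the partition eagerly so the work is done before the task completes.
  Iterator(it.sum)
}.collect()
```

Passing an explicit TaskCompletionListener instance also sidesteps any overload ambiguity that a bare lambda may run into, depending on the Scala and Spark versions in use.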
addTaskFailureListener Method
addTaskFailureListener(listener: TaskFailureListener): TaskContext
addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext
addTaskFailureListener methods register a TaskFailureListener listener to be executed on task failure only. It can be executed multiple times since a task can be re-attempted when it fails.
import org.apache.spark.TaskContext
val rdd = sc.range(0, 2, numSlices = 2)
// Prints the contextual information of the failed task together with the error.
val printTaskErrorInfo = (tc: TaskContext, error: Throwable) => {
  val msg = s"""|-------------------
                |partitionId: ${tc.partitionId}
                |stageId: ${tc.stageId}
                |attemptNum: ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |error: ${error.toString}
                |-------------------""".stripMargin
  println(msg)
}
val throwExceptionForOddNumber = (n: Long) => {
  if (n % 2 == 1) {
    throw new Exception(s"No way it will pass for odd number: $n")
  }
}
// This variant never triggers the listener: the partition iterator is ignored,
// so the lazy map never calls throwExceptionForOddNumber and no task fails.
rdd.map(throwExceptionForOddNumber).foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
}
// Listener registration matters: register it first, then let an action (count)
// consume the iterator so the exception is thrown while the listener is in place.
rdd.mapPartitions { (it: Iterator[Long]) =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
  it
}.map(throwExceptionForOddNumber).count
Accessing Local Properties — getLocalProperty
Method
getLocalProperty(key: String): String
You can use the getLocalProperty method to access local properties that were set by the driver using SparkContext.setLocalProperty. It returns null when no value was set for the key.
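The round trip from the driver to a task could look roughly like the sketch below. It assumes a spark-shell session where sc is already defined; the keys myapp.requestId and myapp.missing are hypothetical and chosen only for this example.

```scala
import org.apache.spark.TaskContext

// Assumes spark-shell, where sc (the SparkContext) is predefined.
// "myapp.requestId" is a hypothetical key used only for illustration.
sc.setLocalProperty("myapp.requestId", "req-42")

val seen = sc.range(0, 2, numSlices = 2).map { _ =>
  val tc = TaskContext.get()
  // Option(...) turns the null returned for an unset key into None.
  (Option(tc.getLocalProperty("myapp.requestId")),
   Option(tc.getLocalProperty("myapp.missing")))
}.collect()

seen.foreach(println)

// Passing null removes the property so it does not leak into later jobs
// submitted from this thread.
sc.setLocalProperty("myapp.requestId", null)
```

Wrapping the result in Option is a design choice of this sketch, not something the API requires; it simply avoids null checks in the task code.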
Task Metrics
taskMetrics(): TaskMetrics
taskMetrics method is part of the Developer API and gives access to the TaskMetrics instance of a task.
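A minimal sketch of reading a few metrics from inside a running task is shown below. It assumes a spark-shell session where sc is already defined; metricsPerTask is a name made up for this example. The values are whatever has been recorded so far, and some metrics may only be filled in after the task body finishes, so treat the numbers as indicative.

```scala
import org.apache.spark.TaskContext

// Assumes spark-shell, where sc (the SparkContext) is predefined.
val rdd = sc.range(0, 1000, numSlices = 2)

val metricsPerTask = rdd.mapPartitions { it =>
  val tc = TaskContext.get()
  // Consume the partition so the task does some work before metrics are read.
  val count = it.size
  val metrics = tc.taskMetrics()
  Iterator((tc.partitionId(), count, metrics.peakExecutionMemory, metrics.memoryBytesSpilled))
}.collect()

metricsPerTask.foreach { case (partitionId, count, peakMemory, spilled) =>
  println(s"partition=$partitionId records=$count peakExecutionMemory=$peakMemory memoryBytesSpilled=$spilled")
}
```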
getMetricsSources Method
getMetricsSources(sourceName: String): Seq[Source]
getMetricsSources returns all the metrics sources registered under sourceName that are associated with the instance that runs the task.
Accessing Active TaskContext — get
Method
get(): TaskContext
get method returns the TaskContext instance for an active task (as a TaskContextImpl object). There can only be one instance per task, and a task can use it to access contextual information about itself. When called outside of a task, e.g. on the driver, get returns null.
val rdd = sc.range(0, 3, numSlices = 3)
scala> rdd.partitions.size
res0: Int = 3
rdd.foreach { n =>
  import org.apache.spark.TaskContext
  val tc = TaskContext.get
  val msg = s"""|-------------------
                |partitionId: ${tc.partitionId}
                |stageId: ${tc.stageId}
                |attemptNum: ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}
Note
|
The TaskContext companion object keeps the active TaskContext in a ThreadLocal, i.e. the context is associated with the thread that runs the task.
|
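Because the context is bound to the task's thread, code that may run both on the driver and inside tasks has to cope with a missing context. The sketch below shows one way to do that. It assumes a spark-shell session where sc is already defined; currentPartition is a hypothetical helper, not part of Spark.

```scala
import org.apache.spark.TaskContext

// Assumes spark-shell, where sc (the SparkContext) is predefined.
// currentPartition is a hypothetical helper, not part of Spark:
// it wraps the possibly-null TaskContext in an Option.
val currentPartition = () => Option(TaskContext.get()).map(_.partitionId())

// On the driver thread no task is running, so no TaskContext is set.
val onDriver = currentPartition()

// Inside a task the thread-local TaskContext is available.
val inTasks = sc.range(0, 2, numSlices = 2).map(_ => currentPartition()).collect()

println(s"driver: $onDriver, tasks: ${inTasks.mkString(", ")}")
```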
TaskContextImpl
TaskContextImpl is the only implementation of the TaskContext abstract class.
Caution
|
FIXME |
-
stage
-
partition
-
task attempt
-
attempt number
-
runningLocally = false
Caution
|
FIXME Where and how is TaskMemoryManager used?
|
Creating TaskContextImpl Instance
Caution
|
FIXME |
markInterrupted
Caution
|
FIXME |