
ForkJoin

Java 7 introduced the Fork/Join framework, centered on a new thread pool, ForkJoinPool, that executes a special kind of task: one that breaks a large task down into multiple smaller tasks and executes them in parallel.

For example, to calculate the sum of a very large array, the simplest approach is a loop that runs in a single thread:

┌─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┘
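
The single-threaded approach needs only one loop. Here is a minimal sketch of it. The class name SequentialSum and the sample data 1..2000 are illustrative and not from the original.

```java
// Single-threaded baseline: one thread visits every element in order.
class SequentialSum {
    static long sum(long[] array) {
        long total = 0;
        for (long x : array) {
            total += x;
        }
        return total;
    }

    public static void main(String[] args) {
        long[] array = new long[2000];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1; // sample data: 1, 2, ..., 2000
        }
        System.out.println(sum(array)); // prints 2001000
    }
}
```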

Another approach is to split the array into two halves, sum each half separately, and then add the two partial sums to get the final result. The two halves can then be summed by two threads in parallel:

┌─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┘
┌─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┴─┘
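
The two-thread version can be sketched with plain java.lang.Thread objects, before any Fork/Join is involved. This is only an illustration of the idea and not how Fork/Join works internally. The class and method names are hypothetical.

```java
// Two-thread sum using plain threads (illustrative only).
class TwoThreadSum {
    // Sum array[from, to) sequentially.
    static long rangeSum(long[] array, int from, int to) {
        long total = 0;
        for (int i = from; i < to; i++) {
            total += array[i];
        }
        return total;
    }

    static long sum(long[] array) {
        int middle = array.length / 2;
        long[] partial = new long[2]; // each thread writes only its own slot
        Thread left = new Thread(() -> partial[0] = rangeSum(array, 0, middle));
        Thread right = new Thread(() -> partial[1] = rangeSum(array, middle, array.length));
        left.start();
        right.start();
        try {
            // join() waits for each thread and makes its write to partial[] visible here
            left.join();
            right.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for workers", e);
        }
        return partial[0] + partial[1];
    }

    public static void main(String[] args) {
        long[] array = new long[2000];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1;
        }
        System.out.println(sum(array)); // prints 2001000
    }
}
```

Hard-coding the number of threads like this does not adapt to the size of the array or to the number of CPU cores.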

If each half is still too large, we can keep splitting and sum the four resulting parts with four threads in parallel:

┌─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┘
┌─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┘
┌─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┘
┌─┬─┬─┬─┬─┬─┐
└─┴─┴─┴─┴─┴─┘

This is the principle of Fork/Join tasks. First, determine whether a task is small enough to compute directly. If it is, compute it. Otherwise, split it into smaller subtasks, compute those (possibly in parallel), and combine their results. Each subtask applies the same rule, so a large task can repeatedly "fork" into a whole series of small tasks.
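
Before adding any threads, the rule above can be written as an ordinary recursive function. This is a sequential sketch. The name RecursiveSum and the cutoff of 500 are illustrative choices. Each recursive call on a half is exactly the piece of work that Fork/Join runs as a separate task.

```java
// Split-or-compute rule, run sequentially (no threads yet).
class RecursiveSum {
    static final int THRESHOLD = 500; // illustrative "small enough" cutoff

    static long sum(long[] array, int start, int end) {
        if (end - start <= THRESHOLD) {
            // Small enough: compute directly with a loop.
            long total = 0;
            for (int i = start; i < end; i++) {
                total += array[i];
            }
            return total;
        }
        // Too large: split into two halves, solve each, then combine.
        int middle = (start + end) >>> 1; // unsigned shift avoids int overflow
        return sum(array, start, middle) + sum(array, middle, end);
    }

    public static void main(String[] args) {
        long[] array = new long[2000];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1;
        }
        System.out.println(sum(array, 0, array.length)); // prints 2001000
    }
}
```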

Let's see how to use Fork/Join to perform parallel summation on large data:

java
import java.util.Random;
import java.util.concurrent.*;

public class Main {
    public static void main(String[] args) throws Exception {
        // Create an array of 2000 random numbers:
        long[] array = new long[2000];
        long expectedSum = 0;
        for (int i = 0; i < array.length; i++) {
            array[i] = random();
            expectedSum += array[i];
        }
        System.out.println("Expected sum: " + expectedSum);
        // Fork/Join:
        ForkJoinTask<Long> task = new SumTask(array, 0, array.length);
        long startTime = System.currentTimeMillis();
        Long result = ForkJoinPool.commonPool().invoke(task);
        long endTime = System.currentTimeMillis();
        System.out.println("Fork/join sum: " + result + " in " + (endTime - startTime) + " ms.");
    }

    static Random random = new Random(0);

    static long random() {
        return random.nextInt(10000);
    }
}

class SumTask extends RecursiveTask<Long> {
    static final int THRESHOLD = 500;
    long[] array;
    int start;
    int end;

    SumTask(long[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {
        if (end - start <= THRESHOLD) {
            // If the task is small enough, compute directly:
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += this.array[i];
                // Intentionally slow down the computation:
                try {
                    Thread.sleep(1);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt(); // restore the interrupt flag
                }
            }
            return sum;
        }
        // Task is too large, split it into two:
        int middle = (end + start) / 2;
        System.out.println(String.format("split %d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end));
        SumTask subtask1 = new SumTask(this.array, start, middle);
        SumTask subtask2 = new SumTask(this.array, middle, end);
        invokeAll(subtask1, subtask2);
        Long subresult1 = subtask1.join();
        Long subresult2 = subtask2.join();
        Long result = subresult1 + subresult2;
        System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result);
        return result;
    }
}

Running the code above shows how the work proceeds. The computation over 0~2000 is first split into two smaller tasks, 0~1000 and 1000~2000. Both are still larger than THRESHOLD (500), so each is split again into 0~500, 500~1000, 1000~1500 and 1500~2000. These four tasks are small enough to sum directly. Their results are then merged level by level to obtain the final result.

The core of the code is SumTask, which extends RecursiveTask and overrides the compute() method. The key is how it "splits" the work into sub-tasks and submits them:

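java
class SumTask extends RecursiveTask<Long> {
    // Fields array, start, end and the constructor are as in the full example above.
    @Override
    protected Long compute() {
        // (The base case for small tasks is omitted here.)
        // "Split" the range into two sub-tasks:
        int middle = (start + end) / 2;
        SumTask subtask1 = new SumTask(array, start, middle);
        SumTask subtask2 = new SumTask(array, middle, end);
        // invokeAll forks both sub-tasks and waits until they complete:
        invokeAll(subtask1, subtask2);
        // Get the results of the sub-tasks:
        Long subresult1 = subtask1.join();
        Long subresult2 = subtask2.join();
        // Combine the results:
        return subresult1 + subresult2;
    }
}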
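
invokeAll() is one way to submit the sub-tasks. Another common idiom appears in the Fibonacci example of the RecursiveTask Javadoc. It calls fork() on one half, runs compute() on the other half directly in the current worker thread, and then calls join() on the forked half. The sketch below applies that variant to the sum. The class name ForkComputeSumTask is hypothetical, and the threshold of 500 mirrors SumTask.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Variant of SumTask: fork one half, compute the other in the current thread.
class ForkComputeSumTask extends RecursiveTask<Long> {
    static final int THRESHOLD = 500;
    final long[] array;
    final int start;
    final int end;

    ForkComputeSumTask(long[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {
        if (end - start <= THRESHOLD) {
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        }
        int middle = (start + end) >>> 1;
        ForkComputeSumTask left = new ForkComputeSumTask(array, start, middle);
        ForkComputeSumTask right = new ForkComputeSumTask(array, middle, end);
        left.fork();                        // schedule the left half asynchronously
        long rightResult = right.compute(); // work on the right half in this thread
        long leftResult = left.join();      // then wait for the left half's result
        return leftResult + rightResult;
    }

    static long sum(long[] array) {
        return ForkJoinPool.commonPool().invoke(new ForkComputeSumTask(array, 0, array.length));
    }

    public static void main(String[] args) {
        long[] array = new long[2000];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1;
        }
        System.out.println(sum(array)); // prints 2001000
    }
}
```

The order matters. The current thread should start compute() on its own half before it calls join(), so that both halves can make progress at the same time.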

The Java standard library itself uses Fork/Join. For example, java.util.Arrays.parallelSort(array) sorts an array in parallel. Internally it uses Fork/Join to split a large array into sub-arrays, sorts them in parallel, and merges the results. This can significantly improve sorting speed on multi-core CPUs.
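
A minimal usage sketch follows. The class and method names are illustrative, and the random test data is an assumption. The program compares the parallelSort() result with the ordinary Arrays.sort(). According to its Javadoc, parallelSort() falls back to a sequential sort when the array is below an internal minimum granularity, so the parallel path only matters for large arrays.

```java
import java.util.Arrays;
import java.util.Random;

class ParallelSortDemo {
    // Return a sorted copy, leaving the input untouched.
    static int[] sortedCopy(int[] data) {
        int[] copy = Arrays.copyOf(data, data.length);
        Arrays.parallelSort(copy); // uses Fork/Join internally for large arrays
        return copy;
    }

    public static void main(String[] args) {
        int[] data = new Random(0).ints(1_000_000, 0, 1_000_000).toArray();
        int[] sorted = sortedCopy(data);
        int[] expected = Arrays.copyOf(data, data.length);
        Arrays.sort(expected); // sequential sort for comparison
        System.out.println(Arrays.equals(sorted, expected)); // prints true
    }
}
```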

Exercise

Use Fork/Join.

Summary

  • Fork/Join is a "divide and conquer" algorithm: it decomposes tasks, executes them in parallel, and then merges the results to obtain the final outcome.
  • A ForkJoinPool executes tasks that split themselves into smaller tasks for parallel execution. Task classes usually extend RecursiveTask (when the task returns a result) or RecursiveAction (when it does not).
  • Using the Fork/Join pattern allows for parallel computation to improve efficiency.
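
The examples above all use RecursiveTask. For work that produces no result, such as updating an array in place, a RecursiveAction follows the same split-or-compute pattern. In the sketch below, the task DoubleAction (a hypothetical name) doubles every element, and its threshold of 500 is illustrative.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// A RecursiveAction returns no value; this one doubles each element in place.
class DoubleAction extends RecursiveAction {
    static final int THRESHOLD = 500; // illustrative cutoff
    final long[] array;
    final int start;
    final int end;

    DoubleAction(long[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected void compute() {
        if (end - start <= THRESHOLD) {
            for (int i = start; i < end; i++) {
                array[i] *= 2;
            }
            return;
        }
        int middle = (start + end) >>> 1;
        // The two sub-ranges are disjoint, so the tasks never write the same element.
        invokeAll(new DoubleAction(array, start, middle), new DoubleAction(array, middle, end));
    }

    static void doubleAll(long[] array) {
        ForkJoinPool.commonPool().invoke(new DoubleAction(array, 0, array.length));
    }

    public static void main(String[] args) {
        long[] array = {1, 2, 3, 4};
        doubleAll(array);
        System.out.println(java.util.Arrays.toString(array)); // prints [2, 4, 6, 8]
    }
}
```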