Multi-threading with ThreadsX

As you saw in the previous section, Base.Threads does not have a built-in parallel reduction. You can implement one by hand, but all such solutions are somewhat awkward, and unless you pay close attention you can run into thread-safety and performance problems (slow atomic variables, false sharing, etc.).

Enter ThreadsX, a multi-threaded Julia library that provides parallel versions of some of the Base functions. To see the list of supported functions, use the double-TAB feature inside the REPL:

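using ThreadsX
ThreadsX.<TAB>          # press TAB twice to list the available functions
?ThreadsX.mapreduce     # help mode: type ? at the julia> prompt
?mapreduce              # help for the serial Base equivalent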

As you can see, not all functions in ThreadsX are well documented, but that is by design: they reproduce the functionality of their serial Base equivalents, so you can always look up help on the corresponding serial function.

Consider this Base function:

mapreduce(x->x^2, +, 1:10)   # sum up squares of all integers from 1 to 10

The same call can also be written with a do block:

mapreduce(+,1:10) do i
    i^2   # plays the role of the function applied to each element
end
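To convince yourself that the two forms are the same call, you can compare them with a plain generator sum: the do block simply becomes the first (function) argument of mapreduce. A minimal sketch in serial base Julia:

```julia
# Three equivalent ways to sum the squares of 1..10 (serial, Base only)
a = mapreduce(x -> x^2, +, 1:10)   # explicit anonymous function

b = mapreduce(+, 1:10) do i        # do block: becomes the first argument of mapreduce
    i^2
end

c = sum(i^2 for i in 1:10)         # the same reduction written as a generator sum

println((a, b, c))                 # (385, 385, 385)
```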

To parallelize either snippet, replace mapreduce with ThreadsX.mapreduce, assuming you are running Julia with multiple threads. Do not time this code: the computation is so fast that the timing would most likely be dominated by the overhead of spawning and synchronizing parallel tasks. Instead, let’s parallelize and time the slow series.
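A minimal sketch of the parallel versions, assuming ThreadsX is installed and Julia was started with several threads (e.g. julia -t 4). Since integer addition is exact, the parallel results should match the serial value exactly, whatever order the threads combine their partial sums in:

```julia
using ThreadsX   # assumes the package is installed, e.g. via ] add ThreadsX

println("Julia is running with ", Threads.nthreads(), " thread(s)")

# Parallel versions of the two serial calls above
p1 = ThreadsX.mapreduce(x -> x^2, +, 1:10)

p2 = ThreadsX.mapreduce(+, 1:10) do i
    i^2
end

println((p1, p2))   # (385, 385)
```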

Parallelizing the slow series with ThreadsX.mapreduce

In this and other examples we assume that you have already defined digitsin(). Save the following as mapreduce.jl:

using BenchmarkTools, ThreadsX
function slow(n::Int64, digitSequence::Int)
    # sum 1/i over all i in 1:n whose digits do not contain digitSequence
    total = ThreadsX.mapreduce(+, 1:n) do i
        if !digitsin(digitSequence, i)
            1.0 / i
        else
            0.0
        end
    end
    return total
end
total = @btime slow(Int64(1e8), 9)
println("total = ", total)   # total = 13.277605949855294
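If you do not have digitsin() at hand, the following hypothetical stand-in (not necessarily the version from the earlier section) returns true when the decimal digits of its first argument appear as a contiguous block in the digits of the second. The equally hypothetical slow_serial() below lets you sanity-check the series on a tiny input with the serial mapreduce before timing the parallel code:

```julia
# Hypothetical stand-in for digitsin(): true if the decimal digits of `seq`
# occur as a contiguous block inside the decimal digits of `num`.
function digitsin(seq::Integer, num::Integer)
    base = 10
    while seq ÷ base > 0       # find the smallest power of 10 that exceeds seq
        base *= 10
    end
    while num > 0
        num % base == seq && return true   # compare the trailing digits of num
        num ÷= 10                          # drop the last digit
    end
    return false
end

# Hypothetical serial version of slow() for a quick check on small n
function slow_serial(n::Int64, digitSequence::Int)
    mapreduce(+, 1:n) do i
        !digitsin(digitSequence, i) ? 1.0 / i : 0.0
    end
end

println(digitsin(9, 1923), " ", digitsin(9, 1823))   # true false
println(slow_serial(Int64(10), 9))   # 1 + 1/2 + ... + 1/8 + 1/10 (only i = 9 is skipped)
```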

With 4 CPU cores, I see:

$ julia mapreduce.jl        # runtime with 1 thread: 2.200 s
$ julia -t 4 mapreduce.jl   # runtime with 4 threads: 543.949 ms
$ julia -t 8 mapreduce.jl   # what should we expect?
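Part of the answer depends on your hardware: running more threads than there are CPU cores does not add compute capacity, so the speedup typically levels off. A quick way to compare the number of Julia threads with what the system reports (note that Sys.CPU_THREADS counts logical CPUs, which on machines with hyperthreading can be twice the number of physical cores):

```julia
# Threads this Julia session was started with (set by julia -t N)
println("Julia threads: ", Threads.nthreads())

# Logical CPUs reported by the operating system
println("Logical CPUs:  ", Sys.CPU_THREADS)
```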

Exercise “ThreadsX.1”

Using the compact (one-line) if-else notation, shorten this code by four lines. Time the new, shorter code with one and several threads.

Hint: the syntax is 1 > 2 ? "1 is greater than 2" : "1 is not greater than 2"

Parallelizing the slow series with ThreadsX.sum

?sum
sum(x->x^2, 1:10)
ThreadsX.sum(x->x^2, 1:10)
ThreadsX.sum(x^2 for x in 1:10)   # alternative syntax

The expression inside the parentheses on the last line is a generator. It produces the sequence on the fly without storing the individual elements, so it takes very little memory.

(i for i in 1:10)          # generator
collect(i for i in 1:10)   # construct a vector (this one takes more space) from it
[i for i in 1:10]          # functionally the same (vector via an array comprehension)
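To see the memory difference for yourself, Base.summarysize() reports how many bytes an object occupies, including everything it references. A small sketch with a million elements:

```julia
n = 10^6
g = (i for i in 1:n)          # generator: stores only the range and the function
v = collect(g)                # Vector{Int}: stores all n elements

println(Base.summarysize(g))  # only a handful of bytes
println(Base.summarysize(v))  # about 8 bytes per element, i.e. roughly 8 MB
println(sum(g) == sum(v))     # true: both produce the same values
```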

Let’s use a generator with $10^8$ elements to compute our slow series sum:

using BenchmarkTools
@btime sum(!digitsin(9, i) ? 1.0/i : 0 for i in 1:100_000_000)
   # serial code: 2.183 s, prints 13.277605949858103

It is very easy to parallelize:

using BenchmarkTools, ThreadsX
@btime ThreadsX.sum(!digitsin(9, i) ? 1.0/i : 0 for i in 1:100_000_000)
   # with 4 threads: 527.573 ms, prints 13.277605949854381
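Notice that the printed sums differ in the last few digits. This is expected: floating-point addition is not associative, so adding the same terms in a different order (serial vs. parallel, or a different split of the range among threads) rounds differently. A small base-Julia illustration:

```julia
# Floating-point addition is not associative: grouping changes the rounding
x = (0.1 + 0.2) + 0.3
y = 0.1 + (0.2 + 0.3)
println(x, " vs ", y)                # 0.6000000000000001 vs 0.6

# The same effect on a longer sum: left-to-right vs right-to-left order
terms = [1.0 / i for i in 1:1_000_000]
s_left  = foldl(+, terms)
s_right = foldr(+, terms)
println(s_left - s_right)            # typically a tiny nonzero difference
println(isapprox(s_left, s_right))   # true: equal up to rounding
```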

Exercise “ThreadsX.2”

The expression [i for i in 1:10 if i%2==1] produces an array of odd integers between 1 and 10. Using this syntax, remove zero terms from the last generator, i.e. write a parallel code for summing the slow series with a generator that contains only non-zero terms. It should run slightly faster than the code with the original generator. (I get 527.159 ms runtime.)
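The same if filter clause also works inside a generator, where the excluded elements are simply never produced. A tiny serial illustration of the syntax with odd numbers:

```julia
# A generator with a filter clause: only odd i are produced, nothing is stored
odds = (i for i in 1:10 if i % 2 == 1)
println(collect(odds))   # [1, 3, 5, 7, 9]
println(sum(odds))       # 25
```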

Finally, let’s rewrite our code so that it applies a function to all integers in a range:

function numericTerm(i)
    !digitsin(9, i) ? 1.0/i : 0
end
@btime ThreadsX.sum(numericTerm, 1:Int64(1e8))   # 571.915 ms, same result

Exercise “ThreadsX.3”

Rewrite the last code replacing sum with mapreduce. Hint: look up help for mapreduce().

Other parallel functions

ThreadsX provides various parallel functions for sorting. Sorting is intrinsically hard to parallelize, so do not expect 100% parallel efficiency. Let’s take a look at sort() and sort!():

n = Int64(1e8)
r = rand(Float32, (n));   # random floats in [0, 1)
r[1:10]      # first 10 elements, same as first(r,10)
last(r,10)   # last 10 elements

?sort              # default algorithm depends on the Julia version; older versions use QuickSort for numeric arrays, MergeSort otherwise
@btime sort(r);    # 10.421 s, serial sorting
@btime sort!(r);   # 1.707 s, in-place serial sorting

r = rand(Float32, (n));
@btime ThreadsX.sort(r);    # 2.950 ms, parallel sorting with 4 threads
@btime ThreadsX.sort!(r);   # 1.115 ms, in-place parallel sorting with 4 threads
?ThreadsX.sort!             # there is actually a good manual page

# the same comparison for 32-bit integers
r = rand(Int32, (n));
@btime sort!(r);   # 1.065 ms in serial

r = rand(Int32, (n));
@btime ThreadsX.sort!(r);   # 1.058 ms with 4 threads
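One subtlety when reading these timings: sort() returns a new array, while sort!() reorders its argument. Since @btime evaluates the expression many times, every evaluation of sort!(r) after the first one works on an array that is already sorted, which is typically faster than sorting random data. A small base-Julia sketch of the difference (n = 10^7 here to keep it quick); the BenchmarkTools line in the comments uses its setup and evals options to give each evaluation fresh data:

```julia
n = 10^7
r = rand(Float32, n)

s = sort(r)           # returns a new sorted vector; r itself is unchanged
println(issorted(s), " ", issorted(r))   # true false (r is still random)

sort!(r)              # sorts r in place, no new vector is allocated
println(r == s)       # true

# From now on r is already sorted, so repeated sort!(r) calls measure that case.
# With BenchmarkTools, fresh unsorted data for every evaluation:
#   @btime sort!(x) setup=(x = rand(Float32, 10^8)) evals=1
```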

Searching for extrema is much more parallel-friendly:

n = Int64(1e9)
r = rand(Int32, (n));        # 10^9 Int32 values need 4 GB of memory
@btime maximum(r)            # 328.375 ms
@btime ThreadsX.maximum(r)   # 82.562 ms with 4 threads
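Finding a maximum is parallel-friendly for the same reason as summation: max is associative, so the array can be split into chunks, the maximum of each chunk found independently, and the partial results combined in any order (and, unlike floating-point sums, without any rounding differences). A serial base-Julia sketch of this idea; the chunked version is only an illustration, not how ThreadsX implements it:

```julia
r = rand(Int32, 10^6)

m1 = maximum(r)
m2 = reduce(max, r)                  # the same reduction spelled out
m3 = mapreduce(identity, max, r)     # and as a mapreduce

# Chunked version: per-chunk maxima first, then the maximum of those
chunks = Iterators.partition(r, length(r) ÷ 4)
m4 = maximum(maximum(c) for c in chunks)

println(m1 == m2 == m3 == m4)        # true
```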

Finally, another useful function is ThreadsX.map(), which applies a function in parallel without a reduction; we will take a closer look at it in one of the following sections.
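As a first taste, a minimal sketch assuming ThreadsX is installed: ThreadsX.map() is meant as a drop-in replacement for the serial map(), so both calls below should return the same vector.

```julia
using ThreadsX   # assumes the package is installed

xs = 1:10
serial   = map(x -> x^2, xs)            # serial Base version
parallel = ThreadsX.map(x -> x^2, xs)   # same computation spread over the available threads

println(serial == parallel)             # true
```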

To sum up this section, ThreadsX.jl provides a super easy way to parallelize some of the Base library functions. It includes multi-threaded reductions and shows very impressive parallel performance: the slow series ran about 4 times faster with 4 threads. To list the supported functions, use ThreadsX.<TAB>, and don’t forget to use the built-in help pages.