Skip to content

Commit fd10439

Browse files
authored
Merge branch 'master' into dw/aqua
2 parents 95c8076 + 9ce4ab4 commit fd10439

File tree

17 files changed

+382
-277
lines changed

17 files changed

+382
-277
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,8 @@ jobs:
3131
- ubuntu-latest
3232
- windows-latest
3333
- macOS-latest
34-
exclude:
35-
# For Julia 1.6 no aarch64 binary exists
36-
- version: 'min'
37-
os: macOS-latest
38-
include:
39-
- version: 'min'
40-
os: macOS-13 # uses x64
4134
steps:
42-
- uses: actions/checkout@v5
35+
- uses: actions/checkout@v6
4336
- uses: julia-actions/setup-julia@v2
4437
with:
4538
version: ${{ matrix.version }}
@@ -56,7 +49,7 @@ jobs:
5649
name: Documentation
5750
runs-on: ubuntu-latest
5851
steps:
59-
- uses: actions/checkout@v5
52+
- uses: actions/checkout@v6
6053
- uses: julia-actions/setup-julia@v2
6154
with:
6255
version: '1'

.github/workflows/CompatHelper.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
- name: "Run CompatHelper"
2525
run: |
2626
import CompatHelper
27-
CompatHelper.main(; dirs = ["", "docs"])
27+
CompatHelper.main(; subdirs = ["", "docs"])
2828
shell: julia --color=yes {0}
2929
env:
3030
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

Project.toml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,19 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
99
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1010
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
11+
12+
[weakdeps]
1113
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1214
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1315

16+
[extensions]
17+
SparseArraysExt = "SparseArrays"
18+
StatisticsExt = "Statistics"
19+
1420
[compat]
1521
Aqua = "0.8.12"
1622
Distributed = "<0.0.1, 1"
23+
ExplicitImports = "1.13.2"
1724
LinearAlgebra = "<0.0.1, 1"
1825
Primes = "0.4, 0.5"
1926
Random = "<0.0.1, 1"
@@ -22,12 +29,15 @@ SparseArrays = "<0.0.1, 1"
2229
SpecialFunctions = "0.8, 1, 2"
2330
Statistics = "<0.0.1, 1"
2431
Test = "<0.0.1, 1"
25-
julia = "1"
32+
julia = "1.10"
2633

2734
[extras]
2835
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
36+
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
37+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2938
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
39+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3040
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3141

3242
[targets]
33-
test = ["Aqua", "Test", "SpecialFunctions"]
43+
test = ["Aqua", "ExplicitImports", "SparseArrays", "SpecialFunctions", "Statistics", "Test"]

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ julia> import Pkg; Pkg.add("DistributedArrays")
3333

3434
## Project Status
3535

36-
The package is tested against Julia `0.7`, `1.0` and the nightly builds of the Julia `master` branch on Linux, and macOS.
36+
The package is tested against
37+
Julia 1.10.0 (oldest supported Julia version),
38+
the Julia LTS version,
39+
the latest stable release of Julia,
40+
and the pre-release version of Julia.
3741

3842
## Questions and Contributions
3943

ext/SparseArraysExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module SparseArraysExt
2+
3+
using DistributedArrays: DArray, localpart
4+
using DistributedArrays.Distributed: remotecall_fetch
5+
using SparseArrays: SparseArrays, nnz
6+
7+
function SparseArrays.nnz(A::DArray)
8+
B = asyncmap(A.pids) do p
9+
remotecall_fetch(nnzlocalpart, p, A)
10+
end
11+
return reduce(+, B)
12+
end
13+
14+
end

ext/StatisticsExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module StatisticsExt
2+
3+
using DistributedArrays: DArray
4+
using Statistics: Statistics
5+
6+
Statistics._mean(f, A::DArray, region) = sum(f, A, dims = region) ./ prod((size(A, i) for i in region))
7+
8+
end

src/DistributedArrays.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1-
__precompile__()
2-
31
module DistributedArrays
42

5-
using Distributed
6-
using Serialization
7-
using LinearAlgebra
8-
using Statistics
3+
using Base: Callable
4+
using Base.Broadcast: BroadcastStyle, Broadcasted
95

10-
import Base: +, -, *, div, mod, rem, &, |, xor
11-
import Base.Callable
12-
import LinearAlgebra: axpy!, dot, norm, mul!
6+
using Distributed: Distributed, RemoteChannel, Future, myid, nworkers, procs, remotecall, remotecall_fetch, remotecall_wait, worker_id_from_socket, workers
7+
using LinearAlgebra: LinearAlgebra, Adjoint, Diagonal, I, Transpose, adjoint, adjoint!, axpy!, dot, lmul!, mul!, norm, rmul!, transpose, transpose!
8+
using Random: Random, rand!
9+
using Serialization: Serialization, AbstractSerializer, deserialize, serialize
1310

14-
import Primes
15-
import Primes: factor
11+
using Primes: factor
1612

1713
import SparseArrays
1814

@@ -24,7 +20,7 @@ export dzeros, dones, dfill, drand, drandn, distribute, localpart, localindices,
2420
export ddata, gather
2521

2622
# immediate release of localparts
27-
export close, d_closeall
23+
export d_closeall
2824

2925
include("darray.jl")
3026
include("core.jl")

src/broadcast.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
# Distributed broadcast implementation
33
##
44

5-
using Base.Broadcast
6-
import Base.Broadcast: BroadcastStyle, Broadcasted
7-
85
# We define a custom ArrayStyle here since we need to keep track of
96
# the fact that it is Distributed and what kind of underlying broadcast behaviour
107
# we will encounter.
@@ -13,11 +10,11 @@ DArrayStyle(::S) where {S} = DArrayStyle{S}()
1310
DArrayStyle(::S, ::Val{N}) where {S,N} = DArrayStyle(S(Val(N)))
1411
DArrayStyle(::Val{N}) where N = DArrayStyle{Broadcast.DefaultArrayStyle{N}}()
1512

16-
BroadcastStyle(::Type{<:DArray{<:Any, N, A}}) where {N, A} = DArrayStyle(BroadcastStyle(A), Val(N))
13+
Broadcast.BroadcastStyle(::Type{<:DArray{<:Any, N, A}}) where {N, A} = DArrayStyle(BroadcastStyle(A), Val(N))
1714

1815
# promotion rules
1916
# TODO: test this
20-
function BroadcastStyle(::DArrayStyle{AStyle}, ::DArrayStyle{BStyle}) where {AStyle, BStyle}
17+
function Broadcast.BroadcastStyle(::DArrayStyle{AStyle}, ::DArrayStyle{BStyle}) where {AStyle, BStyle}
2118
DArrayStyle(BroadcastStyle(AStyle, BStyle))
2219
end
2320

@@ -70,18 +67,18 @@ end
7067
# This will turn local AbstractArrays into DArrays
7168
dbc = bcdistribute(bc)
7269

73-
asyncmap(procs(dest)) do p
74-
remotecall_fetch(p) do
70+
@sync for p in procs(dest)
71+
@async remotecall_wait(p) do
7572
# get the indices for the localpart
7673
lpidx = localpartindex(dest)
7774
@assert lpidx != 0
7875
# create a local version of the broadcast, by constructing views
7976
# Note: creates copies of the argument
8077
lbc = bclocal(dbc, dest.indices[lpidx])
81-
Base.copyto!(localpart(dest), lbc)
82-
return nothing
78+
copyto!(localpart(dest), lbc)
8379
end
8480
end
81+
8582
return dest
8683
end
8784

src/core.jl

Lines changed: 82 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,109 @@
1-
const registry=Dict{Tuple, Any}()
2-
const refs=Set() # Collection of darray identities created on this node
1+
# Thread-safe registry of DArray references
2+
struct DArrayRegistry
3+
data::Dict{Tuple{Int,Int}, Any}
4+
lock::ReentrantLock
5+
DArrayRegistry() = new(Dict{Tuple{Int,Int}, Any}(), ReentrantLock())
6+
end
7+
const REGISTRY = DArrayRegistry()
8+
9+
function Base.get(r::DArrayRegistry, id::Tuple{Int,Int}, default)
10+
@lock r.lock begin
11+
return get(r.data, id, default)
12+
end
13+
end
14+
function Base.getindex(r::DArrayRegistry, id::Tuple{Int,Int})
15+
@lock r.lock begin
16+
return r.data[id]
17+
end
18+
end
19+
function Base.setindex!(r::DArrayRegistry, val, id::Tuple{Int,Int})
20+
@lock r.lock begin
21+
r.data[id] = val
22+
end
23+
return r
24+
end
25+
function Base.delete!(r::DArrayRegistry, id::Tuple{Int,Int})
26+
@lock r.lock delete!(r.data, id)
27+
return r
28+
end
29+
30+
# Thread-safe set of IDs of DArrays created on this node
31+
struct DArrayRefs
32+
data::Set{Tuple{Int,Int}}
33+
lock::ReentrantLock
34+
DArrayRefs() = new(Set{Tuple{Int,Int}}(), ReentrantLock())
35+
end
36+
const REFS = DArrayRefs()
337

4-
let DID::Int = 1
5-
global next_did
6-
next_did() = (id = DID; DID += 1; (myid(), id))
38+
function Base.push!(r::DArrayRefs, id::Tuple{Int,Int})
39+
# Ensure id refers to a DArray created on this node
40+
if first(id) != myid()
41+
throw(
42+
ArgumentError(
43+
lazy"`DArray` is not created on the current worker: Only `DArray`s created on worker $(myid()) can be stored in this set but the `DArray` was created on worker $(first(id))."))
44+
end
45+
@lock r.lock begin
46+
return push!(r.data, id)
47+
end
48+
end
49+
function Base.delete!(r::DArrayRefs, id::Tuple{Int,Int})
50+
@lock r.lock delete!(r.data, id)
51+
return r
752
end
853

54+
# Global counter to generate a unique ID for each DArray
55+
const DID = Threads.Atomic{Int}(1)
56+
957
"""
1058
next_did()
1159
12-
Produces an incrementing ID that will be used for DArrays.
13-
"""
14-
next_did
60+
Increment a global counter and return a tuple of the current worker ID and the incremented
61+
value of the counter.
1562
16-
release_localpart(id::Tuple) = (delete!(registry, id); nothing)
17-
release_localpart(d) = release_localpart(d.id)
63+
This tuple is used as a unique ID for a new `DArray`.
64+
"""
65+
next_did() = (myid(), Threads.atomic_add!(DID, 1))
1866

19-
function close_by_id(id, pids)
20-
# @async println("Finalizer for : ", id)
21-
global refs
67+
release_localpart(id::Tuple{Int,Int}) = (delete!(REGISTRY, id); nothing)
68+
function release_allparts(id::Tuple{Int,Int}, pids::Array{Int})
2269
@sync begin
70+
released_myid = false
2371
for p in pids
24-
@async remotecall_fetch(release_localpart, p, id)
72+
if p == myid()
73+
@async release_localpart(id)
74+
released_myid = true
75+
else
76+
@async remotecall_fetch(release_localpart, p, id)
77+
end
2578
end
26-
if !(myid() in pids)
27-
release_localpart(id)
79+
if !released_myid
80+
@async release_localpart(id)
2881
end
2982
end
30-
delete!(refs, id)
31-
nothing
83+
return nothing
3284
end
3385

34-
function Base.close(d::DArray)
35-
# @async println("close : ", d.id, ", object_id : ", object_id(d), ", myid : ", myid() )
36-
if (myid() == d.id[1]) && d.release
37-
@async close_by_id(d.id, d.pids)
38-
d.release = false
39-
end
86+
function close_by_id(id::Tuple{Int,Int}, pids::Array{Int})
87+
release_allparts(id, pids)
88+
delete!(REFS, id)
4089
nothing
4190
end
4291

4392
function d_closeall()
44-
crefs = copy(refs)
45-
for id in crefs
46-
if id[1] == myid() # sanity check
47-
if haskey(registry, id)
48-
d = d_from_weakref_or_d(id)
49-
(d === nothing) || close(d)
93+
@lock REFS.lock begin
94+
while !isempty(REFS.data)
95+
id = pop!(REFS.data)
96+
d = d_from_weakref_or_d(id)
97+
if d isa DArray
98+
finalize(d)
5099
end
51-
yield()
52100
end
53101
end
102+
return nothing
54103
end
55104

105+
Base.close(d::DArray) = finalize(d)
106+
56107
"""
57108
procs(d::DArray)
58109
@@ -67,4 +118,3 @@ Distributed.procs(d::SubDArray) = procs(parent(d))
67118
The identity when input is not distributed
68119
"""
69120
localpart(A) = A
70-

0 commit comments

Comments
 (0)