diff --git a/LocalPreferences.toml b/LocalPreferences.toml new file mode 100644 index 000000000..3a11c113f --- /dev/null +++ b/LocalPreferences.toml @@ -0,0 +1,10 @@ +# When using system MPI, run once in the environment where you run MPI jobs (with MPI module loaded): +# julia --project=Dagger.jl -e 'using MPIPreferences; MPIPreferences.use_system_binary()' +# That populates abi, libmpi, mpiexec and avoids "Unknown MPI ABI nothing". +[MPIPreferences] +_format = "1.0" +abi = "MPICH" +binary = "system" +libmpi = "libmpi" +mpiexec = "mpiexec" +preloads = [] diff --git a/Project.toml b/Project.toml index ce49bf6d7..69163e027 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" NextLA = "d37ed344-79c4-486d-9307-6d11355a15a3" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" @@ -77,6 +78,7 @@ Graphs = "1" JSON3 = "1" KernelAbstractions = "0.9" MacroTools = "0.5" +MPI = "0.20.22" MemPool = "0.4.12" Metal = "1.1" NextLA = "0.2.2" diff --git a/benchmarks/check_comm_asymmetry.jl b/benchmarks/check_comm_asymmetry.jl new file mode 100644 index 000000000..684240ec5 --- /dev/null +++ b/benchmarks/check_comm_asymmetry.jl @@ -0,0 +1,111 @@ +#!/usr/bin/env julia +# Parse MPI+Dagger logs and report communication decision asymmetry per tag. +# Asymmetry: for the same tag, one rank decides to send (local+bcast, sender+communicated, etc.) +# and another rank decides to infer (inferred, uninvolved) and never recv → deadlock. +# +# Usage: julia check_comm_asymmetry.jl < logfile +# Or: mpiexec -n 10 julia ... run_matmul.jl 2>&1 | tee matmul.log; julia check_comm_asymmetry.jl < matmul.log + +const SEND_DECISIONS = Set([ + "local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast", + "aliasing", # when followed by local+bcast we already capture local+bcast +]) +const RECV_DECISIONS = Set([ + "communicated", "receiver", "sender+communicated", # received data +]) +const INFER_DECISIONS = Set([ + "inferred", "uninvolved", # did not recv (uses inferred type) +]) + +function parse_line(line) + # Match [rank X][tag Y] then any [...] and capture the last bracket pair before space or end + rank = nothing + tag = nothing + decision = nothing + category = nothing # aliasing, execute!, remotecall_endpoint + for m in eachmatch(r"\[rank\s+(\d+)\]", line) + rank = parse(Int, m.captures[1]) + end + for m in eachmatch(r"\[tag\s+(\d+)\]", line) + tag = parse(Int, m.captures[1]) + end + for m in eachmatch(r"\[(execute!|aliasing|remotecall_endpoint)\]", line) + category = m.captures[1] + end + # Decision is usually in last [...] that looks like [word] or [word+word] + for m in eachmatch(r"\]\[([^\]]+)\]", line) + candidate = m.captures[1] + # Normalize: "communicated" "inferred" "local+bcast" "sender+inferred" "receiver" etc. + if occursin("inferred", candidate) && !occursin("communicated", candidate) + decision = "inferred" + break + elseif occursin("communicated", candidate) + decision = "communicated" + break + elseif occursin("local+bcast", candidate) + decision = "local+bcast" + break + elseif occursin("sender+", candidate) + decision = startswith(candidate, "sender+inferred") ? "sender+inferred" : "sender+communicated" + break + elseif candidate == "receiver" + decision = "receiver" + break + elseif candidate == "receiver+bcast" + decision = "receiver+bcast" + break + elseif candidate == "inplace_move" + decision = "inplace_move" + break + end + end + return rank, tag, category, decision +end + +function main() + # tag => Dict(rank => decision) + by_tag = Dict{Int, Dict{Int, String}}() + for line in eachline(stdin) + rank, tag, category, decision = parse_line(line) + isnothing(rank) && continue + isnothing(tag) && continue + isnothing(decision) && continue + if !haskey(by_tag, tag) + by_tag[tag] = Dict{Int, String}() + end + by_tag[tag][rank] = decision + end + + # For each tag, check: is there at least one sender and one inferrer (non-receiver)? + send_keys = Set(["local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"]) + infer_keys = Set(["inferred", "sender+inferred"]) # sender+inferred means sender didn't need to recv + recv_keys = Set(["communicated", "receiver", "sender+communicated"]) + + asymmetries = [] + for (tag, ranks) in sort(collect(by_tag), by = first) + senders = [r for (r, d) in ranks if d in send_keys] + inferrers = [r for (r, d) in ranks if d in infer_keys || d == "uninvolved"] + receivers = [r for (r, d) in ranks if d in recv_keys] + # Asymmetry: someone sends (bcast) so will send to ALL other ranks; someone chose infer and won't recv. + if !isempty(senders) && !isempty(inferrers) + push!(asymmetries, (tag, senders, inferrers, receivers, ranks)) + end + end + + if isempty(asymmetries) + println("No communication decision asymmetry found (no tag has both sender and inferrer).") + return + end + + println("=== Communication decision asymmetry (can cause deadlock) ===\n") + for (tag, senders, inferrers, receivers, ranks) in asymmetries + println("Tag $tag:") + println(" Senders (will bcast to all others): $senders") + println(" Inferrers (did not recv): $inferrers") + println(" Receivers: $receivers") + println(" All decisions: $ranks") + println() + end +end + +main() diff --git a/benchmarks/check_comm_asymmetry.py b/benchmarks/check_comm_asymmetry.py new file mode 100644 index 000000000..31a117442 --- /dev/null +++ b/benchmarks/check_comm_asymmetry.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +Parse MPI+Dagger logs and report communication decision asymmetry per tag. +Asymmetry: for the same tag, one rank decides to send (local+bcast, etc.) +and another decides to infer (inferred) and never recv → deadlock. + +Usage: + # Capture full log (all ranks' Core.println from mpi.jl go to stdout): + mpiexec -n 10 julia --project=/path/to/Dagger.jl benchmarks/run_matmul.jl 2>&1 | tee matmul.log + # Then look for asymmetry (same tag: one rank sends, another infers → deadlock): + python3 check_comm_asymmetry.py < matmul.log +""" + +import re +import sys +from collections import defaultdict + +SEND_DECISIONS = {"local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"} +RECV_DECISIONS = {"communicated", "receiver", "sender+communicated"} +INFER_DECISIONS = {"inferred", "uninvolved", "sender+inferred"} + + +def parse_line(line: str): + rank = tag = category = decision = None + m = re.search(r"\[rank\s+(\d+)\]", line) + if m: + rank = int(m.group(1)) + m = re.search(r"\[tag\s+(\d+)\]", line) + if m: + tag = int(m.group(1)) + m = re.search(r"\[(execute!|aliasing|remotecall_endpoint)\]", line) + if m: + category = m.group(1) + # Capture decision from [...] blocks + for m in re.finditer(r"\]\[([^\]]+)\]", line): + candidate = m.group(1) + if "inferred" in candidate and "communicated" not in candidate: + decision = "inferred" + break + if "communicated" in candidate: + decision = "communicated" + break + if "local+bcast" in candidate: + decision = "local+bcast" + break + if candidate.startswith("sender+"): + decision = "sender+inferred" if "inferred" in candidate else "sender+communicated" + break + if candidate == "receiver": + decision = "receiver" + break + if candidate == "receiver+bcast": + decision = "receiver+bcast" + break + if candidate == "inplace_move": + decision = "inplace_move" + break + return rank, tag, category, decision + + +def main(): + by_tag = defaultdict(dict) # tag -> {rank: decision} + for line in sys.stdin: + rank, tag, category, decision = parse_line(line) + if rank is None or tag is None or decision is None: + continue + by_tag[tag][rank] = decision + + send_keys = {"local+bcast", "sender+communicated", "sender+inferred", "receiver+bcast"} + infer_keys = {"inferred", "sender+inferred", "uninvolved"} + recv_keys = {"communicated", "receiver", "sender+communicated"} + + asymmetries = [] + for tag in sorted(by_tag.keys()): + ranks = by_tag[tag] + senders = [r for r, d in ranks.items() if d in send_keys] + inferrers = [r for r, d in ranks.items() if d in infer_keys] + receivers = [r for r, d in ranks.items() if d in recv_keys] + if senders and inferrers: + asymmetries.append((tag, senders, inferrers, receivers, ranks)) + + if not asymmetries: + print("No communication decision asymmetry found (no tag has both sender and inferrer).") + return + + print("=== Communication decision asymmetry (can cause deadlock) ===\n") + for tag, senders, inferrers, receivers, ranks in asymmetries: + print(f"Tag {tag}:") + print(f" Senders (will bcast to all others): {senders}") + print(f" Inferrers (did not recv): {inferrers}") + print(f" Receivers: {receivers}") + print(f" All decisions: {dict(ranks)}") + print() + + +if __name__ == "__main__": + main() diff --git a/benchmarks/run_distribute_fetch.jl b/benchmarks/run_distribute_fetch.jl new file mode 100644 index 000000000..822e1ad2c --- /dev/null +++ b/benchmarks/run_distribute_fetch.jl @@ -0,0 +1,42 @@ +#!/usr/bin/env julia +# Create a matrix with a fixed reproducible pattern, distribute it with an +# MPI procgrid, then on each rank fetch and println the chunk(s) it owns. +# Usage (from repo root, use full path to Dagger.jl): +# mpiexec -n 4 julia --project=/path/to/Dagger.jl benchmarks/run_distribute_fetch.jl + +using MPI +using Dagger + +if !isdefined(Dagger, :accelerate!) + error("Dagger.accelerate! not found. Run with the local Dagger project: julia --project=/path/to/Dagger.jl ...") +end +Dagger.accelerate!(:mpi) + +const comm = MPI.COMM_WORLD +const rank = MPI.Comm_rank(comm) +const nranks = MPI.Comm_size(comm) + +# Fixed reproducible pattern: 6×6 matrix, M[i,j] = 10*i + j (same on all ranks) +const N = 6 +const BLOCK = 2 +A = [10 * i + j for i in 1:N, j in 1:N] + +# Procgrid: use Dagger's compatible processors so the procgrid passes validation +availprocs = collect(Dagger.compatible_processors()) +nblocks = (cld(N, BLOCK), cld(N, BLOCK)) +procgrid = reshape( + [availprocs[mod(i - 1, length(availprocs)) + 1] for i in 1:prod(nblocks)], + nblocks, +) + +# Distribute so chunk (i,j) is computed on procgrid[i,j] +D = distribute(A, Blocks(BLOCK, BLOCK), procgrid) +D_fetched = fetch(D) + +# On each rank: fetch and print only the chunk(s) this rank owns +for (idx, ch) in enumerate(D_fetched.chunks) + if ch isa Dagger.Chunk && ch.handle isa Dagger.MPIRef && ch.handle.rank == rank + data = fetch(ch) + println("rank $rank chunk $idx: ", data) + end +end diff --git a/benchmarks/run_matmul.jl b/benchmarks/run_matmul.jl new file mode 100644 index 000000000..0eb4ec0d7 --- /dev/null +++ b/benchmarks/run_matmul.jl @@ -0,0 +1,105 @@ +#!/usr/bin/env julia +# N×N matmul benchmark (Float32); block size scales with number of ranks. +# Usage (use the full path to Dagger.jl, not "..."): +# mpiexec -n 10 julia --project=/home/felipetome/dagger-dev/mpi/Dagger.jl benchmarks/run_matmul.jl +# Set CHECK_CORRECTNESS=true to collect and compare against GPU baseline: +# CHECK_CORRECTNESS=true mpiexec -n 10 julia --project=/home/felipetome/dagger-dev/mpi/Dagger.jl benchmarks/run_matmul.jl + +using MPI +using Dagger +using LinearAlgebra + +if !isdefined(Dagger, :accelerate!) + error("Dagger.accelerate! not found. Run with the local Dagger project: julia --project=/path/to/Dagger.jl ...") +end +Dagger.accelerate!(:mpi) + +const N = 2_000 +const comm = MPI.COMM_WORLD +const rank = MPI.Comm_rank(comm) +const nranks = MPI.Comm_size(comm) +# Block size proportional to ranks: ~nranks blocks in 2D => side blocks ≈ √nranks +const BLOCK = max(1, ceil(Int, N / ceil(Int, sqrt(nranks)))) + +const CHECK_CORRECTNESS = parse(Bool, get(ENV, "CHECK_CORRECTNESS", "false")) + +if rank == 0 + println("Benchmark: ", nranks, " ranks, N=", N, ", block size ", BLOCK, "×", BLOCK, " (matmul)") +end + +# Allocate and fill matrices in blocks (Float32) +A = rand(Blocks(BLOCK, BLOCK), Float32, N, N) +B = rand(Blocks(BLOCK, BLOCK), Float32, N, N) + +# Matrix multiply C = A * B +t_matmul = @elapsed begin + C = A * B +end + +if rank == 0 + println("Matmul time: ", round(t_matmul; digits=4), " s") +end + +# Optional: collect via datadeps (root=0). All ranks participate in the datadeps region. +if CHECK_CORRECTNESS + t_collect = @elapsed begin + A_full = Dagger.collect_datadeps(A; root=0) + B_full = Dagger.collect_datadeps(B; root=0) + C_dagger = Dagger.collect_datadeps(C; root=0) + end + if rank == 0 + println("Collecting result and computing baseline for correctness check (GPU)...") + using CUDA + CUDA.functional() || error("CUDA not functional; cannot compute GPU baseline. Check CUDA driver and device.") + t_upload = @elapsed begin + A_g = CUDA.cu(A_full) + B_g = CUDA.cu(B_full) + end + println("Collect + upload time: ", round(t_collect + t_upload; digits=4), " s") + + t_baseline = @elapsed begin + C_ref_g = A_g * B_g + end + println("Baseline (GPU/CUDA) time: ", round(t_baseline; digits=4), " s") + + # Require all elements within 100× machine epsilon relative error (componentwise) + C_dagger_cpu = C_dagger + C_ref_cpu = Array(C_ref_g) + eps_f = eps(Float32) + rtol = 50.0f0 * eps_f + diff = C_dagger_cpu .- C_ref_cpu + # rel_ij = |diff|/|C_ref|, denominator at least eps to avoid div by zero + denom = max.(abs.(C_ref_cpu), eps_f) + rel_err = abs.(diff) ./ denom + max_rel_err = Float32(maximum(rel_err)) + ok = max_rel_err <= rtol + if ok + println("Correctness: OK (max rel_err = ", max_rel_err, " <= 100×eps = ", rtol, ")") + else + println("Correctness: FAIL (max rel_err = ", max_rel_err, " > 100×eps = ", rtol, ")") + end + + # Per-block: which blocks have any element with rel_err > 100×eps + n_bi = ceil(Int, N / BLOCK) + n_bj = ceil(Int, N / BLOCK) + bad_blocks = Tuple{Int,Int,Float32}[] + for bi in 1:n_bi, bj in 1:n_bj + ri = (bi - 1) * BLOCK + 1 : min(bi * BLOCK, N) + rj = (bj - 1) * BLOCK + 1 : min(bj * BLOCK, N) + block_rel = Float32(maximum(@view(rel_err[ri, rj]))) + if block_rel > rtol + push!(bad_blocks, (bi, bj, block_rel)) + end + end + if isempty(bad_blocks) + println("Per-block: all ", n_bi * n_bj, " blocks within 100×eps rel_err.") + else + println("Per-block: ", length(bad_blocks), " block(s) exceed 100×eps rel_err (block size ", BLOCK, "×", BLOCK, "):") + sort!(bad_blocks; by = x -> -x[3]) + for (bi, bj, block_rel) in bad_blocks + println(" block [", bi, ",", bj, "] rows ", (bi - 1) * BLOCK + 1, ":", min(bi * BLOCK, N), + ", cols ", (bj - 1) * BLOCK + 1, ":", min(bj * BLOCK, N), " max rel_err = ", block_rel) + end + end + end +end diff --git a/benchmarks/run_qr.jl b/benchmarks/run_qr.jl new file mode 100644 index 000000000..c5915db2a --- /dev/null +++ b/benchmarks/run_qr.jl @@ -0,0 +1,46 @@ +#!/usr/bin/env julia +# 10k×10k QR + matmul benchmark; block size scales with number of ranks. +# Usage: mpiexec -n 100 julia --project=/path/to/Dagger.jl benchmarks/bench_100rank_qr_matmul.jl +# Or: bash benchmarks/run_100rank_qr_matmul.sh . + +using MPI +using Dagger +using LinearAlgebra + +Dagger.accelerate!(:mpi) + +const N = 10_000 +const comm = MPI.COMM_WORLD +const rank = MPI.Comm_rank(comm) +const nranks = MPI.Comm_size(comm) +# Block size proportional to ranks: ~nranks blocks in 2D => side blocks ≈ √nranks +const BLOCK = max(1, ceil(Int, N / ceil(Int, sqrt(nranks)))) + +if rank == 0 + println("Benchmark: ", nranks, " ranks, N=", N, ", block size ", BLOCK, "×", BLOCK, " (QR + matmul)") +end + +# Allocate and fill 10k×10k matrix in 1k×1k blocks +A = rand(Blocks(BLOCK, BLOCK), Float64, N, N) +MPI.Barrier(comm) + +# QR factorization (computing Q runs the full factorization) +t_qr = @elapsed begin + qr!(A) +end +MPI.Barrier(comm) + +if rank == 0 + println("QR time: ", round(t_qr; digits=4), " s") +end + +# Matrix multiply A * A +t_matmul = @elapsed begin + C = A * A +end +MPI.Barrier(comm) + +if rank == 0 + println("Matmul time: ", round(t_matmul; digits=4), " s") +end + diff --git a/demo.jl b/demo.jl new file mode 100644 index 000000000..0c9ef9e0c --- /dev/null +++ b/demo.jl @@ -0,0 +1,55 @@ +begin +using Revise +using Dagger +using LinearAlgebra + +using Profile +include("filter-traces.jl") +end + +function demo(pivot=RowMaximum()) + fetch(Dagger.@spawn 1+1) + + N = 2000 + nt = Threads.nthreads() + #bs = cld(N, np) + bs = div(N, 4) + println("OpenBLAS Initialization:") + GC.enable(false) + A = @time rand(N, N) + GC.enable(true) + println("Dagger Initialization:") + GC.enable(false) + @time begin + DA = DArray(A, Blocks(bs, bs)) + wait.(DA.chunks) + end + GC.enable(true) + + println("OpenBLAS:") + BLAS.set_num_threads(nt) + lu_A = @time lu(A, pivot; check=false) + println("Dagger:") + BLAS.set_num_threads(1) + GC.enable(false) + lu_DA = @time lu(DA, pivot; check=false) + GC.enable(true) + + Profile.@profile 1+1 + Profile.clear() + println("Dagger (profiler):") + GC.enable(false) + Profile.@profile @time lu(DA, pivot; check=false) + GC.enable(true) + + @show norm(lu_A.U - UpperTriangular(collect(lu_DA.factors))) + + return +end + +demo(); + +begin + samples, lidata = Profile.retrieve() + validate_and_filter_traces!(samples, lidata) +end \ No newline at end of file diff --git a/filter-traces.jl b/filter-traces.jl new file mode 100644 index 000000000..0e6ab49c4 --- /dev/null +++ b/filter-traces.jl @@ -0,0 +1,168 @@ +""" +Filter out traces from the Julia Profile.jl buffer. + +Each trace in the buffer has the following structure (all UInt64): +- Stack frames (variable number) +- Thread ID +- Task ID +- CPU cycle clock +- Thread state (1 = awake, 2 = sleeping) +- Null word +- Null word + +The trace ends are marked by two consecutive null words (0x0). +""" +function filter_traces!(f, buffer::Vector{UInt64}, lidata) + if length(buffer) < 6 + return 0 # Buffer too small to contain even one complete trace + end + + filtered_count = 0 + i = 1 + + while i <= length(buffer) + # Find the end of the current trace by looking for two consecutive nulls + trace_start = i + trace_end = find_trace_end(buffer, i) + + if trace_end == -1 + # No complete trace found from this position + error("Failed to find trace end for $i") + break + end + + # Extract trace metadata (last 6 elements before the two nulls) + if trace_end - trace_start < 5 + # Trace too short to have proper metadata + i = trace_end + 1 + continue + end + + # Check if the trace should be filtered + do_filter = f(buffer, lidata, trace_start, trace_end)::Bool + + # If the trace should be filtered, null out the entire trace + if do_filter + for j in trace_start:trace_end + buffer[j] = 0x0 + end + filtered_count += 1 + end + + # Move to the next trace + i = trace_end + 1 + end + + return filtered_count +end +""" +Find the end of a trace starting from start_idx. +Returns the index of the second null word, or -1 if not found. +""" +function find_trace_end(buffer::Vector{UInt64}, start_idx::Int) + i = start_idx + while i < length(buffer) + if buffer[i] == 0x0 && buffer[i + 1] == 0x0 + return i + 1 # Return index of second null + end + i += 1 + end + + return -1 # No complete trace found +end + +"Count total number of trace entries in the buffer." +function count_traces(buffer::Vector{UInt64}) + count = 0 + + filter_traces!(buffer, lidata) do buffer, lidata, trace_start, trace_end + count += 1 + return false + end + + return count +end + +"Parse profile buffer and null out traces from sleeping threads." +function filter_sleeping_traces!(buffer::Vector{UInt64}, lidata) + return filter_traces!(buffer, lidata) do buffer, lidata, trace_start, trace_end + # The structure before the two nulls is: + # [...stack frames...][thread_id][task_id][cpu_cycles][thread_state][null][null] + thread_state_idx = trace_end - 2 # thread_state is 4th from end (before 2 nulls + 1 other field) + thread_state = buffer[thread_state_idx] + return thread_state == 2 + end +end + +"Parse profile buffer and null out traces without calls to a slowlock path." +function filter_for_slowlock_traces!(buffer::Vector{UInt64}, lidata) + return filter_traces!(buffer, lidata) do buffer, lidata, trace_start, trace_end + # Find slowlock frames + slowlock = false + frames_end = trace_end - 6 + for j in trace_start:frames_end + slowlock && break + ptr = buffer[j] + for frame in lidata[ptr] + if occursin("slowlock", string(frame)) + slowlock = true + break + end + end + end + return !slowlock + end +end + +"Parse profile buffer and keep only traces from a specific thread." +function filter_for_thread!(buffer::Vector{UInt64}, lidata, thread) + return filter_traces!(buffer, lidata) do buffer, lidata, trace_start, trace_end + # The structure before the two nulls is: + # [...stack frames...][thread_id][task_id][cpu_cycles][thread_state][null][null] + thread_id_idx = trace_end - 5 # thread_id is 5th from end (before 2 nulls + 1 other field) + thread_id = buffer[thread_id_idx] + return thread_id+1 == thread + end +end + +""" +Filters out traces from the Julia Profile.jl buffer. Performs: +- Removal of sleeping thread traces +- If slocklock is true, also remove traces that do not call into a slowlock path + +Args: + buffer: Vector{UInt64} containing profile trace data + +Returns: + (filtered_count, total_traces) tuple +""" +function filter_traces_multi!(buffer::Vector{UInt64}, lidata; + slocklock::Bool=false, thread=nothing) + total_traces = count_traces(buffer) + sleeping_count = filter_sleeping_traces!(buffer, lidata) + if slocklock + slowlock_count = filter_for_slowlock_traces!(buffer, lidata) + else + slowlock_count = 0 + end + if thread !== nothing + thread_count = filter_for_thread!(buffer, lidata, thread) + else + thread_count = 0 + end + + #= Find the last double-zero in the buffer and truncate the buffer there + last_zero = 1 + idx = 1 + while idx < length(buffer) + if buffer[idx] == 0x0 && buffer[idx + 1] == 0x0 + last_zero = idx + break + end + idx += 1 + end + deleteat!(buffer, last_zero:length(buffer)) + =# + + return (;total_traces, sleeping_count, slowlock_count, thread_count) +end \ No newline at end of file diff --git a/src/Dagger.jl b/src/Dagger.jl index b29254d5d..a74cec3e6 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -53,6 +53,13 @@ import Adapt include("lib/util.jl") include("utils/dagdebug.jl") +# Type definitions (for MPI/acceleration) +include("types/processor.jl") +include("types/scope.jl") +include("types/memory-space.jl") +include("types/chunk.jl") +include("types/acceleration.jl") + # Distributed data include("utils/locked-object.jl") include("utils/tasks.jl") @@ -77,6 +84,7 @@ include("queue.jl") include("thunk.jl") include("utils/fetch.jl") include("utils/chunks.jl") +include("weakchunk.jl") include("utils/logging.jl") include("submission.jl") abstract type MemorySpace end @@ -90,6 +98,7 @@ include("utils/clock.jl") include("utils/system_uuid.jl") include("utils/caching.jl") include("sch/Sch.jl"); using .Sch +include("tochunk.jl") # Data dependency task queue include("datadeps/aliasing.jl") @@ -156,6 +165,10 @@ function set_distributed_package!(value) @info "Dagger.jl preference has been set, restart your Julia session for this change to take effect!" end +# MPI (mpi.jl loads MPI; mpi_mempool uses it) +include("mpi.jl") +include("mpi_mempool.jl") + # Precompilation import PrecompileTools: @compile_workload include("precompile.jl") diff --git a/src/affinity.jl b/src/affinity.jl new file mode 100644 index 000000000..aab663a51 --- /dev/null +++ b/src/affinity.jl @@ -0,0 +1,32 @@ +export domain, UnitDomain, project, alignfirst, ArrayDomain + +import Base: isempty, getindex, intersect, ==, size, length, ndims + +""" + domain(x::T) + +Returns metadata about `x`. This metadata will be in the `domain` +field of a Chunk object when an object of type `T` is created as +the result of evaluating a Thunk. +""" +function domain end + +""" + UnitDomain + +Default domain -- has no information about the value +""" +struct UnitDomain end + +""" +If no `domain` method is defined on an object, then +we use the `UnitDomain` on it. A `UnitDomain` is indivisible. +""" +domain(x::Any) = UnitDomain() + +### ChunkIO +affinity(r::DRef) = OSProc(r.owner)=>r.size +# this previously returned a vector with all machines that had the file cached +# but now only returns the owner and size, for consistency with affinity(::DRef), +# see #295 +affinity(r::FileRef) = OSProc(1)=>r.size diff --git a/src/array/alloc.jl b/src/array/alloc.jl index fe92ae1e1..0c45443da 100644 --- a/src/array/alloc.jl +++ b/src/array/alloc.jl @@ -93,14 +93,31 @@ function stage(ctx, A::AllocateArray) scope = ExactScope(A.procgrid[CartesianIndex(mod1.(Tuple(I), size(A.procgrid))...)]) end + N = ndims(A.domainchunks) + ret_type = Array{A.eltype, N} if A.want_index - task = Dagger.@spawn compute_scope=scope allocate_array(A.f, A.eltype, i, size(x)) + task = Dagger.@spawn compute_scope=scope return_type=ret_type allocate_array(A.f, A.eltype, i, size(x)) else - task = Dagger.@spawn compute_scope=scope allocate_array(A.f, A.eltype, size(x)) + task = Dagger.@spawn compute_scope=scope return_type=ret_type allocate_array(A.f, A.eltype, size(x)) end tasks[i] = task end end + # MPI type propagation: ensure all ranks know the concrete chunk types + accel = Dagger.current_acceleration() + if accel isa Dagger.MPIAcceleration + N = ndims(A.domainchunks) + expected_type = Array{A.eltype, N} + Dagger.mpi_propagate_chunk_types!(tasks, accel, expected_type) + # Log chunk types per rank after array creation + rank = MPI.Comm_rank(accel.comm) + #=chunk_types = Type[chunktype(t) for t in tasks] + if allequal(chunk_types) + @info "[rank $rank] Array creation (alloc): all $(length(chunk_types)) chunk types are uniform: $(first(chunk_types))" + else + @warn "[rank $rank] Array creation (alloc): chunk types are NOT uniform: $chunk_types" + end=# + end return DArray(A.eltype, A.domain, A.domainchunks, tasks, A.partitioning) end diff --git a/src/array/darray.jl b/src/array/darray.jl index 20722b8ed..7f22d6f88 100644 --- a/src/array/darray.jl +++ b/src/array/darray.jl @@ -1,7 +1,7 @@ -import Base: ==, fetch +import Base: ==, fetch, length, isempty, size export DArray, DVector, DMatrix, DVecOrMat, Blocks, AutoBlocks -export distribute +export distribute, collect_datadeps ###### Array Domains ###### @@ -83,7 +83,8 @@ isempty(a::ArrayDomain) = length(a) == 0 The domain of an array is an ArrayDomain. """ domain(x::AbstractArray) = ArrayDomain([1:l for l in size(x)]) - +# Scalar / non-array values (e.g. for Chunk of immediate data) +domain(x::Any) = ArrayDomain(()) abstract type ArrayOp{T, N} <: AbstractArray{T, N} end Base.IndexStyle(::Type{<:ArrayOp}) = IndexCartesian() @@ -174,6 +175,7 @@ domain(d::DArray) = d.domain chunks(d::DArray) = d.chunks domainchunks(d::DArray) = d.subdomains size(x::DArray) = size(domain(x)) +Base.ndims(d::DArray{T,N}) where {T,N} = N stage(ctx, c::DArray) = c function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N} @@ -200,6 +202,31 @@ function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N} collect(treereduce_nd(dimcatfuncs, asyncmap(fetch, a.chunks))) end end + +""" + collect_datadeps(d::DArray; root=nothing) + +Collect a DArray to a single array by fetching every chunk on the current rank +and assembling into a full array. No datadeps scheduling or root-only assembly: +each rank that calls this gets the full matrix (useful when correctness matters +more than communication cost). +""" +function collect_datadeps(d::DArray{T,N}; root=nothing) where {T,N} + if isempty(d.chunks) + return Array{eltype(d)}(undef, size(d)...) + end + if N == 0 + return fetch(d.chunks[1]) + end + + chks = d.chunks + doms = domainchunks(d) + out = Array{T,N}(undef, size(d)) + for I in CartesianIndices(chks) + copyto!(view(out, indexes(doms[I])...), fetch(chks[I])) + end + return out +end Array{T,N}(A::DArray{S,N}) where {T,N,S} = convert(Array{T,N}, collect(A)) Base.wait(A::DArray) = foreach(wait, A.chunks) @@ -481,6 +508,21 @@ function stage(ctx::Context, d::Distribute) Dagger.@spawn compute_scope=scope identity(d.data[c]) end end + # MPI type propagation: ensure all ranks know the concrete chunk types + accel = Dagger.current_acceleration() + if accel isa Dagger.MPIAcceleration + N = Base.ndims(d.data) + expected_type = Array{eltype(d.data), N} + Dagger.mpi_propagate_chunk_types!(cs, accel, expected_type) + # Log chunk types per rank after array creation + rank = MPI.Comm_rank(accel.comm) + #=chunk_types = Type[chunktype(t) for t in cs] + if allequal(chunk_types) + @info "[rank $rank] Array creation (distribute): all $(length(chunk_types)) chunk types are uniform: $(first(chunk_types))" + else + @warn "[rank $rank] Array creation (distribute): chunk types are NOT uniform: $chunk_types" + end=# + end return DArray(eltype(d.data), domain(d.data), d.domainchunks, @@ -612,7 +654,7 @@ end mapchunk(f, chunk) = tochunk(f(poolget(chunk.handle))) function mapchunks(f, d::DArray{T,N,F}) where {T,N,F} chunks = map(d.chunks) do chunk - owner = get_parent(chunk.processor).pid + owner = root_worker_id(chunk.processor) remotecall_fetch(mapchunk, owner, f, chunk) end DArray{T,N,F}(d.domain, d.subdomains, chunks, d.concat) diff --git a/src/array/gmres.jl b/src/array/gmres.jl new file mode 100644 index 000000000..5127eb314 --- /dev/null +++ b/src/array/gmres.jl @@ -0,0 +1,152 @@ +function gmres(A::DArray, b::DVector; x0=nothing, m=length(b), tol=1e-6, maxiter=100) + """ + GMRES algorithm for solving Ax = b + + Args: + A: coefficient matrix (or function that computes A*v) + b: right-hand side vector + x0: initial guess (default: zero vector) + m: restart parameter (default: no restart) + tol: convergence tolerance + maxiter: maximum number of restarts + + Returns: + x: solution vector + residual_norm: final residual norm + iterations: number of iterations + """ + n = length(b) + x = x0 === nothing ? zeros(AutoBlocks(), n) : DArray(copy(x0)) + + # Initial residual + r = b - A * x + β = norm(r) + + if β < tol + return x, β, 0 + end + + for restart in 1:maxiter + # Krylov subspace basis vectors + V = zeros(AutoBlocks(), n, m + 1) + V[:, 1] = r / β + + # Upper Hessenberg matrix + H = zeros(m + 1, m) + + # Givens rotation matrices (store cos and sin) + cs = zeros(m) + sn = zeros(m) + + # RHS for least squares problem + e1 = zeros(AutoBlocks(), m + 1) + e1[1] = β + + # Arnoldi iteration + for j in 1:m + # Apply matrix to current basis vector + w = A * V[:, j] + + # Modified Gram-Schmidt orthogonalization + for i in 1:j + H[i, j] = dot(w, V[:, i]) + w -= H[i, j] * V[:, i] + end + + H[j + 1, j] = norm(w) + + # Check for breakdown + if abs(H[j + 1, j]) < eps() + m = j + break + end + + V[:, j + 1] = w / H[j + 1, j] + + # Apply previous Givens rotations to new column of H + for i in 1:(j-1) + temp = cs[i] * H[i, j] + sn[i] * H[i + 1, j] + H[i + 1, j] = -sn[i] * H[i, j] + cs[i] * H[i + 1, j] + H[i, j] = temp + end + + # Compute new Givens rotation + if abs(H[j + 1, j]) < eps() + cs[j] = 1.0 + sn[j] = 0.0 + else + if abs(H[j + 1, j]) > abs(H[j, j]) + τ = H[j, j] / H[j + 1, j] + sn[j] = 1.0 / sqrt(1 + τ^2) + cs[j] = sn[j] * τ + else + τ = H[j + 1, j] / H[j, j] + cs[j] = 1.0 / sqrt(1 + τ^2) + sn[j] = cs[j] * τ + end + end + + # Apply new Givens rotation + temp = cs[j] * H[j, j] + sn[j] * H[j + 1, j] + H[j + 1, j] = -sn[j] * H[j, j] + cs[j] * H[j + 1, j] + H[j, j] = temp + + # Apply rotation to RHS + temp = cs[j] * e1[j] + sn[j] * e1[j + 1] + e1[j + 1] = -sn[j] * e1[j] + cs[j] * e1[j + 1] + e1[j] = temp + + # Check convergence + residual_norm = abs(e1[j + 1]) + if residual_norm < tol + m = j + break + end + end + + # Solve upper triangular system H[1:m, 1:m] * y = e1[1:m] + y = zeros(m) + for i in m:-1:1 + y[i] = e1[i] + for k in (i+1):m + y[i] -= H[i, k] * y[k] + end + y[i] /= H[i, i] + end + + # Update solution + for i in 1:m + x += y[i] * V[:, i] + end + + # Check final convergence + r = b - A * x + β = norm(r) + + if β < tol + return x, β, restart + end + end + + return x, β, maxiter +end + +# Example usage +function example_usage() + # Create test problem + n = 100 + A = DArray(randn(n, n) + 5*I) # Well-conditioned matrix + x_true = randn(AutoBlocks(), n) + b = A * x_true + + # Solve with GMRES + allowscalar(false) do + x_gmres, res_norm, iters = gmres(A, b, tol=1e-10) + end + + println("GMRES converged in $iters iterations") + println("Final residual norm: $res_norm") + println("Solution error: $(norm(x_gmres - x_true))") + + return x_gmres +end diff --git a/src/array/mul.jl b/src/array/mul.jl index 02b207641..5890473da 100644 --- a/src/array/mul.jl +++ b/src/array/mul.jl @@ -41,7 +41,7 @@ function LinearAlgebra.generic_matmatmul!( return gemm_dagger!(C, transA, transB, A, B, alpha, beta) end end -function _repartition_matmatmul(C, A, B, transA::Char, transB::Char) +function _repartition_matmatmul(C, A, B, transA::Char, transB::Char)::Tuple{Blocks{2}, Blocks{2}, Blocks{2}} partA = A.partitioning.blocksize partB = B.partitioning.blocksize istransA = transA == 'T' || transA == 'C' @@ -93,6 +93,24 @@ function _repartition_matmatmul(C, A, B, transA::Char, transB::Char) return Blocks(partC...), Blocks(partA...), Blocks(partB...) end +# Typed BLAS wrappers so that every @spawn kernel has an inferable return type +@inline function _gemm!(transA::Char, transB::Char, alpha::T, A, B, mzone, C)::Matrix{T} where {T} + BLAS.gemm!(transA, transB, alpha, A, B, mzone, C) + return C +end +@inline function _syrk!(uplo::AbstractChar, trans::AbstractChar, alpha::T, A, mzone, C)::Matrix{T} where {T} + BLAS.syrk!(uplo, trans, alpha, A, mzone, C) + return C +end +@inline function _herk!(uplo::AbstractChar, trans::AbstractChar, alpha::Real, A, mzone, C)::Matrix{<:Complex} + BLAS.herk!(uplo, trans, alpha, A, mzone, C) + return C +end +@inline function _gemv!(transA::Char, alpha::T, A, x, mzone, y)::Vector{T} where {T} + BLAS.gemv!(transA, alpha, A, x, mzone, y) + return y +end + """ Performs one of the matrix-matrix operations @@ -136,7 +154,7 @@ function gemm_dagger!( # A: NoTrans / B: NoTrans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -150,7 +168,7 @@ function gemm_dagger!( # A: NoTrans / B: [Conj]Trans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -166,7 +184,7 @@ function gemm_dagger!( # A: [Conj]Trans / B: NoTrans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -180,7 +198,7 @@ function gemm_dagger!( # A: [Conj]Trans / B: [Conj]Trans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transA, transB, alpha, @@ -243,7 +261,7 @@ function syrk_dagger!( for k in range(1, Ant) mzone = k == 1 ? real(beta) : one(real(T)) if iscomplex - Dagger.@spawn BLAS.herk!( + Dagger.@spawn _herk!( uplo, trans, real(alpha), @@ -252,7 +270,7 @@ function syrk_dagger!( InOut(Cc[n, n]), ) else - Dagger.@spawn BLAS.syrk!( + Dagger.@spawn _syrk!( uplo, trans, alpha, @@ -267,7 +285,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Ant) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( trans, transs, alpha, @@ -283,7 +301,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Ant) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( trans, transs, alpha, @@ -300,7 +318,7 @@ function syrk_dagger!( for k in range(1, Amt) mzone = k == 1 ? real(beta) : one(real(T)) if iscomplex - Dagger.@spawn BLAS.herk!( + Dagger.@spawn _herk!( uplo, transs, real(alpha), @@ -309,7 +327,7 @@ function syrk_dagger!( InOut(Cc[n, n]), ) else - Dagger.@spawn BLAS.syrk!( + Dagger.@spawn _syrk!( uplo, trans, alpha, @@ -324,7 +342,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Amt) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transs, 'N', alpha, @@ -340,7 +358,7 @@ function syrk_dagger!( for m in range(n + 1, Cmt) for k in range(1, Amt) mzone = k == 1 ? beta : one(T) - Dagger.@spawn BLAS.gemm!( + Dagger.@spawn _gemm!( transs, 'N', alpha, @@ -393,16 +411,17 @@ end return A end -@inline function copytile!(A, B) +@inline function copytile!(A::AbstractMatrix{T}, B::AbstractMatrix{T})::Nothing where {T} m, n = size(A) C = B' for i = 1:m, j = 1:n A[i, j] = C[i, j] end + return nothing end -@inline function copydiagtile!(A, uplo) +@inline function copydiagtile!(A::AbstractMatrix{T}, uplo::AbstractChar)::Nothing where {T} m, n = size(A) Acpy = copy(A) @@ -417,6 +436,7 @@ end for i = 1:m, j = 1:n A[i, j] = C[i, j] end + return nothing end function LinearAlgebra.generic_matvecmul!( C::DVector{T}, @@ -440,7 +460,7 @@ function LinearAlgebra.generic_matvecmul!( return gemv_dagger!(C, transA, A, B, _alpha, _beta) end end -function _repartition_matvecmul(C, A, B, transA::Char) +function _repartition_matvecmul(C, A, B, transA::Char)::Tuple{Blocks{1}, Blocks{2}, Blocks{1}} partA = A.partitioning.blocksize partB = B.partitioning.blocksize istransA = transA == 'T' || transA == 'C' @@ -495,7 +515,7 @@ function gemv_dagger!( # A: NoTrans for k in range(1, Ant) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemv!( + Dagger.@spawn _gemv!( transA, alpha, In(Ac[m, k]), @@ -508,7 +528,7 @@ function gemv_dagger!( # A: [Conj]Trans for k in range(1, Amt) mzone = k == 1 ? beta : T(1.0) - Dagger.@spawn BLAS.gemv!( + Dagger.@spawn _gemv!( transA, alpha, In(Ac[k, m]), diff --git a/src/chunks.jl b/src/chunks.jl index 03bdfb65d..0defc1ff6 100644 --- a/src/chunks.jl +++ b/src/chunks.jl @@ -1,56 +1,4 @@ -export domain, UnitDomain, project, alignfirst, ArrayDomain - -import Base: isempty, getindex, intersect, ==, size, length, ndims - -""" - domain(x::T) - -Returns metadata about `x`. This metadata will be in the `domain` -field of a Chunk object when an object of type `T` is created as -the result of evaluating a Thunk. -""" -function domain end - -""" - UnitDomain - -Default domain -- has no information about the value -""" -struct UnitDomain end - -""" -If no `domain` method is defined on an object, then -we use the `UnitDomain` on it. A `UnitDomain` is indivisible. -""" -domain(x::Any) = UnitDomain() - -###### Chunk ###### - -""" - Chunk - -A reference to a piece of data located on a remote worker. `Chunk`s are -typically created with `Dagger.tochunk(data)`, and the data can then be -accessed from any worker with `collect(::Chunk)`. `Chunk`s are -serialization-safe, and use distributed refcounting (provided by -`MemPool.DRef`) to ensure that the data referenced by a `Chunk` won't be GC'd, -as long as a reference exists on some worker. - -Each `Chunk` is associated with a given `Dagger.Processor`, which is (in a -sense) the processor that "owns" or contains the data. Calling -`collect(::Chunk)` will perform data movement and conversions defined by that -processor to safely serialize the data to the calling worker. - -## Constructors -See [`tochunk`](@ref). -""" -mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope} - chunktype::Type{T} - domain - handle::H - processor::P - scope::S -end +###### Chunk Methods ###### domain(c::Chunk) = c.domain chunktype(c::Chunk) = c.chunktype @@ -72,20 +20,27 @@ function collect(ctx::Context, chunk::Chunk; options=nothing) elseif chunk.handle isa FileRef return poolget(chunk.handle) else - return move(chunk.processor, OSProc(), chunk.handle) + return move(chunk.processor, default_processor(), chunk.handle) end end collect(ctx::Context, ref::DRef; options=nothing) = move(OSProc(ref.owner), OSProc(), ref) collect(ctx::Context, ref::FileRef; options=nothing) = poolget(ref) # FIXME: Do move call -function Base.fetch(chunk::Chunk; raw=false) - if raw - poolget(chunk.handle) - else - collect(chunk) +@warn "Fix semantics of collect" maxlog=1 +function Base.fetch(chunk::Chunk{T}; unwrap::Bool=false, uniform::Bool=false, kwargs...) where T + value = fetch_handle(chunk.handle; uniform)::T + if unwrap && unwrappable(value) + return fetch(value; unwrap, uniform, kwargs...) end + return value end +fetch_handle(ref::DRef; uniform::Bool=false) = poolget(ref) +fetch_handle(ref::FileRef; uniform::Bool=false) = poolget(ref) +unwrappable(x::Chunk) = true +unwrappable(x::DRef) = true +unwrappable(x::FileRef) = true +unwrappable(x) = false # Unwrap Chunk, DRef, and FileRef by default move(from_proc::Processor, to_proc::Processor, x::Chunk) = @@ -100,32 +55,3 @@ move(to_proc::Processor, d::DRef) = move(OSProc(d.owner), to_proc, d) move(to_proc::Processor, x) = move(OSProc(), to_proc, x) - -### ChunkIO -affinity(r::DRef) = OSProc(r.owner)=>r.size -# this previously returned a vector with all machines that had the file cached -# but now only returns the owner and size, for consistency with affinity(::DRef), -# see #295 -affinity(r::FileRef) = OSProc(1)=>r.size - -struct WeakChunk - wid::Int - id::Int - x::WeakRef - function WeakChunk(c::Chunk) - return new(c.handle.owner, c.handle.id, WeakRef(c)) - end -end -unwrap_weak(c::WeakChunk) = c.x.value -function unwrap_weak_checked(c::WeakChunk) - cw = unwrap_weak(c) - @assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))" - return cw -end -wrap_weak(c::Chunk) = WeakChunk(c) -isweak(c::WeakChunk) = true -isweak(c::Chunk) = false -is_task_or_chunk(c::WeakChunk) = true -Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) = - error("Cannot serialize a WeakChunk") -chunktype(c::WeakChunk) = chunktype(unwrap_weak_checked(c)) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl index 64ce11be5..518c4bd2c 100644 --- a/src/datadeps/aliasing.jl +++ b/src/datadeps/aliasing.jl @@ -8,7 +8,7 @@ export In, Out, InOut, Deps, spawn_datadeps ============================================================================== This file implements the data dependencies system for Dagger tasks, which allows -tasks to access their arguments in a controlled manner. The system maintains +tasks to write to their arguments in a controlled manner. The system maintains data coherency across distributed workers by tracking aliasing relationships and orchestrating data movement operations. @@ -25,59 +25,26 @@ KEY CONCEPTS: 1. ALIASING ANALYSIS: - Every mutable argument is analyzed for its memory access pattern - Memory spans are computed to determine which bytes in memory are accessed - - Arguments that access overlapping memory spans are considered "aliasing" + - Objects that access overlapping memory spans are considered "aliasing" - Examples: An array A and view(A, 2:3, 2:3) alias each other 2. DATA LOCALITY TRACKING: - The system tracks where the "source of truth" for each piece of data lives - As tasks execute and modify data, the source of truth may move between workers - - Each argument can have its own independent source of truth location + - Each aliasing region can have its own independent source of truth location 3. ALIASED OBJECT MANAGEMENT: - When copying arguments between workers, the system tracks "aliased objects" - This ensures that if both an array and its view need to be copied to a worker, only one copy of the underlying array is made, with the view pointing to it - - The aliased_object!() and move_rewrap() functions manage this sharing - -ALIASING INFO: --------------- - -The system uses different types of aliasing info to represent different types of -aliasing relationships: - -- ContiguousAliasing: Single contiguous memory region (e.g., full array) -- StridedAliasing: Multiple non-contiguous regions (e.g., SubArray) -- DiagonalAliasing: Diagonal elements only (e.g., Diagonal(A)) -- TriangularAliasing: Triangular regions (e.g., UpperTriangular(A)) - -Any two aliasing objects can be compared using the will_alias function to -determine if they overlap. Additionally, any aliasing object can be converted to -a vector of memory spans, which represents the contiguous regions of memory that -the aliasing object covers. - -DATA MOVEMENT FUNCTIONS: ------------------------- - -move!(dep_mod, to_space, from_space, to, from): -- The core in-place data movement function -- dep_mod specifies which part of the data to copy (identity, UpperTriangular, etc.) -- Supports partial copies via RemainderAliasing dependency modifiers - -move_rewrap(...): -- Handles copying of wrapped objects (SubArrays, ChunkViews) -- Ensures aliased objects are reused on destination worker - -read/write_remainder!(...): -- Read/write a span of memory from an object to/from a buffer -- Used by move! to copy the remainder of an aliased object + - The aliased_object!() functions manage this sharing THE DISTRIBUTED ALIASING PROBLEM: --------------------------------- In a multithreaded environment, aliasing "just works" because all tasks operate -on the user-provided memory. However, in a distributed environment, arguments -must be copied between workers, which breaks aliasing relationships if care is -not taken. +on the same memory. However, in a distributed environment, arguments must be +copied between workers, which breaks aliasing relationships. Consider this scenario: ```julia @@ -96,9 +63,11 @@ MULTITHREADED BEHAVIOR (WORKS): - Task dependencies ensure correct ordering (e.g., Task 1 then Task 2) DISTRIBUTED BEHAVIOR (THE PROBLEM): +- Tasks may be scheduled on different workers - Each argument must be copied to the destination worker -- Without special handling, we would copy A and vA independently to another worker -- This creates two separate arrays, breaking the aliasing relationship between A and vA +- Without special handling, we would copy A to worker1 and vA to worker2 +- This creates two separate arrays, breaking the aliasing relationship +- Updates to the view on worker2 don't affect the array on worker1 THE SOLUTION - PARTIAL DATA MOVEMENT: ------------------------------------- @@ -112,13 +81,12 @@ The datadeps system solves this by: 2. PARTIAL DATA TRANSFER: - Instead of copying entire objects, only transfer the "dirty" regions - - This prevents overwrites of data that has already been updated by another task - - This also minimizes network traffic and overall copy time - - Uses the move!(dep_mod, ...) function with RemainderAliasing dependency modifiers + - This minimizes network traffic and maximizes parallelism + - Uses the move!(dep_mod, ...) function with dependency modifiers 3. REMAINDER TRACKING: - - When a task needs the full object, copy partial regions as needed - When a partial region is updated, track what parts still need updating + - Before a task needs the full object, copy the remaining "clean" regions - This preserves all updates while avoiding overwrites EXAMPLE EXECUTION FLOW: @@ -140,24 +108,69 @@ Tasks: T1 modifies InOut(A), T2 modifies InOut(vA) - T2 needs vA, but vA aliases with A (which was modified by T1) - Copy vA-region of A from worker1 to worker2 - This is a PARTIAL copy - only the 2:3, 2:3 region - - Create vA on worker2 pointing to the appropriate region of A + - Create vA on worker2 pointing to the appropriate region - T2 executes, modifying vA region on worker2 - Update: vA's data_locality = worker2 4. FINAL SYNCHRONIZATION: - - Need to copy-back A and vA to worker0 - - A needs to be assembled from: worker1 (non-vA regions of A) + worker2 (vA region of A) - - REMAINDER COPY: Copy non-vA regions from worker1 to worker0 - - REMAINDER COPY: Copy vA region from worker2 to worker0 + - Some future task needs the complete A + - A needs to be assembled from: worker1 (non-vA regions) + worker2 (vA region) + - REMAINDER COPY: Copy non-vA regions from worker1 to worker2 + - OR INVERSE: Copy vA-region from worker2 to worker1, then copy full A + +MEMORY SPAN COMPUTATION: +------------------------ -REMAINDER COMPUTATION: ----------------------- +The system uses memory spans to determine aliasing and compute remainders: + +- ContiguousAliasing: Single contiguous memory region (e.g., full array) +- StridedAliasing: Multiple non-contiguous regions (e.g., SubArray) +- DiagonalAliasing: Diagonal elements only (e.g., Diagonal(A)) +- TriangularAliasing: Triangular regions (e.g., UpperTriangular(A)) Remainder computation involves: 1. Computing memory spans for all overlapping aliasing objects 2. Finding the set difference: full_object_spans - updated_spans -3. Creating a RemainderAliasing object representing the difference between spans -4. Performing one or more move! calls with this RemainderAliasing object to copy only needed data +3. Creating a "remainder aliasing" object representing the not-yet-updated regions +4. Performing move! with this remainder object to copy only needed data + +DATA MOVEMENT FUNCTIONS: +------------------------ + +move!(dep_mod, to_space, from_space, to, from): +- The core in-place data movement function +- dep_mod specifies which part of the data to copy (identity, UpperTriangular, etc.) +- Supports partial copies via dependency modifiers + +move_rewrap(): +- Handles copying of wrapped objects (SubArrays, ChunkViews) +- Ensures aliased objects are reused on destination worker + +enqueue_copy_to!(): +- Schedules data movement tasks before user tasks +- Ensures data is up-to-date on the worker where a task will run + +CURRENT LIMITATIONS AND TODOS: +------------------------------- + +1. REMAINDER COMPUTATION: + - The system currently handles simple overlaps but needs sophisticated + remainder calculation for complex aliasing patterns + - Need functions to compute span set differences + +2. ORDERING DEPENDENCIES: + - Need to ensure remainder copies happen in correct order + - Must not overwrite more recent updates with stale data + +3. COMPLEX ALIASING PATTERNS: + - Multiple overlapping views of the same array + - Nested aliasing structures (views of views) + - Mixed aliasing types (diagonal + triangular regions) + +4. PERFORMANCE OPTIMIZATION: + - Minimize number of copy operations + - Batch compatible transfers + - Optimize for common access patterns =# "Specifies a read-only dependency." @@ -179,11 +192,6 @@ struct Deps{T,DT<:Tuple} end Deps(x, deps...) = Deps(x, deps) -chunktype(::In{T}) where T = T -chunktype(::Out{T}) where T = T -chunktype(::InOut{T}) where T = T -chunktype(::Deps{T,DT}) where {T,DT} = T - function unwrap_inout(arg) readdep = false writedep = false @@ -218,6 +226,7 @@ _identity_hash(arg::Chunk, h::UInt=UInt(0)) = hash(arg.handle, hash(Chunk, h)) _identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h)))) _identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h)) +@warn "Dispatch bcast behavior on acceleration" maxlog=1 struct ArgumentWrapper arg dep_mod @@ -226,6 +235,7 @@ struct ArgumentWrapper function ArgumentWrapper(arg, dep_mod) h = hash(dep_mod) h = _identity_hash(arg, h) + check_uniform(h, arg) return new(arg, dep_mod, h) end end @@ -241,123 +251,7 @@ struct HistoryEntry write_num::Int end -struct AliasedObjectCacheStore - keys::Vector{AbstractAliasing} - derived::Dict{AbstractAliasing,AbstractAliasing} - stored::Dict{MemorySpace,Set{AbstractAliasing}} - values::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} -end -AliasedObjectCacheStore() = - AliasedObjectCacheStore(Vector{AbstractAliasing}(), - Dict{AbstractAliasing,AbstractAliasing}(), - Dict{MemorySpace,Set{AbstractAliasing}}(), - Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}()) - -function is_stored(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::AbstractAliasing) - if !haskey(cache.stored, space) - return false - end - if !haskey(cache.derived, ainfo) - return false - end - key = cache.derived[ainfo] - return key in cache.stored[space] -end -function is_key_present(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::AbstractAliasing) - return haskey(cache.derived, ainfo) -end -function get_stored(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::AbstractAliasing) - @assert is_stored(cache, space, ainfo) "Cache does not have derived ainfo $ainfo" - key = cache.derived[ainfo] - return cache.values[space][key] -end -function set_stored!(cache::AliasedObjectCacheStore, dest_space::MemorySpace, value::Chunk, ainfo::AbstractAliasing) - @assert !is_stored(cache, dest_space, ainfo) "Cache already has derived ainfo $ainfo" - key = cache.derived[ainfo] - value_ainfo = aliasing(value, identity) - cache.derived[value_ainfo] = key - push!(get!(Set{AbstractAliasing}, cache.stored, dest_space), key) - values_dict = get!(Dict{AbstractAliasing,Chunk}, cache.values, dest_space) - values_dict[key] = value - return -end -function set_key_stored!(cache::AliasedObjectCacheStore, space::MemorySpace, ainfo::AbstractAliasing, value::Chunk) - push!(cache.keys, ainfo) - cache.derived[ainfo] = ainfo - push!(get!(Set{AbstractAliasing}, cache.stored, space), ainfo) - values_dict = get!(Dict{AbstractAliasing,Chunk}, cache.values, space) - values_dict[ainfo] = value - return -end - -struct AliasedObjectCache - space::MemorySpace - chunk::Chunk -end -function is_stored(cache::AliasedObjectCache, ainfo::AbstractAliasing) - wid = root_worker_id(cache.chunk) - if wid != myid() - return remotecall_fetch(is_stored, wid, cache, ainfo) - end - cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore - return is_stored(cache_raw, cache.space, ainfo) -end -function is_key_present(cache::AliasedObjectCache, space::MemorySpace, ainfo::AbstractAliasing) - wid = root_worker_id(cache.chunk) - if wid != myid() - return remotecall_fetch(is_key_present, wid, cache, space, ainfo) - end - cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore - return is_key_present(cache_raw, space, ainfo) -end -function get_stored(cache::AliasedObjectCache, ainfo::AbstractAliasing) - wid = root_worker_id(cache.chunk) - if wid != myid() - return remotecall_fetch(get_stored, wid, cache, ainfo) - end - cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore - return get_stored(cache_raw, cache.space, ainfo) -end -function set_stored!(cache::AliasedObjectCache, value::Chunk, ainfo::AbstractAliasing) - wid = root_worker_id(cache.chunk) - if wid != myid() - return remotecall_fetch(set_stored!, wid, cache, value, ainfo) - end - cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore - set_stored!(cache_raw, cache.space, value, ainfo) - return -end -function set_key_stored!(cache::AliasedObjectCache, space::MemorySpace, ainfo::AbstractAliasing, value::Chunk) - wid = root_worker_id(cache.chunk) - if wid != myid() - return remotecall_fetch(set_key_stored!, wid, cache, space, ainfo, value) - end - cache_raw = unwrap(cache.chunk)::AliasedObjectCacheStore - set_key_stored!(cache_raw, space, ainfo, value) -end -function aliased_object!(f, cache::AliasedObjectCache, x; ainfo=aliasing(x, identity)) - x_space = memory_space(x) - if !is_key_present(cache, x_space, ainfo) - # Preserve the object's memory-space/processor pairing when inserting - # the source key. Using bare `tochunk(x)` defaults to OSProc, which can - # incorrectly wrap GPU-backed objects as CPU chunks. - x_chunk = x isa Chunk ? x : tochunk(x, first(processors(x_space))) - set_key_stored!(cache, x_space, ainfo, x_chunk) - end - if is_stored(cache, ainfo) - return get_stored(cache, ainfo) - else - y = f(x) - @assert y isa Chunk "Didn't get a Chunk from functor" - @assert memory_space(y) == cache.space "Space mismatch! $(memory_space(y)) != $(cache.space)" - if memory_space(x) != cache.space - @assert ainfo != aliasing(y, identity) "Aliasing mismatch! $ainfo == $(aliasing(y, identity))" - end - set_stored!(cache, y, ainfo) - return y - end -end - +@warn "Switch ArgumentWrapper to contain just the argument, and add DependencyWrapper" maxlog=1 struct DataDepsState # The mapping of original raw argument to its Chunk raw_arg_to_chunk::IdDict{Any,Chunk} @@ -373,13 +267,10 @@ struct DataDepsState # The mapping of remote argument to original argument remote_arg_to_original::IdDict{Any,Any} - # The mapping of original argument wrapper to remote argument wrapper - remote_arg_w::Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}} - # The mapping of ainfo to argument and dep_mod # Used to lookup which argument and dep_mod a given ainfo is generated from # N.B. This is a mapping for remote argument copies - ainfo_arg::Dict{AliasingWrapper,Set{ArgumentWrapper}} + ainfo_arg::Dict{AliasingWrapper,ArgumentWrapper} # The history of writes (direct or indirect) to each argument and dep_mod, in terms of ainfos directly written to, and the memory space they were written to # Updated when a new write happens on an overlapping ainfo @@ -397,7 +288,7 @@ struct DataDepsState # The mapping of, for a given memory space, the backing Chunks that an ainfo references # Used by slot generation to replace the backing Chunks during move - ainfo_backing_chunk::Chunk{AliasedObjectCacheStore} + ainfo_backing_chunk::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} # Cache of argument's supports_inplace_move query result supports_inplace_cache::IdDict{Any,Bool} @@ -406,10 +297,6 @@ struct DataDepsState # N.B. This is a mapping for remote argument copies ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} - # The oracle for aliasing lookups - # Used to populate ainfos_overlaps efficiently - ainfos_lookup::AliasingLookup - # The overlapping ainfos for each ainfo # Incrementally updated as new ainfos are created # Used for fast will_alias lookups @@ -421,30 +308,58 @@ struct DataDepsState ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}} ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}} - function DataDepsState() + function DataDepsState(aliasing::Bool) + if !aliasing + @warn "aliasing=false is no longer supported, aliasing is now always enabled" maxlog=1 + end + arg_to_chunk = IdDict{Any,Chunk}() arg_origin = IdDict{Any,MemorySpace}() remote_args = Dict{MemorySpace,IdDict{Any,Any}}() remote_arg_to_original = IdDict{Any,Any}() - remote_arg_w = Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}}() - ainfo_arg = Dict{AliasingWrapper,Set{ArgumentWrapper}}() - arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() + ainfo_arg = Dict{AliasingWrapper,ArgumentWrapper}() arg_owner = Dict{ArgumentWrapper,MemorySpace}() arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() - ainfo_backing_chunk = tochunk(AliasedObjectCacheStore()) + ainfo_backing_chunk = Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}() + arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() supports_inplace_cache = IdDict{Any,Bool}() ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() - ainfos_lookup = AliasingLookup() ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() - return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, remote_arg_w, ainfo_arg, arg_history, arg_owner, arg_overlaps, ainfo_backing_chunk, - supports_inplace_cache, ainfo_cache, ainfos_lookup, ainfos_overlaps, ainfos_owner, ainfos_readers) + return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, + supports_inplace_cache, ainfo_cache, ainfos_overlaps, ainfos_owner, ainfos_readers) + end +end + +# N.B. arg_w must be the original argument wrapper, not a remote copy +function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) + # Grab the remote copy of the argument, and calculate the ainfo + remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) + remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) + + # Check if we already have the result cached + if haskey(state.ainfo_cache, remote_arg_w) + return state.ainfo_cache[remote_arg_w] end + + # Calculate the ainfo + ainfo = AliasingWrapper(aliasing(current_acceleration(), remote_arg, arg_w.dep_mod)) + + # Cache the result + state.ainfo_cache[remote_arg_w] = ainfo + + # Update the mapping of ainfo to argument and dep_mod + state.ainfo_arg[ainfo] = remote_arg_w + + # Populate info for the new ainfo + populate_ainfo!(state, arg_w, ainfo, target_space) + + return ainfo end function supports_inplace_move(state::DataDepsState, arg) @@ -460,72 +375,70 @@ function is_writedep(arg, deps, task::DTask) end # Aliasing state setup -function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, task::DTask) - # Track the task's arguments and access patterns - return map_or_ntuple(task_args) do idx - _arg = task_args[idx] - - # Unwrap the argument - _arg_with_deps = value(_arg) - pos = _arg.pos +# Internal: iterate over task args and call callback(arg, pos, may_alias, inplace_move, deps) for each tracked arg. +function _populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask, callback) + for (idx, _arg) in enumerate(spec.fargs) + arg_pos = _arg.pos # ArgPosition for this argument (Argument/TypedArgument have .pos) + arg = value(_arg) # Unwrap In/InOut/Out wrappers and record dependencies - arg_pre_unwrap, deps = unwrap_inout(_arg_with_deps) - - # Unwrap the Chunk underlying any DTask arguments - arg = arg_pre_unwrap isa DTask ? fetch(arg_pre_unwrap; raw=true) : arg_pre_unwrap - - # Skip non-aliasing arguments or arguments that don't support in-place move - may_alias = type_may_alias(typeof(arg)) - inplace_move = may_alias && supports_inplace_move(state, arg) - if !may_alias || !inplace_move - arg_w = ArgumentWrapper(arg, identity) - if is_typed(spec) - return TypedDataDepsTaskArgument(arg, pos, may_alias, inplace_move, (DataDepsTaskDependency(arg_w, false, false),)) - else - return DataDepsTaskArgument(arg, pos, may_alias, inplace_move, [DataDepsTaskDependency(arg_w, false, false)]) - end + arg, deps = unwrap_inout(arg) + + # Unwrap the Chunk underlying any DTask arguments only when already ready. + # Fetching an unready DTask here would deadlock: distribute_tasks! runs before + # the scheduler, so dependent tasks have not run yet. Skip aliasing for unready + # DTasks so we pass them through; the worker will fetch at execution time (may block on MPI). + if arg isa DTask + isready(arg) || continue + arg = fetch(arg; move_value=false, unwrap=false) end + # Skip non-aliasing arguments + type_may_alias(typeof(arg)) || continue + + # Skip arguments not supporting in-place move + supports_inplace_move(state, arg) || continue + # Generate a Chunk for the argument if necessary if haskey(state.raw_arg_to_chunk, arg) - arg_chunk = state.raw_arg_to_chunk[arg] + arg = state.raw_arg_to_chunk[arg] else if !(arg isa Chunk) - arg_chunk = tochunk(arg) - state.raw_arg_to_chunk[arg] = arg_chunk + new_arg = with(MPI_UID=>task.uid) do + tochunk(arg) + end + state.raw_arg_to_chunk[arg] = new_arg + arg = new_arg else state.raw_arg_to_chunk[arg] = arg - arg_chunk = arg end end # Track the origin space of the argument - origin_space = memory_space(arg_chunk) - state.arg_origin[arg_chunk] = origin_space - state.remote_arg_to_original[arg_chunk] = arg_chunk + origin_space = memory_space(arg) + check_uniform(origin_space) + state.arg_origin[arg] = origin_space + state.remote_arg_to_original[arg] = arg + + may_alias = true + inplace_move = true + callback(arg, arg_pos, may_alias, inplace_move, deps) # Populate argument info for all aliasing dependencies - # And return the argument, dependencies, and ArgumentWrappers - if is_typed(spec) - deps = Tuple(DataDepsTaskDependency(arg_chunk, dep) for dep in deps) - map_or_ntuple(deps) do dep_idx - dep = deps[dep_idx] - # Populate argument info - populate_argument_info!(state, dep.arg_w, origin_space) - end - return TypedDataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) - else - deps = [DataDepsTaskDependency(arg_chunk, dep) for dep in deps] - map_or_ntuple(deps) do dep_idx - dep = deps[dep_idx] - # Populate argument info - populate_argument_info!(state, dep.arg_w, origin_space) - end - return DataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) + for (dep_mod, _, _) in deps + # Generate an ArgumentWrapper for the argument + aw = ArgumentWrapper(arg, dep_mod) + + # Populate argument info + populate_argument_info!(state, aw, origin_space) end end end + +function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) + # Track the task's arguments and access patterns (callback only for state updates) + _populate_task_info!(state, spec, task, (arg, pos, may_alias, inplace_move, deps) -> nothing) +end function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, origin_space::MemorySpace) # Initialize ownership and history if !haskey(state.arg_owner, arg_w) @@ -545,56 +458,23 @@ function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, o # Calculate the ainfo (which will populate ainfo structures and merge history) aliasing!(state, origin_space, arg_w) end -# N.B. arg_w must be the original argument wrapper, not a remote copy -function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) - if haskey(state.remote_arg_w, arg_w) && haskey(state.remote_arg_w[arg_w], target_space) - remote_arg_w = @inbounds state.remote_arg_w[arg_w][target_space] - remote_arg = remote_arg_w.arg - else - # Grab the remote copy of the argument, and calculate the ainfo - remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) - remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) - get!(Dict{MemorySpace,ArgumentWrapper}, state.remote_arg_w, arg_w)[target_space] = remote_arg_w - end - - # Check if we already have the result cached - if haskey(state.ainfo_cache, remote_arg_w) - return state.ainfo_cache[remote_arg_w] - end - - # Calculate the ainfo - ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod)) - - # Cache the result - state.ainfo_cache[remote_arg_w] = ainfo - - # Update the mapping of ainfo to argument and dep_mod - if !haskey(state.ainfo_arg, ainfo) - state.ainfo_arg[ainfo] = Set{ArgumentWrapper}([remote_arg_w]) - end - push!(state.ainfo_arg[ainfo], remote_arg_w) - - # Populate info for the new ainfo - populate_ainfo!(state, arg_w, ainfo, target_space) - - return ainfo -end function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, target_ainfo::AliasingWrapper, target_space::MemorySpace) + # Initialize owner and readers if !haskey(state.ainfos_owner, target_ainfo) - # Add ourselves to the lookup oracle - ainfo_idx = push!(state.ainfos_lookup, target_ainfo) - - # Find overlapping ainfos overlaps = Set{AliasingWrapper}() push!(overlaps, target_ainfo) - for other_ainfo in intersect(state.ainfos_lookup, target_ainfo; ainfo_idx) + other_ainfos = (Dagger.current_acceleration() isa Dagger.MPIAcceleration + ? sort(collect(keys(state.ainfos_owner)), by=hash) + : keys(state.ainfos_owner)) + for other_ainfo in other_ainfos target_ainfo == other_ainfo && continue - # Mark us and them as overlapping - push!(overlaps, other_ainfo) - push!(state.ainfos_overlaps[other_ainfo], target_ainfo) + if will_alias(target_ainfo, other_ainfo) + # Mark us and them as overlapping + push!(overlaps, other_ainfo) + push!(state.ainfos_overlaps[other_ainfo], target_ainfo) - # Add overlapping history to our own - for other_remote_arg_w in state.ainfo_arg[other_ainfo] + # Add overlapping history to our own + other_remote_arg_w = state.ainfo_arg[other_ainfo] other_arg = state.remote_arg_to_original[other_remote_arg_w.arg] other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) push!(state.arg_overlaps[original_arg_w], other_arg_w) @@ -603,8 +483,6 @@ function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, end end state.ainfos_overlaps[target_ainfo] = overlaps - - # Initialize owner and readers state.ainfos_owner[target_ainfo] = nothing state.ainfos_readers[target_ainfo] = Pair{DTask,Int}[] end @@ -613,6 +491,7 @@ function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_ history = state.arg_history[arg_w] @opcounter :merge_history @opcounter :merge_history_complexity length(history) + largest_value_update!(length(history)) origin_space = state.arg_origin[other_arg_w.arg] for other_entry in state.arg_history[other_arg_w] write_num_tuple = HistoryEntry(AliasingWrapper(NoAliasing()), origin_space, other_entry.write_num) @@ -641,7 +520,6 @@ function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_ end end function truncate_history!(state::DataDepsState, arg_w::ArgumentWrapper) - # FIXME: Do this continuously if possible if haskey(state.arg_history, arg_w) && length(state.arg_history[arg_w]) > 100000 origin_space = state.arg_origin[arg_w.arg] @opcounter :truncate_history @@ -665,8 +543,11 @@ use of `x`, and the data in `x` will not be updated when the `spawn_datadeps` region returns. """ supports_inplace_move(x) = true -supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) +supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; move_value=false, unwrap=false)) +@warn "Fix this to work with MPI (can't call poolget on the wrong rank)" maxlog=1 function supports_inplace_move(c::Chunk) + # FIXME + return true # FIXME: Use MemPool.access_ref pid = root_worker_id(c.processor) if pid == myid() @@ -738,12 +619,19 @@ function add_reader!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::M push!(state.ainfos_readers[ainfo], task=>write_num) end +# FIXME: These should go in MPIExt.jl +const MPI_TID = ScopedValue{Int64}(0) +const MPI_UID = ScopedValue{Int64}(0) + # Make a copy of each piece of data on each worker # memory_space => {arg => copy_of_arg} isremotehandle(x) = false isremotehandle(x::DTask) = true isremotehandle(x::Chunk) = true function generate_slot!(state::DataDepsState, dest_space, data) + if data isa DTask + data = fetch(data; move_value=false, unwrap=false) + end # N.B. We do not perform any sync/copy with the current owner of the data, # because all we want here is to make a copy of some version of the data, # even if the data is not up to date. @@ -751,16 +639,30 @@ function generate_slot!(state::DataDepsState, dest_space, data) to_proc = first(processors(dest_space)) from_proc = first(processors(orig_space)) dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - aliased_object_cache = AliasedObjectCache(dest_space, state.ainfo_backing_chunk) - ctx = Sch.eager_context() - id = rand(Int) - @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_chunk = move_rewrap(aliased_object_cache, from_proc, to_proc, orig_space, dest_space, data) - @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) + ALIASED_OBJECT_CACHE[] = get!(Dict{AbstractAliasing,Chunk}, state.ainfo_backing_chunk, dest_space) + if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) + # Fast path for local data that's already in a Chunk or not a remote handle needing rewrapping + task = DATADEPS_CURRENT_TASK[] + data_chunk = with(MPI_UID=>task.uid) do + tochunk(data, from_proc) + end + else + ctx = Sch.eager_context() + id = rand(Int) + @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) + data_chunk = move_rewrap(from_proc, to_proc, orig_space, dest_space, data) + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) + end @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. $(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" dest_space_args[data] = data_chunk state.remote_arg_to_original[data_chunk] = data + ALIASED_OBJECT_CACHE[] = nothing + + check_uniform(memory_space(dest_space_args[data])) + check_uniform(processor(dest_space_args[data])) + check_uniform(dest_space_args[data].handle) + return dest_space_args[data] end function get_or_generate_slot!(state, dest_space, data) @@ -773,82 +675,86 @@ function get_or_generate_slot!(state, dest_space, data) end return state.remote_args[dest_space][data] end -function remotecall_endpoint(f, from_proc, to_proc, from_space, to_space, data) - to_w = root_worker_id(to_proc) - if to_w == myid() - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc) - end - return remotecall_fetch(to_w, from_proc, to_proc, to_space, data) do from_proc, to_proc, to_space, data - data_converted = f(move(from_proc, to_proc, data)) - return tochunk(data_converted, to_proc) +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) + return aliased_object!(data) do data + return remotecall_endpoint(identity, current_acceleration(), from_proc, to_proc, from_space, to_space, data) end end -function rewrap_aliased_object!(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x) - return aliased_object!(cache, x) do x - return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, x) +function remotecall_endpoint(f, ::Dagger.DistributedAcceleration, from_proc, to_proc, orig_space, dest_space, data) + to_w = root_worker_id(to_proc) + return remotecall_fetch(to_w, from_proc, to_proc, dest_space, data) do from_proc, to_proc, dest_space, data + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc, dest_space) end end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data::Chunk) - # Unwrap so that we hit the right dispatch - wid = root_worker_id(data) - if wid != myid() - return remotecall_fetch(move_rewrap, wid, cache, from_proc, to_proc, from_space, to_space, data) - end - data_raw = unwrap(data) - return move_rewrap(cache, from_proc, to_proc, from_space, to_space, data_raw) +const ALIASED_OBJECT_CACHE = TaskLocalValue{Union{Dict{AbstractAliasing,Chunk}, Nothing}}(()->nothing) + +# Explicit cache for move_rewrap (used by haloarray, tests) +struct AliasedObjectCacheStore end +struct AliasedObjectCache + dest_space::MemorySpace + backing::Chunk + cache::Dict{AbstractAliasing,Chunk} + AliasedObjectCache(dest_space::MemorySpace, backing::Chunk) = new(dest_space, backing, Dict{AbstractAliasing,Chunk}()) end function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) - # For generic data - return aliased_object!(cache, data) do data - return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data) + old = ALIASED_OBJECT_CACHE[] + ALIASED_OBJECT_CACHE[] = cache.cache + try + return move_rewrap(from_proc, to_proc, from_space, to_space, data) + finally + ALIASED_OBJECT_CACHE[] = old end end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) - to_w = root_worker_id(to_proc) - p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) - inds = parentindices(v) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, inds) do from_proc, to_proc, from_space, to_space, p_chunk, inds - p_new = move(from_proc, to_proc, p_chunk) - v_new = view(p_new, inds...) - return tochunk(v_new, to_proc) - end + +@warn "Document these public methods" maxlog=1 +# TODO: Use state to cache aliasing() results +function declare_aliased_object!(x; ainfo=aliasing(current_acceleration(), x, identity)) + cache = ALIASED_OBJECT_CACHE[] + cache[ainfo] = x end -# FIXME: Do this programmatically via recursive dispatch -for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular) - @eval function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::$(wrapper)) - to_w = root_worker_id(to_proc) - p_chunk = rewrap_aliased_object!(cache, from_proc, to_proc, from_space, to_space, parent(v)) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk) do from_proc, to_proc, from_space, to_space, p_chunk - p_new = move(from_proc, to_proc, p_chunk) - v_new = $(wrapper)(p_new) - return tochunk(v_new, to_proc) - end +function aliased_object!(x; ainfo=aliasing(current_acceleration(), x, identity)) + cache = ALIASED_OBJECT_CACHE[] + if haskey(cache, ainfo) + y = cache[ainfo] + else + @assert x isa Chunk "x must be a Chunk\nUse functor form of aliased_object!" + cache[ainfo] = x + y = x end + return y end -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::Base.RefValue) - return aliased_object!(cache, v) do v - return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, v) +function aliased_object!(f, x; ainfo=aliasing(current_acceleration(), x, identity)) + cache = ALIASED_OBJECT_CACHE[] + if haskey(cache, ainfo) + y = cache[ainfo] + else + y = f(x) + @assert y isa Chunk "Didn't get a Chunk from functor" + cache[ainfo] = y end + return y end -#= FIXME: Make this work so we can automatically move-rewrap recursive objects -function move_rewrap_recursive(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::T) where T - if isstructtype(T) - # Check all object fields (recursive) - for field in fieldnames(T) - value = getfield(x, field) - new_value = aliased_object!(cache, value) do value - return move_rewrap_recursive(cache, from_proc, to_proc, from_space, to_space, value) - end - setfield!(x, field, new_value) - end - return x - else - @warn "Cannot move-rewrap object of type $T" - return x +function aliased_object_unwrap!(x::Chunk) + y = unwrap(x) + ainfo = aliasing(current_acceleration(), y, identity) + return unwrap(aliased_object!(x; ainfo)) +end + +struct DataDepsSchedulerState + task_to_spec::Dict{DTask,DTaskSpec} + assignments::Dict{DTask,MemorySpace} + dependencies::Dict{DTask,Set{DTask}} + task_completions::Dict{DTask,UInt64} + space_completions::Dict{MemorySpace,UInt64} + capacities::Dict{MemorySpace,Int} + + function DataDepsSchedulerState() + return new(Dict{DTask,DTaskSpec}(), + Dict{DTask,MemorySpace}(), + Dict{DTask,Set{DTask}}(), + Dict{DTask,UInt64}(), + Dict{MemorySpace,UInt64}(), + Dict{MemorySpace,Int}()) end end -move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::String) = x # FIXME: Not necessarily true -move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Symbol) = x -move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Type) = x -=# diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl index 1c2aa600f..6e2a21dfd 100644 --- a/src/datadeps/chunkview.jl +++ b/src/datadeps/chunkview.jl @@ -3,10 +3,6 @@ struct ChunkView{N} slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}} end -function _identity_hash(arg::ChunkView, h::UInt=UInt(0)) - return hash(arg.slices, _identity_hash(arg.chunk, h)) -end - function Base.view(c::Chunk, slices...) if c.domain isa ArrayDomain nd, sz = ndims(c.domain), size(c.domain) @@ -29,39 +25,31 @@ function Base.view(c::Chunk, slices...) return ChunkView(c, slices) end -Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) +Base.view(c::DTask, slices...) = view(fetch(c; move_value=false, unwrap=false), slices...) -function aliasing(x::ChunkView{N}) where N - return remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices - x = unwrap(x) - v = view(x, slices...) - return aliasing(v) - end -end +aliasing(x::ChunkView) = + throw(ConcurrencyViolationError("Cannot query aliasing of a ChunkView directly")) memory_space(x::ChunkView) = memory_space(x.chunk) isremotehandle(x::ChunkView) = true -function move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) - to_w = root_worker_id(to_proc) - # N.B. We use move_rewrap (not rewrap_aliased_object!) so that if the inner - # chunk is a SubArray, it goes through the SubArray-aware path which shares - # the parent array via the aliased object cache. Using rewrap_aliased_object! - # would simply serialize the entire SubArray, creating a new parent copy on - # the destination, breaking aliasing with other views of the same parent. - p_chunk = move_rewrap(cache, from_proc, to_proc, from_space, to_space, slice.chunk) - return remotecall_fetch(to_w, from_proc, to_proc, from_space, to_space, p_chunk, slice.slices) do from_proc, to_proc, from_space, to_space, p_chunk, inds - p_new = move(from_proc, to_proc, p_chunk) - v_new = view(p_new, inds...) - return tochunk(v_new, to_proc) +# This definition is here because it's so similar to ChunkView +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, v::SubArray) + p_chunk = aliased_object!(parent(v)) do p_chunk + return remotecall_endpoint(identity, current_acceleration(), from_proc, to_proc, from_space, to_space, p_chunk) + end + inds = parentindices(v) + return remotecall_endpoint(current_acceleration(), from_proc, to_proc, from_space, to_space, p_chunk) do p_new + return view(p_new, inds...) end end -function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) - to_w = root_worker_id(to_proc) - return remotecall_fetch(to_w, from_proc, to_proc, slice.chunk, slice.slices) do from_proc, to_proc, chunk, slices - chunk_new = move(from_proc, to_proc, chunk) - v_new = view(chunk_new, slices...) - return tochunk(v_new, to_proc) +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, slice::ChunkView) + p_chunk = aliased_object!(slice.chunk) do p_chunk + return remotecall_endpoint(identity, current_acceleration(), from_proc, to_proc, from_space, to_space, p_chunk) + end + inds = slice.slices + return remotecall_endpoint(current_acceleration(), from_proc, to_proc, from_space, to_space, p_chunk) do p_new + return view(p_new, inds...) end end -Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) \ No newline at end of file +Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl index 3e8d89d50..b203c3e44 100644 --- a/src/datadeps/queue.jl +++ b/src/datadeps/queue.jl @@ -1,4 +1,21 @@ -struct DataDepsTaskQueue{Scheduler<:DataDepsScheduler} <: AbstractTaskQueue + +const TAG_WAITING = Base.Lockable(Ref{UInt32}(1)) +function to_tag() + intask = Dagger.in_task() + if intask + opts = Dagger.get_tls().task_spec.options + tag = opts.tag + return tag + end + lock(TAG_WAITING) do counter_ref + @assert Sch.SCHED_MOVE[] == false "We should not create a tag on the scheduler unwrap move" + tag = counter_ref[] + counter_ref[] = tag + 1 > MPI.tag_ub() ? 1 : tag + 1 + return tag + end +end + +struct DataDepsTaskQueue <: AbstractTaskQueue # The queue above us upper_queue::AbstractTaskQueue # The set of tasks that have already been seen @@ -7,14 +24,24 @@ struct DataDepsTaskQueue{Scheduler<:DataDepsScheduler} <: AbstractTaskQueue g::Union{SimpleDiGraph{Int},Nothing} # The mapping from task to graph ID task_to_id::Union{Dict{DTask,Int},Nothing} + # How to traverse the dependency graph when launching tasks + traversal::Symbol # Which scheduler to use to assign tasks to processors - scheduler::Scheduler + scheduler::Symbol + + # Whether aliasing across arguments is possible + # The fields following only apply when aliasing==true + aliasing::Bool - function DataDepsTaskQueue(upper_queue; scheduler::DataDepsScheduler) + function DataDepsTaskQueue(upper_queue; + traversal::Symbol=:inorder, + scheduler::Symbol=:naive, + aliasing::Bool=true) seen_tasks = DTaskPair[] g = SimpleDiGraph() task_to_id = Dict{DTask,Int}() - return new{typeof(scheduler)}(upper_queue, seen_tasks, g, task_to_id, scheduler) + return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, + aliasing) end end @@ -25,8 +52,10 @@ function enqueue!(queue::DataDepsTaskQueue, pairs::Vector{DTaskPair}) append!(queue.seen_tasks, pairs) end +const DATADEPS_CURRENT_TASK = TaskLocalValue{Union{DTask,Nothing}}(Returns(nothing)) + """ - spawn_datadeps(f::Base.Callable) + spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder) Constructs a "datadeps" (data dependencies) region and calls `f` within it. Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or @@ -53,41 +82,46 @@ appropriately. At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks to complete, rethrowing the first error, if any. The result of `f` will be returned from `spawn_datadeps`. + +The keyword argument `traversal` controls the order that tasks are launched by +the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling +or Depth-First Scheduling, respectively. All traversal orders respect the +dependencies and ordering of the launched tasks, but may provide better or +worse performance for a given set of datadeps tasks. This argument is +experimental and subject to change. """ function spawn_datadeps(f::Base.Callable; static::Bool=true, traversal::Symbol=:inorder, - scheduler::Union{DataDepsScheduler,Nothing}=nothing, + scheduler::Union{Symbol,Nothing}=nothing, aliasing::Bool=true, launch_wait::Union{Bool,Nothing}=nothing) if !static throw(ArgumentError("Dynamic scheduling is no longer available")) end - if traversal != :inorder - throw(ArgumentError("Traversal order is no longer configurable, and always :inorder")) - end - if !aliasing - throw(ArgumentError("Aliasing analysis is no longer optional")) - end wait_all(; check_errors=true) do - scheduler = something(scheduler, DATADEPS_SCHEDULER[], RoundRobinScheduler()) + scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool if launch_wait result = spawn_bulk() do - queue = DataDepsTaskQueue(get_options(:task_queue); scheduler) + queue = DataDepsTaskQueue(get_options(:task_queue); + traversal, scheduler, aliasing) with_options(f; task_queue=queue) distribute_tasks!(queue) end else - queue = DataDepsTaskQueue(get_options(:task_queue); scheduler) + queue = DataDepsTaskQueue(get_options(:task_queue); + traversal, scheduler, aliasing) result = with_options(f; task_queue=queue) distribute_tasks!(queue) end + DATADEPS_CURRENT_TASK[] = nothing return result end end -const DATADEPS_SCHEDULER = ScopedValue{Union{DataDepsScheduler,Nothing}}(nothing) +const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing) const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) +@warn "Don't blindly set occupancy=0, only do for MPI" maxlog=1 function distribute_tasks!(queue::DataDepsTaskQueue) #= TODO: Improvements to be made: # - Support for copying non-AbstractArray arguments @@ -98,37 +132,96 @@ function distribute_tasks!(queue::DataDepsTaskQueue) =# # Get the set of all processors to be scheduled on - all_procs = Processor[] scope = get_compute_scope() - for w in procs() - append!(all_procs, get_processors(OSProc(w))) + accel = current_acceleration() + accel_procs = filter(procs(Dagger.Sch.eager_context())) do proc + Dagger.accel_matches_proc(accel, proc) end + all_procs = unique(vcat([collect(Dagger.get_processors(gp)) for gp in accel_procs]...)) + # FIXME: This is an unreliable way to ensure processor uniformity + sort!(all_procs, by=short_name) filter!(proc->proc_in_scope(proc, scope), all_procs) if isempty(all_procs) throw(Sch.SchedulingException("No processors available, try widening scope")) end - all_scope = UnionScope(map(ExactScope, all_procs)) exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) - if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) + #=if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 + end=# + for proc in all_procs + check_uniform(proc) end # Round-robin assign tasks to processors upper_queue = get_options(:task_queue) + traversal = queue.traversal + if traversal == :inorder + # As-is + task_order = Colon() + elseif traversal == :bfs + # BFS + task_order = Int[1] + to_walk = Int[1] + seen = Set{Int}([1]) + while !isempty(to_walk) + # N.B. next_root has already been seen + next_root = popfirst!(to_walk) + for v in outneighbors(queue.g, next_root) + if !(v in seen) + push!(task_order, v) + push!(seen, v) + push!(to_walk, v) + end + end + end + elseif traversal == :dfs + # DFS (modified with backtracking) + task_order = Int[] + to_walk = Int[1] + seen = Set{Int}() + while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk) + next_root = popfirst!(to_walk) + if !(next_root in seen) + iv = inneighbors(queue.g, next_root) + if all(v->v in seen, iv) + push!(task_order, next_root) + push!(seen, next_root) + ov = outneighbors(queue.g, next_root) + prepend!(to_walk, ov) + else + push!(to_walk, next_root) + end + end + end + else + throw(ArgumentError("Invalid traversal mode: $traversal")) + end + + state = DataDepsState(queue.aliasing) + sstate = DataDepsSchedulerState() + for proc in all_procs + space = only(memory_spaces(proc)) + get!(()->0, sstate.capacities, space) + sstate.capacities[space] += 1 + end + # Start launching tasks and necessary copies - state = DataDepsState() write_num = 1 + proc_idx = 1 + #pressures = Dict{Processor,Int}() proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) - for pair in queue.seen_tasks + for pair in queue.seen_tasks[task_order] spec = pair.spec task = pair.task - write_num = distribute_task!(queue, state, all_procs, all_scope, spec, task, spec.fargs, proc_to_scope_lfu, write_num) + write_num, proc_idx = distribute_task!(queue, state, all_procs, spec, task, spec.fargs, proc_to_scope_lfu, write_num, proc_idx) end # Copy args from remote to local # N.B. We sort the keys to ensure a deterministic order for uniformity + check_uniform(length(state.arg_owner)) for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash) + check_uniform(arg_w) arg = arg_w.arg origin_space = state.arg_origin[arg] remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num) @@ -141,10 +234,6 @@ function distribute_tasks!(queue::DataDepsTaskQueue) else @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" - ctx = Sch.eager_context() - id = rand(UInt) - @maybelog ctx timespan_start(ctx, :datadeps_copy_skip, (;id), (;)) - @maybelog ctx timespan_finish(ctx, :datadeps_copy_skip, (;id), (;thunk_id=0, from_space=origin_space, to_space=origin_space, arg_w, from_arg=arg, to_arg=arg)) end end end @@ -170,29 +259,174 @@ struct TypedDataDepsTaskArgument{T,N} deps::NTuple{N,DataDepsTaskDependency} end map_or_ntuple(f, xs::Vector) = map(f, 1:length(xs)) -@inline map_or_ntuple(@specialize(f), xs::NTuple{N,T}) where {N,T} = ntuple(f, Val(N)) -function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, all_scope, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int) where typed +map_or_ntuple(f, xs::Tuple) = ntuple(f, length(xs)) + +# 4-arg version: side effects + returns Vector/Tuple of DataDepsTaskArgument for distribute_task! +function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, task::DTask) + result = DataDepsTaskArgument[] + _populate_task_info!(state, spec, task, (arg, pos, may_alias, inplace_move, deps) -> begin + dep_infos = DataDepsTaskDependency[DataDepsTaskDependency(arg, d) for d in deps] + push!(result, DataDepsTaskArgument(arg, pos, may_alias, inplace_move, dep_infos)) + end) + return spec.fargs isa Tuple ? (result...,) : result +end + +function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int, proc_idx::Int) where typed @specialize spec fargs + DATADEPS_CURRENT_TASK[] = task + if typed fargs::Tuple else fargs::Vector{Argument} end - task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) scheduler = queue.scheduler - our_proc = datadeps_schedule_task(scheduler, state, all_procs, all_scope, task_scope, spec, task) + if scheduler == :naive + raw_args = map(arg->tochunk(value(arg)), spec.fargs) + our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + # Calculate costs per processor and select the most optimal + # FIXME: This should consider any already-allocated slots, + # whether they are up-to-date, and if not, the cost of moving + # data to them + procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) + return first(procs) + end + end + elseif scheduler == :smart + raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg + arg_chunk = tochunk(value(arg)) + # Only the owned slot is valid + # FIXME: Track up-to-date copies and pass all of those + return arg_chunk => data_locality[arg] + end + f_chunk = tochunk(value(spec.fargs[1])) + our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + tx_rate = sch_state.transfer_rate[] + + costs = Dict{Processor,Float64}() + for proc in all_procs + # Filter out chunks that are already local + chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) + + # Estimate network transfer costs based on data size + # N.B. `affinity(x)` really means "data size of `x`" + # N.B. We treat same-worker transfers as having zero transfer cost + tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) + + # Estimate total cost to move data and get task running after currently-scheduled tasks + est_time_util = get(pressures, proc, UInt64(0)) + costs[proc] = est_time_util + (tx_cost/tx_rate) + end + + # Look up estimated task cost + sig = Sch.signature(sch_state, f, map(first, chunks_locality)) + task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) + + # Shuffle procs around, so equally-costly procs are equally considered (skip when MPI for deterministic tie-breaking) + procs = if current_acceleration() isa Dagger.MPIAcceleration + collect(all_procs) + else + P = randperm(length(all_procs)) + getindex.(Ref(all_procs), P) + end + + # Sort by lowest cost first + sort!(procs, by=p->costs[p]) + + best_proc = first(procs) + return best_proc, task_pressure + end + end + # FIXME: Pressure should be decreased by pressure of syncdeps on same processor + pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure + elseif scheduler == :ultra + args = Base.mapany(spec.fargs) do arg + pos, data = arg + data, _ = unwrap_inout(data) + if data isa DTask + data = fetch(data; move_value=false, unwrap=false) + end + return pos => tochunk(data) + end + f_chunk = tochunk(value(spec.fargs[1])) + task_time = remotecall_fetch(1, f_chunk, args) do f, args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + return @lock sch_state.lock begin + sig = Sch.signature(sch_state, f, args) + return get(sch_state.signature_time_cost, sig, 1000^3) + end + end + + # FIXME: Copy deps are computed eagerly + deps = @something(spec.options.syncdeps, Set{Any}()) + + # Find latest time-to-completion of all syncdeps + deps_completed = UInt64(0) + for dep in deps + haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded + deps_completed = max(deps_completed, sstate.task_completions[dep]) + end + + # Find latest time-to-completion of each memory space + # FIXME: Figure out space completions based on optimal packing + spaces_completed = Dict{MemorySpace,UInt64}() + for space in exec_spaces + completed = UInt64(0) + for (task, other_space) in sstate.assignments + space == other_space || continue + completed = max(completed, sstate.task_completions[task]) + end + spaces_completed[space] = completed + end + + # Choose the earliest-available memory space and processor + # FIXME: Consider move time + move_time = UInt64(0) + local our_space_completed + while true + our_space_completed, our_space = findmin(spaces_completed) + our_space_procs = filter(proc->proc in all_procs, processors(our_space)) + if isempty(our_space_procs) + delete!(spaces_completed, our_space) + continue + end + our_proc = if current_acceleration() isa Dagger.MPIAcceleration + first(sort(collect(our_space_procs), by=short_name)) + else + rand(our_space_procs) + end + break + end + + sstate.task_to_spec[task] = spec + sstate.assignments[task] = our_space + sstate.task_completions[task] = our_space_completed + move_time + task_time + elseif scheduler == :roundrobin + our_proc = all_procs[proc_idx] + else + error("Invalid scheduler: $sched") + end @assert our_proc in all_procs our_space = only(memory_spaces(our_proc)) # Find the scope for this task (and its copies) task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) - if task_scope == all_scope + if task_scope == scope # Optimize for the common case, cache the proc=>scope mapping our_scope = get!(proc_to_scope_lfu, our_proc) do our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - return constrain(UnionScope(map(ExactScope, our_procs)...), all_scope) + return constrain(UnionScope(map(ExactScope, our_procs)...), scope) end else # Use the provided scope and constrain it to the available processors @@ -202,12 +436,13 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr if our_scope isa InvalidScope throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) end + check_uniform(our_proc) + check_uniform(our_space) f = spec.fargs[1] - tid = task.uid # FIXME: May not be correct to move this under uniformity #f.value = move(default_processor(), our_proc, value(f)) - @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" # Copy raw task arguments for analysis # N.B. Used later for checking dependencies @@ -234,13 +469,13 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr # Is the data written previously or now? if !arg_ws.may_alias - @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" return arg end # Is the data writeable? if !arg_ws.inplace_move - @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" return arg end @@ -257,7 +492,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) else @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" - @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" end end return arg_remote @@ -276,9 +511,6 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr end # Check that any mutable and written arguments are already in the correct space - # N.B. We only do this check when the argument supports in-place - # moves, because for the moment, we are not guaranteeing updates or - # write-back of results if is_writedep(arg, deps, task) && arg_ws.may_alias && arg_ws.inplace_move arg_space = memory_space(arg) @assert arg_space == our_space "($(repr(value(f))))[$(idx-1)] Tried to pass $(typeof(arg)) from $arg_space to $our_space" @@ -287,8 +519,12 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr # Calculate this task's syncdeps if spec.options.syncdeps === nothing - spec.options.syncdeps = Set{ThunkSyncdep}() + spec.options.syncdeps = Set{Any}() end + if spec.options.tag === nothing + spec.options.tag = to_tag() + end + syncdeps = spec.options.syncdeps map_or_ntuple(task_arg_ws) do idx arg_ws = task_arg_ws[idx] @@ -301,33 +537,46 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr ainfo = aliasing!(state, our_space, arg_w) dep_mod = arg_w.dep_mod if dep.writedep - @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" get_write_deps!(state, our_space, ainfo, write_num, syncdeps) else - @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" get_read_deps!(state, our_space, ainfo, write_num, syncdeps) end end return end - @dagdebug tid :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" - - # Launch user's task - new_fargs = map_or_ntuple(task_arg_ws) do idx - if is_typed(spec) - return TypedArgument(task_arg_ws[idx].pos, remote_args[idx]) - else - return Argument(task_arg_ws[idx].pos, remote_args[idx]) + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" + + # Launch user's task: preserve full argument list (spec.fargs); use remote values only for tracked args + new_fargs = if spec.fargs isa Tuple + ntuple(length(spec.fargs)) do i + arg = spec.fargs[i] + pos = arg.pos + j = findfirst(w -> w.pos == pos, task_arg_ws) + if j !== nothing + val = remote_args[j] + is_typed(spec) ? TypedArgument(pos, val) : Argument(pos, val) + else + copy(arg) + end end + else + [let arg = spec.fargs[i], pos = arg.pos + j = findfirst(w -> w.pos == pos, task_arg_ws) + if j !== nothing + val = remote_args[j] + is_typed(spec) ? TypedArgument(pos, val) : Argument(pos, val) + else + copy(arg) + end + end for i in 1:length(spec.fargs)] end new_spec = DTaskSpec(new_fargs, spec.options) new_spec.options.scope = our_scope new_spec.options.exec_scope = our_scope new_spec.options.occupancy = Dict(Any=>0) - ctx = Sch.eager_context() - @maybelog ctx timespan_start(ctx, :datadeps_execute, (;thunk_id=task.uid), (;)) enqueue!(queue.upper_queue, DTaskPair(new_spec, task)) - @maybelog ctx timespan_finish(ctx, :datadeps_execute, (;thunk_id=task.uid), (;space=our_space, deps=task_arg_ws, args=remote_args)) # Update read/write tracking for arguments map_or_ntuple(task_arg_ws) do idx @@ -340,7 +589,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr ainfo = aliasing!(state, our_space, arg_w) dep_mod = arg_w.dep_mod if dep.writedep - @dagdebug tid :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" add_writer!(state, arg_w, our_space, ainfo, task, write_num) else add_reader!(state, arg_w, our_space, ainfo, task, write_num) @@ -350,6 +599,7 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr end write_num += 1 + proc_idx = mod1(proc_idx + 1, length(all_procs)) - return write_num + return write_num, proc_idx end diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl index 2c2c49920..671365793 100644 --- a/src/datadeps/remainders.jl +++ b/src/datadeps/remainders.jl @@ -9,11 +9,10 @@ This is used to perform partial data copies that only update the "remainder" reg struct RemainderAliasing{S<:MemorySpace} <: AbstractAliasing space::S spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}} - ainfos::Vector{AliasingWrapper} syncdeps::Set{ThunkSyncdep} end -RemainderAliasing(space::S, spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}, ainfos::Vector{AliasingWrapper}, syncdeps::Set{ThunkSyncdep}) where S = - RemainderAliasing{S}(space, spans, ainfos, syncdeps) +RemainderAliasing(space::S, spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}, syncdeps::Set{ThunkSyncdep}) where S = + RemainderAliasing{S}(space, spans, syncdeps) memory_spans(ra::RemainderAliasing) = ra.spans @@ -43,6 +42,42 @@ memory_spans(mra::MultiRemainderAliasing) = vcat(memory_spans.(mra.remainders).. Base.hash(mra::MultiRemainderAliasing, h::UInt) = hash(mra.remainders, hash(MultiRemainderAliasing, h)) Base.:(==)(mra1::MultiRemainderAliasing, mra2::MultiRemainderAliasing) = mra1.remainders == mra2.remainders +#= FIXME: Integrate with main documentation +Problem statement: + +Remainder copy calculation needs to ensure that, for a given argument and +dependency modifier, and for a given target memory space, any data not yet +updated (whether through this arg or through another that aliases) is added to +the remainder, while any data that has been updated is not in the remainder. +Remainder copies may be multi-part, as data may be spread across multiple other +memory spaces. + +Ainfo is not alone sufficient to identify the combination of argument and +dependency modifier, as ainfo is specific to an allocation in a given memory +space. Thus, this combination needs to be tracked together, and separately from +memory space. However, information may span multiple memory spaces (and thus +multiple ainfos), so we should try to make queries of cross-memory space +information fast, as they will need to be performed for every task, for every +combination. + +Game Plan: + +- Use ArgumentWrapper to track this combination throughout the codebase, ideally generated just once +- Maintain the keying of remote_args only on argument, as the dependency modifier doesn’t affect the argument being passed into the task, so it should not factor into generating and tracking remote argument copies +- Add a structure to track the mapping from ArgumentWrapper to memory space to ainfo, as a quick way to lookup all ainfos needing to be considered +- When considering a remainder copy, only look at a single memory space’s ainfos at a time, as the ainfos should overlap exactly the same way on any memory space, and this allows us to use ainfo_overlaps to track overlaps +- Remainder copies will need to separately consider the source memory space, and the destination memory space when acquiring spans to copy to/from +- Memory spans for ainfos generated from the same ArgumentWrapper should be assumed to be paired in the same order, regardless of memory space, to ensure we can perform the translation from source to destination span address + - Alternatively, we might provide an API to take source and destination ainfos, and desired remainder memory spans, which then performs the copy for us +- When a task or copy writes to arguments, we should record this happening for all overlapping ainfos, in a manner that will be efficient to query from another memory space. We can probably walk backwards and attach this to a structure keyed on ArgumentWrapper, as that will be very efficient for later queries (because the history will now be linearized in one vector). +- Remainder copies will need to know, for all overlapping ainfos of the ArgumentWrapper ainfo at the target memory space, how recently that ainfo was updated relative to other ainfos, and relative to how recently the target ainfo was written. + - The last time the target ainfo was written is the furthest back we need to consider, as the target data must have been fully up-to-date when that write completed. + - Consideration of updates should start at most recent first, walking backwards in time, as the most recent updates contain the up-to-date data. + - For each span under consideration, we should subtract from it the current remainder set, to ensure we only copy up-to-date data. + - We must add that span portion to the remainder set no matter what, but if it was updated on the target memory space, we don’t need to schedule a copy for it, since it’s already where it needs to be. + - Even before the last target write is seen, we are allowed to stop searching if we find that our target ainfo is fully covered (because this implies that the target ainfo is fully out-of-date). +=# + struct FullCopy end """ @@ -87,14 +122,13 @@ function compute_remainder_for_arg!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper, write_num::Int; compute_syncdeps::Bool=true) + @label restart + + # Determine all memory spaces of the history spaces_set = Set{MemorySpace}() push!(spaces_set, target_space) owner_space = state.arg_owner[arg_w] push!(spaces_set, owner_space) - - @label restart - - # Determine all memory spaces of the history for entry in state.arg_history[arg_w] push!(spaces_set, entry.space) end @@ -109,7 +143,6 @@ function compute_remainder_for_arg!(state::DataDepsState, push!(target_ainfos, LocalMemorySpan.(spans)) end nspans = length(first(target_ainfos)) - @assert all(==(nspans), length.(target_ainfos)) "Aliasing info for $(typeof(arg_w.arg))[$(arg_w.dep_mod)] has different number of spans in different memory spaces" # FIXME: This is a hack to ensure that we don't miss any history generated by aliasing(...) for entry in state.arg_history[arg_w] @@ -118,6 +151,8 @@ function compute_remainder_for_arg!(state::DataDepsState, @goto restart end end + check_uniform(spaces) + check_uniform(target_ainfos) # We may only need to schedule a full copy from the origin space to the # target space if this is the first time we've written to `arg_w` @@ -130,14 +165,10 @@ function compute_remainder_for_arg!(state::DataDepsState, end # Create our remainder as an interval tree over all target ainfos - VERIFY_SPAN_CURRENT_OBJECT[] = arg_w.arg remainder = IntervalTree{ManyMemorySpan{N}}(ManyMemorySpan{N}(ntuple(i -> target_ainfos[i][j], N)) for j in 1:nspans) - for span in remainder - verify_span(span) - end # Create our tracker - tracker = Dict{MemorySpace,Tuple{Vector{Tuple{LocalMemorySpan,LocalMemorySpan}},Vector{AliasingWrapper},Set{ThunkSyncdep}}}() + tracker = Dict{MemorySpace,Tuple{Vector{Tuple{LocalMemorySpan,LocalMemorySpan}},Set{ThunkSyncdep}}}() # Walk backwards through the history of writes to this target # other_ainfo is the overlapping ainfo that was written to @@ -159,9 +190,11 @@ function compute_remainder_for_arg!(state::DataDepsState, other_ainfo = aliasing!(state, owner_space, arg_w) other_space = owner_space end + check_uniform(other_ainfo) + check_uniform(other_space) # Lookup all memory spans for arg_w in these spaces - other_remote_arg_w = first(collect(state.ainfo_arg[other_ainfo])) + other_remote_arg_w = state.ainfo_arg[other_ainfo] other_arg_w = ArgumentWrapper(state.remote_arg_to_original[other_remote_arg_w.arg], other_remote_arg_w.dep_mod) other_ainfos = Vector{Vector{LocalMemorySpan}}() for space in spaces @@ -171,9 +204,9 @@ function compute_remainder_for_arg!(state::DataDepsState, end nspans = length(first(other_ainfos)) other_many_spans = [ManyMemorySpan{N}(ntuple(i -> other_ainfos[i][j], N)) for j in 1:nspans] - foreach(other_many_spans) do span - verify_span(span) - end + + check_uniform(other_many_spans) + check_uniform(spaces) if other_space == target_space # Only subtract, this data is already up-to-date in target_space @@ -188,19 +221,17 @@ function compute_remainder_for_arg!(state::DataDepsState, other_space_idx = something(findfirst(==(other_space), spaces)) target_space_idx = something(findfirst(==(target_space), spaces)) tracker_other_space = get!(tracker, other_space) do - (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Vector{AliasingWrapper}(), Set{ThunkSyncdep}()) + (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Set{ThunkSyncdep}()) end @opcounter :compute_remainder_for_arg_schedule - has_overlap = schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans) - if compute_syncdeps && has_overlap + schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans) + if compute_syncdeps @assert haskey(state.ainfos_owner, other_ainfo) "[idx $idx] ainfo $(typeof(other_ainfo)) has no owner" - get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[3]) - push!(tracker_other_space[2], other_ainfo) + get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[2]) end end - VERIFY_SPAN_CURRENT_OBJECT[] = nothing - if isempty(tracker) || all(tracked->isempty(tracked[1]), values(tracker)) + if isempty(tracker) return NoAliasing(), 0 end @@ -208,13 +239,12 @@ function compute_remainder_for_arg!(state::DataDepsState, mra = MultiRemainderAliasing() for space in spaces if haskey(tracker, space) - spans, ainfos, syncdeps = tracker[space] + spans, syncdeps = tracker[space] if !isempty(spans) - push!(mra.remainders, RemainderAliasing(space, spans, ainfos, syncdeps)) + push!(mra.remainders, RemainderAliasing(space, spans, syncdeps)) end end end - @assert !isempty(mra.remainders) "Expected at least one remainder (spaces: $spaces, tracker spaces: $(collect(keys(tracker))))" return mra, last_idx end @@ -230,13 +260,12 @@ copy from `other_many_spans` to the subtraced portion of `remainder`. function schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) where N diff = Vector{ManyMemorySpan{N}}() subtract_spans!(remainder, other_many_spans, diff) + for span in diff source_span = span.spans[source_space_idx] dest_span = span.spans[dest_space_idx] - @assert span_len(source_span) == span_len(dest_span) "Source and dest spans are not the same size: $(span_len(source_span)) != $(span_len(dest_span))" push!(tracker, (source_span, dest_span)) end - return !isempty(diff) end ### Remainder copy functions @@ -250,7 +279,9 @@ Enqueues a copy operation to update the remainder regions of an object before a function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, f, idx, dest_scope, task, write_num::Int) for remainder in remainder_aliasing.remainders + check_uniform(remainder.space) @assert !isempty(remainder.spans) + check_uniform(remainder.spans) enqueue_remainder_copy_to!(state, dest_space, arg_w, remainder, f, idx, dest_scope, task, write_num) end end @@ -263,7 +294,7 @@ function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpac # overwritten by more recent partial updates source_space = remainder_aliasing.space - @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" # Get the source and destination arguments arg_dest = state.remote_args[dest_space][arg_w.arg] @@ -276,23 +307,16 @@ function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpac push!(remainder_syncdeps, syncdep) end empty!(remainder_aliasing.syncdeps) # We can't bring these to move! - source_ainfos = copy(remainder_aliasing.ainfos) - empty!(remainder_aliasing.ainfos) get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) - @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" # Launch the remainder copy task - ctx = Sch.eager_context() - id = rand(UInt) - @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) - copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) - @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) - - # This copy task reads the sources and writes to the target - for ainfo in source_ainfos - add_reader!(state, arg_w, source_space, ainfo, copy_task, write_num) + copy_task = Dagger.with_options(; tag=to_tag()) do + Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps occupancy=Dict(Any=>0) meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) end + + # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) end """ @@ -304,7 +328,9 @@ Enqueues a copy operation to update the remainder regions of an object back to t function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, dest_scope, write_num::Int) for remainder in remainder_aliasing.remainders + check_uniform(remainder.space) @assert !isempty(remainder.spans) + check_uniform(remainder.spans) enqueue_remainder_copy_from!(state, dest_space, arg_w, remainder, dest_scope, write_num) end end @@ -330,23 +356,16 @@ function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySp push!(remainder_syncdeps, syncdep) end empty!(remainder_aliasing.syncdeps) # We can't bring these to move! - source_ainfos = copy(remainder_aliasing.ainfos) - empty!(remainder_aliasing.ainfos) get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Remainder copy-from has $(length(remainder_syncdeps)) syncdeps" # Launch the remainder copy task - ctx = Sch.eager_context() - id = rand(UInt) - @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) - copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) - @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) - - # This copy task reads the sources and writes to the target - for ainfo in source_ainfos - add_reader!(state, arg_w, source_space, ainfo, copy_task, write_num) + copy_task = Dagger.with_options(; tag=to_tag()) do + Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps occupancy=Dict(Any=>0) meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) end + + # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) end @@ -357,7 +376,7 @@ function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w:: source_space = state.arg_owner[arg_w] target_ainfo = aliasing!(state, dest_space, arg_w) - @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" # Get the source and destination arguments arg_dest = state.remote_args[dest_space][arg_w.arg] @@ -370,17 +389,12 @@ function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w:: get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) - @dagdebug task.uid :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" # Launch the remainder copy task - ctx = Sch.eager_context() - id = rand(UInt) - @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) - copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) - @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) - - # This copy task reads the source and writes to the target - add_reader!(state, arg_w, source_space, source_ainfo, copy_task, write_num) + copy_task = Dagger.with_options(; tag=to_tag()) do + Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps occupancy=Dict(Any=>0) meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + end add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) end function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, @@ -405,47 +419,38 @@ function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Full copy-from has $(length(copy_syncdeps)) syncdeps" # Launch the remainder copy task - ctx = Sch.eager_context() - id = rand(UInt) - @maybelog ctx timespan_start(ctx, :datadeps_copy, (;id), (;)) - copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) - @maybelog ctx timespan_finish(ctx, :datadeps_copy, (;id), (;thunk_id=copy_task.uid, from_space=source_space, to_space=dest_space, arg_w, from_arg=arg_source, to_arg=arg_dest)) - - # This copy task reads the source and writes to the target - add_reader!(state, arg_w, source_space, source_ainfo, copy_task, write_num) + copy_task = Dagger.with_options(; tag=to_tag()) do + Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps occupancy=Dict(Any=>0) meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + end + + # This copy task becomes a new writer for the target region add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) end # Main copy function for RemainderAliasing -function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S - # TODO: Support direct copy between GPU memory spaces - - # Copy the data from the source object - copies = remotecall_fetch(root_worker_id(from_space), from_space, dep_mod, from) do from_space, dep_mod, from - len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) - copies = Vector{UInt8}(undef, len) - from_raw = unwrap(from) - offset = UInt64(1) - with_context!(from_space) - GC.@preserve copies begin - for (from_span, _) in dep_mod.spans - read_remainder!(copies, offset, from_raw, from_span.ptr, from_span.len) - offset += from_span.len +function move!(dep_mod::RemainderAliasing, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) + # Get the source data for each span + copies = remotecall_fetch(root_worker_id(from_space), dep_mod) do dep_mod + copies = Vector{UInt8}[] + for (from_span, _) in dep_mod.spans + copy = Vector{UInt8}(undef, from_span.len) + GC.@preserve copy begin + from_ptr = Ptr{UInt8}(from_span.ptr) + to_ptr = Ptr{UInt8}(pointer(copy)) + unsafe_copyto!(to_ptr, from_ptr, from_span.len) end + push!(copies, copy) end - @assert offset == len+UInt64(1) return copies end # Copy the data into the destination object - offset = UInt64(1) - to_raw = unwrap(to) - GC.@preserve copies begin - for (_, to_span) in dep_mod.spans - write_remainder!(copies, offset, to_raw, to_span.ptr, to_span.len) - offset += to_span.len + for (copy, (_, to_span)) in zip(copies, dep_mod.spans) + GC.@preserve copy begin + from_ptr = Ptr{UInt8}(pointer(copy)) + to_ptr = Ptr{UInt8}(to_span.ptr) + unsafe_copyto!(to_ptr, from_ptr, to_span.len) end - @assert offset == length(copies)+UInt64(1) end # Ensure that the data is visible @@ -453,88 +458,3 @@ function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space: return end - -function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::Array, from_ptr::UInt64, len::UInt64) - elsize = sizeof(eltype(from)) - @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" - n = UInt64(len / elsize) - from_offset_n = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) - from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)} - # unsafe_wrap(Array, ...) doesn't like unaligned memory - unsafe_copyto!(Ptr{eltype(from)}(pointer(copies, copies_offset)), pointer(from_vec, from_offset_n), n) -end -function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::DenseArray, from_ptr::UInt64, len::UInt64) - elsize = sizeof(eltype(from)) - @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" - n = UInt64(len / elsize) - from_offset_n = UInt64((from_ptr - UInt64(pointer(from))) / elsize) + UInt64(1) - from_vec = reshape(from, prod(size(from)))::DenseVector{eltype(from)} - copies_typed = unsafe_wrap(Vector{eltype(from)}, Ptr{eltype(from)}(pointer(copies, copies_offset)), n) - copyto!(copies_typed, 1, from_vec, Int(from_offset_n), Int(n)) -end -function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from, from_ptr::UInt64, n::UInt64) - real_from = find_object_holding_ptr(from, from_ptr) - return read_remainder!(copies, copies_offset, real_from, from_ptr, n) -end - -function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::Array, to_ptr::UInt64, len::UInt64) - elsize = sizeof(eltype(to)) - @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" - n = UInt64(len / elsize) - to_offset_n = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) - to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)} - # unsafe_wrap(Array, ...) doesn't like unaligned memory - unsafe_copyto!(pointer(to_vec, to_offset_n), Ptr{eltype(to)}(pointer(copies, copies_offset)), n) -end -function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::DenseArray, to_ptr::UInt64, len::UInt64) - elsize = sizeof(eltype(to)) - @assert len / elsize == round(UInt64, len / elsize) "Span length is not an integer multiple of the element size: $(len) / $(elsize) = $(len / elsize) (elsize: $elsize)" - n = UInt64(len / elsize) - to_offset_n = UInt64((to_ptr - UInt64(pointer(to))) / elsize) + UInt64(1) - to_vec = reshape(to, prod(size(to)))::DenseVector{eltype(to)} - copies_typed = unsafe_wrap(Vector{eltype(to)}, Ptr{eltype(to)}(pointer(copies, copies_offset)), n) - copyto!(to_vec, Int(to_offset_n), copies_typed, 1, Int(n)) -end -function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to, to_ptr::UInt64, n::UInt64) - real_to = find_object_holding_ptr(to, to_ptr) - return write_remainder!(copies, copies_offset, real_to, to_ptr, n) -end - -# Remainder copies for common objects -for wrapper in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular, SubArray) - @eval function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::$wrapper, from_ptr::UInt64, n::UInt64) - read_remainder!(copies, copies_offset, parent(from), from_ptr, n) - end - @eval function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::$wrapper, to_ptr::UInt64, n::UInt64) - write_remainder!(copies, copies_offset, parent(to), to_ptr, n) - end -end - -function read_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, from::Base.RefValue, from_ptr::UInt64, n::UInt64) - if from_ptr == UInt64(Base.pointer_from_objref(from) + fieldoffset(typeof(from), 1)) - unsafe_copyto!(pointer(copies, copies_offset), Ptr{UInt8}(from_ptr), n) - else - read_remainder!(copies, copies_offset, from[], from_ptr, n) - end -end -function write_remainder!(copies::Vector{UInt8}, copies_offset::UInt64, to::Base.RefValue, to_ptr::UInt64, n::UInt64) - if to_ptr == UInt64(Base.pointer_from_objref(to) + fieldoffset(typeof(to), 1)) - unsafe_copyto!(Ptr{UInt8}(to_ptr), pointer(copies, copies_offset), n) - else - write_remainder!(copies, copies_offset, to[], to_ptr, n) - end -end - -function find_object_holding_ptr(A::SparseMatrixCSC, ptr::UInt64) - span = LocalMemorySpan(pointer(A.nzval), length(A.nzval)*sizeof(eltype(A.nzval))) - if span_start(span) <= ptr <= span_end(span) - return A.nzval - end - span = LocalMemorySpan(pointer(A.colptr), length(A.colptr)*sizeof(eltype(A.colptr))) - if span_start(span) <= ptr <= span_end(span) - return A.colptr - end - span = LocalMemorySpan(pointer(A.rowval), length(A.rowval)*sizeof(eltype(A.rowval))) - @assert span_start(span) <= ptr <= span_end(span) "Pointer $ptr not found in SparseMatrixCSC" - return A.rowval -end \ No newline at end of file diff --git a/src/datadeps/scheduling.jl b/src/datadeps/scheduling.jl index 0bf9818f6..b2bcaca7b 100644 --- a/src/datadeps/scheduling.jl +++ b/src/datadeps/scheduling.jl @@ -111,7 +111,11 @@ function datadeps_schedule_task(sched::UltraScheduler, state::DataDepsState, all delete!(spaces_completed, our_space) continue end - our_proc = rand(our_space_procs) + our_proc = if Dagger.current_acceleration() isa Dagger.MPIAcceleration + first(sort(collect(our_space_procs), by=Dagger.short_name)) + else + rand(our_space_procs) + end break end diff --git a/src/dtask.jl b/src/dtask.jl index e94803502..13e66cafe 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -65,11 +65,13 @@ function Base.wait(t::DTask) wait(t.future) return end -function Base.fetch(t::DTask; raw=false) +function Base.fetch(t::DTask; raw=false, move_value=nothing, unwrap=nothing) if !istaskstarted(t) throw(ConcurrencyViolationError("Cannot `fetch` an unlaunched `DTask`")) end - return fetch(t.future; raw) + # Datadeps/aliasing API: move_value=false => don't move => raw=true + raw_eff = move_value !== nothing ? !move_value : raw + return fetch(t.future; raw=raw_eff) end function waitany(tasks::Vector{DTask}) if isempty(tasks) diff --git a/src/lib/domain-blocks.jl b/src/lib/domain-blocks.jl index 2a0854e3b..95e5c360f 100644 --- a/src/lib/domain-blocks.jl +++ b/src/lib/domain-blocks.jl @@ -6,6 +6,8 @@ struct DomainBlocks{N} <: AbstractArray{ArrayDomain{N, NTuple{N, UnitRange{Int}} end Base.@deprecate_binding BlockedDomains DomainBlocks +ndims(::DomainBlocks{N}) where N = N + size(x::DomainBlocks) = map(length, x.cumlength) function _getindex(x::DomainBlocks{N}, idx::Tuple) where N starts = map((vec, i) -> i == 0 ? 0 : getindex(vec,i), x.cumlength, map(x->x-1, idx)) diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 1184f34dd..91cf88da3 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -1,24 +1,63 @@ +struct DistributedAcceleration <: Acceleration end + +const ACCELERATION = TaskLocalValue{Acceleration}(() -> DistributedAcceleration()) + +current_acceleration() = ACCELERATION[] + +default_processor(::DistributedAcceleration) = OSProc(myid()) +default_processor(accel::DistributedAcceleration, x) = default_processor(accel) +default_processor() = default_processor(current_acceleration()) + +accelerate!(accel::Symbol) = accelerate!(Val{accel}()) +accelerate!(::Val{:distributed}) = accelerate!(DistributedAcceleration()) + +initialize_acceleration!(a::DistributedAcceleration) = nothing +function accelerate!(accel::Acceleration) + initialize_acceleration!(accel) + ACCELERATION[] = accel +end +accelerate!(::Nothing) = nothing + +accel_matches_proc(accel::DistributedAcceleration, proc::OSProc) = true +accel_matches_proc(accel::DistributedAcceleration, proc) = true + +function compatible_processors(accel::Union{Acceleration,Nothing}, scope::AbstractScope, procs::Vector{<:Processor}) + comp = compatible_processors(scope, procs) + accel === nothing && return comp + return Set(p for p in comp if accel_matches_proc(accel, p)) +end + struct CPURAMMemorySpace <: MemorySpace owner::Int end -CPURAMMemorySpace() = CPURAMMemorySpace(myid()) root_worker_id(space::CPURAMMemorySpace) = space.owner -memory_space(x) = CPURAMMemorySpace(myid()) -function memory_space(x::Chunk) - proc = processor(x) - if proc isa OSProc - # TODO: This should probably be programmable - return CPURAMMemorySpace(proc.pid) - else - return only(memory_spaces(proc)) - end -end -memory_space(x::DTask) = - memory_space(fetch(x; raw=true)) +CPURAMMemorySpace() = CPURAMMemorySpace(myid()) + +default_processor(space::CPURAMMemorySpace) = OSProc(space.owner) +default_memory_space(accel::DistributedAcceleration) = CPURAMMemorySpace(myid()) +default_memory_space(accel::DistributedAcceleration, x) = default_memory_space(accel) +default_memory_space(x) = default_memory_space(current_acceleration(), x) +default_memory_space() = default_memory_space(current_acceleration()) + +memory_space(x, proc::Processor=default_processor()) = first(memory_spaces(proc)) +memory_space(x::Processor) = first(memory_spaces(x)) +memory_space(x::Chunk) = x.space +memory_space(x::DTask) = memory_space(fetch(x; move_value=false, unwrap=false)) memory_spaces(::P) where {P<:Processor} = throw(ArgumentError("Must define `memory_spaces` for `$P`")) + +function memory_spaces(proc::OSProc) + children = get_processors(proc) + spaces = Set{MemorySpace}() + for proc in children + for space in memory_spaces(proc) + push!(spaces, space) + end + end + return spaces +end memory_spaces(proc::ThreadProc) = Set([CPURAMMemorySpace(proc.owner)]) processors(::S) where {S<:MemorySpace} = @@ -28,9 +67,12 @@ processors(space::CPURAMMemorySpace) = ### In-place Data Movement -function unwrap(x::Chunk) - @assert x.handle.owner == myid() - MemPool.poolget(x.handle) +function unwrap(x::Chunk; uniform::Bool=false) + @assert root_worker_id(x.handle) == myid() "Chunk $x is not owned by this process: $(root_worker_id(x.handle)) != $(myid())" + if x.handle isa DRef + return MemPool.poolget(x.handle) + end + return MemPool.poolget(x.handle; uniform) end move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::T, from::F) where {T,F} = throw(ArgumentError("No `move!` implementation defined for $F -> $T")) @@ -69,6 +111,16 @@ function move!(::Type{<:Tridiagonal}, to_space::MemorySpace, from_space::MemoryS return end +# FIXME: Take MemorySpace instead +function move_type(from_proc::Processor, to_proc::Processor, ::Type{T}) where T + if from_proc == to_proc + return T + end + return Base._return_type(move, Tuple{typeof(from_proc), typeof(to_proc), T}) +end +move_type(from_proc::Processor, to_proc::Processor, ::Type{<:Chunk{T}}) where T = + move_type(from_proc, to_proc, T) + ### Aliasing and Memory Spans type_may_alias(::Type{String}) = false @@ -88,20 +140,20 @@ function type_may_alias(::Type{T}) where T return false end -may_alias(::MemorySpace, ::MemorySpace) = false -may_alias(space1::M, space2::M) where M<:MemorySpace = space1 == space2 +may_alias(::MemorySpace, ::MemorySpace) = true may_alias(space1::CPURAMMemorySpace, space2::CPURAMMemorySpace) = space1.owner == space2.owner +# RemotePtr and MemorySpan are defined in utils/memory-span.jl (included earlier). + abstract type AbstractAliasing end memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `memory_spans` for `$T`")) memory_spans(x) = memory_spans(aliasing(x)) memory_spans(x, T) = memory_spans(aliasing(x, T)) -### Type-generic aliasing info wrapper - -mutable struct AliasingWrapper <: AbstractAliasing +struct AliasingWrapper <: AbstractAliasing inner::AbstractAliasing hash::UInt64 + AliasingWrapper(inner::AbstractAliasing) = new(inner, hash(inner)) end memory_spans(x::AliasingWrapper) = memory_spans(x.inner) @@ -110,204 +162,8 @@ equivalent_structure(x::AliasingWrapper, y::AliasingWrapper) = Base.hash(x::AliasingWrapper, h::UInt64) = hash(x.hash, h) Base.isequal(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash Base.:(==)(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash -will_alias(x::AliasingWrapper, y::AliasingWrapper) = will_alias(x.inner, y.inner) - -### Small dictionary type - -struct SmallDict{K,V} <: AbstractDict{K,V} - keys::Vector{K} - vals::Vector{V} -end -SmallDict{K,V}() where {K,V} = SmallDict{K,V}(Vector{K}(), Vector{V}()) -function Base.getindex(d::SmallDict{K,V}, key) where {K,V} - key_idx = findfirst(==(convert(K, key)), d.keys) - if key_idx === nothing - throw(KeyError(key)) - end - return @inbounds d.vals[key_idx] -end -function Base.setindex!(d::SmallDict{K,V}, val, key) where {K,V} - key_conv = convert(K, key) - key_idx = findfirst(==(key_conv), d.keys) - if key_idx === nothing - push!(d.keys, key_conv) - push!(d.vals, convert(V, val)) - else - d.vals[key_idx] = convert(V, val) - end - return val -end -Base.haskey(d::SmallDict{K,V}, key) where {K,V} = in(convert(K, key), d.keys) -Base.keys(d::SmallDict) = d.keys -Base.length(d::SmallDict) = length(d.keys) -Base.iterate(d::SmallDict) = iterate(d, 1) -Base.iterate(d::SmallDict, state) = state > length(d.keys) ? nothing : (d.keys[state] => d.vals[state], state+1) - -### Type-stable lookup structure for AliasingWrappers - -struct AliasingLookup - # The set of memory spaces that are being tracked - spaces::Vector{MemorySpace} - # The set of AliasingWrappers that are being tracked - # One entry for each AliasingWrapper - ainfos::Vector{AliasingWrapper} - # The memory spaces for each AliasingWrapper - # One entry for each AliasingWrapper - ainfos_spaces::Vector{Vector{Int}} - # The spans for each AliasingWrapper in each memory space - # One entry for each AliasingWrapper - spans::Vector{SmallDict{Int,Vector{LocalMemorySpan}}} - # The set of AliasingWrappers that only exist in a single memory space - # One entry for each AliasingWrapper - ainfos_only_space::Vector{Int} - # The bounding span for each AliasingWrapper in each memory space - # One entry for each AliasingWrapper - bounding_spans::Vector{SmallDict{Int,LocalMemorySpan}} - # The interval tree of the bounding spans for each AliasingWrapper - # One entry for each MemorySpace - bounding_spans_tree::Vector{IntervalTree{LocatorMemorySpan{Int},UInt64}} - - AliasingLookup() = new(MemorySpace[], - AliasingWrapper[], - Vector{Int}[], - SmallDict{Int,Vector{LocalMemorySpan}}[], - Int[], - SmallDict{Int,LocalMemorySpan}[], - IntervalTree{LocatorMemorySpan{Int},UInt64}[]) -end -function Base.push!(lookup::AliasingLookup, ainfo::AliasingWrapper) - # Update the set of memory spaces and spans, - # and find the bounding spans for this AliasingWrapper - spaces_set = Set{MemorySpace}(lookup.spaces) - self_spaces_set = Set{Int}() - spans = SmallDict{Int,Vector{LocalMemorySpan}}() - for span in memory_spans(ainfo) - space = span.ptr.space - if !in(space, spaces_set) - push!(spaces_set, space) - push!(lookup.spaces, space) - push!(lookup.bounding_spans_tree, IntervalTree{LocatorMemorySpan{Int}}()) - end - space_idx = findfirst(==(space), lookup.spaces) - push!(self_spaces_set, space_idx) - spans_in_space = get!(Vector{LocalMemorySpan}, spans, space_idx) - push!(spans_in_space, LocalMemorySpan(span)) - end - push!(lookup.ainfos_spaces, collect(self_spaces_set)) - push!(lookup.spans, spans) - - # Update the set of AliasingWrappers - push!(lookup.ainfos, ainfo) - ainfo_idx = length(lookup.ainfos) - - # Check if the AliasingWrapper only exists in a single memory space - if length(self_spaces_set) == 1 - space_idx = only(self_spaces_set) - push!(lookup.ainfos_only_space, space_idx) - else - push!(lookup.ainfos_only_space, 0) - end - - # Add the bounding spans for this AliasingWrapper - bounding_spans = SmallDict{Int,LocalMemorySpan}() - for space_idx in keys(spans) - space_spans = spans[space_idx] - bound_start = minimum(span_start, space_spans) - bound_end = maximum(span_end, space_spans) - bounding_span = LocalMemorySpan(bound_start, bound_end - bound_start) - bounding_spans[space_idx] = bounding_span - insert!(lookup.bounding_spans_tree[space_idx], LocatorMemorySpan(bounding_span, ainfo_idx)) - end - push!(lookup.bounding_spans, bounding_spans) - - return ainfo_idx -end -struct AliasingLookupFinder - lookup::AliasingLookup - ainfo::AliasingWrapper - ainfo_idx::Int - spaces_idx::Vector{Int} - to_consider::Vector{Int} -end -Base.eltype(::AliasingLookupFinder) = AliasingWrapper -Base.IteratorSize(::AliasingLookupFinder) = Base.SizeUnknown() -# FIXME: We should use a Dict{UInt,Int} to find the ainfo_idx instead of linear search -function Base.intersect(lookup::AliasingLookup, ainfo::AliasingWrapper; ainfo_idx=nothing) - if ainfo_idx === nothing - ainfo_idx = something(findfirst(==(ainfo), lookup.ainfos)) - end - spaces_idx = lookup.ainfos_spaces[ainfo_idx] - to_consider_spans = LocatorMemorySpan{Int}[] - for space_idx in spaces_idx - bounding_spans_tree = lookup.bounding_spans_tree[space_idx] - self_bounding_span = LocatorMemorySpan(lookup.bounding_spans[ainfo_idx][space_idx], 0) - find_overlapping!(bounding_spans_tree, self_bounding_span, to_consider_spans; exact=false) - end - to_consider = Int[locator.owner for locator in to_consider_spans] - @assert all(to_consider .> 0) - return AliasingLookupFinder(lookup, ainfo, ainfo_idx, spaces_idx, to_consider) -end -Base.iterate(finder::AliasingLookupFinder) = iterate(finder, 1) -function Base.iterate(finder::AliasingLookupFinder, cursor_ainfo_idx) - ainfo_spaces = nothing - cursor_space_idx = 1 - - # New ainfos enter here - @label ainfo_restart - - # Check if we've exhausted all ainfos - if cursor_ainfo_idx > length(finder.to_consider) - return nothing - end - ainfo_idx = finder.to_consider[cursor_ainfo_idx] - - # Find the appropriate memory spaces for this ainfo - if ainfo_spaces === nothing - ainfo_spaces = finder.lookup.ainfos_spaces[ainfo_idx] - end - - # New memory spaces (for the same ainfo) enter here - @label space_restart - - # Check if we've exhausted all memory spaces for this ainfo, and need to move to the next ainfo - if cursor_space_idx > length(ainfo_spaces) - cursor_ainfo_idx += 1 - ainfo_spaces = nothing - cursor_space_idx = 1 - @goto ainfo_restart - end - - # Find the currently considered memory space for this ainfo - space_idx = ainfo_spaces[cursor_space_idx] - - # Check if this memory space is part of our target ainfo's spaces - if !(space_idx in finder.spaces_idx) - cursor_space_idx += 1 - @goto space_restart - end - - # Check if this ainfo's bounding span is part of our target ainfo's bounding span in this memory space - other_ainfo_bounding_span = finder.lookup.bounding_spans[ainfo_idx][space_idx] - self_bounding_span = finder.lookup.bounding_spans[finder.ainfo_idx][space_idx] - if !spans_overlap(other_ainfo_bounding_span, self_bounding_span) - cursor_space_idx += 1 - @goto space_restart - end - - # We have a overlapping bounds in the same memory space, so check if the ainfos are aliasing - # This is the slow path! - other_ainfo = finder.lookup.ainfos[ainfo_idx] - aliasing = will_alias(finder.ainfo, other_ainfo) - if !aliasing - cursor_ainfo_idx += 1 - ainfo_spaces = nothing - cursor_space_idx = 1 - @goto ainfo_restart - end - - # We overlap, so return the ainfo and the next ainfo index - return other_ainfo, cursor_ainfo_idx+1 -end +will_alias(x::AliasingWrapper, y::AliasingWrapper) = + will_alias(x.inner, y.inner) struct NoAliasing <: AbstractAliasing end memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[] @@ -322,11 +178,8 @@ struct CombinedAliasing <: AbstractAliasing end function memory_spans(ca::CombinedAliasing) # FIXME: Don't hardcode CPURAMMemorySpace - if length(ca.sub_ainfos) == 0 - return MemorySpan{CPURAMMemorySpace}[] - end - all_spans = memory_spans(ca.sub_ainfos[1]) - for sub_a in ca.sub_ainfos[2:end] + all_spans = MemorySpan{CPURAMMemorySpace}[] + for sub_a in ca.sub_ainfos append!(all_spans, memory_spans(sub_a)) end return all_spans @@ -336,23 +189,23 @@ Base.:(==)(ca1::CombinedAliasing, ca2::CombinedAliasing) = Base.hash(ca1::CombinedAliasing, h::UInt) = hash(ca1.sub_ainfos, hash(CombinedAliasing, h)) -struct ObjectAliasing{S<:MemorySpace} <: AbstractAliasing - ptr::RemotePtr{Cvoid,S} +struct ObjectAliasing <: AbstractAliasing + ptr::Ptr{Cvoid} sz::UInt end -ObjectAliasing(ptr::RemotePtr{Cvoid,S}, sz::Integer) where {S<:MemorySpace} = - ObjectAliasing{S}(ptr, UInt(sz)) function ObjectAliasing(x::T) where T @nospecialize x - ptr = RemotePtr{Cvoid}(pointer_from_objref(x)) + ptr = pointer_from_objref(x) sz = sizeof(T) return ObjectAliasing(ptr, sz) end -function memory_spans(oa::ObjectAliasing{S}) where S - span = MemorySpan{S}(oa.ptr, oa.sz) +function memory_spans(oa::ObjectAliasing) + rptr = RemotePtr{Cvoid}(oa.ptr) + span = MemorySpan{CPURAMMemorySpace}(rptr, oa.sz) return [span] end +aliasing(accel::Acceleration, x, T) = aliasing(x, T) function aliasing(x, dep_mod) if dep_mod isa Symbol return aliasing(getfield(x, dep_mod)) @@ -388,31 +241,16 @@ end aliasing(::String) = NoAliasing() # FIXME: Not necessarily true aliasing(::Symbol) = NoAliasing() aliasing(::Type) = NoAliasing() -function aliasing(x::Chunk, T) +aliasing(x::DTask, T) = aliasing(fetch(x; move_value=false, unwrap=false), T) +aliasing(x::DTask) = aliasing(fetch(x; move_value=false, unwrap=false)) +function aliasing(accel::DistributedAcceleration, x::Chunk, T) @assert x.handle isa DRef - if root_worker_id(x.processor) == myid() - return aliasing(unwrap(x), T) - end return remotecall_fetch(root_worker_id(x.processor), x, T) do x, T aliasing(unwrap(x), T) end end -aliasing(x::Chunk) = remotecall_fetch(root_worker_id(x.processor), x) do x - aliasing(unwrap(x)) -end -aliasing(x::DTask, T) = aliasing(fetch(x; raw=true), T) -aliasing(x::DTask) = aliasing(fetch(x; raw=true)) - -function aliasing(x::Base.RefValue{T}) where T - addr = UInt(Base.pointer_from_objref(x) + fieldoffset(typeof(x), 1)) - ptr = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid())) - ainfo = ObjectAliasing(ptr, sizeof(x)) - if isassigned(x) && type_may_alias(T) && type_may_alias(typeof(x[])) - return CombinedAliasing([ainfo, aliasing(x[])]) - else - return CombinedAliasing([ainfo]) - end -end +aliasing(x::Chunk, T) = aliasing(unwrap(x), T) +aliasing(x::Chunk) = aliasing(unwrap(x)) struct ContiguousAliasing{S} <: AbstractAliasing span::MemorySpan{S} @@ -465,22 +303,13 @@ function _memory_spans(a::StridedAliasing{T,N,S}, spans, ptr, dim) where {T,N,S} return spans end -function aliasing(x::SubArray{T,N}) where {T,N} +function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} if isbitstype(T) - p = parent(x) - space = memory_space(p) - S = typeof(space) - parent_ptr = RemotePtr{Cvoid}(UInt64(pointer(p)), space) - ptr = RemotePtr{Cvoid}(UInt64(pointer(x)), space) - NA = ndims(p) - raw_inds = parentindices(x) - inds = ntuple(i->raw_inds[i] isa Integer ? (raw_inds[i]:raw_inds[i]) : UnitRange(raw_inds[i]), NA) - sz = ntuple(i->length(inds[i]), NA) - return StridedAliasing{T,NA,S}(parent_ptr, - ptr, - inds, - sz, - strides(p)) + S = CPURAMMemorySpace + return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(parent(x))), + RemotePtr{Cvoid}(pointer(x)), + parentindices(x), + size(x), strides(x)) else # FIXME: Also ContiguousAliasing of container #return IteratedAliasing(x) @@ -597,8 +426,10 @@ end function will_alias(x_span::MemorySpan, y_span::MemorySpan) may_alias(x_span.ptr.space, y_span.ptr.space) || return false # FIXME: Allow pointer conversion instead of just failing - @assert x_span.ptr.space == y_span.ptr.space "Memory spans are in different spaces: $(x_span.ptr.space) vs. $(y_span.ptr.space)" + @assert x_span.ptr.space == y_span.ptr.space x_end = x_span.ptr + x_span.len - 1 y_end = y_span.ptr + y_span.len - 1 return x_span.ptr <= y_end && y_span.ptr <= x_end end + +# LocalMemorySpan, ManyMemorySpan, ManyPair are defined in utils/memory-span.jl (included earlier). diff --git a/src/mpi.jl b/src/mpi.jl new file mode 100644 index 000000000..1b84a7b9d --- /dev/null +++ b/src/mpi.jl @@ -0,0 +1,948 @@ +using MPI + +const CHECK_UNIFORMITY = Ref{Bool}(false) +function check_uniformity!(check::Bool=true) + CHECK_UNIFORMITY[] = check +end +function check_uniform(value::Integer, original=value) + CHECK_UNIFORMITY[] || return true + comm = MPI.COMM_WORLD + rank = MPI.Comm_rank(comm) + matched = compare_all(value, comm) + if !matched + if rank == 0 + Core.print("[$rank] Found non-uniform value!\n") + end + Core.print("[$rank] value=$value, original=$original") + throw(ArgumentError("Non-uniform value")) + end + MPI.Barrier(comm) + return matched +end +function check_uniform(value, original=value) + CHECK_UNIFORMITY[] || return true + return check_uniform(hash(value), original) +end + +function compare_all(value, comm) + rank = MPI.Comm_rank(comm) + size = MPI.Comm_size(comm) + for i in 0:(size-1) + if i != rank + send_yield(value, comm, i, UInt32(0); check_seen=false) + end + end + match = true + for i in 0:(size-1) + if i != rank + other_value = recv_yield(comm, i, UInt32(0)) + if value != other_value + match = false + end + end + end + return match +end + +struct MPIAcceleration <: Acceleration + comm::MPI.Comm +end +MPIAcceleration() = MPIAcceleration(MPI.COMM_WORLD) + +function aliasing(accel::MPIAcceleration, x::Chunk, T) + handle = x.handle::MPIRef + @assert accel.comm == handle.comm "MPIAcceleration comm mismatch" + tag = to_tag() + check_uniform(tag) + rank = MPI.Comm_rank(accel.comm) + if handle.rank == rank + ainfo = aliasing(x, T) + #Core.print("[$rank] aliasing: $ainfo, sending\n") + @opcounter :aliasing_bcast_send_yield + bcast_send_yield(ainfo, accel.comm, handle.rank, tag) + else + #Core.print("[$rank] aliasing: receiving from $(handle.rank)\n") + ainfo = recv_yield(accel.comm, handle.rank, tag) + #Core.print("[$rank] aliasing: received $ainfo\n") + end + check_uniform(ainfo) + return ainfo +end +default_processor(accel::MPIAcceleration) = MPIOSProc(accel.comm, 0) +default_processor(accel::MPIAcceleration, x) = MPIOSProc(accel.comm, 0) +default_processor(accel::MPIAcceleration, x::Chunk) = MPIOSProc(x.handle.comm, x.handle.rank) +default_processor(accel::MPIAcceleration, x::Function) = MPIOSProc(accel.comm, MPI.Comm_rank(accel.comm)) +default_processor(accel::MPIAcceleration, T::Type) = MPIOSProc(accel.comm, MPI.Comm_rank(accel.comm)) + +#TODO: Add a lock +const MPIClusterProcChildren = Dict{MPI.Comm, Set{Processor}}() + +struct MPIClusterProc <: Processor + comm::MPI.Comm + function MPIClusterProc(comm::MPI.Comm) + populate_children(comm) + return new(comm) + end +end + +Sch.init_proc(state, proc::MPIClusterProc, log_sink) = Sch.init_proc(state, MPIOSProc(proc.comm), log_sink) + +MPIClusterProc() = MPIClusterProc(MPI.COMM_WORLD) + +function populate_children(comm::MPI.Comm) + children = get_processors(OSProc()) + MPIClusterProcChildren[comm] = children +end + +struct MPIOSProc <: Processor + comm::MPI.Comm + rank::Int +end + +function MPIOSProc(comm::MPI.Comm) + rank = MPI.Comm_rank(comm) + return MPIOSProc(comm, rank) +end + +function MPIOSProc() + return MPIOSProc(MPI.COMM_WORLD) +end + +ProcessScope(p::MPIOSProc) = ProcessScope(myid()) + +function check_uniform(proc::MPIOSProc, original=proc) + return check_uniform(hash(MPIOSProc), original) && + check_uniform(proc.rank, original) +end + +function memory_spaces(proc::MPIOSProc) + children = get_processors(proc) + spaces = Set{MemorySpace}() + for proc in children + for space in memory_spaces(proc) + push!(spaces, space) + end + end + return spaces +end + +struct MPIProcessScope <: AbstractScope + comm::MPI.Comm + rank::Int +end + +Base.isless(::MPIProcessScope, ::MPIProcessScope) = false +Base.isless(::MPIProcessScope, ::NodeScope) = true +Base.isless(::MPIProcessScope, ::UnionScope) = true +Base.isless(::MPIProcessScope, ::TaintScope) = true +Base.isless(::MPIProcessScope, ::AnyScope) = true +constrain(x::MPIProcessScope, y::MPIProcessScope) = + x == y ? y : InvalidScope(x, y) +constrain(x::NodeScope, y::MPIProcessScope) = + x == y.parent ? y : InvalidScope(x, y) + +Base.isless(::ExactScope, ::MPIProcessScope) = true +constrain(x::MPIProcessScope, y::ExactScope) = + x == y.parent ? y : InvalidScope(x, y) + +function enclosing_scope(proc::MPIOSProc) + return MPIProcessScope(proc.comm, proc.rank) +end + +function Dagger.to_scope(::Val{:mpi_rank}, sc::NamedTuple) + if sc.mpi_rank == Colon() + return Dagger.to_scope(Val{:mpi_ranks}(), merge(sc, (;mpi_ranks=Colon()))) + else + @assert sc.mpi_rank isa Integer "Expected a single GPU device ID for :mpi_rank, got $(sc.mpi_rank)\nConsider using :mpi_ranks instead." + return Dagger.to_scope(Val{:mpi_ranks}(), merge(sc, (;mpi_ranks=[sc.mpi_rank]))) + end +end +Dagger.scope_key_precedence(::Val{:mpi_rank}) = 2 +function Dagger.to_scope(::Val{:mpi_ranks}, sc::NamedTuple) + comm = get(sc, :mpi_comm, MPI.COMM_WORLD) + if sc.ranks != Colon() + ranks = sc.ranks + else + ranks = MPI.Comm_size(comm) + end + inner_sc = NamedTuple(filter(kv->kv[1] != :mpi_ranks, Base.pairs(sc))...) + # FIXME: What to do here? + inner_scope = Dagger.to_scope(inner_sc) + scopes = Dagger.ExactScope[] + for rank in ranks + procs = Dagger.get_processors(Dagger.MPIOSProc(comm, rank)) + rank_scope = MPIProcessScope(comm, rank) + for proc in procs + proc_scope = Dagger.ExactScope(proc) + constrain(proc_scope, rank_scope) isa Dagger.InvalidScope && continue + push!(scopes, proc_scope) + end + end + return Dagger.UnionScope(scopes) +end +Dagger.scope_key_precedence(::Val{:mpi_ranks}) = 2 + +struct MPIProcessor{P<:Processor} <: Processor + innerProc::P + comm::MPI.Comm + rank::Int +end +proc_in_scope(proc::Processor, scope::MPIProcessScope) = false +proc_in_scope(proc::MPIProcessor, scope::MPIProcessScope) = + proc.comm == scope.comm && proc.rank == scope.rank + +function check_uniform(proc::MPIProcessor, original=proc) + return check_uniform(hash(MPIProcessor), original) && + check_uniform(proc.rank, original) && + # TODO: Not always valid (if pointer is embedded, say for GPUs) + check_uniform(hash(proc.innerProc), original) +end + +Dagger.iscompatible_func(::MPIProcessor, opts, ::Any) = true +Dagger.iscompatible_arg(::MPIProcessor, opts, ::Any) = true + +default_enabled(proc::MPIProcessor) = default_enabled(proc.innerProc) + +root_worker_id(proc::MPIProcessor) = myid() +root_worker_id(proc::MPIOSProc) = myid() +root_worker_id(proc::MPIClusterProc) = myid() + +get_parent(proc::MPIClusterProc) = proc +get_parent(proc::MPIOSProc) = MPIClusterProc(proc.comm) +get_parent(proc::MPIProcessor) = MPIOSProc(proc.comm, proc.rank) + +short_name(proc::MPIProcessor) = "(MPI: $(proc.rank), $(short_name(proc.innerProc)))" + +function get_processors(mosProc::MPIOSProc) + populate_children(mosProc.comm) + children = MPIClusterProcChildren[mosProc.comm] + mpiProcs = Set{Processor}() + for proc in children + push!(mpiProcs, MPIProcessor(proc, mosProc.comm, mosProc.rank)) + end + return mpiProcs +end + +#TODO: non-uniform ranking through MPI groups +#TODO: use a lazy iterator +function get_processors(proc::MPIClusterProc) + children = Set{Processor}() + for i in 0:(MPI.Comm_size(proc.comm)-1) + for innerProc in MPIClusterProcChildren[proc.comm] + push!(children, MPIProcessor(innerProc, proc.comm, i)) + end + end + return children +end + +struct MPIMemorySpace{S<:MemorySpace} <: MemorySpace + innerSpace::S + comm::MPI.Comm + rank::Int +end + +function check_uniform(space::MPIMemorySpace, original=space) + return check_uniform(space.rank, original) && + # TODO: Not always valid (if pointer is embedded, say for GPUs) + check_uniform(hash(space.innerSpace), original) +end + +default_processor(space::MPIMemorySpace) = MPIOSProc(space.comm, space.rank) +default_memory_space(accel::MPIAcceleration) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, 0) + +default_memory_space(accel::MPIAcceleration, x) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, 0) +default_memory_space(accel::MPIAcceleration, x::Chunk) = MPIMemorySpace(CPURAMMemorySpace(myid()), x.handle.comm, x.handle.rank) +default_memory_space(accel::MPIAcceleration, x::Function) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, MPI.Comm_rank(accel.comm)) +default_memory_space(accel::MPIAcceleration, T::Type) = MPIMemorySpace(CPURAMMemorySpace(myid()), accel.comm, MPI.Comm_rank(accel.comm)) + +function memory_spaces(proc::MPIClusterProc) + rawMemSpace = Set{MemorySpace}() + for rnk in 0:(MPI.Comm_size(proc.comm) - 1) + for innerSpace in memory_spaces(OSProc()) + push!(rawMemSpace, MPIMemorySpace(innerSpace, proc.comm, rnk)) + end + end + return rawMemSpace +end + +function memory_spaces(proc::MPIProcessor) + rawMemSpace = Set{MemorySpace}() + for innerSpace in memory_spaces(proc.innerProc) + push!(rawMemSpace, MPIMemorySpace(innerSpace, proc.comm, proc.rank)) + end + return rawMemSpace +end + +root_worker_id(mem_space::MPIMemorySpace) = myid() + +function processors(memSpace::MPIMemorySpace) + rawProc = Set{Processor}() + for innerProc in processors(memSpace.innerSpace) + push!(rawProc, MPIProcessor(innerProc, memSpace.comm, memSpace.rank)) + end + return rawProc +end + +struct MPIRefID + tid::Int + uid::UInt + id::Int + function MPIRefID(tid, uid, id) + @assert tid > 0 || uid > 0 "Invalid MPIRefID: tid=$tid, uid=$uid, id=$id" + return new(tid, uid, id) + end +end +Base.hash(id::MPIRefID, h::UInt=UInt(0)) = + hash(id.tid, hash(id.uid, hash(id.id, hash(MPIRefID, h)))) + +function check_uniform(ref::MPIRefID, original=ref) + return check_uniform(ref.tid, original) && + check_uniform(ref.uid, original) && + check_uniform(ref.id, original) +end + +const MPIREF_TID = Dict{Int, Threads.Atomic{Int}}() +const MPIREF_UID = Dict{Int, Threads.Atomic{Int}}() + +mutable struct MPIRef + comm::MPI.Comm + rank::Int + size::Int + innerRef::Union{DRef, Nothing} + id::MPIRefID +end +Base.hash(ref::MPIRef, h::UInt=UInt(0)) = hash(ref.id, hash(MPIRef, h)) +root_worker_id(ref::MPIRef) = myid() + +function check_uniform(ref::MPIRef, original=ref) + return check_uniform(ref.rank, original) && + check_uniform(ref.id, original) +end + +move(from_proc::Processor, to_proc::Processor, x::MPIRef) = + move(from_proc, to_proc, poolget(x; uniform=FETCH_UNIFORM[])) + +function affinity(x::MPIRef) + if x.innerRef === nothing + return MPIOSProc(x.comm, x.rank)=>0 + else + return MPIOSProc(x.comm, x.rank)=>x.innerRef.size + end +end + +function take_ref_id!() + tid = 0 + uid = 0 + id = 0 + if Dagger.in_task() + tid = sch_handle().thunk_id.id + uid = 0 + counter = get!(MPIREF_TID, tid, Threads.Atomic{Int}(1)) + id = Threads.atomic_add!(counter, 1) + elseif MPI_TID[] != 0 + tid = MPI_TID[] + uid = 0 + counter = get!(MPIREF_TID, tid, Threads.Atomic{Int}(1)) + id = Threads.atomic_add!(counter, 1) + elseif MPI_UID[] != 0 + tid = 0 + uid = MPI_UID[] + counter = get!(MPIREF_UID, uid, Threads.Atomic{Int}(1)) + id = Threads.atomic_add!(counter, 1) + end + return MPIRefID(tid, uid, id) +end + +#TODO: partitioned scheduling with comm bifurcation +function tochunk_pset(x, space::MPIMemorySpace; device=nothing, kwargs...) + @assert space.comm == MPI.COMM_WORLD "$(space.comm) != $(MPI.COMM_WORLD)" + local_rank = MPI.Comm_rank(space.comm) + Mid = take_ref_id!() + if local_rank != space.rank + return MPIRef(space.comm, space.rank, 0, nothing, Mid) + else + # type= is for Chunk metadata only; MemPool.poolset does not accept it + pset_kw = (; (k => v for (k, v) in pairs(kwargs) if k !== :type)...) + return MPIRef(space.comm, space.rank, sizeof(x), poolset(x; device, pset_kw...), Mid) + end +end + +const DEADLOCK_DETECT = TaskLocalValue{Bool}(()->true) +const DEADLOCK_WARN_PERIOD = TaskLocalValue{Float64}(()->10.0) +const DEADLOCK_TIMEOUT_PERIOD = TaskLocalValue{Float64}(()->120.0) +const RECV_WAITING = Base.Lockable(Dict{Tuple{MPI.Comm, Int, Int}, Base.Event}()) + +struct InplaceInfo + type::DataType + shape::Tuple +end +struct InplaceSparseInfo + type::DataType + m::Int + n::Int + colptr::Int + rowval::Int + nzval::Int +end + +function supports_inplace_mpi(value) + if value isa DenseArray && isbitstype(eltype(value)) + return true + else + return false + end +end +function recv_yield!(buffer, comm, src, tag) + rank = MPI.Comm_rank(comm) + #Core.println("buffer recv: $buffer, type of buffer: $(typeof(buffer)), is in place? $(supports_inplace_mpi(buffer))") + if !supports_inplace_mpi(buffer) + return recv_yield(comm, src, tag), false + end + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv! from [$src]") + + # Ensure no other receiver is waiting + our_event = Base.Event() + @label retry + other_event = lock(RECV_WAITING) do waiting + if haskey(waiting, (comm, src, tag)) + waiting[(comm, src, tag)] + else + waiting[(comm, src, tag)] = our_event + nothing + end + end + if other_event !== nothing + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Waiting for other receiver...") + wait(other_event) + @goto retry + end + + buffer = recv_yield_inplace!(buffer, comm, rank, src, tag) + + lock(RECV_WAITING) do waiting + delete!(waiting, (comm, src, tag)) + notify(our_event) + end + + return buffer, true + +end + +function recv_yield(comm, src, tag) + rank = MPI.Comm_rank(comm) + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv from [$src]") + + # Ensure no other receiver is waiting + our_event = Base.Event() + @label retry + other_event = lock(RECV_WAITING) do waiting + if haskey(waiting, (comm, src, tag)) + waiting[(comm, src, tag)] + else + waiting[(comm, src, tag)] = our_event + nothing + end + end + if other_event !== nothing + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Waiting for other receiver...") + wait(other_event) + @goto retry + end + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Receiving...") + + type = nothing + @label receive + value = recv_yield_serialized(comm, rank, src, tag) + if value isa InplaceInfo || value isa InplaceSparseInfo + value = recv_yield_inplace(value, comm, rank, src, tag) + end + + lock(RECV_WAITING) do waiting + delete!(waiting, (comm, src, tag)) + notify(our_event) + end + return value +end + +function recv_yield_inplace!(array, comm, my_rank, their_rank, tag) + time_start = time_ns() + detect = DEADLOCK_DETECT[] + warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9) + timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9) + + while true + (got, msg, stat) = MPI.Improbe(their_rank, tag, comm, MPI.Status) + if got + if MPI.Get_error(stat) != MPI.SUCCESS + error("recv_yield failed with error $(MPI.Get_error(stat))") + end + count = MPI.Get_count(stat, UInt8) + @assert count == sizeof(array) "recv_yield_inplace: expected $(sizeof(array)) bytes, got $count" + buf = MPI.Buffer(array) + req = MPI.Imrecv!(buf, msg) + __wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv") + return array + end + warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank) + yield() + end +end + +function recv_yield_inplace(_value::InplaceInfo, comm, my_rank, their_rank, tag) + T = _value.type + @assert T <: Array && isbitstype(eltype(T)) "recv_yield_inplace only supports inplace MPI transfers of bitstype dense arrays" + array = Array{eltype(T)}(undef, _value.shape) + return recv_yield_inplace!(array, comm, my_rank, their_rank, tag) +end + +function recv_yield_inplace(_value::InplaceSparseInfo, comm, my_rank, their_rank, tag) + T = _value.type + @assert T <: SparseMatrixCSC "recv_yield_inplace only supports inplace MPI transfers of SparseMatrixCSC" + + colptr = recv_yield_inplace!(Vector{Int64}(undef, _value.colptr), comm, my_rank, their_rank, tag) + rowval = recv_yield_inplace!(Vector{Int64}(undef, _value.rowval), comm, my_rank, their_rank, tag) + nzval = recv_yield_inplace!(Vector{eltype(T)}(undef, _value.nzval), comm, my_rank, their_rank, tag) + + return SparseMatrixCSC{eltype(T), Int64}(_value.m, _value.n, colptr, rowval, nzval) +end + +function recv_yield_serialized(comm, my_rank, their_rank, tag) + time_start = time_ns() + detect = DEADLOCK_DETECT[] + warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9) + timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9) + + while true + (got, msg, stat) = MPI.Improbe(their_rank, tag, comm, MPI.Status) + if got + if MPI.Get_error(stat) != MPI.SUCCESS + error("recv_yield failed with error $(MPI.Get_error(stat))") + end + count = MPI.Get_count(stat, UInt8) + buf = Array{UInt8}(undef, count) + req = MPI.Imrecv!(MPI.Buffer(buf), msg) + __wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv") + return MPI.deserialize(buf) + end + warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank) + yield() + end +end + +const SEEN_TAGS = Dict{Int32, Type}() +send_yield!(value, comm, dest, tag; check_seen::Bool=true) = + _send_yield(value, comm, dest, tag; check_seen, inplace=true) +send_yield(value, comm, dest, tag; check_seen::Bool=true) = + _send_yield(value, comm, dest, tag; check_seen, inplace=false) +function _send_yield(value, comm, dest, tag; check_seen::Bool=true, inplace::Bool) + rank = MPI.Comm_rank(comm) + + if check_seen && haskey(SEEN_TAGS, tag) && SEEN_TAGS[tag] !== typeof(value) + @error "[rank $(MPI.Comm_rank(comm))][tag $tag] Already seen tag (previous type: $(SEEN_TAGS[tag]), new type: $(typeof(value)))" exception=(InterruptException(),backtrace()) + end + if check_seen + SEEN_TAGS[tag] = typeof(value) + end + #Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting send to [$dest]: $(typeof(value)), is support inplace? $(supports_inplace_mpi(value))") + if inplace && supports_inplace_mpi(value) + send_yield_inplace(value, comm, rank, dest, tag) + else + send_yield_serialized(value, comm, rank, dest, tag) + end +end + +function send_yield_inplace(value, comm, my_rank, their_rank, tag) + @opcounter :send_yield_inplace + req = MPI.Isend(value, comm; dest=their_rank, tag) + __wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send") +end + +function send_yield_serialized(value, comm, my_rank, their_rank, tag) + @opcounter :send_yield_serialized + if value isa Array && isbitstype(eltype(value)) + send_yield_serialized(InplaceInfo(typeof(value), size(value)), comm, my_rank, their_rank, tag) + send_yield_inplace(value, comm, my_rank, their_rank, tag) + elseif value isa SparseMatrixCSC && isbitstype(eltype(value)) + send_yield_serialized(InplaceSparseInfo(typeof(value), value.m, value.n, length(value.colptr), length(value.rowval), length(value.nzval)), comm, my_rank, their_rank, tag) + send_yield_inplace(value.colptr, comm, my_rank, their_rank, tag) + send_yield_inplace(value.rowval, comm, my_rank, their_rank, tag) + send_yield_inplace(value.nzval, comm, my_rank, their_rank, tag) + else + req = MPI.isend(value, comm; dest=their_rank, tag) + __wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send") + end +end + +function __wait_for_request(req, comm, my_rank, their_rank, tag, fn::String, kind::String) + time_start = time_ns() + detect = DEADLOCK_DETECT[] + warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9) + timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9) + while true + finish, status = MPI.Test(req, MPI.Status) + if finish + if MPI.Get_error(status) != MPI.SUCCESS + error("$fn failed with error $(MPI.Get_error(status))") + end + return + end + warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, kind, their_rank) + yield() + end +end + +function bcast_send_yield(value, comm, root, tag) + @opcounter :bcast_send_yield + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + for other_rank in 0:(sz-1) + rank == other_rank && continue + send_yield(value, comm, other_rank, tag) + end +end + +#= Maybe can be worth it to implement this +function bcast_send_yield!(value, comm, root, tag) + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + + for other_rank in 0:(sz-1) + rank == other_rank && continue + #println("[rank $rank] Sending to rank $other_rank") + send_yield!(value, comm, other_rank, tag) + end +end + +function bcast_recv_yield!(value, comm, root, tag) + sz = MPI.Comm_size(comm) + rank = MPI.Comm_rank(comm) + #println("[rank $rank] receive from rank $root") + recv_yield!(value, comm, root, tag) +end +=# +function mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, kind, srcdest) + time_elapsed = (time_ns() - time_start) + if detect && time_elapsed > warn_period + @warn "[rank $rank][tag $tag] Hit probable hang on $kind (dest: $srcdest)" + return typemax(UInt64) + end + if detect && time_elapsed > timeout_period + error("[rank $rank][tag $tag] Hit hang on $kind (dest: $srcdest)") + end + return warn_period +end + +#discuss this with julian +WeakChunk(c::Chunk{T,H}) where {T,H<:MPIRef} = WeakChunk(c.handle.rank, c.handle.id.id, WeakRef(c)) + +function MemPool.poolget(ref::MPIRef; uniform::Bool=false) + @assert uniform || ref.rank == MPI.Comm_rank(ref.comm) "MPIRef rank mismatch: $(ref.rank) != $(MPI.Comm_rank(ref.comm))" + if uniform + tag = to_tag() + if ref.rank == MPI.Comm_rank(ref.comm) + value = poolget(ref.innerRef) + @opcounter :poolget_bcast_send_yield + bcast_send_yield(value, ref.comm, ref.rank, tag) + return value + else + return recv_yield(ref.comm, ref.rank, tag) + end + else + return poolget(ref.innerRef) + end +end +fetch_handle(ref::MPIRef; uniform::Bool=false) = poolget(ref; uniform) + +function move!(dep_mod, to_space::MPIMemorySpace, from_space::MPIMemorySpace, to::Chunk, from::Chunk) + @assert to.handle isa MPIRef && from.handle isa MPIRef "MPIRef expected" + @assert to.handle.comm == from.handle.comm "MPIRef comm mismatch" + @assert to.handle.rank == to_space.rank && from.handle.rank == from_space.rank "MPIRef rank mismatch" + local_rank = MPI.Comm_rank(from.handle.comm) + if to_space.rank == from_space.rank == local_rank + move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to, from) + else + @dagdebug nothing :mpi "[$local_rank][$tag] Moving from $(from_space.rank) to $(to_space.rank)\n" + tag = to_tag() + if local_rank == from_space.rank + send_yield!(poolget(from.handle; uniform=false), to_space.comm, to_space.rank, tag) + elseif local_rank == to_space.rank + #@dagdebug nothing :mpi "[$local_rank][$tag] Receiving from rank $(from_space.rank) with tag $tag, type of buffer: $(typeof(poolget(to.handle; uniform=false)))" + to_val = poolget(to.handle; uniform=false) + val, inplace = recv_yield!(to_val, from_space.comm, from_space.rank, tag) + if !inplace + move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to_val, val) + end + end + end + @dagdebug nothing :mpi "[$local_rank][$tag] Finished moving from $(from_space.rank) to $(to_space.rank) successfuly\n" +end +function move!(dep_mod::RemainderAliasing{<:MPIMemorySpace}, to_space::MPIMemorySpace, from_space::MPIMemorySpace, to::Chunk, from::Chunk) + @assert to.handle isa MPIRef && from.handle isa MPIRef "MPIRef expected" + @assert to.handle.comm == from.handle.comm "MPIRef comm mismatch" + @assert to.handle.rank == to_space.rank && from.handle.rank == from_space.rank "MPIRef rank mismatch" + local_rank = MPI.Comm_rank(from.handle.comm) + if to_space.rank == from_space.rank == local_rank + move!(dep_mod, to_space.innerSpace, from_space.innerSpace, to, from) + else + tag = to_tag() + @dagdebug nothing :mpi "[$local_rank][$tag] Moving from $(from_space.rank) to $(to_space.rank)\n" + if local_rank == from_space.rank + # Get the source data for each span + len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) + copies = Vector{UInt8}(undef, len) + offset = 1 + for (from_span, _) in dep_mod.spans + #GC.@preserve copy begin + from_ptr = Ptr{UInt8}(from_span.ptr) + to_ptr = Ptr{UInt8}(pointer(copies, offset)) + unsafe_copyto!(to_ptr, from_ptr, from_span.len) + offset += from_span.len + #end + end + + # Send the spans + #send_yield(len, to_space.comm, to_space.rank, tag) + send_yield!(copies, to_space.comm, to_space.rank, tag; check_seen=false) + #send_yield(copies, to_space.comm, to_space.rank, tag) + elseif local_rank == to_space.rank + # Receive the spans + len = sum(span_tuple->span_len(span_tuple[1]), dep_mod.spans) + copies = Vector{UInt8}(undef, len) + recv_yield!(copies, from_space.comm, from_space.rank, tag) + #copies = recv_yield(from_space.comm, from_space.rank, tag) + + # Copy the data into the destination object + #for (copy, (_, to_span)) in zip(copies, dep_mod.spans) + offset = 1 + for (_, to_span) in dep_mod.spans + #GC.@preserve copy begin + from_ptr = Ptr{UInt8}(pointer(copies, offset)) + to_ptr = Ptr{UInt8}(to_span.ptr) + unsafe_copyto!(to_ptr, from_ptr, to_span.len) + offset += to_span.len + #end + end + + # Ensure that the data is visible + Core.Intrinsics.atomic_fence(:release) + end + end + + return +end + + +move(::MPIOSProc, ::MPIProcessor, x::Union{Function,Type}) = x +move(::MPIOSProc, ::MPIProcessor, x::Chunk{<:Union{Function,Type}}) = poolget(x.handle) + +#TODO: out of place MPI move +function move(src::MPIOSProc, dst::MPIProcessor, x::Chunk) + @assert src.comm == dst.comm "Multi comm move not supported" + if Sch.SCHED_MOVE[] + if dst.rank == MPI.Comm_rank(dst.comm) + return poolget(x.handle) + end + else + @assert src.rank == MPI.Comm_rank(src.comm) "Unwrapping not permited" + @assert src.rank == x.handle.rank == dst.rank + return poolget(x.handle) + end +end + +const MPI_UNIFORM = ScopedValue{Bool}(false) +# When true, move(_, _, MPIRef) uses poolget(; uniform=true) so the owner bcasts and the fetcher recv (e.g. rank 0 collecting). +const FETCH_UNIFORM = ScopedValue{Bool}(true) + +function remotecall_endpoint(f, accel::Dagger.MPIAcceleration, from_proc, to_proc, from_space, to_space, data) + loc_rank = MPI.Comm_rank(accel.comm) + task = DATADEPS_CURRENT_TASK[] + return with(MPI_UID=>task.uid, MPI_UNIFORM=>true) do + @assert data isa Chunk "Expected Chunk, got $(typeof(data))" + space = memory_space(data) + tag = to_tag() + type_tag = to_tag() + T = move_type(from_proc.innerProc, to_proc.innerProc, chunktype(data)) + T_new = f !== identity ? Base._return_type(f, Tuple{T}) : T + need_bcast = !isconcretetype(T_new) || T_new === Union{} || T_new === Nothing || T_new === Any + + if space.rank != from_proc.rank + # Data is already at destination (to_proc.rank) + @assert space.rank == to_proc.rank + if space.rank == loc_rank + value = poolget(data.handle) + data_converted = f(move(from_proc.innerProc, to_proc.innerProc, value)) + T_actual = typeof(data_converted) + if need_bcast + bcast_send_yield(T_actual, accel.comm, to_proc.rank, type_tag) + end + return tochunk(data_converted, to_proc, to_space; type=T_actual) + else + T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, type_tag) : T_new + return tochunk(nothing, to_proc, to_space; type=T_actual) + end + end + + # Data is on the source rank + @assert space.rank == from_proc.rank + if loc_rank == from_proc.rank == to_proc.rank + value = poolget(data.handle) + data_converted = f(move(from_proc.innerProc, to_proc.innerProc, value)) + return tochunk(data_converted, to_proc, to_space; type=typeof(data_converted)) + end + + if loc_rank == from_proc.rank + value = poolget(data.handle) + data_moved = move(from_proc.innerProc, to_proc.innerProc, value) + Dagger.send_yield(data_moved, accel.comm, to_proc.rank, tag) + T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, type_tag) : T_new + return tochunk(nothing, to_proc, to_space; type=T_actual) + elseif loc_rank == to_proc.rank + data_moved = Dagger.recv_yield(accel.comm, from_space.rank, tag) + data_converted = f(move(from_proc.innerProc, to_proc.innerProc, data_moved)) + T_actual = typeof(data_converted) + if need_bcast + bcast_send_yield(T_actual, accel.comm, to_proc.rank, type_tag) + end + return tochunk(data_converted, to_proc, to_space; type=T_actual) + else + T_actual = need_bcast ? recv_yield(accel.comm, to_proc.rank, type_tag) : T_new + return tochunk(nothing, to_proc, to_space; type=T_actual) + end + end +end + +# Chunk may be MPI-backed (MPIRef) but labeled with OSProc; treat source as the owning rank +function move(src::OSProc, dst::MPIProcessor, x::Chunk) + if x.handle isa MPIRef + return move(MPIOSProc(x.handle.comm, x.handle.rank), dst, x) + end + error("MPI move not supported") +end + +move(src::Processor, dst::MPIProcessor, x::Chunk) = error("MPI move not supported") +move(to_proc::MPIProcessor, chunk::Chunk) = + move(chunk.processor, to_proc, chunk) +move(to_proc::Processor, d::MPIRef) = + move(MPIOSProc(d.rank), to_proc, d) +move(to_proc::MPIProcessor, x) = + move(MPIOSProc(), to_proc, x) + +move(::MPIProcessor, ::MPIProcessor, x::Union{Function,Type}) = x +move(::MPIProcessor, ::MPIProcessor, x::Chunk{<:Union{Function,Type}}) = poolget(x.handle) + +@warn "Is this uniform logic valuable to have?" maxlog=1 +function move(src::MPIProcessor, dst::MPIProcessor, x::Chunk) + uniform = false #uniform = MPI_UNIFORM[] + @assert uniform || src.rank == dst.rank "Unwrapping not permitted" + if Sch.SCHED_MOVE[] + # We can either unwrap locally, or return nothing + if dst.rank == MPI.Comm_rank(dst.comm) + return poolget(x.handle) + end + else + # Either we're uniform (so everyone cooperates), or we're unwrapping locally + if !uniform + @assert src.rank == MPI.Comm_rank(src.comm) "Unwrapping not permitted" + @assert src.rank == x.handle.rank == dst.rank + end + return poolget(x.handle; uniform) + end +end + + +#FIXME:try to think of a better move! scheme +function execute!(proc::MPIProcessor, f, args...; kwargs...) + local_rank = MPI.Comm_rank(proc.comm) + islocal = local_rank == proc.rank + inplace_move = f === move! + result = nothing + tag = to_tag() + + if islocal || inplace_move + result = execute!(proc.innerProc, f, args...; kwargs...) + end + + if inplace_move + space = memory_space(nothing, proc)::MPIMemorySpace + dest_type = chunktype(args[4]) + return tochunk(nothing, proc, space; type=dest_type) + end + + # Infer return type; only bcast when inference is not concrete + fname = nameof(f) + arg_types = map(chunktype, args) + inferred_type = Base.promote_op(f, arg_types...) + + need_bcast = !isconcretetype(inferred_type) || inferred_type === Union{} || inferred_type === Nothing || inferred_type === Any + + if islocal + T = typeof(result) + space = memory_space(result, proc)::MPIMemorySpace + if need_bcast + @opcounter :execute_bcast_send_yield + bcast_send_yield((T, space.innerSpace), proc.comm, proc.rank, tag) + end + return tochunk(result, proc, space; type=T) + else + if need_bcast + T, innerSpace = recv_yield(proc.comm, proc.rank, tag) + space = MPIMemorySpace(innerSpace, proc.comm, proc.rank) + return tochunk(nothing, proc, space; type=T) + else + space = memory_space(nothing, proc)::MPIMemorySpace + return tochunk(nothing, proc, space; type=inferred_type) + end + end +end + +accelerate!(::Val{:mpi}) = accelerate!(MPIAcceleration()) + +function initialize_acceleration!(a::MPIAcceleration) + if !MPI.Initialized() + MPI.Init(;threadlevel=:multiple) + end + ctx = Dagger.Sch.eager_context() + sz = MPI.Comm_size(a.comm) + for i in 0:(sz-1) + push!(ctx.procs, MPIOSProc(a.comm, i)) + end + unique!(ctx.procs) +end + +""" + mpi_propagate_chunk_types!(tasks, accel::MPIAcceleration, expected_type) + +Ensure all ranks use the same concrete type for the given tasks by setting +each task's options.return_type to expected_type when it is concrete. +This allows chunktype(task) to return the concrete type on every rank +without an MPI allgather of actual result types. +""" +function mpi_propagate_chunk_types!(tasks, accel::MPIAcceleration, expected_type) + isconcretetype(expected_type) || return + for t in tasks + if t isa Thunk + if t.options !== nothing + t.options.return_type = expected_type + else + t.options = Options(return_type=expected_type) + end + end + end + return +end + +accel_matches_proc(accel::MPIAcceleration, proc::MPIOSProc) = true +accel_matches_proc(accel::MPIAcceleration, proc::MPIClusterProc) = true +accel_matches_proc(accel::MPIAcceleration, proc::MPIProcessor) = true +accel_matches_proc(accel::MPIAcceleration, proc) = false + +function distribute(accel::MPIAcceleration, A::AbstractArray{T,N}, dist::Blocks{N}) where {T,N} + comm = accel.comm + rank = MPI.Comm_rank(comm) + + DA = view(A, dist) + DB = DArray{T,N}(undef, dist, size(A)) + copyto!(DB, DA) + + return DB +end diff --git a/src/mpi_mempool.jl b/src/mpi_mempool.jl new file mode 100644 index 000000000..149c7900a --- /dev/null +++ b/src/mpi_mempool.jl @@ -0,0 +1,36 @@ +# Mempool for received MPI message data only (no envelopes). +# Key: (comm, source, tag). Used when a message is received but not the one the caller was waiting for. +# Included from mpi.jl; runs in Dagger module scope. + +const MPI_RECV_MEMPOOL = Base.Lockable(Dict{Tuple{MPI.Comm, Int, Int}, Vector{Any}}()) + +function mpi_mempool_put!(comm::MPI.Comm, source::Integer, tag::Integer, data::Any) + key = (comm, Int(source), Int(tag)) + ref = poolset(data) + lock(MPI_RECV_MEMPOOL) do pool + if !haskey(pool, key) + pool[key] = Any[] + end + push!(pool[key], ref) + end + return nothing +end + +function mpi_mempool_take!(comm::MPI.Comm, source::Integer, tag::Integer) + key = (comm, Int(source), Int(tag)) + ref = lock(MPI_RECV_MEMPOOL) do pool + if !haskey(pool, key) || isempty(pool[key]) + return nothing + end + popfirst!(pool[key]) + end + ref === nothing && return nothing + return poolget(ref) +end + +function mpi_mempool_has(comm::MPI.Comm, source::Integer, tag::Integer) + key = (comm, Int(source), Int(tag)) + return lock(MPI_RECV_MEMPOOL) do pool + haskey(pool, key) && !isempty(pool[key]) + end +end diff --git a/src/mutable.jl b/src/mutable.jl new file mode 100644 index 000000000..1f48ead53 --- /dev/null +++ b/src/mutable.jl @@ -0,0 +1,41 @@ +function _mutable_inner(@nospecialize(f), proc, scope) + result = f() + return Ref(Dagger.tochunk(result, proc, scope)) +end + +""" + mutable(f::Base.Callable; worker, processor, scope) -> Chunk + +Calls `f()` on the specified worker or processor, returning a `Chunk` +referencing the result with the specified scope `scope`. +""" +function mutable(@nospecialize(f); worker=nothing, processor=nothing, scope=nothing) + if processor === nothing + if worker === nothing + processor = OSProc() + else + processor = OSProc(worker) + end + else + @assert worker === nothing "mutable: Can't mix worker and processor" + end + if scope === nothing + scope = processor isa OSProc ? ProcessScope(processor) : ExactScope(processor) + end + return fetch(Dagger.@spawn scope=scope _mutable_inner(f, processor, scope))[] +end + +""" + @mutable [worker=1] [processor=OSProc()] [scope=ProcessorScope()] f() + +Helper macro for [`mutable()`](@ref). +""" +macro mutable(exs...) + opts = esc.(exs[1:end-1]) + ex = exs[end] + quote + let f = @noinline ()->$(esc(ex)) + $mutable(f; $(opts...)) + end + end +end diff --git a/src/options.jl b/src/options.jl index eca59fbc9..09067da51 100644 --- a/src/options.jl +++ b/src/options.jl @@ -26,6 +26,7 @@ Stores per-task options to be passed to the scheduler. - `storage_leaf_tag::Union{MemPool.Tag,Nothing}=nothing`: If not `nothing`, specifies the MemPool storage leaf tag to associate with the task's result. This tag can be used by MemPool's storage devices to manipulate their behavior, such as the file name used to store data on disk." - `storage_retain::Union{Bool,Nothing}=nothing`: The value of `retain` to pass to `MemPool.poolset` when constructing the result `Chunk`. `nothing` defaults to `false`. - `name::Union{String,Nothing}=nothing`: If not `nothing`, annotates the task with a name for logging purposes. +- `tag::Union{UInt32,Nothing}=nothing`: (Data-deps/MPI) MPI message tag for this task; assigned automatically if `nothing`. - `stream_input_buffer_amount::Union{Int,Nothing}=nothing`: (Streaming only) Specifies the amount of slots to allocate for the input buffer of the task. Defaults to 1. - `stream_output_buffer_amount::Union{Int,Nothing}=nothing`: (Streaming only) Specifies the amount of slots to allocate for the output buffer of the task. Defaults to 1. - `stream_buffer_type::Union{Type,Nothing}=nothing`: (Streaming only) Specifies the type of buffer to use for the input and output buffers of the task. Defaults to `Dagger.ProcessRingBuffer`. @@ -61,10 +62,16 @@ Base.@kwdef mutable struct Options name::Union{String,Nothing} = nothing + tag::Union{UInt32,Nothing} = nothing + stream_input_buffer_amount::Union{Int,Nothing} = nothing stream_output_buffer_amount::Union{Int,Nothing} = nothing stream_buffer_type::Union{Type, Nothing} = nothing stream_max_evals::Union{Int,Nothing} = nothing + + acceleration::Union{Acceleration,Nothing} = nothing + + return_type::Union{Type,Nothing} = nothing end Options(::Nothing) = Options() function Options(old_options::NamedTuple) diff --git a/src/processor.jl b/src/processor.jl index ac2e74f14..4944dc083 100644 --- a/src/processor.jl +++ b/src/processor.jl @@ -2,16 +2,6 @@ export OSProc, Context, addprocs!, rmprocs! import Base: @invokelatest -""" - Processor - -An abstract type representing a processing device and associated memory, where -data can be stored and operated on. Subtypes should be immutable, and -instances should compare equal if they represent the same logical processing -device/memory. Subtype instances should be serializable between different -nodes. Subtype instances may contain a "parent" `Processor` to make it easy to -transfer data to/from other types of `Processor` at runtime. -""" abstract type Processor end const PROCESSOR_CALLBACKS = Dict{Symbol,Any}() @@ -150,3 +140,20 @@ iscompatible_arg(proc::OSProc, opts, args...) = "Returns a very brief `String` representation of `proc`." short_name(proc::Processor) = string(proc) short_name(p::OSProc) = "W: $(p.pid)" + +"Returns true if the processor is on the local worker (for MPI/ordering)." +is_local_processor(proc::Processor) = (root_worker_id(proc) == myid()) + +"Ordering key for task firing (used by MPI to avoid deadlock)." +fire_order_key(proc::Processor) = (root_worker_id(proc), 0) + +@doc """ + Processor + +An abstract type representing a processing device and associated memory, where +data can be stored and operated on. Subtypes should be immutable, and +instances should compare equal if they represent the same logical processing +device/memory. Subtype instances should be serializable between different +nodes. Subtype instances may contain a "parent" `Processor` to make it easy to +transfer data to/from other types of `Processor` at runtime. +""" Processor diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index f0bed125e..7c724c0ab 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -15,7 +15,7 @@ import Base: @invokelatest import ..Dagger import ..Dagger: Context, Processor, SchedulerOptions, Options, Thunk, WeakThunk, ThunkFuture, ThunkID, DTaskFailedException, Chunk, WeakChunk, OSProc, AnyScope, DefaultScope, InvalidScope, LockedObject, Argument, Signature -import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, wrap_weak, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, default_enabled, processor, get_processors, get_parent, execute!, rmprocs!, task_processor, constrain, cputhreadtime, maybe_take_or_alloc! +import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, wrap_weak, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, default_enabled, processor, get_processors, get_parent, root_worker_id, execute!, rmprocs!, task_processor, constrain, cputhreadtime, maybe_take_or_alloc!, is_local_processor, fire_order_key, short_name import ..Dagger: @dagdebug, @safe_lock_spin1, @maybelog, @take_or_alloc! import DataStructures: PriorityQueue, enqueue!, dequeue_pair!, peek @@ -25,7 +25,7 @@ import ..Dagger: @reusable, @reusable_dict, @reusable_vector, @reusable_tasks, @ import TimespanLogging import TaskLocalValues: TaskLocalValue -import ScopedValues: @with +import ScopedValues: ScopedValue, @with, with const OneToMany = Dict{Thunk, Set{Thunk}} @@ -56,7 +56,7 @@ Fields: - `cache::WeakKeyDict{Thunk, Any}` - Maps from a finished `Thunk` to it's cached result, often a DRef - `valid::WeakKeyDict{Thunk, Nothing}` - Tracks all `Thunk`s that are in a valid scheduling state - `running::Set{Thunk}` - The set of currently-running `Thunk`s -- `running_on::Dict{Thunk,OSProc}` - Map from `Thunk` to the OS process executing it +- `running_on::Dict{Thunk,Processor}` - Map from `Thunk` to the OS process executing it - `thunk_dict::Dict{Int, WeakThunk}` - Maps from thunk IDs to a `Thunk` - `node_order::Any` - Function that returns the order of a thunk - `equiv_chunks::WeakKeyDict{DRef,Chunk}` - Cache mapping from `DRef` to a `Chunk` which contains it @@ -82,15 +82,15 @@ struct ComputeState ready::Vector{Thunk} valid::Dict{Thunk, Nothing} running::Set{Thunk} - running_on::Dict{Thunk,OSProc} + running_on::Dict{Thunk,Processor} thunk_dict::Dict{Int, WeakThunk} node_order::Any - equiv_chunks::WeakKeyDict{DRef,Chunk} - worker_time_pressure::Dict{Int,Dict{Processor,UInt64}} - worker_storage_pressure::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}} - worker_storage_capacity::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}} - worker_loadavg::Dict{Int,NTuple{3,Float64}} - worker_chans::Dict{Int, Tuple{RemoteChannel,RemoteChannel}} + equiv_chunks::WeakKeyDict{Any,Chunk} + worker_time_pressure::Dict{Processor,Dict{Processor,UInt64}} + worker_storage_pressure::Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}} + worker_storage_capacity::Dict{Processor,Dict{Union{StorageResource,Nothing},UInt64}} + worker_loadavg::Dict{Processor,NTuple{3,Float64}} + worker_chans::Dict{Int,Tuple{RemoteChannel,RemoteChannel}} signature_time_cost::Dict{Signature,UInt64} signature_alloc_cost::Dict{Signature,UInt64} transfer_rate::Ref{UInt64} @@ -111,10 +111,10 @@ function start_state(deps::Dict, node_order, chan) Vector{Thunk}(undef, 0), Dict{Thunk, Nothing}(), Set{Thunk}(), - Dict{Thunk,OSProc}(), + Dict{Thunk,Processor}(), Dict{Int, WeakThunk}(), node_order, - WeakKeyDict{DRef,Chunk}(), + WeakKeyDict{Any,Chunk}(), Dict{Int,Dict{Processor,UInt64}}(), Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(), Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(), @@ -152,29 +152,29 @@ const WORKER_MONITOR_TASKS = Dict{Int,Task}() const WORKER_MONITOR_CHANS = Dict{Int,Dict{UInt64,RemoteChannel}}() function init_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) - @maybelog ctx timespan_start(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) + pid = Dagger.root_worker_id(p) + @maybelog ctx timespan_start(ctx, :init_proc, (;uid=state.uid, worker=pid), nothing) # Initialize pressure and capacity - gproc = OSProc(p.pid) lock(state.lock) do - state.worker_time_pressure[p.pid] = Dict{Processor,UInt64}() + state.worker_time_pressure[p] = Dict{Processor,UInt64}() - state.worker_storage_pressure[p.pid] = Dict{Union{StorageResource,Nothing},UInt64}() - state.worker_storage_capacity[p.pid] = Dict{Union{StorageResource,Nothing},UInt64}() + state.worker_storage_pressure[p] = Dict{Union{StorageResource,Nothing},UInt64}() + state.worker_storage_capacity[p] = Dict{Union{StorageResource,Nothing},UInt64}() #= FIXME for storage in get_storage_resources(gproc) - pressure, capacity = remotecall_fetch(gproc.pid, storage) do storage + pressure, capacity = remotecall_fetch(root_worker_id(gproc), storage) do storage storage_pressure(storage), storage_capacity(storage) end - state.worker_storage_pressure[p.pid][storage] = pressure - state.worker_storage_capacity[p.pid][storage] = capacity + state.worker_storage_pressure[p][storage] = pressure + state.worker_storage_capacity[p][storage] = capacity end =# - state.worker_loadavg[p.pid] = (0.0, 0.0, 0.0) + state.worker_loadavg[p] = (0.0, 0.0, 0.0) end - if p.pid != 1 + if pid != 1 lock(WORKER_MONITOR_LOCK) do - wid = p.pid + wid = pid if !haskey(WORKER_MONITOR_TASKS, wid) t = Threads.@spawn begin try @@ -208,16 +208,16 @@ function init_proc(state, p, log_sink) end # Setup worker-to-scheduler channels - inp_chan = RemoteChannel(p.pid) - out_chan = RemoteChannel(p.pid) + inp_chan = RemoteChannel(pid) + out_chan = RemoteChannel(pid) lock(state.lock) do - state.worker_chans[p.pid] = (inp_chan, out_chan) + state.worker_chans[pid] = (inp_chan, out_chan) end # Setup dynamic listener - dynamic_listener!(ctx, state, p.pid) + dynamic_listener!(ctx, state, pid) - @maybelog ctx timespan_finish(ctx, :init_proc, (;uid=state.uid, worker=p.pid), nothing) + @maybelog ctx timespan_finish(ctx, :init_proc, (;uid=state.uid, worker=pid), nothing) end function _cleanup_proc(uid, log_sink) empty!(CHUNK_CACHE) # FIXME: Should be keyed on uid! @@ -235,7 +235,7 @@ function _cleanup_proc(uid, log_sink) end function cleanup_proc(state, p, log_sink) ctx = Context(Int[]; log_sink) - wid = p.pid + wid = root_worker_id(p) @maybelog ctx timespan_start(ctx, :cleanup_proc, (;uid=state.uid, worker=wid), nothing) lock(WORKER_MONITOR_LOCK) do if haskey(WORKER_MONITOR_CHANS, wid) @@ -298,7 +298,7 @@ function compute_dag(ctx::Context, d::Thunk, options=SchedulerOptions()) node_order = x -> -get(ord, x, 0) state = start_state(deps, node_order, chan) - master = OSProc(myid()) + master = Dagger.default_processor() @maybelog ctx timespan_start(ctx, :scheduler_init, (;uid=state.uid), master) try @@ -393,8 +393,8 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt res = tresult.result @dagdebug thunk_id :take "Got finished task" - gproc = OSProc(pid) safepoint(state) + gproc = proc != nothing ? get_parent(proc) : OSProc(pid) lock(state.lock) do thunk_failed = false if res isa Exception @@ -421,11 +421,11 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt node = unwrap_weak_checked(state.thunk_dict[thunk_id])::Thunk metadata = tresult.metadata if metadata !== nothing - state.worker_time_pressure[pid][proc] = metadata.time_pressure + state.worker_time_pressure[gproc][proc] = metadata.time_pressure #to_storage = fetch(node.options.storage) #state.worker_storage_pressure[pid][to_storage] = metadata.storage_pressure #state.worker_storage_capacity[pid][to_storage] = metadata.storage_capacity - #state.worker_loadavg[pid] = metadata.loadavg + #state.worker_loadavg[gproc] = metadata.loadavg sig = signature(state, node) state.signature_time_cost[sig] = (metadata.threadtime + get(state.signature_time_cost, sig, 0)) ÷ 2 state.signature_alloc_cost[sig] = (metadata.gc_allocd + get(state.signature_alloc_cost, sig, 0)) ÷ 2 @@ -434,8 +434,8 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt end end if res isa Chunk - if !haskey(state.equiv_chunks, res) - state.equiv_chunks[res.handle::DRef] = res + if !haskey(state.equiv_chunks, res.handle) + state.equiv_chunks[res.handle] = res end end store_result!(state, node, res; error=thunk_failed) @@ -522,7 +522,7 @@ end const CHUNK_CACHE = Dict{Chunk,Dict{Processor,Any}}() struct ScheduleTaskLocation - gproc::OSProc + gproc::Processor proc::Processor end struct ScheduleTaskSpec @@ -532,6 +532,25 @@ struct ScheduleTaskSpec est_alloc_util::UInt64 est_occupancy::UInt32 end + +"Ordering key for task locations when using MPI acceleration (deterministic across ranks)." +function _mpi_fire_order_key(loc::ScheduleTaskLocation) + g = loc.gproc + p = loc.proc + g_rank = g isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? g.rank : root_worker_id(g) + p_rank = p isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? p.rank : root_worker_id(p) + return (g_rank, p_rank) +end + +"Ordering key for a single Processor when using MPI acceleration (deterministic across ranks)." +function _mpi_proc_rank(proc::Processor) + g = get_parent(proc) + p = proc + g_rank = g isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? g.rank : root_worker_id(g) + p_rank = p isa Union{Dagger.MPIOSProc, Dagger.MPIProcessor} ? p.rank : root_worker_id(p) + return (g_rank, p_rank) +end + @reuse_scope function schedule!(ctx, state, sch_options, procs=procs_to_use(ctx, sch_options)) lock(state.lock) do safepoint(state) @@ -546,6 +565,7 @@ end to_fire_cleanup = @reuse_defer_cleanup empty!(to_fire) failed_scheduling = @reusable_vector :schedule!_failed_scheduling Union{Thunk,Nothing} nothing 32 failed_scheduling_cleanup = @reuse_defer_cleanup empty!(failed_scheduling) + # Select a new task and get its options task = nothing @label pop_task @@ -622,9 +642,9 @@ end end @label scope_computed - input_procs = @reusable_vector :schedule!_input_procs Processor OSProc() 32 + input_procs = @reusable_vector :schedule!_input_procs Union{Processor,Nothing} nothing 32 input_procs_cleanup = @reuse_defer_cleanup empty!(input_procs) - for proc in Dagger.compatible_processors(scope, procs) + for proc in Dagger.compatible_processors(options.acceleration, scope, procs) if !(proc in input_procs) push!(input_procs, proc) end @@ -656,7 +676,7 @@ end can_use, scope = can_use_proc(state, task, gproc, proc, options, scope) if can_use has_cap, est_time_util, est_alloc_util, est_occupancy = - has_capacity(state, proc, gproc.pid, options.time_util, options.alloc_util, options.occupancy, sig) + has_capacity(state, proc, gproc, options.time_util, options.alloc_util, options.occupancy, sig) if has_cap # Schedule task onto proc # FIXME: est_time_util = est_time_util isa MaxUtilization ? cap : est_time_util @@ -665,10 +685,10 @@ end Vector{ScheduleTaskSpec}() end push!(proc_tasks, ScheduleTaskSpec(task, scope, est_time_util, est_alloc_util, est_occupancy)) - state.worker_time_pressure[gproc.pid][proc] = - get(state.worker_time_pressure[gproc.pid], proc, 0) + + state.worker_time_pressure[gproc][proc] = + get(state.worker_time_pressure[gproc], proc, 0) + est_time_util - @dagdebug task :schedule "Scheduling to $gproc -> $proc (cost: $(costs[proc]), pressure: $(state.worker_time_pressure[gproc.pid][proc]))" + @dagdebug task :schedule "Scheduling to $gproc -> $proc (cost: $(costs[proc]), pressure: $(state.worker_time_pressure[gproc][proc]))" sorted_procs_cleanup() costs_cleanup() @goto pop_task @@ -684,10 +704,21 @@ end costs_cleanup() @goto pop_task - # Fire all newly-scheduled tasks + # Fire all newly-scheduled tasks (owner/local first, then by fire_order_key to avoid MPI execute! deadlock) @label fire_tasks - for (task_loc, task_spec) in to_fire - fire_tasks!(ctx, task_loc, task_spec, state) + task_locs = collect(keys(to_fire)) + if Dagger.current_acceleration() isa Dagger.MPIAcceleration + sort!(task_locs, by=_mpi_fire_order_key) + end + rank = try + M = parentmodule(@__MODULE__) + (isdefined(M, :MPI) && M.MPI.Initialized()) ? Int(M.MPI.Comm_rank(M.MPI.COMM_WORLD)) : nothing + catch + nothing + end + for (i, task_loc) in enumerate(task_locs) + #Core.println("fire_order rank=", rank, " [", i, "/", length(task_locs), "] task_loc=", task_loc) + fire_tasks!(ctx, task_loc, to_fire[task_loc], state) end to_fire_cleanup() @@ -736,13 +767,13 @@ function monitor_procs_changed!(ctx, state, options) end function remove_dead_proc!(ctx, state, proc, options) - @assert options.single !== proc.pid "Single worker failed, cannot continue." + @assert options.single !== root_worker_id(proc) "Single worker failed, cannot continue." rmprocs!(ctx, [proc]) - delete!(state.worker_time_pressure, proc.pid) - delete!(state.worker_storage_pressure, proc.pid) - delete!(state.worker_storage_capacity, proc.pid) - delete!(state.worker_loadavg, proc.pid) - delete!(state.worker_chans, proc.pid) + delete!(state.worker_time_pressure, proc) + delete!(state.worker_storage_pressure, proc) + delete!(state.worker_storage_capacity, proc) + delete!(state.worker_loadavg, proc) + delete!(state.worker_chans, root_worker_id(proc)) end function finish_task!(ctx, state, node, thunk_failed) @@ -785,7 +816,7 @@ end function evict_all_chunks!(ctx, options, to_evict) if !isempty(to_evict) - @sync for w in map(p->p.pid, procs_to_use(ctx, options)) + @sync for w in map(p->root_worker_id(p), procs_to_use(ctx, options)) Threads.@spawn remote_do(evict_chunks!, w, ctx.log_sink, to_evict) end end @@ -856,9 +887,10 @@ Base.hash(task::TaskSpec, h::UInt) = hash(task.thunk_id, hash(TaskSpec, h)) end Tf = chunktype(first(args)) - @assert (options.single === nothing) || (gproc.pid == options.single) + pid = root_worker_id(gproc) + @assert (options.single === nothing) || (pid == options.single) # TODO: Set `sch_handle.tid.ref` to the right `DRef` - sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...) + sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[pid]...) # TODO: De-dup common fields (log_sink, uid, etc.) push!(to_send, TaskSpec( @@ -870,7 +902,7 @@ Base.hash(task::TaskSpec, h::UInt) = hash(task.thunk_id, hash(TaskSpec, h)) end if !isempty(to_send) - if Dagger.root_worker_id(gproc) == myid() + if root_worker_id(gproc) == myid() @reusable_tasks :fire_tasks!_task_cache 32 _->nothing "fire_tasks!" FireTaskSpec(proc, state.chan, to_send) else # N.B. We don't batch these because we might get a deserialization @@ -1076,7 +1108,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re proc_occupancy = istate.proc_occupancy time_pressure = istate.time_pressure - wid = get_parent(to_proc).pid + wid = root_worker_id(to_proc) work_to_do = false while isopen(return_queue) # Wait for new tasks @@ -1131,12 +1163,15 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Try to steal a task @maybelog ctx timespan_start(ctx, :proc_steal_local, (;uid, worker=wid, processor=to_proc), nothing) - # Try to steal from local queues randomly + # Try to steal from local queues randomly (deterministic order when MPI to avoid deadlocks) # TODO: Prioritize stealing from busiest processors states = proc_states_values(uid) - # TODO: Try to pre-allocate this - P = randperm(length(states)) - for state in getindex.(Ref(states), P) + order = if Dagger.current_acceleration() isa Dagger.MPIAcceleration + sort(1:length(states), by=i->_mpi_proc_rank(states[i].state.proc)) + else + randperm(length(states)) + end + for state in getindex.(Ref(states), order) other_istate = state.state if other_istate.proc === to_proc continue @@ -1151,7 +1186,8 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re end task, occupancy = peek(queue) scope = task.scope - if Dagger.proc_in_scope(to_proc, scope) + accel = task.options.acceleration + if Dagger.proc_in_scope(to_proc, scope) && Dagger.accel_matches_proc(accel, to_proc) typemax(UInt32) - proc_occupancy_cached >= occupancy # Compatible, steal this task return dequeue_pair!(queue) @@ -1344,11 +1380,15 @@ function do_tasks(to_proc, return_queue, tasks) end notify(istate.reschedule) - # Kick other processors to make them steal + # Kick other processors to make them steal (deterministic order when MPI to avoid deadlocks) # TODO: Alternatively, automatically balance work instead of blindly enqueueing states = proc_states_values(uid) - P = randperm(length(states)) - for other_state in getindex.(Ref(states), P) + order = if Dagger.current_acceleration() isa Dagger.MPIAcceleration + sort(1:length(states), by=i->_mpi_proc_rank(states[i].state.proc)) + else + randperm(length(states)) + end + for other_state in getindex.(Ref(states), order) other_istate = other_state.state if other_istate.proc === to_proc continue @@ -1357,6 +1397,8 @@ function do_tasks(to_proc, return_queue, tasks) end @dagdebug nothing :processor "Kicked processors" end + +const SCHED_MOVE = ScopedValue{Bool}(false) """ do_task(to_proc, task::TaskSpec) -> Any @@ -1369,13 +1411,15 @@ Executes a single task specified by `task` on `to_proc`. ctx_vars = task.ctx_vars ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile) - from_proc = OSProc() + options = task.options + Dagger.accelerate!(options.acceleration) + + from_proc = Dagger.default_processor() data = task.data Tf = task.Tf f = isdefined(Tf, :instance) ? Tf.instance : nothing # Wait for required resources to become available - options = task.options propagated = get_propagated_options(options) to_storage = options.storage !== nothing ? fetch(options.storage) : MemPool.GLOBAL_DEVICE[] #to_storage_name = nameof(typeof(to_storage)) @@ -1443,7 +1487,7 @@ Executes a single task specified by `task` on `to_proc`. @maybelog ctx timespan_finish(ctx, :storage_wait, (;thunk_id, processor=to_proc), (;f, device=typeof(to_storage))) =# - @dagdebug thunk_id :execute "Moving data" + @dagdebug thunk_id :execute "Moving data for $Tf" # Initiate data transfers for function and arguments transfer_time = Threads.Atomic{UInt64}(0) @@ -1462,11 +1506,13 @@ Executes a single task specified by `task` on `to_proc`. #= FIXME: This isn't valid if x is written to x = if x isa Chunk value = lock(TASK_SYNC) do - if haskey(CHUNK_CACHE, x) - Some{Any}(get!(CHUNK_CACHE[x], to_proc) do - # Convert from cached value - # TODO: Choose "closest" processor of same type first - some_proc = first(keys(CHUNK_CACHE[x])) + if haskey(CHUNK_CACHE, x) + Some{Any}(get!(CHUNK_CACHE[x], to_proc) do + # Convert from cached value + # TODO: Choose "closest" processor of same type first + cache_procs = keys(CHUNK_CACHE[x]) + some_proc = Dagger.current_acceleration() isa Dagger.MPIAcceleration ? + minimum(cache_procs, by=_mpi_proc_rank) : first(cache_procs) some_x = CHUNK_CACHE[x][some_proc] @dagdebug thunk_id :move "Cache hit for argument $id at $some_proc: $some_x" @invokelatest move(some_proc, to_proc, some_x) @@ -1501,13 +1547,23 @@ Executes a single task specified by `task` on `to_proc`. end else =# - new_value = @invokelatest move(to_proc, value) + new_value = with(SCHED_MOVE=>true) do + @invokelatest move(to_proc, value) + end #end - if new_value !== value - @dagdebug thunk_id :move "Moved argument @ $position to $to_proc: $(typeof(value)) -> $(typeof(new_value))" + # Preserve Chunk reference when move returns nothing (placeholder on this rank). This keeps + # type information correct at all ranks: chunktype(Chunk) is concrete even when Chunk holds no data. + # So execute! sees correct arg_types. Materializing the value (for the kernel) must happen in + # execute! and may require lazy recv from the executor if this rank has a placeholder. + if new_value === nothing && (value isa Dagger.Chunk || value isa Dagger.WeakChunk) + arg.value = value + else + if new_value !== value + @dagdebug thunk_id :move "Moved argument @ $position to $to_proc: $(typeof(value)) -> $(typeof(new_value))" + end + arg.value = new_value end - @maybelog ctx timespan_finish(ctx, :move, (;thunk_id, position, processor=to_proc), (;f, data=new_value); tasks=[Base.current_task()]) - arg.value = new_value + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id, position, processor=to_proc), (;f, data=Dagger.value(arg)); tasks=[Base.current_task()]) return end end @@ -1546,7 +1602,7 @@ Executes a single task specified by `task` on `to_proc`. # FIXME #gcnum_start = Base.gc_num() - @dagdebug thunk_id :execute "Executing $(typeof(f))" + @dagdebug thunk_id :execute "Executing $Tf" logging_enabled = !(ctx.log_sink isa TimespanLogging.NoOpLog) @@ -1609,7 +1665,7 @@ Executes a single task specified by `task` on `to_proc`. notify(TASK_SYNC) end - @dagdebug thunk_id :execute "Returning" + @dagdebug thunk_id :execute "Returning $Tf with $(typeof(result_meta))" # TODO: debug_storage("Releasing $to_storage_name") metadata = ( diff --git a/src/sch/util.jl b/src/sch/util.jl index d3b7a4804..eee360c3b 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -430,8 +430,8 @@ function can_use_proc(state, task, gproc, proc, opts, scope) # Check against single if opts.single !== nothing @warn "The `single` option is deprecated, please use scopes instead\nSee https://juliaparallel.org/Dagger.jl/stable/scopes/ for details" maxlog=1 - if gproc.pid != opts.single - @dagdebug task :scope "Rejected $proc: gproc.pid ($(gproc.pid)) != single ($(opts.single))" + if root_worker_id(gproc) != opts.single + @dagdebug task :scope "Rejected $proc: gproc root_worker_id ($(root_worker_id(gproc))) != single ($(opts.single))" return false, scope end scope = constrain(scope, Dagger.ProcessScope(opts.single)) @@ -583,19 +583,21 @@ end # Add fixed cost for cross-worker task transfer (esimated at 1ms) # TODO: Actually estimate/benchmark this - task_xfer_cost = gproc.pid != myid() ? 1_000_000 : 0 # 1ms + task_xfer_cost = root_worker_id(gproc) != myid() ? 1_000_000 : 0 # 1ms # Compute final cost costs[proc] = est_time_util + (tx_cost/tx_rate) + task_xfer_cost end chunks_cleanup() - # Shuffle procs around, so equally-costly procs are equally considered + # Shuffle procs around, so equally-costly procs are equally considered (skip shuffle when MPI for deterministic tie-breaking) np = length(procs) @reusable :estimate_task_costs_P Vector{Int} 0 4 np P begin resize!(P, np) copyto!(P, 1:np) - randperm!(P) + if !(Dagger.current_acceleration() isa Dagger.MPIAcceleration) + randperm!(P) + end for idx in 1:np sorted_procs[idx] = procs[P[idx]] end diff --git a/src/scopes.jl b/src/scopes.jl index 79190c292..28aa8fa00 100644 --- a/src/scopes.jl +++ b/src/scopes.jl @@ -101,7 +101,7 @@ struct ExactScope <: AbstractScope parent::ProcessScope processor::Processor end -ExactScope(proc) = ExactScope(ProcessScope(get_parent(proc).pid), proc) +ExactScope(proc) = ExactScope(ProcessScope(root_worker_id(get_parent(proc))), proc) proc_in_scope(proc::Processor, scope::ExactScope) = proc == scope.processor "Indicates that the applied scopes `x` and `y` are incompatible." diff --git a/src/shard.jl b/src/shard.jl new file mode 100644 index 000000000..ecd0ee570 --- /dev/null +++ b/src/shard.jl @@ -0,0 +1,89 @@ +""" +Maps a value to one of multiple distributed "mirror" values automatically when +used as a thunk argument. Construct using `@shard` or `shard`. +""" +struct Shard + chunks::Dict{Processor,Chunk} +end + +""" + shard(f; kwargs...) -> Chunk{Shard} + +Executes `f` on all workers in `workers`, wrapping the result in a +process-scoped `Chunk`, and constructs a `Chunk{Shard}` containing all of these +`Chunk`s on the current worker. + +Keyword arguments: +- `procs` -- The list of processors to create pieces on. May be any iterable container of `Processor`s. +- `workers` -- The list of workers to create pieces on. May be any iterable container of `Integer`s. +- `per_thread::Bool=false` -- If `true`, creates a piece per each thread, rather than a piece per each worker. +""" +function shard(@nospecialize(f); procs=nothing, workers=nothing, per_thread=false) + if procs === nothing + if workers !== nothing + procs = [OSProc(w) for w in workers] + else + procs = lock(Sch.eager_context()) do + copy(Sch.eager_context().procs) + end + end + if per_thread + _procs = ThreadProc[] + for p in procs + append!(_procs, filter(p->p isa ThreadProc, get_processors(p))) + end + procs = _procs + end + else + if workers !== nothing + throw(ArgumentError("Cannot combine `procs` and `workers`")) + elseif per_thread + throw(ArgumentError("Cannot combine `procs` and `per_thread=true`")) + end + end + isempty(procs) && throw(ArgumentError("Cannot create empty Shard")) + shard_running_dict = Dict{Processor,DTask}() + for proc in procs + scope = proc isa OSProc ? ProcessScope(proc) : ExactScope(proc) + thunk = Dagger.@spawn scope=scope _mutable_inner(f, proc, scope) + shard_running_dict[proc] = thunk + end + shard_dict = Dict{Processor,Chunk}() + for proc in procs + shard_dict[proc] = fetch(shard_running_dict[proc])[] + end + return Shard(shard_dict) +end + +"Creates a `Shard`. See [`Dagger.shard`](@ref) for details." +macro shard(exs...) + opts = esc.(exs[1:end-1]) + ex = exs[end] + quote + let f = @noinline ()->$(esc(ex)) + $shard(f; $(opts...)) + end + end +end + +function move(from_proc::Processor, to_proc::Processor, shard::Shard) + # Match either this proc or some ancestor + # N.B. This behavior may bypass the piece's scope restriction + proc = to_proc + if haskey(shard.chunks, proc) + return move(from_proc, to_proc, shard.chunks[proc]) + end + parent = Dagger.get_parent(proc) + while parent != proc + proc = parent + parent = Dagger.get_parent(proc) + if haskey(shard.chunks, proc) + return move(from_proc, to_proc, shard.chunks[proc]) + end + end + + throw(KeyError(to_proc)) +end +Base.iterate(s::Shard) = iterate(values(s.chunks)) +Base.iterate(s::Shard, state) = iterate(values(s.chunks), state) +Base.length(s::Shard) = length(s.chunks) diff --git a/src/submission.jl b/src/submission.jl index 4ff4f2294..fffcc577d 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -285,7 +285,13 @@ function eager_process_args_submission_to_local(id_map, spec::DTaskSpec{true}) return ntuple(i->eager_process_elem_submission_to_local(id_map, spec.fargs[i]), length(spec.fargs)) end -DTaskMetadata(spec::DTaskSpec) = DTaskMetadata(eager_metadata(spec.fargs)) +function DTaskMetadata(spec::DTaskSpec) + rt = spec.options.return_type + if rt !== nothing && isconcretetype(rt) && rt !== Any + return DTaskMetadata(rt) + end + return DTaskMetadata(eager_metadata(spec.fargs)) +end function eager_metadata(fargs) f = value(fargs[1]) f = f isa StreamingFunction ? f.f : f @@ -298,6 +304,10 @@ function eager_spawn(spec::DTaskSpec) uid = eager_next_id() future = ThunkFuture() metadata = DTaskMetadata(spec) + # Propagate inferred return type to options so execute! can skip MPI bcast + if isconcretetype(metadata.return_type) + spec.options.return_type = metadata.return_type + end return DTask(uid, future, metadata) end @@ -320,10 +330,16 @@ function eager_launch!(pair::DTaskPair) end end + # Propagate DTask return_type into options so the created Thunk has chunktype for downstream inference + options = spec.options + if isconcretetype(task.metadata.return_type) + options = copy(options) + options.return_type = task.metadata.return_type + end # Submit the task #=FIXME:REALLOC=# thunk_id = eager_submit!(PayloadOne(task.uid, task.future, - fargs, spec.options, true)) + fargs, options, true)) task.thunk_ref = thunk_id.ref end # FIXME: Don't convert Tuple to Vector{Argument} @@ -353,7 +369,13 @@ function eager_launch!(pairs::Vector{DTaskPair}) end end end - all_options = Options[pair.spec.options for pair in pairs] + # Propagate DTask return_type into options so created Thunks have chunktype for downstream inference + all_options = Options[ + let opts = pair.spec.options + isconcretetype(pair.task.metadata.return_type) ? (o = copy(opts); o.return_type = pair.task.metadata.return_type; o) : opts + end + for pair in pairs + ] # Submit the tasks #=FIXME:REALLOC=# diff --git a/src/thunk.jl b/src/thunk.jl index e13e299f0..c24e0c329 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -247,6 +247,14 @@ isweak(t) = false Base.show(io::IO, t::WeakThunk) = (print(io, "~"); Base.show(io, t.x.value)) Base.convert(::Type{WeakThunk}, t::Thunk) = WeakThunk(t) chunktype(t::WeakThunk) = chunktype(unwrap_weak_checked(t)) +# Use options.return_type when set (e.g. from mpi_propagate_chunk_types! or eager_metadata) +# so that Thunk arguments propagate type to downstream eager_metadata/execute! +function chunktype(t::Thunk) + if t.options !== nothing && t.options.return_type !== nothing && isconcretetype(t.options.return_type) + return t.options.return_type + end + return typeof(t) +end Base.convert(::Type{ThunkSyncdep}, t::WeakThunk) = ThunkSyncdep(nothing, t) ThunkSyncdep(t::WeakThunk) = ThunkSyncdep(nothing, t) @@ -462,7 +470,7 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=()) end args = filter(arg->!Meta.isexpr(arg, :parameters), allargs) kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs) - if !isempty(kwargs) + if !Base.isempty(kwargs) kwargs = only(kwargs).args end if body !== nothing @@ -530,7 +538,7 @@ function spawn(f, args...; kwargs...) @nospecialize f args kwargs # Merge all passed options - if length(args) >= 1 && first(args) isa Options + if length(args) >= 1 && first(args) isa Options # N.B. Make a defensive copy in case user aliases Options struct task_options = copy(first(args)::Options) args = args[2:end] @@ -545,7 +553,7 @@ function spawn(f, args...; kwargs...) end function typed_spawn(f, args...; kwargs...) # Merge all passed options - if length(args) >= 1 && first(args) isa Options + if length(args) >= 1 && first(args) isa Options # N.B. Make a defensive copy in case user aliases Options struct task_options = copy(first(args)::Options) args = args[2:end] diff --git a/src/tochunk.jl b/src/tochunk.jl new file mode 100644 index 000000000..ff15e426e --- /dev/null +++ b/src/tochunk.jl @@ -0,0 +1,119 @@ +@warn "Update tochunk docstring" maxlog=1 +""" + tochunk(x, proc::Processor, scope::AbstractScope; device=nothing, rewrap=false, kwargs...) -> Chunk + +Create a chunk from data `x` which resides on `proc` and which has scope +`scope`. + +`device` specifies a `MemPool.StorageDevice` (which is itself wrapped in a +`Chunk`) which will be used to manage the reference contained in the `Chunk` +generated by this function. If `device` is `nothing` (the default), the data +will be inspected to determine if it's safe to serialize; if so, the default +MemPool storage device will be used; if not, then a `MemPool.CPURAMDevice` will +be used. + +`type` can be specified manually to force the type to be `Chunk{type}`. + +If `rewrap==true` and `x isa Chunk`, then the `Chunk` will be rewrapped in a +new `Chunk`. + +All other kwargs are passed directly to `MemPool.poolset`. +""" +tochunk(x::X, proc::P, space::M; kwargs...) where {X,P<:Processor,M<:MemorySpace} = + tochunk(x, proc, space, AnyScope(); kwargs...) +function tochunk(x::X, proc::P, space::M, scope::S; device=nothing, type=X, rewrap=false, kwargs...) where {X,P<:Processor,S,M<:MemorySpace} + if type === Nothing + throw(ArgumentError("Chunk type cannot be Nothing. Placeholder chunks must be created with an explicit type= (e.g. tochunk(nothing, proc, space; type=Matrix{Float64})). x=$(repr(x))")) + end + if x isa Chunk + check_proc_space(x, proc, space) + return maybe_rewrap(x, proc, space, scope; type, rewrap) + end + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + ref = tochunk_pset(x, space; device, type, kwargs...) + return Chunk{type,typeof(ref),P,S,typeof(space)}(type, domain(x), ref, proc, scope, space) +end +# Disambiguate: Chunk-specific 3-arg so kwcall(tochunk, Chunk, Processor, Scope) is not ambiguous with utils/chunks.jl +function tochunk(x::Chunk, proc::P, scope::S; rewrap=false, kwargs...) where {P<:Processor,S} + if rewrap + return remotecall_fetch(x.handle.owner) do + tochunk(MemPool.poolget(x.handle), proc, scope; kwargs...) + end + else + return x + end +end +function tochunk(x::X, proc::P, scope::S; device=nothing, type=X, rewrap=false, kwargs...) where {X,P<:Processor,S} + if type === Nothing + throw(ArgumentError("Chunk type cannot be Nothing. Placeholder chunks must be created with an explicit type= (e.g. tochunk(nothing, proc, space; type=Matrix{Float64})). x=$(repr(x))")) + end + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + if x isa Chunk + space = x.space + check_proc_space(x, proc, space) + return maybe_rewrap(x, proc, space, scope; type, rewrap) + end + space = default_memory_space(current_acceleration(), x) + ref = tochunk_pset(x, space; device, type, kwargs...) + return Chunk{type,typeof(ref),P,S,typeof(space)}(type, domain(x), ref, proc, scope, space) +end +function tochunk(x::X, space::M, scope::S; device=nothing, type=X, rewrap=false, kwargs...) where {X,M<:MemorySpace,S} + if type === Nothing + throw(ArgumentError("Chunk type cannot be Nothing. Placeholder chunks must be created with an explicit type= (e.g. tochunk(nothing, proc, space; type=Matrix{Float64})). x=$(repr(x))")) + end + if device === nothing + device = if Sch.walk_storage_safe(x) + MemPool.GLOBAL_DEVICE[] + else + MemPool.CPURAMDevice() + end + end + if x isa Chunk + proc = x.processor + check_proc_space(x, proc, space) + return maybe_rewrap(x, proc, space, scope; type, rewrap) + end + proc = default_processor(current_acceleration(), x) + ref = tochunk_pset(x, space; device, type, kwargs...) + return Chunk{type,typeof(ref),typeof(proc),S,M}(type, domain(x), ref, proc, scope, space) +end +# 2-arg: avoid overwriting utils/chunks.jl's tochunk(Any, Any) and tochunk(Any); only add Processor/MemorySpace variants +# Chunk + Processor: disambiguate vs utils/chunks.jl's tochunk(x::Chunk, proc; ...) +tochunk(x::Chunk, proc::Processor; kwargs...) = tochunk(x, proc, AnyScope(); kwargs...) +tochunk(x, proc::Processor; kwargs...) = tochunk(x, proc, AnyScope(); kwargs...) +tochunk(x, space::MemorySpace; kwargs...) = tochunk(x, space, AnyScope(); kwargs...) + +check_proc_space(x, proc, space) = nothing +function check_proc_space(x::Chunk, proc, space) + if x.space !== space + throw(ArgumentError("Memory space mismatch: Chunk=$(x.space) != Requested=$space")) + end +end +function check_proc_space(x::Thunk, proc, space) + # FIXME: Validate +end +function maybe_rewrap(x, proc, space, scope; type, rewrap) + if rewrap + return remotecall_fetch(x.handle.owner) do + tochunk(MemPool.poolget(x.handle), proc, scope; kwargs...) + end + else + return x + end +end + +tochunk_pset(x, space::MemorySpace; device=nothing, type=nothing, kwargs...) = poolset(x; device, kwargs...) + +# savechunk: defined in utils/chunks.jl (fork Chunk has space field; do not duplicate here) diff --git a/src/types/acceleration.jl b/src/types/acceleration.jl new file mode 100644 index 000000000..b647dd303 --- /dev/null +++ b/src/types/acceleration.jl @@ -0,0 +1 @@ +abstract type Acceleration end \ No newline at end of file diff --git a/src/types/chunk.jl b/src/types/chunk.jl new file mode 100644 index 000000000..9b8102a6d --- /dev/null +++ b/src/types/chunk.jl @@ -0,0 +1,27 @@ +""" + Chunk + +A reference to a piece of data located on a remote worker. `Chunk`s are +typically created with `Dagger.tochunk(data)`, and the data can then be +accessed from any worker with `collect(::Chunk)`. `Chunk`s are +serialization-safe, and use distributed refcounting (provided by +`MemPool.DRef`) to ensure that the data referenced by a `Chunk` won't be GC'd, +as long as a reference exists on some worker. + +Each `Chunk` is associated with a given `Dagger.Processor`, which is (in a +sense) the processor that "owns" or contains the data. Calling +`collect(::Chunk)` will perform data movement and conversions defined by that +processor to safely serialize the data to the calling worker. + +## Constructors +See [`tochunk`](@ref). +""" + +mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope, M<:MemorySpace} + chunktype::Type{T} + domain + handle::H + processor::P + scope::S + space::M +end diff --git a/src/types/memory-space.jl b/src/types/memory-space.jl new file mode 100644 index 000000000..247ceccb0 --- /dev/null +++ b/src/types/memory-space.jl @@ -0,0 +1 @@ +abstract type MemorySpace end \ No newline at end of file diff --git a/src/types/processor.jl b/src/types/processor.jl new file mode 100644 index 000000000..1e333413f --- /dev/null +++ b/src/types/processor.jl @@ -0,0 +1,2 @@ +# Docstring for Processor is attached in src/processor.jl after OSProc is defined (avoids "Replacing docs" warning). +abstract type Processor end \ No newline at end of file diff --git a/src/types/scope.jl b/src/types/scope.jl new file mode 100644 index 000000000..0197fddf9 --- /dev/null +++ b/src/types/scope.jl @@ -0,0 +1 @@ +abstract type AbstractScope end \ No newline at end of file diff --git a/src/utils/chunks.jl b/src/utils/chunks.jl index 9f0c3b487..1300a5a1d 100644 --- a/src/utils/chunks.jl +++ b/src/utils/chunks.jl @@ -161,7 +161,8 @@ function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); device=nothing, re end end ref = poolset(x; device, kwargs...) - Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope) + space = memory_space(proc) + Chunk{X,typeof(ref),P,S,typeof(space)}(X, domain(x), ref, proc, scope, space) end function tochunk(x::Chunk, proc=nothing, scope=nothing; rewrap=false, kwargs...) if rewrap @@ -185,5 +186,6 @@ function savechunk(data, dir, f) fr = FileRef(f, sz) proc = OSProc() scope = AnyScope() # FIXME: Scoped to this node - Chunk{typeof(data),typeof(fr),typeof(proc),typeof(scope)}(typeof(data), domain(data), fr, proc, scope, true) + space = memory_space(proc) + Chunk{typeof(data),typeof(fr),typeof(proc),typeof(scope),typeof(space)}(typeof(data), domain(data), fr, proc, scope, space) end diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 873e47e79..678445051 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -59,4 +59,7 @@ macro opcounter(category, count=1) end end) end -opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] \ No newline at end of file +opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] + +# No-op debug helper for tracking largest values (used alongside @opcounter) +largest_value_update!(::Any) = nothing \ No newline at end of file diff --git a/src/weakchunk.jl b/src/weakchunk.jl new file mode 100644 index 000000000..e31070536 --- /dev/null +++ b/src/weakchunk.jl @@ -0,0 +1,23 @@ +struct WeakChunk + wid::Int + id::Int + x::WeakRef +end + +function WeakChunk(c::Chunk) + return WeakChunk(c.handle.owner, c.handle.id, WeakRef(c)) +end + +unwrap_weak(c::WeakChunk) = c.x.value +function unwrap_weak_checked(c::WeakChunk) + cw = unwrap_weak(c) + @assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))" + return cw +end +wrap_weak(c::Chunk) = WeakChunk(c) +isweak(c::WeakChunk) = true +isweak(c::Chunk) = false +is_task_or_chunk(c::WeakChunk) = true +Serialization.serialize(io::AbstractSerializer, wc::WeakChunk) = + error("Cannot serialize a WeakChunk") +chunktype(c::WeakChunk) = chunktype(unwrap_weak_checked(c)) diff --git a/test/mpi.jl b/test/mpi.jl new file mode 100644 index 000000000..a84ffdce1 --- /dev/null +++ b/test/mpi.jl @@ -0,0 +1,70 @@ +using Dagger +using MPI +using LinearAlgebra +using SparseArrays + +Dagger.accelerate!(:mpi) + +comm = MPI.COMM_WORLD +rank = MPI.Comm_rank(comm) +size = MPI.Comm_size(comm) + +# Use a large array (adjust size as needed for your RAM) +N = 100 +tag = 123 + +if rank == 0 + arr = sprand(N, N, 0.6) +else + arr = spzeros(N, N) +end + +# --- Out-of-place broadcast --- +function bcast_outofplace() + MPI.Barrier(comm) + if rank == 0 + Dagger.bcast_send_yield(arr, comm, 0, tag+1) + else + Dagger.bcast_recv_yield(comm, 0, tag+1) + end + MPI.Barrier(comm) +end +# --- In-place broadcast --- + +function bcast_inplace() + MPI.Barrier(comm) + if rank == 0 + Dagger.bcast_send_yield!(arr, comm, 0, tag) + else + Dagger.bcast_recv_yield!(arr, comm, 0, tag) + end + MPI.Barrier(comm) +end + +function bcast_inplace_metadata() + MPI.Barrier(comm) + if rank == 0 + Dagger.bcast_send_yield_metadata(arr, comm, 0) + end + MPI.Barrier(comm) +end + + +inplace = @time bcast_inplace() + + +MPI.Barrier(comm) +MPI.Finalize() + + + + +#= +A = rand(Blocks(2,2), 4, 4) +Ac = collect(A) +println(Ac) + + +move!(identity, Ac[1].space , Ac[2].space, Ac[1], Ac[2]) +=# +