diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 08d0f0a5..87eb75d8 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -29,6 +29,7 @@ jobs: group: - "Core" - "Downstream" + - "JET" uses: "SciML/.github/.github/workflows/tests.yml@v1" with: group: "${{ matrix.group }}" diff --git a/Project.toml b/Project.toml index 617ceac2..56f4baab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecursiveArrayTools" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -authors = ["Chris Rackauckas "] version = "3.41.0" +authors = ["Chris Rackauckas "] [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -48,6 +48,7 @@ DocStringExtensions = "0.9.3" FastBroadcast = "0.3.5" ForwardDiff = "0.10.38, 1" GPUArraysCore = "0.2" +JET = "0.9, 0.11" KernelAbstractions = "0.9.36" LinearAlgebra = "1.10" Measurements = "2.11" @@ -76,6 +77,7 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" @@ -93,4 +95,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SciMLBase", "SparseArrays", "StaticArrays", "StructArrays", "Tables", "Test", "Unitful", "Zygote"] +test = ["Aqua", "FastBroadcast", "ForwardDiff", "JET", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SciMLBase", "SparseArrays", "StaticArrays", "StructArrays", "Tables", "Test", "Unitful", "Zygote"] diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl index fbaf3486..3e8abf4e 100644 --- a/src/named_array_partition.jl +++ b/src/named_array_partition.jl @@ -1,6 +1,6 @@ """ NamedArrayPartition(; kwargs...) - NamedArrayPartition(x::NamedTuple) + NamedArrayPartition(x::NamedTuple) Similar to an `ArrayPartition` but the individual arrays can be accessed via the constructor-specified names. However, unlike `ArrayPartition`, each individual array @@ -22,7 +22,7 @@ function NamedArrayPartition(x::NamedTuple) return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices) end -# Note: overloading `getproperty` means we cannot access `NamedArrayPartition` +# Note: overloading `getproperty` means we cannot access `NamedArrayPartition` # fields except through `getfield` and accessor functions. ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition) @@ -53,7 +53,7 @@ end function Base.similar( A::NamedArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S} NamedArrayPartition( - similar(getfield(A, :array_partition), T, S, R), getfield(A, :names_to_indices)) + similar(getfield(A, :array_partition), T, S, R...), getfield(A, :names_to_indices)) end Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x)) @@ -68,7 +68,7 @@ function Base.getproperty(x::NamedArrayPartition, s::Symbol) getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s)) end -# this enables x.s = some_array. +# this enables x.s = some_array. @inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v) index = getproperty(getfield(x, :names_to_indices), s) ArrayPartition(x).x[index] .= v diff --git a/src/utils.jl b/src/utils.jl index 908f55ad..b783eaf9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -119,7 +119,8 @@ function recursivefill!(b::AbstractArray{T, N}, a::T2) where {T <: StaticArraysCore.SArray, T2 <: Union{Number, Bool}, N} @inbounds for i in eachindex(b) - b[i] = fill(a, typeof(b[i])) + # Preserve static array shape while replacing all entries with the scalar + b[i] = map(_ -> a, b[i]) end end @@ -128,7 +129,8 @@ function recursivefill!(bs::AbstractVectorOfArray{T, N}, T2 <: Union{Number, Bool}, N} @inbounds for b in bs, i in eachindex(b) - b[i] = fill(a, typeof(b[i])) + # Preserve static array shape while replacing all entries with the scalar + b[i] = map(_ -> a, b[i]) end end diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 2778bfcc..55ffaa9d 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -924,6 +924,8 @@ function Base.view(A::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, J = map(i -> Base.unalias(A, i), to_indices(A, I)) elseif length(I) == 2 && (I[1] == Colon() || I[1] == 1) J = map(i -> Base.unalias(A, i), to_indices(A, Base.tail(I))) + else + J = map(i -> Base.unalias(A, i), to_indices(A, I)) end @boundscheck checkbounds(A, J...) SubArray(A, J) @@ -1200,6 +1202,7 @@ end struct VectorOfArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end # N is only used when voa sees other abstract arrays VectorOfArrayStyle{N}(::Val{N}) where {N} = VectorOfArrayStyle{N}() +VectorOfArrayStyle(::Val{N}) where {N} = VectorOfArrayStyle{N}() # The order is important here. We want to override Base.Broadcast.DefaultArrayStyle to return another Base.Broadcast.DefaultArrayStyle. Broadcast.BroadcastStyle(a::VectorOfArrayStyle, ::Base.Broadcast.DefaultArrayStyle{0}) = a diff --git a/test/jet_tests.jl b/test/jet_tests.jl new file mode 100644 index 00000000..494b9fb9 --- /dev/null +++ b/test/jet_tests.jl @@ -0,0 +1,14 @@ +using JET, Test, RecursiveArrayTools + +# Get all reports first +result = JET.report_package(RecursiveArrayTools; target_modules = (RecursiveArrayTools,)) +reports = JET.get_reports(result) + +# Filter out similar_type inference errors from StaticArraysCore +filtered_reports = filter(reports) do report + s = string(report) + !(occursin("similar_type", s) && occursin("StaticArraysCore", s)) +end + +# Check if there are any non-filtered errors +@test isempty(filtered_reports) diff --git a/test/runtests.jl b/test/runtests.jl index 4ec9d6f4..647eacfc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,4 +56,8 @@ end @time @safetestset "VectorOfArray GPU" include("gpu/vectorofarray_gpu.jl") @time @safetestset "ArrayPartition GPU" include("gpu/arraypartition_gpu.jl") end + + if GROUP == "JET" || GROUP == "All" + @time @safetestset "JET Tests" include("jet_tests.jl") + end end