Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.10.17"
version = "0.10.18"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -33,13 +33,13 @@ BlockArrays = "1.2"
DiagonalArrays = "0.3"
Dictionaries = "0.4.3"
FillArrays = "1.13"
FunctionImplementations = "0.3.1"
FunctionImplementations = "0.4"
GPUArraysCore = "0.1, 0.2"
LinearAlgebra = "1.10"
MacroTools = "0.5.13"
MapBroadcast = "0.1.5"
MatrixAlgebraKit = "0.6"
SparseArraysBase = "0.8.3"
SparseArraysBase = "0.9"
SplitApplyCombine = "1.2.3"
TensorAlgebra = "0.6.2"
TypeParameterAccessors = "0.4.1"
Expand Down
10 changes: 5 additions & 5 deletions src/abstractblocksparsearray/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using BlockArrays: AbstractBlockedUnitRange, BlockSlice
using Base.Broadcast: BroadcastStyle

function Base.Broadcast.BroadcastStyle(arraytype::Type{<:AnyAbstractBlockSparseArray})
return Broadcast.BlockSparseArrayStyle(BroadcastStyle(blocktype(arraytype)))
return BlockSparseArrayStyle(BroadcastStyle(blocktype(arraytype)))
end

# Fix ambiguity error with `BlockArrays`.
Expand All @@ -16,7 +16,7 @@ function Base.Broadcast.BroadcastStyle(
},
},
)
return Broadcast.BlockSparseArrayStyle{ndims(arraytype)}()
return BlockSparseArrayStyle{ndims(arraytype)}()
end
function Base.Broadcast.BroadcastStyle(
arraytype::Type{
Expand All @@ -32,7 +32,7 @@ function Base.Broadcast.BroadcastStyle(
},
},
)
return Broadcast.BlockSparseArrayStyle{ndims(arraytype)}()
return BlockSparseArrayStyle{ndims(arraytype)}()
end
function Base.Broadcast.BroadcastStyle(
arraytype::Type{
Expand All @@ -44,7 +44,7 @@ function Base.Broadcast.BroadcastStyle(
},
},
)
return Broadcast.BlockSparseArrayStyle{ndims(arraytype)}()
return BlockSparseArrayStyle{ndims(arraytype)}()
end

# These catch cases that aren't caught by the standard
Expand All @@ -59,7 +59,7 @@ function Base.copyto!(
return copyto!_blocksparse(dest, bc)
end
function Base.copyto!(
dest::AnyAbstractBlockSparseArray{<:Any, N}, bc::Broadcasted{<:Broadcast.BlockSparseArrayStyle{N}}
dest::AnyAbstractBlockSparseArray{<:Any, N}, bc::Broadcasted{<:BlockSparseArrayStyle{N}}
) where {N}
return copyto!_blocksparse(dest, bc)
end
6 changes: 3 additions & 3 deletions src/abstractblocksparsearray/unblockedsubarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ function BlockArrays.blocks(a::UnblockedSubArray)
return SingleBlockView(a)
end

using FunctionImplementations: FunctionImplementations, Style
function FunctionImplementations.Style(arraytype::Type{<:UnblockedSubArray})
return Style(blocktype(parenttype(arraytype)))
using FunctionImplementations: FunctionImplementations, ImplementationStyle
function FunctionImplementations.ImplementationStyle(arraytype::Type{<:UnblockedSubArray})
return ImplementationStyle(blocktype(parenttype(arraytype)))
end

function ArrayLayouts.MemoryLayout(arraytype::Type{<:UnblockedSubArray})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using BlockArrays:
blockedrange,
mortar,
unblock
using FunctionImplementations: FunctionImplementations, Style, style, zero!
using FunctionImplementations: FunctionImplementations, ImplementationStyle, style, zero!
using GPUArraysCore: @allowscalar
using SplitApplyCombine: groupcount
using TypeParameterAccessors: similartype
Expand All @@ -28,8 +28,8 @@ const AnyAbstractBlockSparseVecOrMat{T, N} = Union{
AnyAbstractBlockSparseVector{T}, AnyAbstractBlockSparseMatrix{T},
}

function FunctionImplementations.Style(arrayt::Type{<:AnyAbstractBlockSparseArray})
return BlockSparseArrayStyle()
function FunctionImplementations.ImplementationStyle(arrayt::Type{<:AnyAbstractBlockSparseArray})
return BlockSparseArrayImplementationStyle()
end

# a[1:2, 1:2]
Expand Down Expand Up @@ -303,7 +303,7 @@ function Base.similar(
elt::Type,
axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}},
)
return Style(arraytype)(similar)(arraytype, elt, axes)
return ImplementationStyle(arraytype)(similar)(arraytype, elt, axes)
end

function Base.similar(
Expand Down
16 changes: 8 additions & 8 deletions src/blocksparsearrayinterface/blocksparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using BlockArrays:
using FunctionImplementations: FunctionImplementations, permuteddims, zero!
using LinearAlgebra: Adjoint, Transpose
using SparseArraysBase:
AbstractSparseArrayStyle,
AbstractSparseArrayImplementationStyle,
getstoredindex,
getunstoredindex,
eachstoredindex,
Expand Down Expand Up @@ -109,16 +109,16 @@ blockstype(a::BlockArray) = blockstype(typeof(a))
blocktype(arraytype::Type{<:BlockArray}) = eltype(blockstype(arraytype))
blocktype(a::BlockArray) = eltype(blocks(a))

abstract type AbstractBlockSparseArrayStyle <: AbstractSparseArrayStyle end
abstract type AbstractBlockSparseArrayImplementationStyle <: AbstractSparseArrayImplementationStyle end

struct BlockSparseArrayStyle <: AbstractBlockSparseArrayStyle end
const blocksparse_style = BlockSparseArrayStyle()
struct BlockSparseArrayImplementationStyle <: AbstractBlockSparseArrayImplementationStyle end
const blocksparse_style = BlockSparseArrayImplementationStyle()

function FunctionImplementations.Style(
style1::AbstractBlockSparseArrayStyle,
style2::AbstractBlockSparseArrayStyle,
function FunctionImplementations.ImplementationStyle(
style1::AbstractBlockSparseArrayImplementationStyle,
style2::AbstractBlockSparseArrayImplementationStyle,
)
return BlockSparseArrayStyle()
return BlockSparseArrayImplementationStyle()
end

const blocks_blocksparse = blocksparse_style(blocks)
Expand Down
60 changes: 29 additions & 31 deletions src/blocksparsearrayinterface/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,55 @@ using Base.Broadcast: BroadcastStyle, Broadcasted
using GPUArraysCore: @allowscalar
using MapBroadcast: Mapped

module Broadcast
using Base.Broadcast: AbstractArrayStyle
abstract type AbstractBlockSparseArrayStyle{N, B <: AbstractArrayStyle{N}} <:
AbstractArrayStyle{N} end
struct BlockSparseArrayStyle{N, B <: AbstractArrayStyle{N}} <:
AbstractBlockSparseArrayStyle{N, B}
blockstyle::B
end
function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N}
return BlockSparseArrayStyle{N, typeof(blockstyle)}(blockstyle)
end
function BlockSparseArrayStyle{N, B}() where {N, B <: AbstractArrayStyle{N}}
return BlockSparseArrayStyle{N, B}(B())
end
function BlockSparseArrayStyle{N}() where {N}
return BlockSparseArrayStyle{N}(Base.Broadcast.DefaultArrayStyle{N}())
end
BlockSparseArrayStyle(::Val{N}) where {N} = BlockSparseArrayStyle{N}()
BlockSparseArrayStyle{M}(::Val{N}) where {M, N} = BlockSparseArrayStyle{N}()
function BlockSparseArrayStyle{M, B}(::Val{N}) where {M, B <: AbstractArrayStyle{M}, N}
return BlockSparseArrayStyle{N}(B(Val(N)))
end
using Base.Broadcast: AbstractArrayStyle
abstract type AbstractBlockSparseArrayStyle{N, B <: AbstractArrayStyle{N}} <:
AbstractArrayStyle{N} end
struct BlockSparseArrayStyle{N, B <: AbstractArrayStyle{N}} <:
AbstractBlockSparseArrayStyle{N, B}
blockstyle::B
end
function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N}
return BlockSparseArrayStyle{N, typeof(blockstyle)}(blockstyle)
end
function BlockSparseArrayStyle{N, B}() where {N, B <: AbstractArrayStyle{N}}
return BlockSparseArrayStyle{N, B}(B())
end
function BlockSparseArrayStyle{N}() where {N}
return BlockSparseArrayStyle{N}(Base.Broadcast.DefaultArrayStyle{N}())
end
BlockSparseArrayStyle(::Val{N}) where {N} = BlockSparseArrayStyle{N}()
BlockSparseArrayStyle{M}(::Val{N}) where {M, N} = BlockSparseArrayStyle{N}()
function BlockSparseArrayStyle{M, B}(::Val{N}) where {M, B <: AbstractArrayStyle{M}, N}
return BlockSparseArrayStyle{N}(B(Val(N)))
end

function blockstyle(
::Broadcast.AbstractBlockSparseArrayStyle{N, B},
::AbstractBlockSparseArrayStyle{N, B},
) where {N, B <: Base.Broadcast.AbstractArrayStyle{N}}
return B()
end

function Base.Broadcast.BroadcastStyle(
style1::Broadcast.AbstractBlockSparseArrayStyle,
style2::Broadcast.AbstractBlockSparseArrayStyle,
style1::AbstractBlockSparseArrayStyle,
style2::AbstractBlockSparseArrayStyle,
)
style = Base.Broadcast.result_style(blockstyle(style1), blockstyle(style2))
return Broadcast.BlockSparseArrayStyle(style)
return BlockSparseArrayStyle(style)
end

Base.Broadcast.BroadcastStyle(a::Broadcast.BlockSparseArrayStyle, ::Base.Broadcast.DefaultArrayStyle{0}) = a
Base.Broadcast.BroadcastStyle(a::BlockSparseArrayStyle, ::Base.Broadcast.DefaultArrayStyle{0}) = a
function Base.Broadcast.BroadcastStyle(
::Broadcast.BlockSparseArrayStyle{N}, a::Base.Broadcast.DefaultArrayStyle
::BlockSparseArrayStyle{N}, a::Base.Broadcast.DefaultArrayStyle
) where {N}
return Base.Broadcast.BroadcastStyle(Base.Broadcast.DefaultArrayStyle{N}(), a)
end
function Base.Broadcast.BroadcastStyle(
::Broadcast.BlockSparseArrayStyle{N}, ::Base.Broadcast.Style{Tuple}
::BlockSparseArrayStyle{N}, ::Base.Broadcast.Style{Tuple}
) where {N}
return Base.Broadcast.DefaultArrayStyle{N}()
end

function Base.similar(bc::Broadcasted{<:Broadcast.BlockSparseArrayStyle}, elt::Type, ax)
function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type, ax)
# Find the first array in the broadcast expression.
# TODO: Make this more generic, base it off sure this handles GPU arrays properly.
bc′ = Base.Broadcast.flatten(bc)
Expand Down Expand Up @@ -84,7 +82,7 @@ end

# Broadcasting implementation
function Base.copyto!(
dest::AbstractArray{<:Any, N}, bc::Broadcasted{Broadcast.BlockSparseArrayStyle{N}}
dest::AbstractArray{<:Any, N}, bc::Broadcasted{BlockSparseArrayStyle{N}}
) where {N}
return copyto!_blocksparse(dest, bc)
end
2 changes: 1 addition & 1 deletion src/blocksparsearrayinterface/cat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using BlockArrays: blocks
using FunctionImplementations.Concatenate: Concatenated, cat!

function Base.copyto!(
dest::AbstractArray, concat::Concatenated{<:Broadcast.BlockSparseArrayStyle}
dest::AbstractArray, concat::Concatenated{<:BlockSparseArrayStyle}
)
# TODO: This assumes the destination blocking is commensurate with
# the blocking of the sources, for example because it was constructed
Expand Down
4 changes: 2 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ ArrayLayouts = "1"
BlockArrays = "1.8"
BlockSparseArrays = "0.10"
DiagonalArrays = "0.3"
FunctionImplementations = "0.3"
FunctionImplementations = "0.4"
GPUArraysCore = "0.2"
JLArrays = "0.2, 0.3"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6"
Random = "1"
SafeTestsets = "0.1"
SparseArraysBase = "0.8"
SparseArraysBase = "0.9"
StableRNGs = "1"
Suppressor = "0.2"
TensorAlgebra = "0.6"
Expand Down
Loading