Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "FunctionImplementations"
uuid = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
authors = ["ITensor developers <support@itensor.org> and contributors"]
version = "0.3.2"
version = "0.4.0"

[weakdeps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ FunctionImplementations = {path = ".."}
[compat]
Documenter = "1"
Literate = "2"
FunctionImplementations = "0.3"
FunctionImplementations = "0.4"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
FunctionImplementations = {path = ".."}

[compat]
FunctionImplementations = "0.3"
FunctionImplementations = "0.4"
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ module FunctionImplementationsLinearAlgebraExt
import FunctionImplementations as FI
import LinearAlgebra as LA

struct DiagonalStyle <: FI.AbstractArrayStyle end
FI.Style(::Type{<:LA.Diagonal}) = DiagonalStyle()
const permuteddims_diag = DiagonalStyle()(FI.permuteddims)
struct DiagonalImplementationStyle <: FI.AbstractArrayImplementationStyle end
FI.ImplementationStyle(::Type{<:LA.Diagonal}) = DiagonalImplementationStyle()
const permuteddims_diag = DiagonalImplementationStyle()(FI.permuteddims)
function permuteddims_diag(a::AbstractArray, perm)
(ndims(a) == length(perm) && isperm(perm)) ||
throw(ArgumentError("no valid permutation of dimensions"))
Expand Down
142 changes: 71 additions & 71 deletions src/style.jl
Original file line number Diff line number Diff line change
@@ -1,156 +1,156 @@
### This is based on the BroadcastStyle code in
### https://github.com/JuliaLang/julia/blob/master/base/broadcast.jl
### Objects with customized behavior for a certain function should declare a Style
### Objects with customized behavior for a certain function should declare a ImplementationStyle

"""
`Style` is an abstract type and trait-function used to determine behavior of
objects. `Style(typeof(x))` returns the style associated
`ImplementationStyle` is an abstract type and trait-function used to determine behavior of
objects. `ImplementationStyle(typeof(x))` returns the style associated
with `x`. To customize the behavior of a type, one can declare a style
by defining a type/method pair

struct MyContainerStyle <: Style end
FunctionImplementations.Style(::Type{<:MyContainer}) = MyContainerStyle()
struct MyContainerImplementationStyle <: ImplementationStyle end
FunctionImplementations.ImplementationStyle(::Type{<:MyContainer}) = MyContainerImplementationStyle()

"""
abstract type Style end
Style(::Type{T}) where {T} = throw(MethodError(Style, (T,)))
abstract type ImplementationStyle end
ImplementationStyle(::Type{T}) where {T} = throw(MethodError(ImplementationStyle, (T,)))

struct UnknownStyle <: Style end
Style(::Type{Union{}}, slurp...) = UnknownStyle() # ambiguity resolution
struct UnknownImplementationStyle <: ImplementationStyle end
ImplementationStyle(::Type{Union{}}, slurp...) = UnknownImplementationStyle() # ambiguity resolution

"""
(s::Style)(f)
(s::ImplementationStyle)(f)

Calling a Style `s` with a function `f` as `s(f)` is a shorthand for creating a
Calling a ImplementationStyle `s` with a function `f` as `s(f)` is a shorthand for creating a
[`FunctionImplementations.Implementation`](@ref) object wrapping the function `f` with
Style `s`.
ImplementationStyle `s`.
"""
(s::Style)(f) = Implementation(f, s)
(s::ImplementationStyle)(f) = Implementation(f, s)

"""
`FunctionImplementations.AbstractArrayStyle <: Style` is the abstract supertype for any style
`FunctionImplementations.AbstractArrayImplementationStyle <: ImplementationStyle` is the abstract supertype for any style
associated with an `AbstractArray` type.

Note that if two or more `AbstractArrayStyle` subtypes conflict, the resulting
Note that if two or more `AbstractArrayImplementationStyle` subtypes conflict, the resulting
style will fall back to that of `Array`s. If this is undesirable, you may need to
define binary [`Style`](@ref) rules to control the output type.
define binary [`ImplementationStyle`](@ref) rules to control the output type.

See also [`FunctionImplementations.DefaultArrayStyle`](@ref).
See also [`FunctionImplementations.DefaultArrayImplementationStyle`](@ref).
"""
abstract type AbstractArrayStyle <: Style end
abstract type AbstractArrayImplementationStyle <: ImplementationStyle end

"""
`FunctionImplementations.DefaultArrayStyle()` is a [`FunctionImplementations.Style`](@ref)
indicating that an object behaves as an array. Specifically, `DefaultArrayStyle` is
`FunctionImplementations.DefaultArrayImplementationStyle()` is a [`FunctionImplementations.ImplementationStyle`](@ref)
indicating that an object behaves as an array. Specifically, `DefaultArrayImplementationStyle` is
used for any `AbstractArray` type that hasn't defined a specialized style, and in the
absence of overrides from other arguments the resulting output type is `Array`.
"""
struct DefaultArrayStyle <: AbstractArrayStyle end
Style(::Type{<:AbstractArray}) = DefaultArrayStyle()
struct DefaultArrayImplementationStyle <: AbstractArrayImplementationStyle end
ImplementationStyle(::Type{<:AbstractArray}) = DefaultArrayImplementationStyle()

# `ArrayConflict` is an internal type signaling that two or more different `AbstractArrayStyle`
# `ArrayImplementationConflict` is an internal type signaling that two or more different `AbstractArrayImplementationStyle`
# objects were supplied as arguments, and that no rule was defined for resolving the
# conflict. The resulting output is `Array`. While this is the same output type
# produced by `DefaultArrayStyle`, `ArrayConflict` "poisons" the Style so that
# 3 or more arguments still return an `ArrayConflict`.
struct ArrayConflict <: AbstractArrayStyle end
# produced by `DefaultArrayImplementationStyle`, `ArrayImplementationConflict` "poisons" the ImplementationStyle so that
# 3 or more arguments still return an `ArrayImplementationConflict`.
struct ArrayImplementationConflict <: AbstractArrayImplementationStyle end

### Binary Style rules
### Binary ImplementationStyle rules
"""
Style(::Style1, ::Style2) = Style3()
ImplementationStyle(::ImplementationStyle1, ::ImplementationStyle2) = ImplementationStyle3()

Indicate how to resolve different `Style`s. For example,
Indicate how to resolve different `ImplementationStyle`s. For example,

Style(::Primary, ::Secondary) = Primary()
ImplementationStyle(::Primary, ::Secondary) = Primary()

would indicate that style `Primary` has precedence over `Secondary`.
You do not have to (and generally should not) define both argument orders.
The result does not have to be one of the input arguments, it could be a third type.
"""
Style(::S, ::S) where {S <: Style} = S() # homogeneous types preserved
# Fall back to UnknownStyle. This is necessary to implement argument-swapping
Style(::Style, ::Style) = UnknownStyle()
# UnknownStyle loses to everything
Style(::UnknownStyle, ::UnknownStyle) = UnknownStyle()
Style(::S, ::UnknownStyle) where {S <: Style} = S()
ImplementationStyle(::S, ::S) where {S <: ImplementationStyle} = S() # homogeneous types preserved
# Fall back to UnknownImplementationStyle. This is necessary to implement argument-swapping
ImplementationStyle(::ImplementationStyle, ::ImplementationStyle) = UnknownImplementationStyle()
# UnknownImplementationStyle loses to everything
ImplementationStyle(::UnknownImplementationStyle, ::UnknownImplementationStyle) = UnknownImplementationStyle()
ImplementationStyle(::S, ::UnknownImplementationStyle) where {S <: ImplementationStyle} = S()
# Precedence rules
Style(::A, ::A) where {A <: AbstractArrayStyle} = A()
function Style(a::A, b::B) where {A <: AbstractArrayStyle, B <: AbstractArrayStyle}
ImplementationStyle(::A, ::A) where {A <: AbstractArrayImplementationStyle} = A()
function ImplementationStyle(a::A, b::B) where {A <: AbstractArrayImplementationStyle, B <: AbstractArrayImplementationStyle}
if Base.typename(A) ≡ Base.typename(B)
return A()
end
return UnknownStyle()
return UnknownImplementationStyle()
end
# Any specific array type beats DefaultArrayStyle
Style(a::AbstractArrayStyle, ::DefaultArrayStyle) = a
# Any specific array type beats DefaultArrayImplementationStyle
ImplementationStyle(a::AbstractArrayImplementationStyle, ::DefaultArrayImplementationStyle) = a

## logic for deciding the Style
## logic for deciding the ImplementationStyle

"""
style(cs...)::Style
style(cs...)::ImplementationStyle

Decides which `Style` to use for any number of value arguments.
Uses [`Style`](@ref) to get the style for each argument, and uses
Decides which `ImplementationStyle` to use for any number of value arguments.
Uses [`ImplementationStyle`](@ref) to get the style for each argument, and uses
[`result_style`](@ref) to combine styles.

# Examples
```jldoctest
julia> FunctionImplementations.style([1], [1 2; 3 4])
FunctionImplementations.DefaultArrayStyle()
FunctionImplementations.DefaultArrayImplementationStyle()
```
"""
function style end

style() = DefaultArrayStyle()
style(c) = result_style(Style(typeof(c)))
style() = DefaultArrayImplementationStyle()
style(c) = result_style(ImplementationStyle(typeof(c)))
style(c1, c2) = result_style(style(c1), style(c2))
@inline style(c1, c2, cs...) = result_style(style(c1), style(c2, cs...))

"""
result_style(s1::Style[, s2::Style])::Style
result_style(s1::ImplementationStyle[, s2::ImplementationStyle])::ImplementationStyle

Takes one or two `Style`s and combines them using [`Style`](@ref) to
determine a common `Style`.
Takes one or two `ImplementationStyle`s and combines them using [`ImplementationStyle`](@ref) to
determine a common `ImplementationStyle`.

# Examples

```jldoctest
julia> FunctionImplementations.result_style(FunctionImplementations.DefaultArrayStyle(), FunctionImplementations.DefaultArrayStyle())
FunctionImplementations.DefaultArrayStyle()
julia> FunctionImplementations.result_style(FunctionImplementations.DefaultArrayImplementationStyle(), FunctionImplementations.DefaultArrayImplementationStyle())
FunctionImplementations.DefaultArrayImplementationStyle()

julia> FunctionImplementations.result_style(FunctionImplementations.UnknownStyle(), FunctionImplementations.DefaultArrayStyle())
FunctionImplementations.DefaultArrayStyle()
julia> FunctionImplementations.result_style(FunctionImplementations.UnknownImplementationStyle(), FunctionImplementations.DefaultArrayImplementationStyle())
FunctionImplementations.DefaultArrayImplementationStyle()
```
"""
function result_style end

result_style(s::Style) = s
function result_style(s1::S, s2::S) where {S <: Style}
result_style(s::ImplementationStyle) = s
function result_style(s1::S, s2::S) where {S <: ImplementationStyle}
return s1 ≡ s2 ? s1 : error("inconsistent styles, custom rule needed")
end
# Test both orders so users typically only have to declare one order
result_style(s1, s2) = result_join(s1, s2, Style(s1, s2), Style(s2, s1))
result_style(s1, s2) = result_join(s1, s2, ImplementationStyle(s1, s2), ImplementationStyle(s2, s1))

# result_join is the final arbiter. Because `Style` for undeclared pairs results in UnknownStyle,
# we defer to any case where the result of `Style` is known.
result_join(::Any, ::Any, ::UnknownStyle, ::UnknownStyle) = UnknownStyle()
result_join(::Any, ::Any, ::UnknownStyle, s::Style) = s
result_join(::Any, ::Any, s::Style, ::UnknownStyle) = s
# result_join is the final arbiter. Because `ImplementationStyle` for undeclared pairs results in UnknownImplementationStyle,
# we defer to any case where the result of `ImplementationStyle` is known.
result_join(::Any, ::Any, ::UnknownImplementationStyle, ::UnknownImplementationStyle) = UnknownImplementationStyle()
result_join(::Any, ::Any, ::UnknownImplementationStyle, s::ImplementationStyle) = s
result_join(::Any, ::Any, s::ImplementationStyle, ::UnknownImplementationStyle) = s
# For AbstractArray types with undefined precedence rules,
# we have to signal conflict. Because ArrayConflict is a subtype of AbstractArray,
# this will "poison" any future operations (if we instead returned `DefaultArrayStyle`, then for
# we have to signal conflict. Because ArrayImplementationConflict is a subtype of AbstractArray,
# this will "poison" any future operations (if we instead returned `DefaultArrayImplementationStyle`, then for
# 3-array functions returned type would depend on argument order).
result_join(::AbstractArrayStyle, ::AbstractArrayStyle, ::UnknownStyle, ::UnknownStyle) =
ArrayConflict()
result_join(::AbstractArrayImplementationStyle, ::AbstractArrayImplementationStyle, ::UnknownImplementationStyle, ::UnknownImplementationStyle) =
ArrayImplementationConflict()
# Fallbacks in case users define `rule` for both argument-orders (not recommended)
result_join(::Any, ::Any, s1::S, s2::S) where {S <: Style} = result_style(s1, s2)
result_join(::Any, ::Any, s1::S, s2::S) where {S <: ImplementationStyle} = result_style(s1, s2)

@noinline function result_join(::S, ::T, ::U, ::V) where {S, T, U, V}
error(
"""
conflicting rules defined
FunctionImplementations.Style(::$S, ::$T) = $U()
FunctionImplementations.Style(::$T, ::$S) = $V()
One of these should be undefined (and thus return FunctionImplementations.UnknownStyle)."""
FunctionImplementations.ImplementationStyle(::$S, ::$T) = $U()
FunctionImplementations.ImplementationStyle(::$T, ::$S) = $V()
One of these should be undefined (and thus return FunctionImplementations.UnknownImplementationStyle)."""
)
end
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Adapt = "4"
Aqua = "0.8"
BlockArrays = "1.4"
FillArrays = "1.15"
FunctionImplementations = "0.3"
FunctionImplementations = "0.4"
JLArrays = "0.3"
LinearAlgebra = "1.10"
SafeTestsets = "0.1"
Expand Down
80 changes: 40 additions & 40 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,78 +12,78 @@ using Test: @test, @testset
@test f.f ≡ +
@test f.style ≡ MyAddAlgorithm()
end
@testset "(s::Style)(f)" begin
# Test the shorthand for creating an Implementation by calling a Style with a
@testset "(s::ImplementationStyle)(f)" begin
# Test the shorthand for creating an Implementation by calling a ImplementationStyle with a
# function.
@test FI.style([1, 2, 3])(getindex) ≡
FI.Implementation(getindex, FI.DefaultArrayStyle())
FI.Implementation(getindex, FI.DefaultArrayImplementationStyle())
end
@testset "Style" begin
# Test basic Style trait for different array types
@test FI.Style(typeof([1, 2, 3])) ≡ FI.DefaultArrayStyle()
@test FI.style([1, 2, 3]) ≡ FI.DefaultArrayStyle()
@test FI.Style(typeof([1 2; 3 4])) ≡ FI.DefaultArrayStyle()
@test FI.Style(typeof(rand(2, 3, 4))) ≡ FI.DefaultArrayStyle()
@testset "ImplementationStyle" begin
# Test basic ImplementationStyle trait for different array types
@test FI.ImplementationStyle(typeof([1, 2, 3])) ≡ FI.DefaultArrayImplementationStyle()
@test FI.style([1, 2, 3]) ≡ FI.DefaultArrayImplementationStyle()
@test FI.ImplementationStyle(typeof([1 2; 3 4])) ≡ FI.DefaultArrayImplementationStyle()
@test FI.ImplementationStyle(typeof(rand(2, 3, 4))) ≡ FI.DefaultArrayImplementationStyle()

# Test custom Style definition
struct CustomStyle <: FI.Style end
# Test custom ImplementationStyle definition
struct CustomImplementationStyle <: FI.ImplementationStyle end
struct CustomArray end
FI.Style(::Type{CustomArray}) = CustomStyle()
@test FI.Style(CustomArray) isa CustomStyle
FI.ImplementationStyle(::Type{CustomArray}) = CustomImplementationStyle()
@test FI.ImplementationStyle(CustomArray) isa CustomImplementationStyle

# Test custom AbstractArrayStyle definition
# Test custom AbstractArrayImplementationStyle definition
struct MyArray{T, N} <: AbstractArray{T, N}
data::Array{T, N}
end
struct MyArrayStyle <: FI.AbstractArrayStyle end
FI.Style(::Type{<:MyArray}) = MyArrayStyle()
@test FI.Style(MyArray) isa MyArrayStyle
struct MyArrayImplementationStyle <: FI.AbstractArrayImplementationStyle end
FI.ImplementationStyle(::Type{<:MyArray}) = MyArrayImplementationStyle()
@test FI.ImplementationStyle(MyArray) isa MyArrayImplementationStyle

# Test style homogeneity rule (same type returns preserved)
s1 = FI.DefaultArrayStyle()
s2 = FI.DefaultArrayStyle()
@test FI.Style(s1, s2) ≡ FI.DefaultArrayStyle()
s1 = FI.DefaultArrayImplementationStyle()
s2 = FI.DefaultArrayImplementationStyle()
@test FI.ImplementationStyle(s1, s2) ≡ FI.DefaultArrayImplementationStyle()

# Test UnknownStyle precedence
unknown = FI.UnknownStyle()
known = FI.DefaultArrayStyle()
@test FI.Style(known, unknown) ≡ known
@test FI.Style(unknown, unknown) ≡ unknown
# Test UnknownImplementationStyle precedence
unknown = FI.UnknownImplementationStyle()
known = FI.DefaultArrayImplementationStyle()
@test FI.ImplementationStyle(known, unknown) ≡ known
@test FI.ImplementationStyle(unknown, unknown) ≡ unknown

# Test ArrayConflict
conflict = FI.ArrayConflict()
@test conflict isa FI.ArrayConflict
@test conflict isa FI.AbstractArrayStyle
# Test ArrayImplementationConflict
conflict = FI.ArrayImplementationConflict()
@test conflict isa FI.ArrayImplementationConflict
@test conflict isa FI.AbstractArrayImplementationStyle

# Test style with no arguments
@test FI.style() ≡ FI.DefaultArrayStyle()
@test FI.style() ≡ FI.DefaultArrayImplementationStyle()

# Test style with single argument
@test FI.style([1, 2]) ≡ FI.DefaultArrayStyle()
@test FI.style([1 2; 3 4]) ≡ FI.DefaultArrayStyle()
@test FI.style([1, 2]) ≡ FI.DefaultArrayImplementationStyle()
@test FI.style([1 2; 3 4]) ≡ FI.DefaultArrayImplementationStyle()

# Test style with two arguments
result = FI.style([1, 2], [1 2; 3 4])
@test result ≡ FI.DefaultArrayStyle()
@test result ≡ FI.DefaultArrayImplementationStyle()

# Test style with same dimensions
result = FI.style([1], [2])
@test result ≡ FI.DefaultArrayStyle()
@test result ≡ FI.DefaultArrayImplementationStyle()

# Test style with multiple arguments
result = FI.style([1], [1 2], rand(2, 3, 4))
@test result ≡ FI.DefaultArrayStyle()
@test result ≡ FI.DefaultArrayImplementationStyle()

# Test result_style with single argument
@test FI.result_style(FI.DefaultArrayStyle()) isa FI.DefaultArrayStyle
@test FI.result_style(FI.DefaultArrayImplementationStyle()) isa FI.DefaultArrayImplementationStyle

# Test result_style with two identical styles
s = FI.DefaultArrayStyle()
s = FI.DefaultArrayImplementationStyle()
@test FI.result_style(s, s) ≡ s

# Test result_style with UnknownStyle
known = FI.DefaultArrayStyle()
unknown = FI.UnknownStyle()
# Test result_style with UnknownImplementationStyle
known = FI.DefaultArrayImplementationStyle()
unknown = FI.UnknownImplementationStyle()
@test FI.result_style(known, unknown) ≡ known
@test FI.result_style(unknown, known) ≡ known
end
Expand Down
Loading