Created
August 20, 2021 05:24
-
-
Save Tokazama/c396eb25bfb2dc3ad57cfb9df150f4ee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
    AccessStyle

Abstract supertype of the singleton traits declared below. In the visible code an
instance of one of these traits is passed as the second argument of `layout` to
select which method (and therefore which kind of index) is used.
"""
abstract type AccessStyle end

# Trait accepted by the `layout` methods that build a `LinearStrideIndex`.
struct LinearElement <: AccessStyle end

# Trait accepted by the `layout` methods that build an N-dimensional index
# (`StrideIndex(A)`, `PermutedIndex`, `SubIndex`).
struct CartesianElement <: AccessStyle end

# NOTE(review): no visible method dispatches on the two collection traits; they are
# presumably the multi-element counterparts of the traits above — confirm.
struct LinearCollection <: AccessStyle end
struct CartesianCollection <: AccessStyle end
## ArrayIndex
# Alias for a one-dimensional `StrideIndex` holding exactly one stride (of type `S`)
# and one offset (of type `O`).
# NOTE(review): the third parameter here is the *type* `Nothing`, whereas the
# `StrideIndex ∘ PermutedIndex` method compares that parameter against the *value*
# `nothing` — confirm which one is intended.
const LinearStrideIndex{S,O} = StrideIndex{1,(1,),Nothing,Tuple{S},Tuple{O}}

# Convenience constructor: wrap a scalar stride and a scalar offset in 1-tuples.
function LinearStrideIndex(stride::CanonicalInt, offset::CanonicalInt)
    return StrideIndex{1,(1,),Nothing}((stride,), (offset,))
end
"""
    PermutedIndex{N,I1,I2}()
    PermutedIndex(a::PermutedDimsArray)
    PermutedIndex(::MatAdjTrans)

Field-less `ArrayIndex{N}` whose two type parameters `I1` and `I2` are `N`-tuples of
`Int`. When built from a `PermutedDimsArray` they are copied from that array's own
third and fourth type parameters; `getindex` on a `PermutedIndex` reorders an index
according to `I2`.
"""
struct PermutedIndex{N,I1,I2} <: ArrayIndex{N}
    # The type assertions reject parameters that are not `NTuple{N,Int}` values.
    function PermutedIndex{N,I1,I2}() where {N,I1,I2}
        return new{N,I1::NTuple{N,Int},I2::NTuple{N,Int}}()
    end

    # Reuse the pair of tuples already stored in the wrapper's type parameters.
    function PermutedIndex(a::PermutedDimsArray{T,N,I1,I2}) where {T,N,I1,I2}
        return PermutedIndex{N,I1,I2}()
    end

    # Two-dimensional swap. `MatAdjTrans` is defined outside the visible code;
    # presumably a matrix adjoint/transpose wrapper — confirm.
    PermutedIndex(::MatAdjTrans) = PermutedIndex{2,(2,1),(2,1)}()
end
"""
    SubIndex{N}(inds::Tuple)
    SubIndex(x::SubArray)

`ArrayIndex{N}` that stores a tuple of indices in its `indices` field. When built
from a `SubArray{T,N}` it takes `N` from the view and stores `x.indices`.
"""
struct SubIndex{N,I} <: ArrayIndex{N}
    indices::I

    # `I` is always inferred from the tuple that is passed in.
    SubIndex{N}(inds::Tuple) where {N} = new{N,typeof(inds)}(inds)

    SubIndex(x::SubArray{T,N}) where {T,N} = SubIndex{N}(x.indices)
end
"""
    ComposedIndex(i1, i2)

`ArrayIndex` that chains two indices: indexing it evaluates `i2[i1[i]]`, i.e. `i1`
is applied first and its result is fed to `i2`. The dimensionality parameter `N` is
computed as `ndims(I1)`, where `I1` is the type of `i1`.
"""
struct ComposedIndex{N,I1,I2} <: ArrayIndex{N}
    i1::I1
    i2::I2

    function ComposedIndex(i1::I1, i2::I2) where {I1,I2}
        return new{ndims(I1),I1,I2}(i1, i2)
    end
end
# Reorder the components of the N-dimensional index `i` according to `I2` and wrap
# the resulting tuple in an `NDIndex`. `permute` is defined outside the visible code.
@inline function Base.getindex(
    x::PermutedIndex{N,I1,I2}, i::AbstractCartesianIndex{N}
) where {N,I1,I2}
    return NDIndex(permute(Tuple(i), Val(I2)))
end
# Map a linear index through a one-dimensional stride index:
# `offset + i * stride`, using the single offset and single stride stored in `x`.
@inline function Base.getindex(x::LinearStrideIndex, i::CanonicalInt)
    return getfield(offsets(x), 1) + i * getfield(strides(x), 1)
end
# Return the second component of the two-dimensional index `i`.
# NOTE(review): `ConjugateIndex` is not defined in the visible code, and the first
# component of `i` is discarded here — confirm this is the intended mapping.
Base.getindex(x::ConjugateIndex, i::AbstractCartesianIndex{2}) = getfield(Tuple(i), 2)
# Apply the composed index to a linear index: `i1` first, then `i2` on its result.
# No `@inbounds` is used inside the body: `@propagate_inbounds` already forwards the
# caller's bounds-checking context to both nested `getindex` calls, whereas an inner
# `@inbounds` would drop the checks even for callers that did not opt out of them.
@propagate_inbounds function Base.getindex(x::ComposedIndex, i::CanonicalInt)
    inner = getfield(x, :i1)[i]
    return getfield(x, :i2)[inner]
end
# Apply the composed index to a multidimensional index: `i1` first, then `i2` on
# its result.
# No `@inbounds` is used inside the body: `@propagate_inbounds` already forwards the
# caller's bounds-checking context to both nested `getindex` calls, whereas an inner
# `@inbounds` would drop the checks even for callers that did not opt out of them.
@propagate_inbounds function Base.getindex(x::ComposedIndex, i::AbstractCartesianIndex)
    inner = getfield(x, :i1)[i]
    return getfield(x, :i2)[inner]
end
## composed
# Generic fallback for composing two indices. The arguments are swapped on purpose:
# `ComposedIndex` evaluates its first field first, so `(x ∘ y)[i] == x[y[i]]`.
Base.:(∘)(x::ArrayIndex, y::ArrayIndex) = ComposedIndex(y, x)
# Fold a `PermutedIndex` into a `StrideIndex`, producing a new `StrideIndex` instead
# of a lazy `ComposedIndex`.
#
# The second type parameter `R`, the strides and the offsets are all reordered with
# `perm`. The third type parameter `C` is passed through unchanged when it is
# `nothing` or `-1`; otherwise it is replaced by `iperm[C]`.
@inline function Base.:(∘)(
    x::StrideIndex{N,R,C}, y::PermutedIndex{N,perm,iperm}
) where {N,R,C,perm,iperm}
    c2 = (C === nothing || C === -1) ? C : getfield(iperm, C)
    return StrideIndex{N,permute(R, Val(perm)),c2}(
        permute(strides(x), Val(perm)),
        permute(offsets(x), Val(perm)),
    )
end
# Fold a `SubIndex` into a `StrideIndex`, producing a new `StrideIndex` with `Ns`
# dimensions.
#
# The third type parameter of the result (`c2`) is derived from the type of the
# sub-index that sits at position `C` of the parent:
#   * `C` is `nothing` or `-1`        -> passed through unchanged (same rule as the
#                                        `PermutedIndex` method; there is no position
#                                        to look up in `I`);
#   * an `AbstractUnitRange`          -> the corresponding entry of
#                                        `_from_sub_dims(static(N), I)`;
#   * any other array, or an integer  -> `-1`;
#   * anything else                   -> `nothing`.
#
# `_get_tuple`, `_from_sub_dims`, `_to_sub_dims`, `eachop`, `getmul` and
# `maybe_static_step` are helpers defined outside the visible code.
@inline function Base.:(∘)(
    x::StrideIndex{N,R,C}, y::SubIndex{Ns,I}
) where {N,R,C,Ns,I<:Tuple{Vararg{Any,N}}}
    if C === nothing || C === -1
        c2 = C
    else
        Ic = _get_tuple(I, static(C))
        if Ic <: AbstractUnitRange
            c2 = known(getfield(_from_sub_dims(static(N), I), C))
        elseif (Ic <: AbstractArray) || (Ic <: Integer)
            # The original joined these tests with `&&`, which no type can satisfy,
            # so this branch was unreachable.
            c2 = -1
        else
            c2 = nothing
        end
    end
    pdims = _to_sub_dims(I)
    return StrideIndex{Ns,permute(R, pdims),c2}(
        eachop(getmul, pdims, map(maybe_static_step, y.indices), strides(x)),
        permute(offsets(x), pdims),
    )
end
## layouts
# Linear layout of a plain `Array`: return the array itself as the buffer together
# with a stride-1, offset-0 index, so a linear index maps to itself.
layout(A::Array, ::LinearElement) = (A, LinearStrideIndex(static(1), static(0)))
# Cartesian layout of a plain `Array`: return the array itself as the buffer
# together with `StrideIndex(A)`.
layout(A::Array, ::CartesianElement) = (A, StrideIndex(A))
# Cartesian layout of a permuted-dimensions wrapper: take the parent's layout and
# compose its index with the permutation carried by `A`.
#
# The argument is the array wrapper, not the index type: the body calls `parent(A)`
# and `PermutedIndex(A)`, and the only visible `PermutedIndex` constructors taking an
# argument accept arrays. Dispatching on `PermutedIndex` itself could never succeed.
@inline function layout(A::PermutedDimsArray, ::CartesianElement)
    buffer, index = layout(parent(A), CartesianElement())
    return buffer, (index ∘ PermutedIndex(A))
end
# Linear layout of a fast-linear-indexing view: take the parent's linear layout and
# compose its index with a `LinearStrideIndex` built from the view's own `stride1`
# and `offset1` fields.
#
# The fields are read from `A`; the original read them from an undefined `x`.
@inline function layout(A::Base.FastSubArray, ::LinearElement)
    buffer, index = layout(parent(A), LinearElement())
    return buffer, (index ∘ LinearStrideIndex(A.stride1, A.offset1))
end
# Linear layout of a contiguous fast view: same as the `FastSubArray` method, but
# the stride is the static constant 1 and only `offset1` is read from the view.
#
# The field is read from `A`; the original read it from an undefined `x`.
function layout(A::Base.FastContiguousSubArray, ::LinearElement)
    buffer, index = layout(parent(A), LinearElement())
    return buffer, (index ∘ LinearStrideIndex(static(1), A.offset1))
end
# Cartesian layout of any `SubArray`: take the parent's Cartesian layout and compose
# its index with a `SubIndex` built from the view's stored indices.
function layout(A::SubArray, ::CartesianElement)
    buffer, index = layout(parent(A), CartesianElement())
    return buffer, (index ∘ SubIndex(A))
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment