Last active
November 28, 2023 17:56
-
-
Save xrq-phys/8b6e52f0f371acb0244950f755d7476f to your computer and use it in GitHub Desktop.
Julia Implementation of NumPy's Tensordot Function, Compatible with Flux/Zygote's Automatic Differentiation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Tensordot.jl - Minimal Tensordot Implementation | |
Minimal tensordot for supporting Zygote's automatic differentiation. | |
This bunch of code is also compatible with FluxML's Tracker module. | |
""" | |
module Tensordot | |
using Zygote: @adjoint | |
using LinearAlgebra | |
""" | |
contract(T1, T2, axesL, axesR) | |
Contracts tensors 'T1' with 'T2' in axes specified in axesL and axesR. | |
Interface is somehow the same as np.tensordot(T1, T2, axes=(axesL, axesR)). | |
""" | |
contract(TL, TR, axesL::Array{Int}, axesR::Array{Int}) = begin | |
# Gets shape and axes information from helper. | |
shapeInfo = contractprep(size(TL), size(TR), axesL, axesR) | |
# Apply transformation | |
contractraw(TL, TR, shapeInfo...) | |
end # contract | |
# Raw contraction function. | |
contractraw(TL, TR, permL::Array{Int}, permR::Array{Int}, | |
shapeL::Tuple, shapeR::Tuple, extL::Array, extR::Array) = begin | |
# Multiply and restore to original shape. | |
reshape((reshape(permutedims(TL, permL), shapeL) * | |
reshape(permutedims(TR, permR), shapeR)), (extL..., extR...)) | |
end # contractraw | |
""" | |
prepcontract(T1, T2, axesL, axesR) | |
Index preparations for contracting. | |
""" | |
contractprep(shapeL::Tuple, shapeR::Tuple, axesL::Array{Int}, axesR::Array{Int}) = begin | |
shapeL = [shapeL...] | |
shapeR = [shapeR...] | |
# TODO: Check axes boundary. | |
# Dumb index permuting & size extraction | |
dumbperm(shape::Array{Int}, pick::Array{Int}) = begin | |
sbarrier = sort(pick) .+ 1 | |
ebarrier = sort(pick) .- 1 | |
sbarrier = vcat([1], sbarrier) | |
ebarrier = vcat(ebarrier, length(shape)) | |
regular = Int[] | |
for i = 1:length(sbarrier) | |
append!(regular, sbarrier[i]:ebarrier[i]) | |
end # for | |
return regular | |
end # dumbperm | |
# External permutation | |
permL = dumbperm(shapeL, axesL) | |
permR = dumbperm(shapeR, axesR) | |
# External shape | |
extL = [shapeL[i] for i in permL] | |
extR = [shapeR[i] for i in permR] | |
# Contractional permutation | |
append!(permL, axesL) | |
prepend!(permR, axesR) | |
outerL = if (length(extL)==0) 1 else reduce(*, extL) end | |
outerR = if (length(extR)==0) 1 else reduce(*, extR) end | |
innerL = reduce(*, [shapeL[i] for i in axesL]) | |
innerR = reduce(*, [shapeR[i] for i in axesR]) | |
permL, permR, (outerL, innerL), (innerR, outerR), extL, extR | |
end # contractprep | |
# Adjoint of contract should refrain from digging into index processing. | |
@adjoint contractprep(shapeL::Tuple, shapeR::Tuple, axesL::Array{Int}, axesR::Array{Int}) = begin | |
contractprep(shapeL, shapeR, axesL, axesR), _ -> nothing | |
end # @adjoint contractprep | |
end # module |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment