1
1
module BlockSparseArraysTensorAlgebraExt
2
- using BlockArrays: AbstractBlockedUnitRange
3
-
4
- using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
5
- using TensorProducts: OneToOne
6
2
3
+ using BlockArrays: AbstractBlockedUnitRange
7
4
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
5
+ using TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
8
6
9
7
TensorAlgebra. FusionStyle (:: AbstractBlockedUnitRange ) = BlockReshapeFusion ()
10
8
@@ -20,99 +18,4 @@ function TensorAlgebra.splitdims(
20
18
return blockreshape (a, axes)
21
19
end
22
20
23
- using BlockArrays:
24
- AbstractBlockVector,
25
- AbstractBlockedUnitRange,
26
- Block,
27
- BlockIndexRange,
28
- blockedrange,
29
- blocks
30
- using BlockSparseArrays:
31
- BlockSparseArrays,
32
- AbstractBlockSparseArray,
33
- AbstractBlockSparseArrayInterface,
34
- AbstractBlockSparseMatrix,
35
- BlockSparseArray,
36
- BlockSparseArrayInterface,
37
- BlockSparseMatrix,
38
- BlockSparseVector,
39
- block_merge
40
- using DerivableInterfaces: @interface
41
- using GradedUnitRanges:
42
- GradedUnitRanges,
43
- AbstractGradedUnitRange,
44
- blockmergesortperm,
45
- blocksortperm,
46
- dual,
47
- invblockperm,
48
- nondual,
49
- unmerged_tensor_product
50
- using LinearAlgebra: Adjoint, Transpose
51
- using TensorAlgebra:
52
- TensorAlgebra, FusionStyle, BlockReshapeFusion, SectorFusion, fusedims, splitdims
53
-
54
- # TODO : Make a `ReduceWhile` library.
55
- include (" reducewhile.jl" )
56
-
57
- TensorAlgebra. FusionStyle (:: AbstractGradedUnitRange ) = SectorFusion ()
58
-
59
- # TODO : Need to implement this! Will require implementing
60
- # `block_merge(a::AbstractUnitRange, blockmerger::BlockedUnitRange)`.
61
- function BlockSparseArrays. block_merge (
62
- a:: AbstractGradedUnitRange , blockmerger:: AbstractBlockedUnitRange
63
- )
64
- return a
65
- end
66
-
67
- # Sort the blocks by sector and then merge the common sectors.
68
- function block_mergesort (a:: AbstractArray )
69
- I = blockmergesortperm .(axes (a))
70
- return a[I... ]
71
- end
72
-
73
- function TensorAlgebra. fusedims (
74
- :: SectorFusion , a:: AbstractArray , merged_axes:: AbstractUnitRange...
75
- )
76
- # First perform a fusion using a block reshape.
77
- # TODO avoid groupreducewhile. Require refactor of fusedims.
78
- unmerged_axes = groupreducewhile (
79
- unmerged_tensor_product, axes (a), length (merged_axes); init= OneToOne ()
80
- ) do i, axis
81
- return length (axis) ≤ length (merged_axes[i])
82
- end
83
-
84
- a_reshaped = fusedims (BlockReshapeFusion (), a, unmerged_axes... )
85
- # Sort the blocks by sector and merge the equivalent sectors.
86
- return block_mergesort (a_reshaped)
87
- end
88
-
89
- function TensorAlgebra. splitdims (
90
- :: SectorFusion , a:: AbstractArray , split_axes:: AbstractUnitRange...
91
- )
92
- # First, fuse axes to get `blockmergesortperm`.
93
- # Then unpermute the blocks.
94
- axes_prod = groupreducewhile (
95
- unmerged_tensor_product, split_axes, ndims (a); init= OneToOne ()
96
- ) do i, axis
97
- return length (axis) ≤ length (axes (a, i))
98
- end
99
- blockperms = blocksortperm .(axes_prod)
100
- sorted_axes = map ((r, I) -> only (axes (r[I])), axes_prod, blockperms)
101
-
102
- # TODO : This is doing extra copies of the blocks,
103
- # use `@view a[axes_prod...]` instead.
104
- # That will require implementing some reindexing logic
105
- # for this combination of slicing.
106
- a_unblocked = a[sorted_axes... ]
107
- a_blockpermed = a_unblocked[invblockperm .(blockperms)... ]
108
- return splitdims (BlockReshapeFusion (), a_blockpermed, split_axes... )
109
- end
110
-
111
- # TODO : Handle this through some kind of trait dispatch, maybe
112
- # a `SymmetryStyle`-like trait to check if the block sparse
113
- # matrix has graded axes.
114
- function Base. axes (a:: Adjoint{<:Any,<:AbstractBlockSparseMatrix} )
115
- return dual .(reverse (axes (a' )))
116
- end
117
-
118
21
end
0 commit comments