Skip to content

Commit 3745b09

Browse files
committed
Allow non-static indices
1 parent 103e9d4 commit 3745b09

File tree

4 files changed

+73
-4
lines changed

4 files changed

+73
-4
lines changed

src/indexing.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ end
7575
@inline index_size(::Size, ::Int) = Size()
7676
@inline index_size(::Size, a::StaticArray) = Size(a)
7777
@inline index_size(s::Size, ::Colon) = s
78-
@inline index_size(s::Size, a::SOneTo{n}) where n = Size(n,)
78+
@inline index_size(::Size, a::AbstractRange{<:Integer}) = Size(length(a),)
7979

8080
@inline index_sizes(::S, inds...) where {S<:Size} = map(index_size, unpack_size(S), inds)
8181

@@ -92,9 +92,9 @@ linear_index_size(ind_sizes::Type{<:Size}...) = _linear_index_size((), ind_sizes
9292
@inline _linear_index_size(t::Tuple, ::Type{Size{S}}, ind_sizes...) where {S} = _linear_index_size((t..., prod(S)), ind_sizes...)
9393

9494
_ind(i::Int, ::Int, ::Type{Int}) = :(inds[$i])
95-
_ind(i::Int, j::Int, ::Type{<:StaticArray}) = :(inds[$i][$j])
9695
_ind(i::Int, j::Int, ::Type{Colon}) = j
9796
_ind(i::Int, j::Int, ::Type{<:SOneTo}) = j
97+
_ind(i::Int, j::Int, ::Type{<:AbstractArray}) = :(inds[$i][$j])
9898

9999
################################
100100
## Non-scalar linear indexing ##
@@ -215,7 +215,7 @@ end
215215

216216
# getindex
217217

218-
@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, SOneTo, Colon}...)
218+
@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, AbstractRange, Colon}...)
219219
_getindex(a, index_sizes(Size(a), inds...), inds)
220220
end
221221

test/abstractarray.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ using StaticArrays, Test, LinearAlgebra
8585
@test similar(v, SOneTo(3), SOneTo(4)) isa MMatrix{3,4,Int}
8686
@test similar(v, 3, SOneTo(4)) isa Matrix
8787

88-
@test m[:, 1:2] isa Matrix
88+
@test m[:, 1:2] isa SMatrix{2, 2, Int}
8989
@test m[:, [true, false, false]] isa Matrix
9090
@test m[:, SOneTo(2)] isa SMatrix{2, 2, Int}
9191
@test m[:, :] isa SMatrix{2, 3, Int}

test/indexing.jl

+30
Original file line numberDiff line numberDiff line change
@@ -223,4 +223,34 @@ using StaticArrays, Test
223223
@test eltype(Bvv) == Int
224224
@test Bvv[:] == [B[1,2,3,4], B[1,1,3,4]]
225225
end
226+
227+
@testset "Indexing with constants" begin
228+
function SVector_UnitRange()
229+
x = SA[1, 2, 3]
230+
x[2:end]
231+
end
232+
@test SVector_UnitRange() === SA[2, 3]
233+
VERSION v"1.1" && @test_const_fold SVector_UnitRange()
234+
235+
function SVector_StepRange()
236+
x = SA[1, 2, 3, 4]
237+
x[1:2:end]
238+
end
239+
@test SVector_StepRange() === SA[1, 3]
240+
VERSION v"1.1" && @test_const_fold SVector_StepRange()
241+
242+
function SMatrix_UnitRange_UnitRange()
243+
x = SA[1 2 3; 4 5 6]
244+
x[1:2, 2:end]
245+
end
246+
@test SMatrix_UnitRange_UnitRange() === SA[2 3; 5 6]
247+
VERSION v"1.1" && @test_const_fold SMatrix_UnitRange_UnitRange()
248+
249+
function SMatrix_StepRange_StepRange()
250+
x = SA[1 2 3; 4 5 6]
251+
x[1:1:2, 1:2:end]
252+
end
253+
@test SMatrix_StepRange_StepRange() === SA[1 3; 4 6]
254+
VERSION v"1.1" && @test_const_fold SMatrix_StepRange_StepRange()
255+
end
226256
end

test/testutil.jl

+39
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,45 @@ should_not_be_inlined(x) = _should_not_be_inlined(x)
9696
end
9797

9898

99+
"""
100+
@test_const_fold f(args...)
101+
102+
Test that constant folding works with a function call `f(args...)`.
103+
"""
104+
macro test_const_fold(ex)
105+
quote
106+
ci, = $(esc(:($InteractiveUtils.@code_typed optimize = true $ex)))
107+
@test $(esc(ex)) == constant_return(ci)
108+
end
109+
end
110+
111+
struct NonConstantValue end
112+
113+
function constant_return(ci)
114+
if :rettype in fieldnames(typeof(ci))
115+
ci.rettype isa Core.Compiler.Const && return ci.rettype.val
116+
return NonConstantValue()
117+
else
118+
# for julia < 1.2
119+
ex = ci.code[end]
120+
Meta.isexpr(ex, :return) || return NonConstantValue()
121+
val = ex.args[1]
122+
return val isa QuoteNode ? val.value : val
123+
end
124+
end
125+
126+
@testset "@test_const_fold" begin
127+
should_const_fold() = (1, 2, 3)
128+
@test_const_fold should_const_fold()
129+
130+
x = Ref(1)
131+
should_not_const_fold() = x[]
132+
ts = @testset ErrorCounterTestSet "" begin
133+
@test_const_fold should_not_const_fold()
134+
end
135+
@test ts.errorcount == 0 && ts.failcount == 1 && ts.passcount == 0
136+
end
137+
99138
"""
100139
@inferred_maybe_allow allow ex
101140

0 commit comments

Comments
 (0)