Skip to content

Commit 5b2845a

Browse files
authored
Improve array indexing (#142)
1 parent ab0d6fa commit 5b2845a

File tree

8 files changed

+35
-38
lines changed

8 files changed

+35
-38
lines changed

src/functions/indAffineIterative.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,10 @@ end
3333
function prox!(y, f::IndAffineIterative{M, V}, x, gamma) where {M, V}
3434
# Von Neumann's alternating projections
3535
R = real(eltype(x))
36-
m = size(f.A, 1)
3736
y .= x
3837
for k = 1:1000
3938
maxres = R(0)
40-
for i = 1:m
39+
for i in eachindex(f.b)
4140
resi = (f.b[i] - dot(f.A[i,:], y))
4241
y .= y + resi*f.A[i,:] # no need to divide: rows of A are normalized
4342
absresi = resi > 0 ? resi : -resi

src/functions/indBallL0.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,30 +34,28 @@ function (f::IndBallL0)(x)
3434
return R(0)
3535
end
3636

37+
function _get_top_k_abs_indices(x::AbstractVector, k)
38+
range = firstindex(x):(firstindex(x) + k - 1)
39+
return partialsortperm(x, range, by=abs, rev=true)
40+
end
41+
42+
_get_top_k_abs_indices(x, k) = _get_top_k_abs_indices(x[:], k)
43+
3744
function prox!(y, f::IndBallL0, x, gamma)
3845
T = eltype(x)
39-
p = []
40-
if ndims(x) == 1
41-
p = partialsortperm(x, 1:f.r, by=abs, rev=true)
42-
else
43-
p = partialsortperm(x[:], 1:f.r, by=abs, rev=true)
44-
end
45-
sort!(p)
46-
idx = 1
47-
for i = 1:length(p)
48-
y[idx:p[i]-1] .= T(0)
46+
p = _get_top_k_abs_indices(x, f.r)
47+
y .= T(0)
48+
for i in eachindex(p)
4949
y[p[i]] = x[p[i]]
50-
idx = p[i]+1
5150
end
52-
y[idx:end] .= T(0)
5351
return real(T)(0)
5452
end
5553

5654
function prox_naive(f::IndBallL0, x, gamma)
5755
T = eltype(x)
5856
p = sortperm(abs.(x)[:], rev=true)
5957
y = similar(x)
60-
y[p[1:f.r]] .= x[p[1:f.r]]
61-
y[p[f.r+1:end]] .= T(0)
58+
y[p[begin:begin+f.r-1]] .= x[p[begin:begin+f.r-1]]
59+
y[p[begin+f.r:end]] .= T(0)
6260
return y, real(T)(0)
6361
end

src/functions/indPSD.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ function (f::IndPSD)(x::AbstractVector{Float64})
9494
f.scaling && scale_diagonal!(y, sqrt(2))
9595
9696
Z = dspev!(:N, :L, y)
97-
for i in 1:length(Z)
97+
for i in eachindex(Z)
9898
# Do we allow for some tolerance here?
9999
if Z[i] <= -1e-14
100100
return +Inf
@@ -118,7 +118,7 @@ function prox!(y::AbstractVector{Float64}, f::IndPSD, x::AbstractVector{Float64}
118118
# Now let M = Z*diagm(W)*Z'
119119
M = M*Z'
120120
n = length(W)
121-
k = 1
121+
k = firstindex(y)
122122
# Store lower diagonal of M in y
123123
for j in 1:n, i in j:n
124124
y[k] = M[i,j]
@@ -135,8 +135,8 @@ function prox_naive(f::IndPSD, x::AbstractVector{Float64}, gamma)
135135
# Formula for size of matrix
136136
n = Int(sqrt(1/4+2*length(x))-1/2)
137137
X = Matrix{Float64}(undef, n, n)
138-
k = 1
139-
# Store y in M
138+
k = firstindex(x)
139+
# Store x in X
140140
for j = 1:n, i = j:n
141141
# Lower half
142142
X[i,j] = x[k]
@@ -164,7 +164,7 @@ function prox_naive(f::IndPSD, x::AbstractVector{Float64}, gamma)
164164
end
165165

166166
y = similar(x)
167-
k = 1
167+
k = firstindex(y)
168168
# Store Lower half of X in y
169169
for j = 1:n, i = j:n
170170
y[k] = X[i,j]

src/functions/normL21.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,17 @@ function (f::NormL21)(X)
3333
nslice = R(0)
3434
n21X = R(0)
3535
if f.dim == 1
36-
for j = 1:size(X, 2)
36+
for j in axes(X, 2)
3737
nslice = R(0)
38-
for i = 1:size(X, 1)
38+
for i in axes(X, 1)
3939
nslice += abs(X[i, j])^2
4040
end
4141
n21X += sqrt(nslice)
4242
end
4343
elseif f.dim == 2
44-
for i = 1:size(X, 1)
44+
for i in axes(X, 1)
4545
nslice = R(0)
46-
for j = 1:size(X, 2)
46+
for j in axes(X, 2)
4747
nslice += abs(X[i, j])^2
4848
end
4949
n21X += sqrt(nslice)
@@ -58,29 +58,29 @@ function prox!(Y, f::NormL21, X, gamma)
5858
nslice = R(0)
5959
n21X = R(0)
6060
if f.dim == 1
61-
for j = 1:size(X, 2)
61+
for j in axes(X, 2)
6262
nslice = R(0)
63-
for i = 1:size(X, 1)
63+
for i in axes(X, 1)
6464
nslice += abs(X[i, j])^2
6565
end
6666
nslice = sqrt(nslice)
6767
scal = 1 - gl / nslice
6868
scal = scal <= 0 ? R(0) : scal
69-
for i = 1:size(X, 1)
69+
for i in axes(X, 1)
7070
Y[i, j] = scal * X[i, j]
7171
end
7272
n21X += scal * nslice
7373
end
7474
elseif f.dim == 2
75-
for i = 1:size(X, 1)
75+
for i in axes(X, 1)
7676
nslice = R(0)
77-
for j = 1:size(X, 2)
77+
for j in axes(X, 2)
7878
nslice += abs(X[i, j])^2
7979
end
8080
nslice = sqrt(nslice)
8181
scal = 1-gl/nslice
8282
scal = scal <= 0 ? R(0) : scal
83-
for j = 1:size(X, 2)
83+
for j in axes(X, 2)
8484
Y[i, j] = scal * X[i, j]
8585
end
8686
n21X += scal * nslice

test/test_calculus.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,11 @@ stuff = [
100100
)
101101
]
102102
103-
@testset "$i" for i = 1:length(stuff)
103+
@testset "$i" for i in eachindex(stuff)
104104
f = stuff[i]["funcs"][1]
105105
g = stuff[i]["funcs"][2]
106106
107-
for j = 1:length(stuff[i]["args"])
107+
for j in eachindex(stuff[i]["args"])
108108
x = stuff[i]["args"][j]
109109
gamma = stuff[i]["gammas"][j]
110110

test/test_gradients.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ stuff = [
141141
),
142142
]
143143

144-
for i = 1:length(stuff)
144+
for i in eachindex(stuff)
145145

146146
f = stuff[i]["f"]
147147
x = stuff[i]["x"]

test/test_graph.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,17 @@ stuff = [
6666
),
6767
]
6868

69-
for i = 1:length(stuff)
69+
for i in eachindex(stuff)
7070
constr = stuff[i]["constr"]
7171

7272
if haskey(stuff[i], "wrong")
73-
for j = 1:length(stuff[i]["wrong"])
73+
for j in eachindex(stuff[i]["wrong"])
7474
wrong = stuff[i]["wrong"][j]
7575
@test_throws ErrorException constr(wrong...)
7676
end
7777
end
7878

79-
for j = 1:length(stuff[i]["params"])
79+
for j in eachindex(stuff[i]["params"])
8080
params = stuff[i]["params"][j]
8181
x = stuff[i]["args"][j]
8282
f = constr(params...)

test/test_results.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ stuff = [
321321
)
322322
]
323323

324-
@testset "$(i)" for i = 1:length(stuff)
324+
@testset "$(i)" for i in eachindex(stuff)
325325

326326
f = stuff[i]["f"]
327327
x = stuff[i]["x"]

0 commit comments

Comments
 (0)