Skip to content

Commit 712a201

Browse files
committed
Refactor common parts
1 parent 5a53ef2 commit 712a201

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

src/fillbroadcast.jl

+20-17
Original file line numberDiff line numberDiff line change
@@ -118,33 +118,36 @@ _isfill(f::Number) = true
118118
_isfill(f::Ref) = true
119119
_isfill(::Any) = false
120120

121-
_broadcast_maybecopy(bc::Broadcast.Broadcasted{<:AbstractFillStyle}) = copy(bc)
122-
_broadcast_maybecopy(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...)
123-
_broadcast_maybecopy(x) = x
121+
function _copy_fill(bc)
122+
v = _getindex_value(bc)
123+
if _iszeros(bc)
124+
return Zeros(typeof(v), axes(bc))
125+
elseif _isones(bc)
126+
return Ones(typeof(v), axes(bc))
127+
end
128+
return Fill(v, axes(bc))
129+
end
130+
131+
# recursively copy the purely fill components
132+
function _preprocess_fill(bc::Broadcast.Broadcasted{<:AbstractFillStyle})
133+
_isfill(bc) ? _copy_fill(bc) : Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...)
134+
end
135+
_preprocess_fill(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...)
136+
_preprocess_fill(x) = x
124137

125138
function _fallback_copy(bc)
126-
# treat the fill components
127-
bc2 = Base.broadcasted(bc.f, map(_broadcast_maybecopy, bc.args)...)
139+
# copy the purely fill components
140+
bc2 = Base.broadcasted(bc.f, map(_preprocess_fill, bc.args)...)
128141
# fallback style
129142
S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{ndims(bc)}}
130143
copy(convert(S, bc2))
131144
end
132145

133146
function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle})
134-
if _iszeros(bc)
135-
return Zeros(typeof(_getindex_value(bc)), axes(bc))
136-
elseif _isones(bc)
137-
return Ones(typeof(_getindex_value(bc)), axes(bc))
138-
elseif _isfill(bc)
139-
return Fill(_getindex_value(bc), axes(bc))
140-
else
141-
_fallback_copy(bc)
142-
end
147+
_isfill(bc) ? _copy_fill(bc) : _fallback_copy(bc)
143148
end
144149
# make the zero-dimensional case consistent with Base
145-
function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}})
146-
_fallback_copy(bc)
147-
end
150+
Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) = _fallback_copy(bc)
148151

149152
# some cases that preserve 0d
150153
function broadcast_preserving_0d(f, As...)

0 commit comments

Comments
 (0)