-
-
Notifications
You must be signed in to change notification settings - Fork 66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix compile times #150
base: master
Are you sure you want to change the base?
Fix compile times #150
Conversation
Chain([conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)..., | ||
MeanPool((2, 2))]...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Chain([conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)..., | |
MeanPool((2, 2))]...) | |
Chain([conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)]..., MeanPool((2, 2))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return type for conv_bn
is already a Vector
, so shouldn't just Chain(conv_bn((1, 1), inplanes, outplanes; bias = false, rev = true)..., MeanPool((2, 2)))
work? Also, I know this suggestion has been shot down before because it would cause visual noise, but simply tweaking conv_bn
to return a Chain
does wonders for the TTFG:
master:
julia> using Metalhead
julia> using Flux: Zygote
julia> den = DenseNet();
julia> ip = rand(Float32, 224, 224, 3, 1);
julia> @time Zygote.gradient((m,x) -> sum(m(x)), den, ip);
77.621622 seconds (124.76 M allocations: 11.324 GiB, 1.67% gc time, 97.00% compilation time)
with conv_bn
returning a Chain
:
julia> @time Zygote.gradient((m,x) -> sum(m(x)), den, ip);
28.244888 seconds (89.40 M allocations: 9.049 GiB, 3.60% gc time, 90.78% compilation time)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
^ This needs some tricks to get this fast though. One major trick being that large Vector
s that are being splatted to give Chain
s....should not be (Flux 0.13 deals with this, so this works). Removing a single splat to a large vector of layers (the "body" of the DenseNet
) makes it shoot back up:
julia> @time Zygote.gradient((m,x) -> sum(m(x)), den, ip);
46.788491 seconds (117.59 M allocations: 10.873 GiB, 2.65% gc time, 94.90% compilation time)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Woops, you are indeed right and the suggestion looks good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing I am curious about is the large discrepancy b/w first compiles on master. I regularly get ~500s TTFG with DenseNet, you don't seem to get nearly as bad times. Mine is with GPUs turned off. Does that make up some of the difference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am testing on an M1 Mac CPU, with 4 threads and Julia master. Maybe some of the discrepancy is there? Julia 1.8+ seemed to be an order of magnitude faster than Julia 1.7 last I checked for compilation of some stuff
DenseNet
had a major regression in the compile time to differentiate it over the releases.This is often times due to very long
Chain
s. This is a small fix that makes things a lot more manageable for the moment.This is a pattern we have across the library, so maybe something to fix elsewhere as well.