First commit to add weighted mean square #140
marcobonici wants to merge 6 commits into PumasAI:main from
Conversation
chriselrod
left a comment
Mind also adding tests for the value and gradient?
src/loss.jl (Outdated)

end
(::WeightedSquaredLoss)(y, w) = WeightedSquaredLoss(y, w)
WeightedSquaredLoss() = WeightedSquaredLoss(nothing)
target(wsl::WeightedSquaredLoss) = getfield(wsl, :y) # maybe need to return both :y and :weights?
Yes: the target should be sliceable, and the loss should be callable on the target's result to create a new one.
It's used for slicing/iterating over batches.
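Concretely, the protocol described here might look like the following minimal sketch. This is a hypothetical illustration, not the PR's actual code: the struct layout and the field names `y`/`w` are assumptions based on this PR's diff.

```julia
# Hypothetical sketch of the protocol: `target` returns everything needed
# to rebuild the loss, the loss is callable on that result to create a new
# one, and `getindex` slices a batch. Field names `y`/`w` are assumed.
struct WeightedSquaredLoss{Y,W}
    y::Y  # targets
    w::W  # per-sample weights
end

target(wsl::WeightedSquaredLoss) = (wsl.y, wsl.w)

# The loss is callable on `target`'s result, producing a new loss.
(::WeightedSquaredLoss)(t::Tuple) = WeightedSquaredLoss(t...)

# Slice both fields along the last dimension for batching.
Base.getindex(wsl::WeightedSquaredLoss, r) =
    WeightedSquaredLoss(view(wsl.y, :, r), view(wsl.w, :, r))
```

With these definitions, `wsl[f:l]` yields a new loss over a batch, and `wsl(target(wsl))` round-trips.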
So the point is that WeightedSquaredLoss(target(wsl)) should be able to run. Did I get that right?
Yes.
$ rg 'target\('
docs/src/examples/custom_loss_layer.md
49:SimpleChains.target(loss::BinaryLogitCrossEntropyLoss) = loss.targets
src/optimize.jl
93: tgt = view_slice_last(target(loss), f:l)
125: tgt = target(loss)
177: tgt = target(loss)
488: t = target(_chn)
679: tgt = target(chn)
src/loss.jl
25:target(_) = nothing
26:target(sc::SimpleChain) = target(last(sc.layers))
27:preserve_buffer(l::AbstractLoss) = target(l)
28:StrideArraysCore.object_and_preserve(l::AbstractLoss) = l, target(l)
31:iterate_over_losses(sc) = _iterate_over_losses(target(sc))
40: align(length(first(target(sl))) * static_sizeof(T)), static_sizeof(T)
42:function _layer_output_size_needs_temp_of_equal_len_as_target(
47: align(length(target(sl)) * static_sizeof(T)), static_sizeof(T)
66:target(sl::SquaredLoss) = getfield(sl, :y)
69:Base.getindex(sl::SquaredLoss, r) = SquaredLoss(view_slice_last(target(sl), r))
120:target(sl::AbsoluteLoss) = getfield(sl, :y)
127: AbsoluteLoss(view_slice_last(target(sl), r))
197:target(sl::LogitCrossEntropyLoss) = getfield(sl, :y)
205: _layer_output_size_needs_temp_of_equal_len_as_target(Val{T}(), sl, s)
212: _layer_output_size_needs_temp_of_equal_len_as_target(Val{T}(), sl, s)
254: LogitCrossEntropyLoss(view(target(sl), r))
273: correct_count(Y, target(loss))
283: ec = correct_count(Y, target(loss))
src/penalty.jl
68:target(c::AbstractPenalty) = target(getchain(c))
Note that we also need things like view_slice_last(target(loss), f:l) to work.
So view_slice_last should be implemented.
Some form of PtrArray(tgt) should also work, but you could define a different function to use there that calls PtrArray by default, as overloading constructors to return something else is generally frowned upon.
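One way to read the PtrArray suggestion is the following sketch. The function name `maybe_ptr_array` is hypothetical (not from the codebase), and this assumes `PtrArray` from StrideArraysCore:

```julia
using StrideArraysCore: PtrArray

# Hypothetical indirection: instead of overloading the PtrArray constructor
# to return something else (frowned upon), route through a function that
# calls PtrArray by default and can be overloaded for tuple targets.
maybe_ptr_array(x) = PtrArray(x)
maybe_ptr_array(t::Tuple) = map(maybe_ptr_array, t)
```

Call sites would then use `maybe_ptr_array(tgt)` instead of `PtrArray(tgt)`, and a tuple target dispatches to the per-field method.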
A few things, @chriselrod.
So, as I thought, target(wsl) needs to give back all the fields of the struct. This is needed because, as you pointed out, WeightedSquaredLoss(target(wsl)) needs to work.
So I have updated the target method:
target(wsl::WeightedSquaredLoss) = getfield(wsl, :y), getfield(wsl, :w)
Since this gives back a tuple, I have added a constructor using the splat operator:
WeightedSquaredLoss(x::Tuple) = WeightedSquaredLoss(x...)
Do you have any thoughts on that? In the meantime, I'll focus on view_slice_last.
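As a quick, self-contained sanity check of that round trip (the struct definition below is an assumption based on this PR's diff, not the actual source):

```julia
# Self-contained check of the two definitions above; the struct layout
# with fields `y` and `w` is assumed from this PR.
struct WeightedSquaredLoss{Y,W}
    y::Y
    w::W
end
target(wsl::WeightedSquaredLoss) = getfield(wsl, :y), getfield(wsl, :w)
WeightedSquaredLoss(x::Tuple) = WeightedSquaredLoss(x...)

wsl = WeightedSquaredLoss(rand(3), rand(3))
wsl2 = WeightedSquaredLoss(target(wsl))  # rebuilds the loss from its target
```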
If I understand correctly, view_slice_last is used to slice the fields of the loss. If so, this could possibly work:
function view_slice_last(t::Tuple, r)
    return Tuple(view_slice_last(f, r) for f in t)
end
I am returning a Tuple, assuming this works with the constructor I just created.
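A standalone check of slicing a tuple of fields, with a simplified stand-in for the library's view_slice_last (here assumed to slice the last dimension of a matrix):

```julia
# Simplified stand-in for the library's view_slice_last on arrays;
# for a matrix this views a range of the last (column) dimension.
view_slice_last(A::AbstractMatrix, r) = view(A, :, r)
# Tuple method: slice each field and return the results as a Tuple.
view_slice_last(t::Tuple, r) = Tuple(view_slice_last(f, r) for f in t)

y, w = rand(2, 10), rand(2, 10)
sliced = view_slice_last((y, w), 1:4)  # a Tuple of two 2x4 views
```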
Co-authored-by: Chris Elrod <elrodc@gmail.com>
Codecov Report
Patch coverage has no change and project coverage change: -0.85%

Additional details and impacted files
@@ Coverage Diff @@
## main #140 +/- ##
==========================================
- Coverage 73.82% 72.97% -0.85%
==========================================
Files 15 15
Lines 2617 2646 +29
==========================================
- Hits 1932 1931 -1
- Misses 685 715 +30
☔ View full report in Codecov by Sentry.
Added the Weighted Mean Square Error Loss