precision_recall() counts within-group connections twice but between-group connections only once #62
Hey @FloHu Great catch!
I think we need to decide with @shntnu and probably also @gwaygenomics how this should behave and then implement it. I find it very confusing that we are not counting nearest neighbors but instead connections that can have two directions.
Thanks for tracking this down @FloHu I'll point you to this:
Maybe this is where the fix belongs? As for how to move forward, I recommend filing a pull request to implement the appropriate code changes. Perhaps you can start by breaking the tests: add the correct calculation, computed by hand (or with a different, more reliable tool). Then we'll know whether any fix corrects the calculation.
Exactly, the counting of the connections did not filter out/consider the bidirectionality. @gwaygenomics, yes, I think what you are quoting is where the fix belongs. I already have this on my to-do list and will submit a pull request when I'm done.
So we discussed this today in a meeting. Unless you @FloHu are motivated to work on this branch and create a proper PR to fix this issue.
Great @michaelbornholdt - I'll do my best to provide a speedy review as well. To ensure that my review is speedy, I suggest that you approach the PR as I suggested above:
Actually, if you could do that @michaelbornholdt that would be great. As you said, it should be a quick fix. I'm happy to check my examples then with the new version and compare them with what I posted here. |
Hopefully this aligns with what we discussed @michaelbornholdt:

The implementation as it is right now is correct in the sense that we are reporting a class probability metric (here, Precision@K) for a query result, where the query happens to be all the compounds of the MOA, and the results are the ranked list of compounds in the dataset, including the compounds in that MOA. Say the query has Q elements and the whole dataset has D elements; the query result should then have QD - Q elements. We are now reporting Precision@K (or whatever) on these QD - Q elements, each of which is an edge between a compound in the MOA and some other compound in the dataset (including compounds in the MOA).

The only issue is that the K here is the size of the head of the list, but the list comprises edges from a set of points, not a single point, so you should look at the top K*Q elements of the list, i.e. when you compute precision, you should really compute Precision@KQ (for it to make intuitive sense). And because Q is different for different MOAs, the KQ will be different for each. So that's one way to fix it: change K to KQ, per MOA.

The other way to fix it is to not use a set as the query, and instead report the class probability metric per compound and then aggregate per MOA, i.e. macro-average instead of micro-average, roughly. That's the solution you converged on yesterday, which works well.

The only issue with the macro-averaging is that, if the package did support multi-labels, your macro-average would be inaccurate, because you'd have averaged metrics of compounds belonging to an MOA, but the metrics of each compound would be based on all the classes to which it belongs. E.g. imagine a compound C belongs to M1 and M2, and its ability to cluster with M2 compounds was awesome but with M1 compounds sucked. The macro-average for M1 would then benefit from C's ability to cluster well with M2 compounds, which is really unfair, you know.

However, because the package does not support multi-label, it is fine to go with your solution (this is the tl;dr, appropriately placed at the end :D).

PS – I'm not in the zone of thinking about this anymore, but I think this is an issue only for class probability metrics that have a threshold, and not others; e.g. PRAUC and AP are not affected.
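The difference between the two aggregations can be sketched in a few lines. Everything below is hypothetical toy data with invented compound names and scores, not cytominer-eval code; it just shows that pooling edges (Precision@KQ) and macro-averaging per compound can disagree:

```python
# Each compound in one MOA has a ranked list of (similarity, is_same_moa)
# edges. Names and numbers are made up for illustration.
edges = {
    "cpd1": [(0.9, 1), (0.8, 1), (0.3, 0)],
    "cpd2": [(0.7, 0), (0.2, 1), (0.1, 0)],
}
k = 1
q = len(edges)  # Q compounds in the MOA

# Macro-average: Precision@k per compound, then average over the MOA.
macro = sum(
    sum(hit for _, hit in ranked[:k]) / k for ranked in edges.values()
) / q

# Micro-average: pool all edges of the MOA into one globally ranked list
# and report Precision@(k*Q) on its head.
pooled = sorted(
    (edge for ranked in edges.values() for edge in ranked), reverse=True
)
micro = sum(hit for _, hit in pooled[: k * q]) / (k * q)

print(macro, micro)
```

Here the globally pooled head is dominated by cpd1's strong edges, so the micro-average comes out higher than the macro-average.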
So there is a branch on my repo which has a 'fixed' version of precision recall. As we said, it's really small. If someone has the time to come up with good tests, we could merge it sometime soon... https://github.com/michaelbornholdt/cytominer-eval/tree/fix_prec_recall
If all the existing tests did was test that ITGAV was the best, I'd argue that it is not your responsibility to improve them. I would say, though, that it is your responsibility to document in a new issue that the tests for precision-recall should be improved; you should link this issue. That being said, you must have done some internal testing to confirm that your fix works - why not add those tests if you have them already? That should take ~1 hour to add and review; worth it IMO. So, my vote is to file a pull request from the branch you list and get it into the repo before having a perfect test (esp. since one didn't exist already), to avoid the situation where branches diverge and future folks use precision recall incorrectly.
My internal test was stepping through in debug mode and checking that everything looked correct. There will be a merge conflict with hit@k, so I think I need to figure out which one to do first.
I am not sure I follow @shntnu. You write: Is this not exactly the problem? Assuming Q = 4 and D = 6 (so 2 non-query nodes): this way we count within-query connections twice (Q*Q - Q, so 16 - 4 in a 4x4 table), whereas connections from the query to the non-query points are counted once: 4*(6-4) = 8 (these edges are all unique). The fact that Q has a different size for each MOA is separate from that. I have not thought through your KQ and macro/micro average discussion in detail now, but clearly this factor needs to be considered too.
Here's an easy way to check: what if Q=1? The count cannot be QQ/2 - Q + Q(D-Q) because you'll get D - 1.5. It should be D-1. (I assume we're clear that D is the total number of vertices?) Don't think of it as double counting same-MOA edges. Think of it as: for each query vertex, you have Q-1 same-MOA edges and D-Q different-MOA edges. |
Sorry, my bad, I meant (Q*Q - Q)/2 (i.e. the number of pairs in a table if order does not matter and without self-self). In which case you always get an integer and for Q = 1 the result is 0. |
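The corrected count from this exchange is easy to check mechanically: same-group pairs counted once, (Q*Q - Q)/2, plus cross-group pairs, Q*(D - Q). The function name below is ours, for illustration only:

```python
def unique_edge_count(q: int, d: int) -> int:
    """Unordered edges touching a query group of size q among d vertices:
    same-group pairs counted once plus query-to-non-query pairs."""
    return (q * q - q) // 2 + q * (d - q)

# The sanity check from the thread: Q = 1 should give D - 1, not D - 1.5.
print(unique_edge_count(1, 6))  # 5

# The thread's worked example, Q = 4 and D = 6: 6 within + 8 between.
print(unique_edge_count(4, 6))  # 14
```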
When reading the source code for cytominer_eval.operations.precision_recall(), I noticed that the similarity_melted_df variable counts each replicate pair twice, e.g. A1 --> A2 and A2 --> A1. This becomes a problem because only the first replicate_group_col in lines 49-52 is subsequently used for grouping.

In the next step, each group is passed to calculate_precision_recall(), with the effect that all samples from within a group are counted twice. However, samples from outside the group are only counted once because group_by will filter out one direction.

Let me clarify this with an example. Consider 5 samples, the first 3 from group 'A', the second 2 from group 'B', both with greater within-group than between-group correlations.

Then what calculate_precision_recall will see is this: for example, one can see that the sample_pair_a column has a row for A1 --> A2 and one for A2 --> A1, but only one for A1 --> B1. B1 --> A1 is missing because of the way the melted data frame is generated and the grouping is performed. One can also see that the similarity metrics for within-group connections appear in duplicates.

Accordingly, the outcome for precision and recall at k = 4 is the following:
Precision: all 4 closest connections are from within the group for A, but only 2 for group B.
Recall: 4/6 within-group connections found for A, but all 2 found for B.
In summary, the computations are not entirely correct, especially for smaller groups. Also consider that with odd values of k, only one of the two connections of a symmetric pair is used.
Admittedly, this is a bit mind-boggling. I recommend using a debugger if you want to trace all the steps in detail by yourself.
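As an alternative to the debugger, the asymmetry can be reproduced with a small pandas sketch. Column names and similarity values below are invented for illustration; this mimics, rather than reuses, the package's melted frame:

```python
import itertools
import pandas as pd

# Toy reconstruction: 5 samples, A1-A3 in group 'A' and B1-B2 in group 'B',
# with within-group similarity (0.9) higher than between-group (0.1).
samples = {"A1": "A", "A2": "A", "A3": "A", "B1": "B", "B2": "B"}

melted = pd.DataFrame(
    [
        {
            "sample_pair_a": a,
            "sample_pair_b": b,
            "group_a": samples[a],
            "group_b": samples[b],
            "similarity": 0.9 if samples[a] == samples[b] else 0.1,
        }
        # Both directions of every pair, like the melted frame.
        for a, b in itertools.permutations(samples, 2)
    ]
)

# Grouping on the first group column only: A1->A2 and A2->A1 both land in
# group A, but B1->A1 lands in group B. So each group keeps its within-group
# edges twice and its between-group edges once.
k = 4
results = {}
for group, df in melted.groupby("group_a"):
    top_k = df.nlargest(k, "similarity")
    hits = int((top_k["group_b"] == group).sum())
    within_total = int((df["group_b"] == group).sum())
    results[group] = (hits / k, hits / within_total)

print(results)
```

This reproduces the numbers above: group A gets precision 4/4 and recall 4/6 (the doubled within-group edges inflate both the head of the list and the recall denominator), while group B gets precision 2/4 and recall 2/2.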
Proposed solution: I suggest counting each pair only once when creating the melted data frame.
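A minimal sketch of that proposal (again with invented column names, not the package's actual melt code): either enumerate unordered pairs once with itertools.combinations, or de-duplicate an existing bidirectional frame by keying each row on the unordered pair:

```python
import itertools
import pandas as pd

samples = ["A1", "A2", "A3", "B1", "B2"]

# Option 1: build the melted frame from unordered pairs in the first place.
melted_once = pd.DataFrame(
    [{"sample_pair_a": a, "sample_pair_b": b}
     for a, b in itertools.combinations(samples, 2)]
)
print(len(melted_once))  # C(5, 2) = 10 unique pairs, no duplicates

# Option 2: de-duplicate an existing bidirectional frame.
both_dirs = pd.DataFrame(
    [{"sample_pair_a": a, "sample_pair_b": b}
     for a, b in itertools.permutations(samples, 2)]
)
key = both_dirs.apply(
    lambda r: frozenset((r["sample_pair_a"], r["sample_pair_b"])), axis=1
)
deduped = both_dirs.loc[~key.duplicated()]
print(len(deduped))  # also 10
```

Either way, each replicate pair then contributes exactly one row, so within-group and between-group connections are counted consistently.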