ENH add `from_cv_results` in `PrecisionRecallDisplay` (single Display) #30508
I forgot to mention, I think I would like to decide on the order of parameters for these display classes and their methods. They seem to have a lot of overlap and it would be great if they could be consistent.
I know that this would not matter when using the methods, but it would be nice for the documentation API page if they were consistent.
estimator_name : str, default=None
    Name of estimator. If None, then the estimator name is not shown.
curve_name : str or list of str, default=None
I thought `curve_name` is a more generalizable term (vs `estimator_name`), especially with cv multi curves where we want to name each curve by the fold number.
Changing this name will mean that we must change `_validate_plot_params` and thus all other classes that use `_BinaryClassifierCurveDisplayMixin`.
I note that the parameter is named differently here (`PrecisionRecallDisplay` init) vs in the `from_predictions` and `from_estimator` methods (where it's called `name`). I'm not sure if this was accidental or to distinguish it from the method parameter `name`?
# If multi-curve, ensure all args are of the right length
req_multi = [
    input for input in (self.precision, self.recall) if isinstance(input, list)
]
if req_multi and ((len(req_multi) != 2) or len({len(arg) for arg in req_multi}) > 1):
    raise ValueError(
        "When plotting multiple precision-recall curves, `self.precision` "
        "and `self.recall` should both be lists of the same length."
    )
elif self.average_precision is not None:
    default_line_kwargs["label"] = f"AP = {self.average_precision:0.2f}"
elif name is not None:
    default_line_kwargs["label"] = name
n_multi = len(self.precision) if req_multi else None
if req_multi:
    for name, param in zip(
        ["self.average_precision", "`name` or `self.curve_name`"],
        (self.average_precision, name_)
    ):
        if not((isinstance(param, list) and len(param) != n_multi) or param is not None):
            raise ValueError(
                f"For multi precision-recall curves, {name} must either be "
                "a list of the same length as `self.precision` and "
                "`self.recall`, or None."
            )
I struggled to come up with a nice way to do this. The checks we need are:
- `precision` and `recall` both need to be lists of the same length, or both need to be a single ndarray
- for multi curve, `average_precision` and `name` can each be either a list of the same length or None.
This latter point is important, as previously I simply checked that all 4 parameters were of the same length if they were lists. I didn't check that the 2 optional parameters needed to be `None` if they were not a list, for the multi-curve situation.
Suggestions welcome for making this nicer.
The good part though is that this is easily factorized out and can be generalised for all similar displays.
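For discussion, here is a minimal sketch of the two checks listed above as a standalone function. The name `validate_multi_curve_params` and its signature are hypothetical and not part of scikit-learn. As in the snippet under review, it treats a `list` as the multi-curve case and anything else as a single curve.

```python
def validate_multi_curve_params(precision, recall, average_precision=None, name=None):
    """Return the number of curves (None for a single curve) or raise ValueError.

    Hypothetical helper: a `list` means "multi-curve", anything else is
    treated as a single curve.
    """
    as_lists = [value for value in (precision, recall) if isinstance(value, list)]

    if not as_lists:
        # Single-curve case: nothing further to check in this sketch.
        return None

    # Both required parameters must be lists, and of the same length.
    if len(as_lists) != 2 or len(precision) != len(recall):
        raise ValueError(
            "When plotting multiple precision-recall curves, `precision` "
            "and `recall` should both be lists of the same length."
        )

    n_curves = len(precision)
    optional = {"average_precision": average_precision, "name": name}
    for param_name, value in optional.items():
        # Each optional parameter is either None or a list with one entry per curve.
        is_matching_list = isinstance(value, list) and len(value) == n_curves
        if not (value is None or is_matching_list):
            raise ValueError(
                f"For multi precision-recall curves, `{param_name}` must either "
                "be a list of the same length as `precision` and `recall`, or None."
            )
    return n_curves
```

Because the function only depends on its arguments, the same shape could plausibly be shared by other curve displays by passing in different parameter names.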
name_ = [name_] * n_multi if name_ is None else name_
average_precision_ = (
    [None] * n_multi if self.average_precision is None else self.average_precision
I don't like this, but could not immediately think of a better way to do it
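One possible way to tidy this up is to move the "expand `None` to a list of `None`s" step into a small helper, so both optional parameters go through the same code path. The sketch below assumes a hypothetical helper named `expand_optional`; it is not an existing scikit-learn function.

```python
def expand_optional(value, n_curves):
    """Return `value` unchanged if given, else a list of `n_curves` Nones.

    Hypothetical helper, used here only to illustrate the idea.
    """
    return [None] * n_curves if value is None else value


n_curves = 3
names = expand_optional(None, n_curves)
average_precisions = expand_optional([0.81, 0.77, 0.84], n_curves)

# After expansion, both parameters can be zipped per curve without special-casing None.
per_curve = list(zip(names, average_precisions))
```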
)
# Note `pos_label` cannot be `None` (default=1), unlike other metrics
# such as roc_auc
average_precision = average_precision_score(
Note `pos_label` cannot be `None` here (default=1), unlike other metrics such as `roc_auc`.
precision_all.append(precision)
recall_all.append(recall)
ap_all.append(average_precision)
Don't like this but not sure on the zip suggested in #30399 (comment) as you've got to unpack at the end 🤔
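To make the trade-off concrete, here is a small sketch of both options side by side. `compute_fold_metrics` is a hypothetical stand-in that returns made-up values; in the PR the per-fold values would come from the precision-recall computation for each fold.

```python
def compute_fold_metrics(fold_index):
    """Hypothetical stand-in returning (precision, recall, average_precision)."""
    precision = [1.0, 0.5 + 0.1 * fold_index]
    recall = [0.0, 1.0]
    average_precision = 0.7 + 0.05 * fold_index
    return precision, recall, average_precision


n_folds = 3

# Option 1: one accumulator list per metric, appended to inside the loop.
precision_all, recall_all, ap_all = [], [], []
for fold_index in range(n_folds):
    precision, recall, average_precision = compute_fold_metrics(fold_index)
    precision_all.append(precision)
    recall_all.append(recall)
    ap_all.append(average_precision)

# Option 2: collect one tuple per fold, then transpose with zip(*...) and
# unpack into per-metric lists at the end.
results = [compute_fold_metrics(fold_index) for fold_index in range(n_folds)]
precision_zip, recall_zip, ap_zip = (list(column) for column in zip(*results))
```

Both options produce the same three lists. The `zip` version avoids the three accumulators but needs the transpose-and-unpack step at the end.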
Some notes on review suggestions, namely to make all the multi class params (`precisions`, `recalls`, etc.) lists of ndarrays.
Also realised we did not need a separate `plot_single_curve` function, as most of the complexity was in `_get_line_kwargs`.
if fold_line_kws is None:
    fold_line_kws = [
        {"alpha": 0.5, "color": "tab:blue", "linestyle": "--"}
Decided that we should not specify a single colour, because the legend would indeed be useless.
names : str, default=None
    Names of each precision-recall curve for labeling. If `None`, use
    name provided at `PrecisionRecallDisplay` initialization. If not
    provided at initialization, no labeling is shown.
It seems reasonable that if we change the `name` parameter in the class init, we should change it here too, especially as we don't encourage people to use `plot` directly.
Discussed this with @glemaitre and decided that it is okay to change to `names`. We should however make it clear what this is setting: the label of the curve in the legend.
The problem use case we thought about was if someone created a plot and display object, then wanted to add one curve to it using `plot`; `names` would not make sense in this case. However, it would be difficult for us to manage the legend there, so we decided that it would be up to the user to manage the legend in such a case.
if len(self.line_) == 1:
    self.line_ = self.line_[0]
Should `line_` always be a list or should we do this to be backwards compatible?
We decided that we should deprecate `line_` and add `lines_`.
We'll add a getter such that if you try to access `line_` you get a warning and the first item of `lines_`. The getter will be removed in 2 releases.
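A minimal sketch of what such a getter could look like, using a hypothetical `CurveDisplay` class in place of the real display. The warning category (`FutureWarning`) and the message wording are assumptions, not the final implementation.

```python
import warnings


class CurveDisplay:
    """Hypothetical stand-in for a display class holding several plotted lines."""

    def __init__(self, lines):
        # New attribute: always a list, one entry per plotted curve.
        self.lines_ = list(lines)

    @property
    def line_(self):
        # Deprecated accessor: warn, then fall back to the first plotted line.
        warnings.warn(
            "`line_` is deprecated; use `lines_` instead.",
            FutureWarning,
            stacklevel=2,
        )
        return self.lines_[0]


display = CurveDisplay(["fold 0", "fold 1"])
```

Accessing `display.line_` would then emit the warning and return `"fold 0"`, while `display.lines_` stays silent.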
Just wanted to document here that we discussed a potential enhancement for comparing between estimators, where you have cv results from several estimators (so several fold curves for each estimator). Potentially this could be added as a separate function, where you pass the display object and the estimators desired. Not planned, just a potential addition in future.
Reference Issues/PRs
Follows on from #30399
What does this implement/fix? Explain your changes.
Proof of concept of adding multi displays to `PrecisionRecallDisplay`
- Overlaps with ENH add `from_cv_results` in `RocCurveDisplay` (single `RocCurveDisplay`) #30399, so we can definitely factorize out, though small intricacies may make it complex
- `plot` method is complex due to handling both single and multi curve and doing a lot more checking, as the user is able to use it outside of the `from_estimator` and `from_predictions` methods.

Detailed discussions of problems in review comments.
Any other comments?
cc @glemaitre @jeremiedbb