Add tests for layernorm and add op_handler functions #3890
base: FIX-pytorch-additivity-failed
Conversation
Adds nonlinear_1d as the op_handler for LayerNorm and passthrough as the handler for Identity.
Adds tests for a model containing LayerNorm, both as the first layer and as a later layer, plus a test for the case where the background input matches the test input exactly in one of the features.
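As a rough illustration of those scenarios (this is not the PR's actual test code; the architecture, sizes, and seed below are made up), such a test might be structured like this:

```python
import torch
import torch.nn as nn
import shap

torch.manual_seed(0)

# Illustrative model with LayerNorm as the very first layer.
model = nn.Sequential(
    nn.LayerNorm(5),
    nn.Linear(5, 3),
    nn.ReLU(),
    nn.Linear(3, 1),
)

background = torch.randn(20, 5)
x = torch.randn(1, 5)
x[0, 2] = background[0, 2]  # background matches the test input exactly in one feature

explainer = shap.DeepExplainer(model, background)
shap_values = explainer.shap_values(x)
```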
@@ -400,6 +401,7 @@ def nonlinear_1d(module, grad_input, grad_output):
 op_handler["BatchNorm2d"] = linear_1d
 op_handler["BatchNorm3d"] = linear_1d

+op_handler["LayerNorm"] = nonlinear_1d
Hey, thanks for the update. Unfortunately this is wrong. Sorry, I had another think about this, and the problem here is that multiple inputs (for a single output) vary at the same time. For instance, look at the tensorflow implementation of multiplicative attribution and derive it from the definition of Shapley values: assume we have just two inputs and their reference values, and we are looking at a multiplication (or any function that takes two arguments). The general Shapley formula then reduces to a closed-form expression that mixes both inputs and both references.
We would need to do something similar for LayerNorm, BUT that is computationally expensive, so we might be able to use the approach the tensorflow implementation takes for softmax (yes, the pytorch implementation is wrong there too!). I don't fully understand this yet, but it doesn't seem too complicated.
Sorry, I only figured this out too late. I really appreciate your effort. Let me know how you want to proceed with this.
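A minimal sketch of the two-input case described above: exact Shapley attributions for f(x, y) = x * y with references rx, ry. Variable names are mine and this is not code from the shap library.

```python
# Exact Shapley attributions for a two-argument multiplication
# f(x, y) = x * y with reference (background) values rx, ry.
# Averaging the marginal contributions over both player orderings gives:
#   phi_x = 0.5 * ((x*ry - rx*ry) + (x*y - rx*y)) = 0.5 * (x - rx) * (y + ry)
#   phi_y = 0.5 * (y - ry) * (x + rx)
def shapley_mul(x, y, rx, ry):
    phi_x = 0.5 * (x - rx) * (y + ry)
    phi_y = 0.5 * (y - ry) * (x + rx)
    return phi_x, phi_y

x, y, rx, ry = 3.0, 2.0, 1.0, 0.5
phi_x, phi_y = shapley_mul(x, y, rx, ry)
# Additivity: the attributions sum to f(x, y) - f(rx, ry).
assert abs((phi_x + phi_y) - (x * y - rx * ry)) < 1e-12
```

Because phi_x depends on y and ry as well, no element-wise rescaling of a single input's gradient can reproduce it. The same obstruction applies to LayerNorm, where every output element depends on all inputs through the mean and variance.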
Ah! It was too easy to be true. So, if I understand this correctly, in the original DeepLIFT paper they have the rescale rule, which actually only applies to a single input. That's what's implemented in the nonlinear_1d function? And that's why this isn't valid for LayerNorm, which takes several inputs. So we need to use the full equation to derive the proper function. Is this right? If so, I think I understand and can try to see how to translate the way it's done in tensorflow to pytorch.
P.S. Does this mean that the softmax implementation for pytorch also needs to be fixed?
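For reference, the rescale rule being discussed replaces the ordinary gradient of a single-input elementwise op with Δoutput/Δinput between the sample and its reference. A simplified sketch of that idea (not a copy of shap's deep_pytorch code; names and the fallback threshold are illustrative):

```python
import torch

def rescale_rule(grad_output, grad_input, x, ref_x, y, ref_y, eps=1e-6):
    """DeepLIFT rescale rule for an elementwise op y = f(x) (sketch only)."""
    delta_in = x - ref_x
    delta_out = y - ref_y
    # Replace the gradient with the multiplier delta_out / delta_in ...
    multipliers = grad_output * delta_out / delta_in
    # ... and fall back to the ordinary gradient where the input barely changed.
    return torch.where(delta_in.abs() < eps, grad_input, multipliers)
```

This works element by element for ops like ReLU or Tanh, but LayerNorm normalizes with the mean and variance of the whole feature vector, so each output delta is driven by every input delta at once and a per-element multiplier like this no longer satisfies additivity.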
@JulioHC00 yes, that's right. And we would also need to fix the implementation of pytorch's softmax attribution. I believe that once we understand how the tensorflow softmax is implemented, we can apply this to pytorch LayerNorm and softmax as well.
@CloseChoice I guess the only part I'm confused about is how this fits with the op_handlers. Don't these handle the gradients in the backward propagation? How does this relate to implementing the shap values equation?
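Roughly how the two connect, as a simplified sketch rather than the actual shap source: the explainer hooks every module's backward pass and lets the registered handler rewrite the gradients, so the "gradients" flowing backwards become DeepLIFT-style multipliers, and multiplying them by the input deltas yields the attributions.

```python
# Simplified sketch of the dispatch idea (not the actual shap source code).
op_handler = {}  # e.g. {"ReLU": nonlinear_1d, "Identity": passthrough, ...}

def passthrough(module, grad_input, grad_output):
    return None  # None tells autograd to leave the gradient unchanged

def deeplift_grad(module, grad_input, grad_output):
    # Look up the handler for this module type and let it rewrite the gradient.
    handler = op_handler.get(type(module).__name__, passthrough)
    return handler(module, grad_input, grad_output)

# During an explanation, a backward hook along these lines is attached to
# every module:
#   module.register_full_backward_hook(deeplift_grad)
```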
@CloseChoice When I add LayerNorm to a test case and set op_handler['LayerNorm'] = softmax, the test case passes. Do we need to create a test case that shows how Softmax/LayerNorm are broken? What would go into such a test case?
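One natural candidate for such a test is an explicit additivity (local accuracy) check: the attributions must sum to f(x) minus the mean model output over the background. A sketch of what that could look like (the model, sizes, and tolerance are made up, and the exact array shapes returned by shap_values differ between shap versions):

```python
import numpy as np
import torch
import torch.nn as nn
import shap

model = nn.Sequential(nn.Linear(5, 5), nn.LayerNorm(5), nn.ReLU(), nn.Linear(5, 1))
background = torch.randn(50, 5)
x = torch.randn(4, 5)

explainer = shap.DeepExplainer(model, background)
sv = explainer.shap_values(x, check_additivity=False)

# Additivity: sum of attributions per sample == f(x) - E[f(background)].
attr_sum = np.asarray(sv).reshape(x.shape[0], -1).sum(axis=1)
delta = model(x).detach().numpy().ravel() - np.asarray(explainer.expected_value).ravel()
assert np.allclose(attr_sum, delta, atol=1e-4), "additivity violated for LayerNorm"
```

A handler that merely makes the explainer run without errors can still produce attributions that violate this identity, which is why a passing test with op_handler['LayerNorm'] = softmax is not by itself evidence of correctness.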
Overview
Partially closes #3438 by adding support for LayerNorm and testing it
Description of the changes proposed in this pull request:
Adds LayerNorm to the op_handler dictionary with nonlinear_1d, and Identity with passthrough. I've kept the structure of the tests as similar as possible to the existing ones for consistency.
Checklist