what do neural network weights reliably learn?
chandan singh
why analyze neural net weights?
- neural nets learn useful representations
- understanding weights might help with engineering
- might help us understand the brain
- approaches: neuroscience learning rules, unsupervised learning, NN analysis, feature engineering
- engineering strong feature representations (mallat 12, huang et al. 06)
- sparse coding (olshausen & field 96)
- analyzing weight matrices (denil et al. 13, martin et al. 18)
- analyzing neural activations (tishby et al. 15, 17)
$W_{\text{final}} = W_{\text{init}} + W_{\text{proj}}, \quad W_{\text{proj}} \in \mathrm{col}(X)$
this is a consequence of the learning rule:
$Y = g(WX)$
$\frac{\partial L}{\partial W} = \frac{\partial L}{\partial g(WX)} \cdot X$
weights sometimes grow (especially when using ADAM)
$\implies W_{\text{proj}}$ dominates
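a minimal sketch (my own, not from the talk) that checks this numerically on a small synthetic setup; here samples are rows of X, so the check is that each row of the first-layer update lies in the span of the training inputs:

```python
# Minimal sketch (PyTorch, synthetic data): check that the first-layer weight
# change W_final - W_init lies in the span of the training inputs.
import torch

torch.manual_seed(0)
n, d, h = 50, 100, 32                      # n samples, input dim d > n, hidden width h
X = torch.randn(n, d)                      # training inputs (rows are samples)
y = torch.randint(0, 2, (n,))              # binary labels

model = torch.nn.Sequential(
    torch.nn.Linear(d, h), torch.nn.ReLU(), torch.nn.Linear(h, 2)
)
W_init = model[0].weight.detach().clone()

opt = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(200):
    opt.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(X), y)
    loss.backward()
    opt.step()

W_proj = model[0].weight.detach() - W_init    # W_final - W_init
Q, _ = torch.linalg.qr(X.T)                   # orthonormal basis spanning the training inputs
residual = W_proj - (W_proj @ Q) @ Q.T        # component of the update outside that span
print(residual.norm() / W_proj.norm())        # ~0: the update lies in the span of the data
```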
1st layer weight viz (optimizer = adam)
1st layer weight viz (optimizer = sgd)
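a sketch of how such a visualization could be produced, assuming `model[0]` is a linear first layer over flattened 28x28 MNIST inputs:

```python
# Sketch: visualize first-layer weights of an MLP trained on flattened 28x28 MNIST.
# Assumes `model[0]` is torch.nn.Linear(784, h); each row reshapes to a 28x28 filter.
import matplotlib.pyplot as plt

W = model[0].weight.detach().cpu().numpy()      # (h, 784)
fig, axes = plt.subplots(4, 8, figsize=(8, 4))
for ax, w in zip(axes.ravel(), W):
    ax.imshow(w.reshape(28, 28), cmap="gray")
    ax.axis("off")
plt.show()
```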
mnist sparse dictionary (λ = 10)
mnist sparse dictionary (λ = 100)
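a sketch of fitting such a dictionary with scikit-learn; the slide's sparsity weight λ is assumed to play the role of sklearn's `alpha` (the exact value and scaling are illustrative, not the talk's settings):

```python
# Sketch: fit a sparse dictionary on MNIST digits (scikit-learn).
from sklearn.datasets import fetch_openml
from sklearn.decomposition import MiniBatchDictionaryLearning

X = fetch_openml("mnist_784", as_frame=False).data / 255.0
dico = MiniBatchDictionaryLearning(n_components=64, alpha=1.0,   # alpha ~ the slide's λ (illustrative)
                                   batch_size=256, random_state=0)
dico.fit(X[:10000])
atoms = dico.components_        # (64, 784); reshape each row to 28x28 to visualize the dictionary
```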
why the first layer?: training mlps
|       | Linear Classifier | MLP-2 First-Layer | MLP-2 Last-Layer |
|-------|-------------------|-------------------|------------------|
| MNIST | 0.92              | 0.96              | 0.90             |
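a sketch of how this comparison could be run: fit a linear (logistic-regression) classifier on raw pixels, on frozen first-layer activations, and on frozen last-hidden-layer activations. `X_train`, `X_test`, `y_train`, `y_test` (flattened MNIST) and `mlp` (a trained 2-hidden-layer `nn.Sequential`) are assumed:

```python
# Sketch: linear classifier on raw pixels vs. frozen MLP-2 layer features.
import torch
from sklearn.linear_model import LogisticRegression

def features(frozen_layers, X):
    """Push inputs through the given frozen layers; return numpy features."""
    with torch.no_grad():
        return frozen_layers(torch.as_tensor(X, dtype=torch.float32)).numpy()

settings = {
    "raw pixels":  (X_train, X_test),
    "first layer": (features(mlp[:2], X_train), features(mlp[:2], X_test)),
    "last layer":  (features(mlp[:-1], X_train), features(mlp[:-1], X_test)),
}
for name, (F_train, F_test) in settings.items():
    clf = LogisticRegression(max_iter=1000).fit(F_train, y_train)
    print(name, clf.score(F_test, y_test))      # test accuracy for each feature set
```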
why the first layer?: training mlps
|       | Linear Classifier | Linear + LeNet First-Layer |
|-------|-------------------|----------------------------|
| MNIST | 0.92              | 0.98                       |
1st and last layer norms grow
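a sketch of tracking per-layer weight norms during training; `train_one_epoch`, `loader`, `opt`, and `num_epochs` are assumed helpers/settings, not from the talk:

```python
# Sketch: record the norm of each weight matrix after every epoch.
norms = {name: [] for name, p in model.named_parameters() if p.dim() > 1}
for epoch in range(num_epochs):
    train_one_epoch(model, loader, opt)                  # assumed training helper
    for name, p in model.named_parameters():
        if p.dim() > 1:
            norms[name].append(p.detach().norm().item())
# plot each curve vs. epoch; per the slide, the first- and last-layer norms grow
```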
different nets learn the same thing
adam weights can be pruned more
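one way such a prunability comparison could be set up is magnitude pruning of the first layer at increasing fractions; `model` (a `Sequential` whose first layer is `model[0]`) and an `evaluate` helper are assumed:

```python
# Sketch: magnitude-prune a fraction of first-layer weights and re-check accuracy,
# e.g. to compare ADAM- vs SGD-trained nets.
import copy
import torch.nn.utils.prune as prune

for frac in (0.5, 0.8, 0.9, 0.95):
    pruned = copy.deepcopy(model)
    prune.l1_unstructured(pruned[0], name="weight", amount=frac)   # zero the smallest-|w| entries
    print(frac, evaluate(pruned))                                  # accuracy after pruning
```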
(validation) predictions correlate
top-k match
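a sketch of two agreement measures between independently trained nets, under my reading of these slide titles (correlation of validation logits, and overlap of top-k predicted classes):

```python
# Sketch: agreement between two nets' validation predictions.
import numpy as np

def pred_correlation(logits_a, logits_b):
    """Pearson correlation between the two nets' flattened validation logits."""
    return np.corrcoef(logits_a.ravel(), logits_b.ravel())[0, 1]

def topk_match(logits_a, logits_b, k=1):
    """Fraction of validation examples whose top-k predicted classes overlap."""
    top_a = np.argsort(-logits_a, axis=1)[:, :k]
    top_b = np.argsort(-logits_b, axis=1)[:, :k]
    return np.mean([len(set(a) & set(b)) > 0 for a, b in zip(top_a, top_b)])
```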
quantifying memorization
memorization is stable
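one possible proxy for this kind of first-layer memorization (an assumption on my part, not necessarily the metric used in the talk) is how closely each first-layer filter matches some training example:

```python
# Sketch of a memorization proxy (assumed, not necessarily the talk's metric):
# max cosine similarity between each first-layer filter and any training example.
import numpy as np

def memorization_scores(W, X):
    """W: (h, d) first-layer weights; X: (n, d) training inputs. Returns (h,) scores."""
    Wn = W / np.linalg.norm(W, axis=1, keepdims=True)
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    return (Wn @ Xn.T).max(axis=1)    # high score -> filter closely matches a training point
```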
hyperparameters that increase memorization
- ADAM over SGD
- smaller batch size
- larger learning rate
- larger width
- etc.
how does ADAM cause memorization?
- ADAM disproportionately increases the first layer's learning rate
- increasing only the first layer's learning rate qualitatively reproduces ADAM (sketched below)
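a sketch of boosting only the first layer's learning rate via PyTorch parameter groups; it assumes a `Sequential` model (so parameter names start with `"0."`) and the specific rates are illustrative:

```python
# Sketch: give only the first layer a larger learning rate under SGD.
first_layer_params = list(model[0].parameters())
other_params = [p for name, p in model.named_parameters() if not name.startswith("0.")]
opt = torch.optim.SGD([
    {"params": first_layer_params, "lr": 0.5},    # boosted first-layer lr (illustrative)
    {"params": other_params,       "lr": 0.01},
])
```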
going deeper: memorization in deep cnns
bias vs. variance
mse changes based on distribution
mse with principal components
logistic regression with cvs (score is mse)
linear regression with cvs (score is mse)
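reading "cvs" here as cross-validation splits (my assumption), a minimal sketch of cross-validated linear regression scored by mse; the logistic-regression case would swap in a different estimator:

```python
# Sketch: cross-validated linear regression scored by mean squared error.
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score

X, y = make_regression(n_samples=200, n_features=10, noise=1.0, random_state=0)
scores = cross_val_score(LinearRegression(), X, y,
                         cv=5, scoring="neg_mean_squared_error")
print(-scores.mean())        # average MSE across the 5 folds
```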