VOOZH about

URL: https://huggingface.co/google/gemma-scope-2b-pt-res/discussions/8

⇱ google/gemma-scope-2b-pt-res · gemma-2-2b layer 20 SAE width 65k SAE seems very off


gemma-2-2b layer 20 SAE width 65k SAE seems very off

#8
by charlieoneill - opened

I have been evaluating gemma-2-2b SAEs on a dataset of medical text. Looking at the 16k width on layer 20, the metrics seem to be about what I'd expect:

{
 "l0_139": {
 "l2_loss": 148.585,
 "l1_loss": 2728.04,
 "l0": 183.4201708984375,
 "frac_variance_explained": -0.40267578125,
 "cossim": 0.92232421875,
 "l2_ratio": 0.93435546875,
 "relative_reconstruction_bias": 1.97640625,
 "loss_original": 1.8141272115707396,
 "loss_reconstructed": 2.10102525472641,
 "loss_zero": 12.452932243347169,
 "frac_recovered": 0.9730078125,
 "frac_alive": 0.9940185546875,
 "hyperparameters": {
 "n_inputs": 200,
 "context_length": 1024,
 "l0": 139,
 "layer": 20,
 "width": "16k"
 }
 },
 "l0_22": {
 "l2_loss": 284.485,
 "l1_loss": 2719.16,
 "l0": 55.3458203125,
 "frac_variance_explained": -31.495,
 "cossim": 0.87486328125,
 "l2_ratio": 0.91640625,
 "relative_reconstruction_bias": 15.23140625,
 "loss_original": 1.8141272115707396,
 "loss_reconstructed": 2.4616863882541655,
 "loss_zero": 12.452932243347169,
 "frac_recovered": 0.9390829825401306,
 "frac_alive": 0.82476806640625,
 "hyperparameters": {
 "n_inputs": 200,
 "context_length": 1024,
 "l0": 22,
 "layer": 20,
 "width": "16k"
 }
 },
 "l0_294": {
 "l2_loss": 130.0175,
 "l1_loss": 3763.92,
 "l0": 352.6443994140625,
 "frac_variance_explained": -0.01845703125,
 "cossim": 0.9406640625,
 "l2_ratio": 0.94486328125,
 "relative_reconstruction_bias": 1.71236328125,
 "loss_original": 1.8141272115707396,
 "loss_reconstructed": 2.0600193762779235,
 "loss_zero": 12.452932243347169,
 "frac_recovered": 0.9768525409698486,
 "frac_alive": 0.99761962890625,
 "hyperparameters": {
 "n_inputs": 200,
 "context_length": 1024,
 "l0": 294,
 "layer": 20,
 "width": "16k"
 }
 },
 "l0_38": {
 "l2_loss": 251.58,
 "l1_loss": 2645.76,
 "l0": 73.9402734375,
 "frac_variance_explained": -20.34841796875,
 "cossim": 0.889765625,
 "l2_ratio": 0.9233984375,
 "relative_reconstruction_bias": 11.4728125,
 "loss_original": 1.8141272115707396,
 "loss_reconstructed": 2.3733639335632324,
 "loss_zero": 12.452932243347169,
 "frac_recovered": 0.947366454899311,
 "frac_alive": 0.89910888671875,
 "hyperparameters": {
 "n_inputs": 200,
 "context_length": 1024,
 "l0": 38,
 "layer": 20,
 "width": "16k"
 }
 },
 "l0_71": {
 "l2_loss": 189.87,
 "l1_loss": 2500.32,
 "l0": 109.7097705078125,
 "frac_variance_explained": -4.80037109375,
 "cossim": 0.90638671875,
 "l2_ratio": 0.92884765625,
 "relative_reconstruction_bias": 4.6397265625,
 "loss_original": 1.8141272115707396,
 "loss_reconstructed": 2.1981925880908966,
 "loss_zero": 12.452932243347169,
 "frac_recovered": 0.9638544994592667,
 "frac_alive": 0.96929931640625,
 "hyperparameters": {
 "n_inputs": 200,
 "context_length": 1024,
 "l0": 71,
 "layer": 20,
 "width": "16k"
 }
 }
 }

However, the 65k for layer 20 has really weird metrics, including a very poor loss recovered (i.e. Equation 10 from the gated SAEs paper: https://arxiv.org/pdf/2404.16014), despite having a low L2 loss. I thought it may be a quirk of the dataset, but have reproduced this somewhat on monology/pile-uncopyrighted:

{
 "l0_114": {
 "l2_loss": 65.14174501419068,
 "l1_loss": 326.4906903076172,
 "l0": 19.7434326171875,
 "frac_variance_explained": -1.1298050680756568,
 "cossim": 0.44833588257431983,
 "l2_ratio": 1.458713674545288,
 "relative_reconstruction_bias": 3.926573168039322,
 "loss_original": 2.151599160730839,
 "loss_reconstructed": 12.79894030570984,
 "loss_zero": 12.452933530807496,
 "frac_recovered": -0.03705257594643627,
 "frac_alive": 0.1755828857421875,
 "hyperparameters": {
 "n_inputs": 200,
 "context_length": 1024,
 "l0": 114,
 "layer": 20,
 "width": "65k"
 }
 },
 "l0_20": {
 "l2_loss": 78.6240915298462,
 "l1_loss": 274.0826930999756,
 "l0": 6.4778857421875,
 "frac_variance_explained": -8.00341603398323,
 "cossim": 0.3754740992188454,
 "l2_ratio": 1.6491711509227753,
 "relative_reconstruction_bias": 8.075071120262146,
 "loss_original": 2.151599160730839,
 "loss_reconstructed": 18.244347710609436,
 "loss_zero": 12.452933530807496,
 "frac_recovered": -0.5657323953509331,
 "frac_alive": 0.02691650390625,
 "hyperparameters": {
 "n_inputs": 200,
 "context_length": 1024,
 "l0": 20,
 "layer": 20,
 "width": "65k"
 }
 },
 "l0_221": {
 "l2_loss": 61.26867036819458,
 "l1_loss": 394.06997283935544,
 "l0": 30.2818212890625,
 "frac_variance_explained": -0.004639597833156586,
 "cossim": 0.47954541400074957,
 "l2_ratio": 1.4224228554964065,
 "relative_reconstruction_bias": 2.8707287490367888,
 "loss_original": 2.151599160730839,
 "loss_reconstructed": 10.630927562713623,
 "loss_zero": 12.452933530807496,
 "frac_recovered": 0.17530182713409886,
 "frac_alive": 0.2276763916015625,
 "hyperparameters": {
 "n_inputs": 200,
 "context_length": 1024,
 "l0": 221,
 "layer": 20,
 "width": "65k"
 }
 },
 "l0_34": {
 "l2_loss": 77.58435577392578,
 "l1_loss": 281.88170654296874,
 "l0": 8.2050439453125,
 "frac_variance_explained": -10.130443168580532,
 "cossim": 0.41340469181537626,
 "l2_ratio": 1.6560911977291106,
 "relative_reconstruction_bias": 9.29422394156456,
 "loss_original": 2.151599160730839,
 "loss_reconstructed": 17.128004446029664,
 "loss_zero": 12.452933530807496,
 "frac_recovered": -0.4539519951120019,
 "frac_alive": 0.065582275390625,
 "hyperparameters": {
 "n_inputs": 200,
 "context_length": 1024,
 "l0": 34,
 "layer": 20,
 "width": "65k"
 }
 },
 "l0_61": {
 "l2_loss": 77.927738571167,
 "l1_loss": 314.41664611816407,
 "l0": 14.6854248046875,
 "frac_variance_explained": -7.959465856552124,
 "cossim": 0.41553613662719724,
 "l2_ratio": 1.6819834589958191,
 "relative_reconstruction_bias": 7.942268486022949,
 "loss_original": 2.151599160730839,
 "loss_reconstructed": 15.694214601516723,
 "loss_zero": 12.452933530807496,
 "frac_recovered": -0.31723272004863245,
 "frac_alive": 0.1244659423828125,
 "hyperparameters": {
 "n_inputs": 200,
 "context_length": 1024,
 "l0": 61,
 "layer": 20,
 "width": "65k"
 }
 }
}

I will evaluate some other SAEs and other gemma models to see if this is just a specific problem with this SAE in this model in this layer. I did all evaluation with the dictionary_learning repo (https://github.com/saprmarks/dictionary_learning). But it would be good if someone could sanity-check me or tell me if I'm missing something.

charlieoneill changed discussion status to closed

Discussion closed? Was this a bug?

· Sign up or log in to comment