Experiment 2.7.1: Delete hue subspace

This is a re-run of Ex 2.7 with more mature tooling. See the earlier notebook for discussion.

Our model is trained (via the Planarity regularizer) to represent hue in the first two dimensions of latent space. If we ablate the weights connected to those activation dimensions, the model should lose the ability to represent hue. We therefore expect high reconstruction loss across all vibrant colors, and low loss on unsaturated colors.

from __future__ import annotations

# Notebook identity. nbid and project are combined to name this notebook's Modal
# app and volume; nbid also names the logger, the dopesheet CSV and the NbViz helper.
nbid = '2.7.1'  # ID for tagging assets
nbname = 'Ablate hue'
experiment_name = f'Ex {nbid}: {nbname}'  # passed to train_model for every run
project = 'ex-preppy'  # passed to train_model; the W&B run URLs live under this project
# Basic setup: Logging, Experiment (Modal)
import logging

import modal

from infra.requirements import uv_freeze, project_packages
from utils.logging import SimpleLoggingConfig
from ex_color.vis import NbViz

# Show INFO logs from our own packages; only errors from matplotlib.axes.
logging_config = (
    SimpleLoggingConfig()
    .info('notebook', 'utils', 'mini', 'ex_color')
    .error('matplotlib.axes')  # Silence warnings about set_aspect
)
logging_config.apply()

# This is the logger for this notebook
log = logging.getLogger(f'notebook.{nbid}')

# Container image for remote runs: frozen project dependencies (all groups except
# 'dev', per the uv_freeze arguments) plus the local project packages.
image = (
    modal.Image.debian_slim()
    .pip_install(*uv_freeze(all_groups=True, not_groups='dev'))
    .add_local_python_source(*project_packages())
)
# Per-notebook persistent volume, mounted at /data in remote functions; `train`
# saves model checkpoints there.
volume = modal.Volume.from_name(f'{project}-{nbid}', create_if_missing=True, version=2)
app = modal.App(name=f'{project}-{nbid}', image=image, volumes={'/data': volume})

viz = NbViz(nbid)
None  # prevent auto-display of this cell

Model parameters

We use the following regularizers:

  • Separate: angular repulsion to reduce global clumping (applied within each batch)
  • Planarity: pulls vibrant hues onto the plane spanned by latent dimensions 0 and 1

The Anchor regularizer (which pins red to $(1,0,0,0)$ in 4D) and the Unitarity regularizer (which pulls all embeddings to the surface of the unit hypersphere) are not used in this run.
import torch

from ex_color.loss import Separate, Planarity, RegularizerConfig

K = 4  # bottleneck dimensionality
N = 1  # number of nonlinear layers
# Hidden layer size. NOTE(review): H is not passed to CNColorMLP(K, n_nonlinear=N);
# confirm that the model's default hidden size matches.
H = 16
# Anchor target for red. No anchor regularizer is configured in this notebook, so
# this is only used for the sanity check below.
RED = (1, 0, 0, 0)
assert len(RED) == K
BATCH_SIZE = 64
CUBE_SUBDIVISIONS = 8
NUM_RUNS = 60  # to probe seed sensitivity
RUN_SEEDS = list(range(NUM_RUNS))  # one seed per run: 0 .. NUM_RUNS-1

# Angular repulsion between bottleneck embeddings. label_affinities=None presumably
# means it applies to every sample regardless of label (TODO confirm in ex_color.loss).
reg_separate = RegularizerConfig(
    name='separate',
    compute_loss_term=Separate(power=100.0, shift=True),
    label_affinities=None,
    layer_affinities=['bottleneck'],
)
# Planarity on the bottleneck, weighted to samples labelled 'vibrant'. It is named
# 'subspace' to match the dopesheet column that schedules its weight.
reg_planar = RegularizerConfig(
    name='subspace',
    compute_loss_term=Planarity(),
    label_affinities={'vibrant': 1.0},
    layer_affinities=['bottleneck'],
    # Also evaluated during validation; presumably the source of the logged val_subspace metric.
    phase=('train', 'validate'),
)
from mini.temporal.dopesheet import Dopesheet

# Load the parameter schedule (learning rate and regularizer weights by training step).
dopesheet = Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv')
viz.tab_dopesheet(dopesheet)  # tabular view of the schedule
viz.plot_dopesheet(dopesheet)  # plot: regularizer weights (top) and learning rate (bottom)

Parameter schedule

STEP PHASE ACTION lr separate subspace
0 Train 1e-08 0.01
10 0.01
375 0.1
750 0.1 0.01
1125 0.1
1425 0.1 0 0
1500 0.05
Plot showing the parameter schedule for the training run. The plot has two sections: the upper section shows various regularization weights over time, and the lower section shows the learning rate over time. The x-axis represents training steps.

Data

The data is the same as in Ex 2.7: color cubes with values in RGB.

from torch.utils.data import DataLoader, RandomSampler

from ex_color.data.cube_dataset import prep_color_dataset, vibrancy, stochastic_labels, exact_labels


def prep_train_data(training_subs: int, *, batch_size: int) -> DataLoader:
    """Return the training loader over the color cube.

    Colors are sampled at cell corners. The 'vibrant' label is driven by
    ``vibrancy(c) ** 8 * 0.008``, and batches are collated with
    ``stochastic_labels``. NOTE(review): presumably that value is a per-sample
    label probability; confirm in ex_color.data.cube_dataset.
    """
    dataset = prep_color_dataset(
        training_subs,
        sample_at='cell-corners',
        vibrant=lambda color: vibrancy(color) ** 8 * 0.008,
    )
    # Draw len(dataset) samples per epoch, with replacement.
    sampler = RandomSampler(dataset, replacement=True, num_samples=len(dataset))
    return DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        collate_fn=stochastic_labels,
        num_workers=4,
    )


def prep_val_data(training_subs: int, *, batch_size: int) -> DataLoader:
    """Return the validation loader over the color cube.

    Colors are sampled at cell centers. Only colors with ``vibrancy(c) == 1`` are
    labelled 'vibrant', and batches are collated with ``exact_labels``. No sampler
    is given, so iteration follows the dataset order.
    """
    dataset = prep_color_dataset(
        training_subs,
        sample_at='cell-centers',
        vibrant=lambda color: vibrancy(color) == 1,
    )
    loader_options = dict(batch_size=batch_size, num_workers=2, collate_fn=exact_labels)
    return DataLoader(dataset, **loader_options)

Train

from typing import Callable

import torch
import wandb

from ex_color.model import CNColorMLP
from ex_color.seed import set_deterministic_mode
from ex_color.workflow import train_model
from ex_color.evaluation import Result
from utils.time import hour


@app.function(
    cpu=1,
    max_containers=20,
    timeout=1 * hour,
    # The key is read from the local wandb login when this cell runs.
    # NOTE(review): a modal.Secret would avoid baking the key into the function env.
    env={'WANDB_API_KEY': wandb.Api().api_key or ''},
)
async def train(
    dopesheet: Dopesheet,
    regularizers: list[RegularizerConfig],
    *,
    seed: int | None,
    score_fn: Callable[[CNColorMLP], float],
) -> Result:
    """Train one model remotely, score it, and save its weights to the volume.

    Args:
        dopesheet: Parameter schedule (learning rate, regularizer weights) for the run.
        regularizers: Regularizers applied during training.
        seed: Seed for deterministic mode. If None, deterministic mode is not enabled.
        score_fn: Called with the trained model; its return value is stored as the
            run's score.

    Returns:
        A Result holding the seed, the checkpoint key (a filename under /data),
        the W&B run URL, the run summary and the score.
    """
    # Re-apply the notebook's logging config inside the remote container.
    logging_config.apply()

    if seed is not None:
        set_deterministic_mode(seed)

    train_loader = prep_train_data(CUBE_SUBDIVISIONS, batch_size=BATCH_SIZE)
    val_loader = prep_val_data(CUBE_SUBDIVISIONS, batch_size=BATCH_SIZE)
    model = CNColorMLP(K, n_nonlinear=N)
    res = train_model(
        model,
        dopesheet,
        regularizers,
        train_loader,
        val_loader,
        experiment_name=experiment_name,
        project=project,
        hparams={'seed': seed},
    )

    score = score_fn(res.model)
    # One checkpoint per run, keyed by the training run's id, on the mounted volume.
    key = f'model-{res.id_}.pt'
    torch.save(res.model.state_dict(), f'/data/{key}')
    return Result(seed, key, res.url, res.summary, score)
from ex_color.evaluation import EvaluationPlan, ScoreByVibrancy
from ex_color.surgery import ablate

# Evaluation used to score each trained model: ablate bottleneck dims 0 and 1
# (the hue plane), then measure how reconstruction error tracks vibrancy.
ablation_plan = EvaluationPlan(
    {'ablated'},  # presumably tags for this evaluation (cf. tags={'ablated'} later)
    lambda m: ablate(m, 'bottleneck', [0, 1]),
    [],  # presumably the interventions list, empty as in test_set.evaluate(model, [], ...); TODO confirm
)

# Per the summary cell, the score measures correlation of reconstruction error with
# vibrancy (raised to `power`) under the ablation plan.
score_fn = ScoreByVibrancy(ablation_plan, power=2.3, cube_subdivisions=CUBE_SUBDIVISIONS)
import asyncio

# Reload dopesheet: makes tweaking params during development easier
dopesheet = Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv')
regularizers = [reg_separate, reg_planar]


async def sweep():
    logging_config.apply()
    workers = [train.remote.aio(dopesheet, regularizers, seed=seed, score_fn=score_fn) for seed in RUN_SEEDS]
    return await asyncio.gather(*workers)


with app.run():
    results = await sweep()
from IPython.display import display
from ex_color.evaluation import results_to_dataframe

runs_df = results_to_dataframe(results)

# Headline statistics (min, max, mean, stddev) for every column across all runs.
log.info(f'Summary statistics for all {len(runs_df)} runs:')
summary_stats = runs_df.describe().loc[['min', 'max', 'mean', 'std']]
display(summary_stats.style.format(precision=4))

# One box plot per metric: (heading, column, extra plot options).
boxplot_specs = [
    ('Correlation of reconstruction error vs. vibrancy', 'score', {'xlim': (None, None)}),
    ('Reconstruction loss', 'val_recon', {'log_scale': True}),
    ('Planar loss', 'val_subspace', {'log_scale': True}),
]
for heading, column, extra_options in boxplot_specs:
    print(heading)
    viz.plot_boxplot(runs_df[column], ylabel='', tags=(column,), **extra_options)
I 334.8 no.2.7.1:Summary statistics for all 60 runs:
  seed score labels/n/_any labels/n_total val_recon val_subspace labels/n/vibrant _runtime val_loss
min 0.0000 0.6163 83.0000 96064.0000 0.0000 0.0045 83.0000 45.1505 0.0000
max 59.0000 0.9119 128.0000 96064.0000 0.0001 0.0103 128.0000 153.1971 0.0001
mean 29.5000 0.8247 101.6833 96064.0000 0.0000 0.0066 101.6833 74.4959 0.0000
std 17.4642 0.0526 9.7276 0.0000 0.0000 0.0013 9.7276 29.0329 0.0000
Correlation of reconstruction error vs. vibrancy
Horizontal box plot showing the distribution of the score across runs.
Reconstruction loss
Horizontal box plot showing the distribution of the validation reconstruction loss (val_recon) across runs.
Planar loss
Horizontal box plot showing the distribution of the validation planarity loss (val_subspace) across runs.

Select the best runs from the Pareto front of non-dominated runs, minimizing the validation losses (reconstruction and planarity) and maximizing the score.

from ex_color.evaluation import pareto_front

# Runs not dominated on (low val_recon, low val_subspace, high score).
non_dominated = pareto_front(
    runs_df,
    minimize=['val_recon', 'val_subspace'],
    maximize=['score'],
)
log.info(f'Best of {len(non_dominated)} non-dominated runs (Pareto front):')
# Show the five highest-scoring of them, with clickable W&B links.
top_runs = non_dominated.sort_values(by='score', ascending=False).head(5)
display(top_runs.style.format(precision=4, hyperlinks='html'))
I 337.3 no.2.7.1:Best of 12 non-dominated runs (Pareto front):
  seed wandb url score labels/n/_any labels/n_total val_recon val_subspace labels/n/vibrant _runtime val_loss
44 44 https://wandb.ai/z0r/ex-preppy/runs/g4rkh06h 0.9119 85 96064 0.0000 0.0080 85 103.6362 0.0000
32 32 https://wandb.ai/z0r/ex-preppy/runs/z4vbty4v 0.9089 107 96064 0.0001 0.0064 107 46.6926 0.0001
27 27 https://wandb.ai/z0r/ex-preppy/runs/tjqsokzy 0.8887 109 96064 0.0000 0.0051 109 54.6538 0.0000
46 46 https://wandb.ai/z0r/ex-preppy/runs/hm2patyd 0.8859 114 96064 0.0000 0.0068 114 56.5487 0.0000
30 30 https://wandb.ai/z0r/ex-preppy/runs/8yspyoe2 0.8750 128 96064 0.0000 0.0051 128 88.9397 0.0000
from typing import cast

from mini.data import load_checkpoint_from_volume

# Pick the highest-scoring non-dominated run. Look it up by seed, not by DataFrame
# label, so this doesn't rely on the frame's index matching each run's position
# in `results` (that only holds while RUN_SEEDS == range(NUM_RUNS)).
best_label = non_dominated['score'].idxmax()
best_seed = int(non_dominated.loc[best_label, 'seed'])
runs_by_seed = {run.seed: run for run in results}
best_run = runs_by_seed[best_seed]
log.info(f'Loading checkpoint of best run: seed={best_run.seed}, score={best_run.score:.4f} @ {best_run.url}')
model = CNColorMLP(K, n_nonlinear=N)
model = load_checkpoint_from_volume(model, volume, best_run.checkpoint_key)
I 337.4 no.2.7.1:Loading checkpoint of best run: seed=44, score=0.9119 @ https://wandb.ai/z0r/ex-preppy/runs/g4rkh06h

Results

# # Generate a list of dimensions to visualize
# (Disabled helper, kept for reference: enumerates latent-dim triples for the
# commented-out plot_latent_space calls below.)
# from itertools import combinations
# [
#     (
#         b,
#         a,
#         (a + 1) % 4 if (a + 1) % 4 not in (a, b) else (a + 2) % 4,
#     )
#     for a, b in combinations((0, 1, 2, 3), 2)
# ]
from ex_color.evaluation import TestSet

# One shared test set for every evaluation below (baseline, ablated, pruned, suppressed).
test_set = TestSet.create()
from IPython.display import clear_output

# Evaluate the unmodified model, with no interventions, as the reference point.
baseline_results = test_set.evaluate(model, [], tags={'baseline'})
clear_output()  # presumably hides output emitted by evaluate

viz.plot_cube(baseline_results)
# Optional extra views (disabled):
# viz.plot_recon_loss(baseline_results)
# viz.plot_latent_space(
#     baseline_results,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1)],
# )
Plot showing four slices of the HSV cube, titled "Predicted colors · baseline · V vs H by S". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color.

Ablation

from IPython.display import clear_output
from ex_color.surgery import ablate

# Ablate the weights connected to bottleneck dims 0 and 1 (the hue plane), then evaluate.
# NOTE(review): later cells reuse `model`; confirm that `ablate` returns a modified
# copy rather than mutating `model` in place.
ablated_model = ablate(model, 'bottleneck', [0, 1])
ablation_results = test_set.evaluate(ablated_model, [], tags={'ablated'})
clear_output()  # presumably hides output emitted by evaluate

viz.plot_cube(ablation_results)
# Optional extra views (disabled):
# viz.plot_recon_loss(ablation_results)
# viz.plot_latent_space(
#     ablation_results,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1)],
# )
Plot showing four slices of the HSV cube, titled "Predicted colors · ablated · V vs H by S". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color.

Pruning

from IPython.display import clear_output

from ex_color.surgery import prune

# Prune bottleneck dims 0 and 1 (the hue plane) from the model, then evaluate.
# NOTE(review): presumably this removes those dims entirely, leaving 2 of K=4, which
# would explain the different plot dims below; confirm in ex_color.surgery. Also
# confirm that `prune` does not mutate `model` in place, since later cells reuse it.
pruned_model = prune(model, 'bottleneck', [0, 1])
pruned_results = test_set.evaluate(pruned_model, [], tags={'pruned'})
clear_output()  # presumably hides output emitted by evaluate

viz.plot_cube(pruned_results)
# Optional extra views (disabled):
# viz.plot_recon_loss(pruned_results)
# viz.plot_latent_space(
#     pruned_results,
#     dims=[(0, None, None), (1, None, None)],
# )
Plot showing four slices of the HSV cube, titled "Predicted colors · pruned · V vs H by S". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color.

Suppression

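Here we use the new axis-aligned suppression intervention to zero out the hue dimensions (0 and 1) of the bottleneck. Unlike ablation and pruning, which produce a modified model, suppression is passed as an intervention when evaluating the original model.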

from IPython.display import clear_output

from ex_color.intervention import AxisAlignedSuppression, InterventionConfig


# Suppress bottleneck dims 0 and 1 (the hue plane). This is an intervention passed to
# evaluate on the original `model`, rather than a modified copy of the model.
suppression = InterventionConfig(
    apply=AxisAlignedSuppression(dims=(0, 1)),
    layer_affinities=['bottleneck'],
)
suppression_results = test_set.evaluate(model, [suppression], tags={'suppression'})
clear_output()  # presumably hides output emitted by evaluate

viz.plot_cube(suppression_results)
# Optional extra views (disabled):
# viz.plot_recon_loss(suppression_results)
# viz.plot_latent_space(
#     suppression_results,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1)],
# )
Plot showing four slices of the HSV cube, titled "Predicted colors · suppression · V vs H by S". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color.
import numpy as np

# Shared upper bound for the error color scale, so all four stacked figures are
# directly comparable. Every plotted result must contribute; otherwise one left out
# (e.g. suppression) could exceed the scale and be clipped.
max_error = np.max(
    [
        baseline_results.loss_cube['MSE'],
        ablation_results.loss_cube['MSE'],
        pruned_results.loss_cube['MSE'],
        suppression_results.loss_cube['MSE'],
    ]
)

# Latent-dim triples for the two top panels of each composite figure.
dims = ((0, 1, 2), (2, 3, 0))
# The pruned model has fewer bottleneck dims, so it needs its own layout.
pruned_dims = ((None, None, 0), (0, 1, 0))

for title, stage_results, latent_dims in (
    ('Baseline', baseline_results, dims),
    ('Ablation', ablation_results, dims),
    ('Pruned', pruned_results, pruned_dims),
    ('Suppression', suppression_results, dims),
):
    print(title)
    viz.plot_stacked_results(
        stage_results,
        latent_dims=latent_dims,
        max_error=max_error,
    )
Baseline
Composite figure with two latent panels (top), a color slice (middle), and a loss chart (bottom).
Ablation
Composite figure with two latent panels (top), a color slice (middle), and a loss chart (bottom).
Pruned
Composite figure with two latent panels (top), a color slice (middle), and a loss chart (bottom).
Suppression
Composite figure with two latent panels (top), a color slice (middle), and a loss chart (bottom).
# Per-color reconstruction error for each condition, with deltas vs. baseline (HTML table)...
viz.tab_error_vs_color(baseline_results, ablation_results, pruned_results, suppression_results)
# ...and the same error values as a LaTeX table.
viz.tab_error_vs_color_latex(baseline_results, ablation_results, pruned_results, suppression_results)
Name RGB Baseline Ablated Δ Abl Pruned Δ Pru Suppression Δ Sup
red
0.001 0.333 +0.333 0.333 +0.333 0.182 +0.181
orange
0.000 0.156 +0.156 0.156 +0.156 0.124 +0.124
yellow
0.001 0.318 +0.317 0.318 +0.317 0.175 +0.174
lime
0.000 0.160 +0.160 0.160 +0.160 0.145 +0.145
green
0.001 0.313 +0.313 0.313 +0.313 0.224 +0.223
teal
0.000 0.174 +0.174 0.174 +0.174 0.190 +0.190
cyan
0.001 0.315 +0.314 0.315 +0.314 0.261 +0.260
azure
0.000 0.181 +0.181 0.181 +0.181 0.216 +0.216
blue
0.001 0.333 +0.333 0.333 +0.333 0.271 +0.271
purple
0.000 0.189 +0.189 0.189 +0.189 0.196 +0.196
magenta
0.000 0.253 +0.253 0.253 +0.253 0.241 +0.240
pink
0.000 0.187 +0.187 0.187 +0.187 0.151 +0.151
black
0.000 0.001 +0.000 0.001 +0.000 0.001 +0.000
dark gray
0.000 0.000 +0.000 0.000 +0.000 0.000 +0.000
gray
0.000 0.000 +0.000 0.000 +0.000 0.000 +0.000
light gray
0.000 0.000 +0.000 0.000 +0.000 0.000 +0.000
white
0.000 0.001 +0.001 0.001 +0.001 0.001 +0.001
\begin{table}
\centering
\caption{Reconstruction error by color and intervention method}
\label{tab:placeholder}
\sisetup{
    round-mode = places,
    round-precision = 3,
    table-auto-round = true,
    % drop-zero-decimal = true,
}
\begin{tabular}{l c g g g g}
\toprule
\multicolumn{2}{c}{{Color}} & \multicolumn{1}{c}{{Baseline}} & \multicolumn{1}{c}{{Ab}} & \multicolumn{1}{c}{{Prun}} & \multicolumn{1}{c}{{Suppression}} \\
\midrule
Red        & \swatch{FF0000} &  0.000792319 &  0.332541019 &  0.332541019 &  0.181122527 \\
Orange     & \swatch{FF7F00} &  0.000019311 &  0.155589819 &  0.155589819 &  0.123887539 \\
Yellow     & \swatch{FFFF00} &  0.000705138 &  0.317163587 &  0.317163587 &  0.174232632 \\
Lime       & \swatch{7FFF00} &  0.000050605 &  0.159755826 &  0.159755826 &  0.144521520 \\
Green      & \swatch{00FF00} &  0.000586422 &  0.312581360 &  0.312581360 &  0.223322719 \\
Teal       & \swatch{00FF7F} &  0.000005727 &  0.173705757 &  0.173705757 &  0.190066084 \\
Cyan       & \swatch{00FFFF} &  0.000942572 &  0.314468980 &  0.314468980 &  0.260203123 \\
Azure      & \swatch{007FFF} &  0.000170504 &  0.181120127 &  0.181120127 &  0.215764135 \\
Blue       & \swatch{0000FF} &  0.000534647 &  0.332798690 &  0.332798690 &  0.270746559 \\
Purple     & \swatch{7F00FF} &  0.000001524 &  0.189247891 &  0.189247921 &  0.195607752 \\
Magenta    & \swatch{FF00FF} &  0.000381831 &  0.252688736 &  0.252688736 &  0.240460932 \\
Pink       & \swatch{FF007F} &  0.000017568 &  0.186851054 &  0.186851054 &  0.151000232 \\
Black      & \swatch{000000} &  0.000391695 &  0.000456497 &  0.000456497 &  0.000495712 \\
Dark gray  & \swatch{3F3F3F} &  0.000069575 &  0.000227389 &  0.000227389 &  0.000234607 \\
Gray       & \swatch{7F7F7F} &  0.000054224 &  0.000235140 &  0.000235140 &  0.000237627 \\
Light gray & \swatch{BFBFBF} &  0.000010381 &  0.000228641 &  0.000228641 &  0.000230542 \\
White      & \swatch{FFFFFF} &  0.000088825 &  0.000954370 &  0.000954370 &  0.001026252 \\
\bottomrule
\end{tabular}
\end{table}
import importlib
import ex_color.vis.helpers
import utils.strings

# Dev convenience: reload the plotting helpers so edits take effect without
# restarting the kernel, then rebuild the NbViz instance from the fresh module.
importlib.reload(utils.strings)
importlib.reload(ex_color.vis.helpers)
viz = ex_color.vis.helpers.NbViz(nbid)

# Plot error vs. vibrancy for each intervention using the same exponent as the
# selection score (score_fn.power), so the reported r / R² are comparable.
for intervention_results in (ablation_results, pruned_results, suppression_results):
    viz.plot_error_vs_vibrancy(intervention_results, power=score_fn.power)
Scatter plot showing reconstruction error versus vibrancy. Each point represents a color, with its position on the x-axis indicating how vibrant (saturated and bright) it is, and its position on the y-axis indicating the reconstruction error (mean squared error) for that color. The points are colored according to their actual color values.
MSE,vib²·³ ablated: r = 0.96, R²: 0.93, p = 0
Scatter plot showing reconstruction error versus vibrancy. Each point represents a color, with its position on the x-axis indicating how vibrant (saturated and bright) it is, and its position on the y-axis indicating the reconstruction error (mean squared error) for that color. The points are colored according to their actual color values.
MSE,vib²·³ pruned: r = 0.96, R²: 0.93, p = 0
Scatter plot showing reconstruction error versus vibrancy. Each point represents a color, with its position on the x-axis indicating how vibrant (saturated and bright) it is, and its position on the y-axis indicating the reconstruction error (mean squared error) for that color. The points are colored according to their actual color values.
MSE,vib² suppression: r = 0.94, R²: 0.89, p = 0