Experiment 2.4.1: Soft intervention on red with color wheel
This is a re-run of Ex 2.4 with more mature tooling. See the earlier notebook for discussion.
from __future__ import annotations
# Identity of this notebook; used to tag W&B runs, Modal resources and plots.
nbid = '2.4.1'  # ID for tagging assets
nbname = 'Soft intervention on red with color wheel'
project = 'ex-preppy'
experiment_name = 'Ex {}: {}'.format(nbid, nbname)
# Basic setup: Logging, Experiment (Modal)
import logging
import modal
from infra.requirements import uv_freeze, project_packages
from utils.logging import SimpleLoggingConfig
from ex_color.vis import NbViz
# Log levels for this notebook. The same config object is re-applied inside the
# remote functions (`train`, `sweep`), which run outside this process.
logging_config = SimpleLoggingConfig().info('notebook', 'utils', 'mini', 'ex_color')
logging_config = logging_config.error('matplotlib.axes')  # Silence warnings about set_aspect
logging_config.apply()

# Logger for this notebook, namespaced under 'notebook.<nbid>'
log = logging.getLogger('notebook.' + nbid)
# Container image: locked dependencies (excluding the 'dev' group) plus this
# project's own packages as local source.
image = modal.Image.debian_slim().pip_install(*uv_freeze(all_groups=True, not_groups='dev'))
image = image.add_local_python_source(*project_packages())

# Per-notebook volume (mounted at /data in remote functions) and app.
volume = modal.Volume.from_name('{}-{}'.format(project, nbid), version=2, create_if_missing=True)
app = modal.App(name='{}-{}'.format(project, nbid), volumes={'/data': volume}, image=image)
viz = NbViz(nbid)
import torch
from ex_color.loss import AngularAnchor, AxisAlignedSubspace, Separate, RegularizerConfig
K = 4  # bottleneck dimensionality
N = 1  # number of nonlinear layers
RED = (1, 0, 0, 0)  # anchor direction for 'red' in the K-dimensional bottleneck
# Explicit check rather than `assert`, so it still fires under `python -O`.
if len(RED) != K:
    raise ValueError(f'RED must have K={K} components, got {len(RED)}')
BATCH_SIZE = 64
CUBE_SUBDIVISIONS = 8
NUM_RUNS = 60  # to probe seed sensitivity
RUN_SEEDS = list(range(NUM_RUNS))  # one training run per seed: 0 .. NUM_RUNS-1
# Separation term on the bottleneck. label_affinities=None: presumably applies
# to every sample regardless of labels (confirm in RegularizerConfig).
reg_separate = RegularizerConfig(
    name='separate',
    layer_affinities=['bottleneck'],
    label_affinities=None,
    compute_loss_term=Separate(shift=True, power=100.0),
)
# Angular anchor on the RED direction, weighted by the 'red' label and active
# in both the train and validate phases.
reg_anchor = RegularizerConfig(
    name='anchor',
    phase=('train', 'validate'),
    layer_affinities=['bottleneck'],
    label_affinities={'red': 1.0},
    compute_loss_term=AngularAnchor(torch.tensor(RED, dtype=torch.float32)),
)
# Axis-aligned subspace term (presumably bottleneck dims 0 and 1), weighted by
# the 'vibrant' label; the commented line is the 'primary'-label alternative.
reg_subspace = RegularizerConfig(
    name='subspace',
    phase=('train', 'validate'),
    layer_affinities=['bottleneck'],
    label_affinities={'vibrant': 1},
    # label_affinities={'primary': 1},
    compute_loss_term=AxisAlignedSubspace((0, 1)),
)
from mini.temporal.dopesheet import Dopesheet
dopesheet = Dopesheet.from_csv('./ex-{}-dopesheet.csv'.format(nbid))
# Show the dopesheet (the training schedule) both as a table and as a plot.
viz.tab_dopesheet(dopesheet)
viz.plot_dopesheet(dopesheet)
from torch.utils.data import DataLoader, RandomSampler
from ex_color.data.cube_dataset import prep_color_dataset, redness, stochastic_labels, exact_labels, vibrancy
def prep_train_data(training_subs: int, *, batch_size: int) -> DataLoader:
    """Build the training DataLoader.

    Colours are sampled at cell corners (`training_subs` is passed through to
    `prep_color_dataset`; callers use CUBE_SUBDIVISIONS). 'red' and 'vibrant'
    get soft per-colour weights, presumably label probabilities consumed by
    `stochastic_labels` at collate time. Each pass draws len(dataset) samples
    with replacement, using 4 worker processes.
    """

    def red_weight(c):
        # Sharply peaked on pure red, scaled down to keep the label rare.
        return redness(c) ** 8 * 0.08

    def vibrant_weight(c):
        return vibrancy(c) ** 10 * 0.01

    colors = prep_color_dataset(
        training_subs,
        sample_at='cell-corners',
        red=red_weight,
        vibrant=vibrant_weight,
    )
    sampler = RandomSampler(colors, replacement=True, num_samples=len(colors))
    return DataLoader(
        colors,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,
        collate_fn=stochastic_labels,
    )
def prep_val_data(training_subs: int, *, batch_size: int) -> DataLoader:
    """Build the validation DataLoader.

    Colours are sampled at cell centres (training uses the corners), and only
    colours with ``redness(c) == 1`` get the 'red' label. Labels are attached by
    `exact_labels` (presumably deterministically, unlike training). No sampler
    is given, so batches come in dataset order; 2 worker processes are used.
    """

    def is_exactly_red(c):
        return redness(c) == 1

    # NOTE(review): the 'vibrant' label (vibrancy(c) == 1) is disabled here, so
    # terms keyed on 'vibrant' get no labelled samples during validation —
    # confirm this is intended.
    colors = prep_color_dataset(training_subs, sample_at='cell-centers', red=is_exactly_red)
    return DataLoader(colors, batch_size=batch_size, num_workers=2, collate_fn=exact_labels)
from typing import Callable
import torch
import wandb
from ex_color.model import CNColorMLP
from ex_color.seed import set_deterministic_mode
from ex_color.workflow import train_model
from ex_color.evaluation import Result
from utils.time import hour
@app.function(
    cpu=1,
    max_containers=20,
    timeout=1 * hour,
    env={'WANDB_API_KEY': wandb.Api().api_key or ''},
)
async def train(
    dopesheet: Dopesheet,
    regularizers: list[RegularizerConfig],
    *,
    seed: int | None,
    score_fn: Callable[[CNColorMLP], float],
):
    """Train one model in a Modal container, score it, and checkpoint it to the volume.

    Args:
        dopesheet: Schedule passed through to `train_model`.
        regularizers: Regularizer configs passed through to `train_model`.
        seed: If not None, passed to `set_deterministic_mode` before the data
            loaders and model are built. Recorded in the run's hparams and in
            the returned Result either way.
        score_fn: Called on the trained model; its return value becomes the
            Result's score.

    Returns:
        A `Result` of (seed, checkpoint key, run URL, run summary, score). The
        checkpoint key is a filename under `/data`, where the notebook's volume
        is mounted.
    """
    # Remote containers don't inherit the notebook's logging setup.
    logging_config.apply()
    if seed is not None:
        set_deterministic_mode(seed)
    train_loader = prep_train_data(CUBE_SUBDIVISIONS, batch_size=BATCH_SIZE)
    val_loader = prep_val_data(CUBE_SUBDIVISIONS, batch_size=BATCH_SIZE)
    model = CNColorMLP(K, n_nonlinear=N)
    res = train_model(
        model,
        dopesheet,
        regularizers,
        train_loader,
        val_loader,
        experiment_name=experiment_name,
        project=project,
        hparams={'seed': seed},
    )
    score = score_fn(res.model)
    # Key by run id so parallel runs never collide on the shared volume.
    key = f'model-{res.id_}.pt'
    torch.save(res.model.state_dict(), f'/data/{key}')
    return Result(seed, key, res.url, res.summary, score)
from math import cos, radians
from ex_color.evaluation import EvaluationPlan, ScoreByHSVSimilarity
from ex_color.intervention import InterventionConfig, Suppression, BoundedFalloff
# Suppression of the RED direction in the bottleneck, used to score every
# training run (this `score_fn` is what `sweep` passes to `train`).
falloff = BoundedFalloff(
    cos(radians(90)),  # cos(max_angle)
    1,  # completely squash fully-aligned vectors
    # 2,  # soft rim, sharp hub
    0,
)
suppression = InterventionConfig(
    # Explicit float32, matching the anchor regularizer's direction tensor;
    # torch.tensor on an int tuple would otherwise infer int64.
    apply=Suppression(torch.tensor(RED, dtype=torch.float32), falloff),
    layer_affinities=['bottleneck'],
)
suppression_plan = EvaluationPlan(
    {'suppression'},
    lambda m: m,  # identity model transform: evaluate the trained model as-is
    [suppression],
)
# (0.0, 1.0, 1.0) is presumably the HSV triple for pure red (hue 0, full S and V).
score_fn = ScoreByHSVSimilarity(suppression_plan, (0.0, 1.0, 1.0), power=2.0, cube_subdivisions=CUBE_SUBDIVISIONS)
import asyncio
# Re-read the dopesheet right before the sweep, so edits to the CSV take
# effect without re-running the earlier setup cells.
dopesheet = Dopesheet.from_csv('./ex-{}-dopesheet.csv'.format(nbid))
# Regularizers active during every training run of the sweep
regularizers = [
    reg_separate,
    reg_anchor,
    reg_subspace,
]
async def sweep():
    """Launch one remote `train` call per seed in RUN_SEEDS and await them all.

    `asyncio.gather` preserves argument order, so the returned list is aligned
    with RUN_SEEDS.
    """
    logging_config.apply()
    return await asyncio.gather(
        *(
            train.remote.aio(dopesheet, regularizers, seed=run_seed, score_fn=score_fn)
            for run_seed in RUN_SEEDS
        )
    )
# Run the sweep inside Modal's `app.run()` context so the remote `train` calls
# can execute. Top-level `await` relies on the IPython/Jupyter event loop.
# `results` holds one Result per seed, in RUN_SEEDS order.
with app.run():
    results = await sweep()
from IPython.display import display
from ex_color.evaluation import results_to_dataframe
runs_df = results_to_dataframe(results)

# Per-column summary across all seeds: min, max, mean and standard deviation
log.info(f'Summary statistics for all {len(runs_df)} runs:')
display(
    runs_df.describe()
    .loc[['min', 'max', 'mean', 'std']]
    .style.format(precision=4)
)

# Distribution of each metric across seeds
print('Correlation of reconstruction error vs. similarity to anchor')
viz.plot_boxplot(runs_df['score'], xlim=(None, 1), ylabel='', tags=('score',))
print('Reconstruction loss')
viz.plot_boxplot(runs_df['val_recon'], log_scale=True, ylabel='', tags=('val_recon',))
print('Anchor loss')
viz.plot_boxplot(runs_df['val_anchor'], log_scale=True, ylabel='', tags=('val_anchor',))
I 949.5 no.2.4.1:Summary statistics for all 60 runs:
| | seed | score | val_loss | labels/n/vibrant | val_recon | labels/n/_any | labels/n/red | labels/n_total | _runtime | val_anchor |
|---|---|---|---|---|---|---|---|---|---|---|
| min | 0.0000 | 0.8821 | 0.0000 | 85.0000 | 0.0000 | 160.0000 | 64.0000 | 96064.0000 | 45.4041 | 0.0005 | 
| max | 59.0000 | 0.9910 | 0.0001 | 132.0000 | 0.0001 | 225.0000 | 104.0000 | 96064.0000 | 153.1740 | 0.0072 | 
| mean | 29.5000 | 0.9515 | 0.0000 | 107.9500 | 0.0000 | 189.7167 | 82.2167 | 96064.0000 | 68.9505 | 0.0029 | 
| std | 17.4642 | 0.0236 | 0.0000 | 10.6460 | 0.0000 | 13.8834 | 8.3951 | 0.0000 | 28.0112 | 0.0015 | 
Correlation of reconstruction error vs. similarity to anchor
 
Reconstruction loss
 
Anchor loss
 
Select the best runs from the Pareto front of non-dominated runs, minimizing the validation reconstruction and anchor losses while maximizing the score.
from ex_color.evaluation import pareto_front
# Runs not dominated on (low val_recon, low val_anchor, high score)
non_dominated = pareto_front(
    runs_df,
    minimize=['val_recon', 'val_anchor'],
    maximize=['score'],
)
log.info(f'Best of {len(non_dominated)} non-dominated runs (Pareto front):')
# Top five of the front by score
display(
    non_dominated.sort_values(by='score', ascending=False)
    .head(5)
    .style.format(precision=4, hyperlinks='html')
)
I 951.5 no.2.4.1:Best of 4 non-dominated runs (Pareto front):
| | seed | wandb url | score | val_loss | labels/n/vibrant | val_recon | labels/n/_any | labels/n/red | labels/n_total | _runtime | val_anchor |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 7 | 7 | https://wandb.ai/z0r/ex-preppy/runs/0kj9jztg | 0.9910 | 0.0000 | 100 | 0.0000 | 173 | 73 | 96064 | 52.0971 | 0.0006 | 
| 40 | 40 | https://wandb.ai/z0r/ex-preppy/runs/11g2ie4j | 0.9848 | 0.0000 | 107 | 0.0000 | 182 | 75 | 96064 | 118.5703 | 0.0020 | 
| 39 | 39 | https://wandb.ai/z0r/ex-preppy/runs/ce5iw8rf | 0.9780 | 0.0000 | 101 | 0.0000 | 182 | 81 | 96064 | 103.4187 | 0.0005 | 
| 30 | 30 | https://wandb.ai/z0r/ex-preppy/runs/aalvzihq | 0.9531 | 0.0000 | 99 | 0.0000 | 188 | 89 | 96064 | 71.2192 | 0.0034 | 
from typing import cast
from mini.data import load_checkpoint_from_volume
# Highest-scoring run on the Pareto front. Resolve it through its 'seed' column
# instead of using the DataFrame index label as a list position into `results`;
# that only works while the labels happen to equal positions in `results`.
best_seed = non_dominated.loc[non_dominated['score'].idxmax(), 'seed']
best_run = next((r for r in results if r.seed == best_seed), None)
if best_run is None:
    raise LookupError(f'No result found for best seed {best_seed!r}')
log.info(f'Loading checkpoint of best run: seed={best_run.seed}, score={best_run.score:.4f} @ {best_run.url}')

# Rebuild the same architecture as in `train`, then load the saved weights.
model = CNColorMLP(K, n_nonlinear=N)
model = load_checkpoint_from_volume(model, volume, best_run.checkpoint_key)
I 951.5 no.2.4.1:Loading checkpoint of best run: seed=7, score=0.9910 @ https://wandb.ai/z0r/ex-preppy/runs/0kj9jztg
# # Generate a list of dimensions to visualize
# from itertools import combinations
# [
#     (
#         b,
#         a,
#         (a + 1) % K if (a + 1) % K not in (a, b) else (a + 2) % K,
#     )
#     for a, b in combinations(tuple(range(K)), 2)
# ]
from ex_color.evaluation import TestSet
# Evaluation set shared by the baseline and every intervention comparison below.
test_set = TestSet.create()
from IPython.display import clear_output
# Baseline: the best checkpoint, evaluated with no interventions.
baseline_results = test_set.evaluate(
    model,
    [],
    tags={'baseline'},
)
clear_output()  # drop whatever evaluate() printed; keep only the plot
viz.plot_cube(baseline_results)
# viz.plot_recon_loss(baseline_results)
# viz.plot_latent_space(
#     baseline_results,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )
 
from math import cos, radians
from IPython.display import clear_output
from ex_color.intervention import Suppression, BoundedFalloff
# Suppression of the RED direction. Redefined here (same values as the scoring
# intervention) so it can be tweaked for visualisation without re-running the sweep.
falloff = BoundedFalloff(
    cos(radians(90)),  # cos(max_angle)
    1,  # completely squash fully-aligned vectors
    # 2,  # soft rim, sharp hub
    0,
)
suppression = InterventionConfig(
    # Explicit float32, matching the anchor regularizer's direction tensor;
    # torch.tensor on an int tuple would otherwise infer int64.
    apply=Suppression(torch.tensor(RED, dtype=torch.float32), falloff),
    layer_affinities=['bottleneck'],
)
suppression_results = test_set.evaluate(model, [suppression], tags={'suppression'})
clear_output()
viz.plot_cube(suppression_results)
# viz.plot_recon_loss(suppression_results)
# viz.plot_latent_space(
#     suppression_results,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )
 
from math import cos, radians
from IPython.display import clear_output
from ex_color.intervention import Repulsion, LinearMapper
# mapper = FastBezierMapper(
#     cos(radians(90)),
#     cos(radians(60)),
# )
# Repulsion intervention on the RED direction, with thresholds given as the
# cosines of 90° and 89°.
mapper = LinearMapper(
    cos(radians(90)),
    cos(radians(89)),
)
repulsion = InterventionConfig(
    # Keyword `apply=` and float32 direction, consistent with the suppression config.
    apply=Repulsion(torch.tensor(RED, dtype=torch.float32), mapper),
    layer_affinities=['bottleneck'],
)
repulsion_results = test_set.evaluate(model, [repulsion], tags={'repulsion'})
clear_output()
viz.plot_cube(repulsion_results)
# viz.plot_recon_loss(repulsion_results)
# viz.plot_latent_space(
#     repulsion_results,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )
 
from ex_color.surgery import ablate
# Ablate the bottleneck dimension(s) that make up the RED anchor direction.
# With RED = (1, 0, 0, 0) this is [0]; deriving it from RED keeps the ablation
# in sync if the anchor is ever moved to another axis.
red_dims = [i for i, component in enumerate(RED) if component]
ablated_model = ablate(model, 'bottleneck', red_dims)
ablation_results = test_set.evaluate(ablated_model, [], tags={'ablated'})
clear_output()
viz.plot_cube(ablation_results)
# viz.plot_recon_loss(ablation_results)
# viz.plot_latent_space(
#     ablation_results.latent_cube,
#     tags=ablation_results.tags,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )
 
import numpy as np
from ex_color.vis.helpers import ThemedAnnotation
# Latent projections shown for every method (each triple presumably selects
# which bottleneck dims to plot — TODO confirm the axis convention in NbViz).
STACKED_LATENT_DIMS = ((1, 0, 2), (1, 3, 0))


def _cone_angle(cos_half_angle: float) -> float:
    """Return the full apex angle (radians) of a cone whose half-angle has cosine `cos_half_angle`."""
    # arccos is exact; the previous `pi/2 - cos(theta)` was only a small-angle
    # approximation, accurate near 90° but off for narrower cones (e.g. 60°).
    # Clip guards against tiny floating-point excursions outside [-1, 1].
    return 2 * float(np.arccos(np.clip(cos_half_angle, -1.0, 1.0)))


# Shared ceiling for the error colour scale, so the stacked plots are comparable.
# Every result plotted below contributes, including ablation: its errors (e.g.
# cyan ≈ 0.34) exceed those of the other methods.
max_error = np.max(
    [
        baseline_results.loss_cube['MSE'],
        suppression_results.loss_cube['MSE'],
        repulsion_results.loss_cube['MSE'],
        ablation_results.loss_cube['MSE'],
    ]
)
print('Baseline')
viz.plot_stacked_results(
    baseline_results,
    latent_dims=STACKED_LATENT_DIMS,
    max_error=max_error,
)
print('Suppression')
viz.plot_stacked_results(
    suppression_results,
    latent_dims=STACKED_LATENT_DIMS,
    max_error=max_error,
    latent_annotations=[
        # Assumes falloff.a holds the cosine threshold passed to BoundedFalloff.
        ThemedAnnotation(direction=RED, angle=_cone_angle(falloff.a), dashed=True),
    ],
)
print('Repulsion')
viz.plot_stacked_results(
    repulsion_results,
    latent_dims=STACKED_LATENT_DIMS,
    max_error=max_error,
    latent_annotations=[
        # Assumes mapper.a / mapper.b hold the cosine thresholds passed to LinearMapper.
        ThemedAnnotation(direction=RED, angle=_cone_angle(mapper.a), dashed=True),
        ThemedAnnotation(direction=RED, angle=_cone_angle(mapper.b), dashed=False),
    ],
)
print('Ablation')
viz.plot_stacked_results(
    ablation_results,
    latent_dims=STACKED_LATENT_DIMS,
    max_error=max_error,
)
Baseline
 
Suppression
 
Repulsion
 
Ablation
 
# Per-colour reconstruction error for each method, with deltas relative to the
# baseline: first as a rendered table, then as LaTeX source.
viz.tab_error_vs_color(baseline_results, suppression_results, repulsion_results, ablation_results)
viz.tab_error_vs_color_latex(baseline_results, suppression_results, repulsion_results, ablation_results)
| Name | RGB | Baseline | Suppression | Δ Sup | Repulsion | Δ Rep | Ablated | Δ Abl |
|---|---|---|---|---|---|---|---|---|
| red | | 0.001 | 0.284 | +0.284 | 0.333 | +0.333 | 0.333 | +0.333 |
| orange | | 0.000 | 0.154 | +0.154 | 0.135 | +0.135 | 0.138 | +0.137 |
| yellow | | 0.000 | 0.046 | +0.046 | 0.030 | +0.030 | 0.031 | +0.031 |
| lime | | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.000 | +0.000 |
| green | | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.035 | +0.035 |
| teal | | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.140 | +0.140 |
| cyan | | 0.001 | 0.001 | +0.000 | 0.001 | +0.000 | 0.342 | +0.341 |
| azure | | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.153 | +0.153 |
| blue | | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.034 | +0.034 |
| purple | | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.000 | +0.000 |
| magenta | | 0.000 | 0.049 | +0.048 | 0.033 | +0.033 | 0.035 | +0.034 |
| pink | | 0.000 | 0.158 | +0.158 | 0.144 | +0.144 | 0.148 | +0.148 |
| black | | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.000 | +0.000 |
| dark gray | | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.000 | +0.000 |
| gray | | 0.000 | 0.000 | -0.000 | 0.000 | -0.000 | 0.000 | -0.000 |
| light gray | | 0.000 | 0.000 | -0.000 | 0.000 | -0.000 | 0.000 | -0.000 |
| white | | 0.000 | 0.000 | -0.000 | 0.000 | -0.000 | 0.000 | -0.000 |
\begin{table}
\centering
\caption{Reconstruction error by color and intervention method}
\label{tab:placeholder}
\sisetup{
    round-mode = places,
    round-precision = 3,
    table-auto-round = true,
    % drop-zero-decimal = true,
}
\begin{tabular}{l c g g g g}
\toprule
\multicolumn{2}{c}{{Color}} & \multicolumn{1}{c}{{Baseline}} & \multicolumn{1}{c}{{Suppression}} & \multicolumn{1}{c}{{Repulsion}} & \multicolumn{1}{c}{{Ablated}} \\
\midrule
Red        & \swatch{FF0000} &  0.000646436 &  0.283565700 &  0.332686901 &  0.332686901 \\
Orange     & \swatch{FF7F00} &  0.000063273 &  0.153662786 &  0.135243252 &  0.137482300 \\
Yellow     & \swatch{FFFF00} &  0.000411471 &  0.045807716 &  0.029817864 &  0.030973246 \\
Lime       & \swatch{7FFF00} &  0.000108202 &  0.000000000 &  0.000000000 &  0.000005554 \\
Green      & \swatch{00FF00} &  0.000220425 &  0.000000000 &  0.000000000 &  0.034675915 \\
Teal       & \swatch{00FF7F} &  0.000000988 &  0.000000000 &  0.000000000 &  0.140236348 \\
Cyan       & \swatch{00FFFF} &  0.001272683 &  0.000000000 &  0.000000000 &  0.341102749 \\
Azure      & \swatch{007FFF} &  0.000013033 &  0.000000000 &  0.000000000 &  0.152602151 \\
Blue       & \swatch{0000FF} &  0.000310250 &  0.000000000 &  0.000000000 &  0.034143761 \\
Purple     & \swatch{7F00FF} &  0.000052661 &  0.000000000 &  0.000000000 &  0.000002185 \\
Magenta    & \swatch{FF00FF} &  0.000248736 &  0.048393689 &  0.033117618 &  0.034419145 \\
Pink       & \swatch{FF007F} &  0.000200588 &  0.157892883 &  0.143878624 &  0.148011312 \\
Black      & \swatch{000000} &  0.000000000 &  0.000009886 &  0.000006748 &  0.000007794 \\
Dark gray  & \swatch{3F3F3F} &  0.000121933 &  0.000156107 &  0.000148776 &  0.000153554 \\
Gray       & \swatch{7F7F7F} &  0.000064245 & -0.000056821 & -0.000056655 & -0.000056814 \\
Light gray & \swatch{BFBFBF} &  0.000018872 & -0.000007299 & -0.000007730 & -0.000007258 \\
White      & \swatch{FFFFFF} &  0.000108910 & -0.000092652 & -0.000093251 & -0.000093862 \\
\bottomrule
\end{tabular}
\end{table}
# Anchor colour for the similarity axis. Presumably HSV for pure red, like the
# (0.0, 1.0, 1.0) target given to score_fn.
RED_HSV = (0, 1, 1)

# Reconstruction error vs. squared similarity to red, for each intervention
viz.plot_error_vs_similarity(suppression_results, RED_HSV, anchor_name='red', power=2)
viz.plot_error_vs_similarity(repulsion_results, RED_HSV, anchor_name='red', power=2)
viz.plot_error_vs_similarity(ablation_results, RED_HSV, anchor_name='red', power=2)
 
MSE,sim² suppression: r = 1.00, R²: 0.99, p = 0
 
MSE,sim² repulsion: r = 0.99, R²: 0.98, p = 0
 
MSE,sim² ablated: r = 0.61, R²: 0.37, p = 0
