Experiment 2.4.1: Soft intervention on red with color wheel
This is a re-run of Ex 2.4 with more mature tooling. See the earlier notebook for discussion.
from __future__ import annotations
# Project identifier: used below to name the Modal app and volume, and passed
# on to the training workflow.
project = 'ex-preppy'

# Notebook identity. The ID tags generated assets; ID and title together form
# the experiment name.
nbid = '2.4.1'  # ID for tagging assets
nbname = 'Soft intervention on red with color wheel'
experiment_name = f'Ex {nbid}: {nbname}'
# Basic setup: Logging, Experiment (Modal)
import logging
import modal
from infra.requirements import uv_freeze, project_packages
from utils.logging import SimpleLoggingConfig
from ex_color.vis import NbViz
# Build the logging configuration step by step, then install it. The same
# object is re-applied inside remote functions further down.
logging_config = SimpleLoggingConfig()
logging_config = logging_config.info('notebook', 'utils', 'mini', 'ex_color')
logging_config = logging_config.error('matplotlib.axes')  # Silence warnings about set_aspect
logging_config.apply()

# This is the logger for this notebook
log = logging.getLogger(f'notebook.{nbid}')

# Container image for remote execution: Debian slim, plus the frozen
# dependency list (every group except 'dev'), plus the local project packages.
image = modal.Image.debian_slim()
image = image.pip_install(*uv_freeze(all_groups=True, not_groups='dev'))
image = image.add_local_python_source(*project_packages())

# Per-notebook volume (created on first use), mounted at /data in the app's
# containers.
volume = modal.Volume.from_name(f'{project}-{nbid}', create_if_missing=True, version=2)
app = modal.App(name=f'{project}-{nbid}', image=image, volumes={'/data': volume})

# Visualization helper, keyed by the notebook ID
viz = NbViz(nbid)
import torch
from ex_color.loss import AngularAnchor, AxisAlignedSubspace, Separate, RegularizerConfig
K = 4  # bottleneck dimensionality
N = 1  # number of nonlinear layers

# Direction that the 'red' label is anchored to in the bottleneck; it needs one
# component per bottleneck dimension.
RED = (1, 0, 0, 0)
assert len(RED) == K, f'RED must have K={K} components, got {len(RED)}'

BATCH_SIZE = 64
CUBE_SUBDIVISIONS = 8

NUM_RUNS = 60  # to probe seed sensitivity
RUN_SEEDS = list(range(NUM_RUNS))  # one seed per run: 0 .. NUM_RUNS - 1
# All three regularizers act on the 'bottleneck' layer. They differ in the
# loss term they compute and in which labels they are tied to.

# Separation term: not tied to any label.
reg_separate = RegularizerConfig(
    name='separate',
    layer_affinities=['bottleneck'],
    label_affinities=None,
    compute_loss_term=Separate(power=100.0, shift=True),
)

# Anchor term: tied to the 'red' label and built around the RED direction.
reg_anchor = RegularizerConfig(
    name='anchor',
    layer_affinities=['bottleneck'],
    label_affinities={'red': 1.0},
    phase=('train', 'validate'),
    compute_loss_term=AngularAnchor(torch.tensor(RED, dtype=torch.float32)),
)

# Subspace term: tied to the 'vibrant' label, constructed with axes (0, 1).
# Alternative that was tried: label_affinities={'primary': 1}
reg_subspace = RegularizerConfig(
    name='subspace',
    layer_affinities=['bottleneck'],
    label_affinities={'vibrant': 1},
    phase=('train', 'validate'),
    compute_loss_term=AxisAlignedSubspace((0, 1)),
)
from mini.temporal.dopesheet import Dopesheet
# Load the dopesheet from the CSV file that sits next to this notebook; the
# file name is derived from the notebook ID.
dopesheet = Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv')
# Show it twice: first in tabular form, then as a plot.
viz.tab_dopesheet(dopesheet)
viz.plot_dopesheet(dopesheet)
from torch.utils.data import DataLoader, RandomSampler
from ex_color.data.cube_dataset import prep_color_dataset, redness, stochastic_labels, exact_labels, vibrancy
def prep_train_data(training_subs: int, *, batch_size: int) -> DataLoader:
    """Create the training data loader.

    The dataset is sampled at cell corners and carries two label functions:
    ``red`` (``redness ** 8 * 0.08``) and ``vibrant`` (``vibrancy ** 10 * 0.01``).
    Batches are collated with ``stochastic_labels``.

    Args:
        training_subs: Passed through to ``prep_color_dataset`` as its first
            argument (callers pass the cube subdivision count).
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` that draws ``len(dataset)`` samples per pass, with
        replacement, using 4 worker processes.
    """
    cube = prep_color_dataset(
        training_subs,
        sample_at='cell-corners',
        red=lambda c: redness(c) ** 8 * 0.08,
        vibrant=lambda c: vibrancy(c) ** 10 * 0.01,
    )
    # Sampling with replacement: one pass has as many draws as the dataset has
    # items, but individual items may repeat or be skipped.
    resampler = RandomSampler(cube, num_samples=len(cube), replacement=True)
    return DataLoader(
        cube,
        batch_size=batch_size,
        num_workers=4,
        sampler=resampler,
        collate_fn=stochastic_labels,
    )
def prep_val_data(training_subs: int, *, batch_size: int) -> DataLoader:
    """Create the validation data loader.

    The dataset is sampled at cell centers and only carries the ``red`` label,
    which is set where ``redness(c) == 1``. Batches are collated with
    ``exact_labels`` and no sampler is supplied.

    Args:
        training_subs: Passed through to ``prep_color_dataset`` as its first
            argument (callers pass the cube subdivision count).
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` over the validation dataset using 2 worker processes.
    """
    cube = prep_color_dataset(
        training_subs,
        sample_at='cell-centers',
        red=lambda c: redness(c) == 1,
        # A 'vibrant' label is deliberately not attached here:
        # vibrant=lambda c: vibrancy(c) == 1,
    )
    loader = DataLoader(
        cube,
        batch_size=batch_size,
        num_workers=2,
        collate_fn=exact_labels,
    )
    return loader
from typing import Callable
import torch
import wandb
from ex_color.model import CNColorMLP
from ex_color.seed import set_deterministic_mode
from ex_color.workflow import train_model
from ex_color.evaluation import Result
from utils.time import hour
@app.function(
    cpu=1,
    max_containers=20,
    timeout=1 * hour,
    # The API key is read from the local wandb client when this cell runs and
    # handed to the function as an environment variable (empty if unavailable).
    env={'WANDB_API_KEY': wandb.Api().api_key or ''},
)
async def train(
    dopesheet: Dopesheet,
    regularizers: list[RegularizerConfig],
    *,
    seed: int,
    score_fn: Callable[[CNColorMLP], float],
):
    """Train one model for a single seed, score it, and store its weights.

    Args:
        dopesheet: Dopesheet handed to ``train_model``.
        regularizers: Regularizer configurations handed to ``train_model``.
        seed: Seed for deterministic mode; also recorded as a hyperparameter.
        score_fn: Called with the trained model to produce the run's score.

    Returns:
        A ``Result`` built from the seed, the checkpoint file name, the run
        URL, the run summary, and the score.
    """
    logging_config.apply()
    if seed is not None:
        set_deterministic_mode(seed)

    # Data loaders and model are built after seeding, in this order.
    train_batches = prep_train_data(CUBE_SUBDIVISIONS, batch_size=BATCH_SIZE)
    val_batches = prep_val_data(CUBE_SUBDIVISIONS, batch_size=BATCH_SIZE)
    net = CNColorMLP(K, n_nonlinear=N)

    outcome = train_model(
        net,
        dopesheet,
        regularizers,
        train_batches,
        val_batches,
        experiment_name=experiment_name,
        project=project,
        hparams={'seed': seed},
    )

    score = score_fn(outcome.model)

    # Save the weights under /data (where the app mounts its volume), named
    # after the run ID so the notebook can find the checkpoint later.
    checkpoint_key = f'model-{outcome.id_}.pt'
    torch.save(outcome.model.state_dict(), f'/data/{checkpoint_key}')

    return Result(seed, checkpoint_key, outcome.url, outcome.summary, score)
from math import cos, radians
from ex_color.evaluation import EvaluationPlan, ScoreByHSVSimilarity
from ex_color.intervention import InterventionConfig, Suppression, BoundedFalloff
# Falloff used by the suppression intervention during scoring.
falloff = BoundedFalloff(
    cos(radians(90)),  # cos(max_angle)
    1,  # completely squash aligned vectors
    0,  # constant effect (no fall-off)
)

# Suppress along RED at the bottleneck. The direction is created as float32
# explicitly, matching the anchor regularizer; torch.tensor would otherwise
# infer an integer dtype from the tuple of ints.
suppression = InterventionConfig(
    apply=Suppression(torch.tensor(RED, dtype=torch.float32), falloff),
    layer_affinities=['bottleneck'],
)

# Evaluation plan for scoring: tagged 'suppression', identity transform on the
# model, with the suppression intervention applied.
suppression_plan = EvaluationPlan(
    {'suppression'},
    lambda m: m,
    [suppression],
)

# Scoring function passed to each training run.
# NOTE(review): (0.0, 1.0, 1.0) is presumably the HSV triple of pure red - confirm.
score_fn = ScoreByHSVSimilarity(suppression_plan, (0.0, 1.0, 1.0), power=2.0, cube_subdivisions=CUBE_SUBDIVISIONS)
import asyncio
# Reload dopesheet: makes tweaking params during development easier
dopesheet = Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv')
regularizers = [reg_separate, reg_anchor, reg_subspace]
async def sweep():
logging_config.apply()
workers = [train.remote.aio(dopesheet, regularizers, seed=seed, score_fn=score_fn) for seed in RUN_SEEDS]
return await asyncio.gather(*workers)
with app.run():
results = await sweep()
from IPython.display import display
from ex_color.evaluation import results_to_dataframe
# Collect the per-run results into one DataFrame (one row per run)
runs_df = results_to_dataframe(results)
# Show min, max, mean, stddev of each column
log.info(f'Summary statistics for all {len(runs_df)} runs:')
display(runs_df.describe().loc[['min', 'max', 'mean', 'std']].style.format(precision=4))
# Box plots of three columns across all runs. Each print() acts as the heading
# for the plot that follows it; the first heading describes the 'score' column.
print('Correlation of reconstruction error vs. similarity to anchor')
viz.plot_boxplot(runs_df['score'], ylabel='', xlim=(None, 1), tags=('score',))
print('Reconstruction loss')
viz.plot_boxplot(runs_df['val_recon'], ylabel='', log_scale=True, tags=('val_recon',))
print('Anchor loss')
viz.plot_boxplot(runs_df['val_anchor'], ylabel='', log_scale=True, tags=('val_anchor',))
I 477.4 no.2.4.1:Summary statistics for all 60 runs:
| | seed | score | val_recon | labels/n/_any | val_loss | labels/n_total | labels/n/red | _runtime | val_anchor | labels/n/vibrant |
|---|---|---|---|---|---|---|---|---|---|---|
| min | 0.0000 | 0.8821 | 0.0000 | 160.0000 | 0.0000 | 96064.0000 | 64.0000 | 47.8293 | 0.0005 | 85.0000 |
| max | 59.0000 | 0.9910 | 0.0001 | 225.0000 | 0.0001 | 96064.0000 | 104.0000 | 150.0075 | 0.0072 | 132.0000 |
| mean | 29.5000 | 0.9515 | 0.0000 | 189.7167 | 0.0000 | 96064.0000 | 82.2167 | 73.2835 | 0.0029 | 107.9500 |
| std | 17.4642 | 0.0236 | 0.0000 | 13.8834 | 0.0000 | 0.0000 | 8.3951 | 24.8543 | 0.0015 | 10.6460 |
Correlation of reconstruction error vs. similarity to anchor
Reconstruction loss
Anchor loss
Select the best runs from the Pareto front of non-dominated runs, minimizing the validation reconstruction and anchor losses while maximizing the score.
from ex_color.evaluation import pareto_front
# Keep only the runs that no other run beats on all three criteria at once:
# lower val_recon, lower val_anchor, higher score.
non_dominated = pareto_front(runs_df, minimize=['val_recon', 'val_anchor'], maximize=['score'])
log.info(f'Best of {len(non_dominated)} non-dominated runs (Pareto front):')
# Show at most the five highest-scoring runs on the front
display(non_dominated.sort_values(by='score', ascending=False).head(5).style.format(precision=4, hyperlinks='html'))
I 479.4 no.2.4.1:Best of 4 non-dominated runs (Pareto front):
| | seed | wandb url | score | val_recon | labels/n/_any | val_loss | labels/n_total | labels/n/red | _runtime | val_anchor | labels/n/vibrant |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 7 | 7 | https://wandb.ai/z0r/ex-preppy/runs/lgikdd2e | 0.9910 | 0.0000 | 173 | 0.0000 | 96064 | 73 | 122.0421 | 0.0006 | 100 |
| 40 | 40 | https://wandb.ai/z0r/ex-preppy/runs/cq6fz0fz | 0.9848 | 0.0000 | 182 | 0.0000 | 96064 | 75 | 57.5879 | 0.0020 | 107 |
| 39 | 39 | https://wandb.ai/z0r/ex-preppy/runs/rk1yiphh | 0.9780 | 0.0000 | 182 | 0.0000 | 96064 | 81 | 61.5974 | 0.0005 | 101 |
| 30 | 30 | https://wandb.ai/z0r/ex-preppy/runs/wa4rvule | 0.9531 | 0.0000 | 188 | 0.0000 | 96064 | 89 | 68.8661 | 0.0034 | 99 |
from typing import cast
from mini.data import load_checkpoint_from_volume
# Pick the highest-scoring run on the Pareto front.
# NOTE(review): idxmax() returns a DataFrame index *label*, which is then used
# as a *position* in `results`. This assumes the frame keeps a default 0..n-1
# index in the same order as `results` - confirm against results_to_dataframe.
best_run = results[cast(int, non_dominated['score'].idxmax())]
log.info(f'Loading checkpoint of best run: seed={best_run.seed}, score={best_run.score:.4f} @ {best_run.url}')
# Rebuild the model with the same constructor arguments used in training, then
# load the stored weights from the volume.
model = CNColorMLP(K, n_nonlinear=N)
model = load_checkpoint_from_volume(model, volume, best_run.checkpoint_key)
I 479.5 no.2.4.1:Loading checkpoint of best run: seed=7, score=0.9910 @ https://wandb.ai/z0r/ex-preppy/runs/lgikdd2e
# # Generate a list of dimensions to visualize
# from itertools import combinations
# [
# (
# b,
# a,
# (a + 1) % K if (a + 1) % K not in (a, b) else (a + 2) % K,
# )
# for a, b in combinations(tuple(range(K)), 2)
# ]
from ex_color.evaluation import TestSet
# Test set shared by every evaluation below
test_set = TestSet.create()
from IPython.display import clear_output

# Baseline: evaluate the unmodified model with an empty intervention list
baseline_results = test_set.evaluate(model, [], tags={'baseline'})
# Clear whatever the evaluation wrote to the cell output before plotting
clear_output()
viz.plot_cube(baseline_results)
# viz.plot_recon_loss(baseline_results)
# viz.plot_latent_space(
# baseline_results,
# dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )
from math import cos, radians
from IPython.display import clear_output
from ex_color.intervention import Suppression, BoundedFalloff
# Falloff for the suppression intervention (same parameters as used for scoring)
falloff = BoundedFalloff(
    cos(radians(90)),  # cos(max_angle)
    1,  # completely squash fully-aligned vectors
    # 2,  # soft rim, sharp hub
    0,
)

# Suppress along RED at the bottleneck. The direction is created as float32
# explicitly, matching the anchor regularizer; torch.tensor would otherwise
# infer an integer dtype from the tuple of ints.
suppression = InterventionConfig(
    apply=Suppression(torch.tensor(RED, dtype=torch.float32), falloff),
    layer_affinities=['bottleneck'],
)

# Evaluate the model with the suppression intervention applied
suppression_results = test_set.evaluate(model, [suppression], tags={'suppression'})
# Clear whatever the evaluation wrote to the cell output before plotting
clear_output()
viz.plot_cube(suppression_results)
# viz.plot_recon_loss(suppression_results)
# viz.plot_latent_space(
# suppression_results,
# dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )
from math import cos, radians
from IPython.display import clear_output
from ex_color.intervention import Repulsion, LinearMapper
# mapper = FastBezierMapper(
# cos(radians(90)),
# cos(radians(60)),
# )
# Mapper for the repulsion intervention, parameterized by the cosines of 90 and
# 89 degrees.
mapper = LinearMapper(
    cos(radians(90)),
    cos(radians(89)),
)

# Repel from RED at the bottleneck. The direction is created as float32
# explicitly, matching the anchor regularizer; torch.tensor would otherwise
# infer an integer dtype from the tuple of ints.
repulsion = InterventionConfig(
    Repulsion(torch.tensor(RED, dtype=torch.float32), mapper),
    layer_affinities=['bottleneck'],
)

# Evaluate the model with the repulsion intervention applied
repulsion_results = test_set.evaluate(model, [repulsion], tags={'repulsion'})
# Clear whatever the evaluation wrote to the cell output before plotting
clear_output()
viz.plot_cube(repulsion_results)
# viz.plot_recon_loss(repulsion_results)
# viz.plot_latent_space(
# repulsion_results,
# dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )
from ex_color.surgery import ablate
# Ablate index 0 of the 'bottleneck' layer. Index 0 is the only non-zero
# component of RED = (1, 0, 0, 0).
ablated_model = ablate(model, 'bottleneck', [0])
# The ablated model is evaluated without any runtime intervention
ablation_results = test_set.evaluate(ablated_model, [], tags={'ablated'})
# Clear whatever the evaluation wrote to the cell output before plotting
clear_output()
viz.plot_cube(ablation_results)
# viz.plot_recon_loss(ablation_results)
# viz.plot_latent_space(
# ablation_results.latent_cube,
# tags=ablation_results.tags,
# dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )
import numpy as np
from ex_color.vis.helpers import ThemedAnnotation
# Largest MSE over every result set plotted below, so that all four figures
# receive the same `max_error`. Ablation is included because it is plotted
# with this value too.
max_error = np.max(
    [
        baseline_results.loss_cube['MSE'],
        suppression_results.loss_cube['MSE'],
        repulsion_results.loss_cube['MSE'],
        ablation_results.loss_cube['MSE'],
    ]
)

# The same two latent projections are used for every method.
print('Baseline')
viz.plot_stacked_results(
    baseline_results,
    latent_dims=((1, 0, 2), (1, 3, 0)),
    max_error=max_error,
)

# NOTE(review): `falloff.a` and `mapper.a` / `mapper.b` look like cosines (the
# objects are constructed from cos(radians(...))). If so, `pi / 2 - value` only
# approximates the angle when it is close to 90 degrees; np.arccos(value) would
# be exact. Confirm what these attributes hold before changing the thresholds.
print('Suppression')
viz.plot_stacked_results(
    suppression_results,
    latent_dims=((1, 0, 2), (1, 3, 0)),
    max_error=max_error,
    latent_annotations=[
        ThemedAnnotation(direction=RED, angle=2 * (np.pi / 2 - falloff.a), dashed=True),
    ],
)

print('Repulsion')
viz.plot_stacked_results(
    repulsion_results,
    latent_dims=((1, 0, 2), (1, 3, 0)),
    max_error=max_error,
    latent_annotations=[
        ThemedAnnotation(direction=RED, angle=2 * (np.pi / 2 - mapper.a), dashed=True),
        ThemedAnnotation(direction=RED, angle=2 * (np.pi / 2 - mapper.b), dashed=False),
    ],
)

print('Ablation')
viz.plot_stacked_results(
    ablation_results,
    latent_dims=((1, 0, 2), (1, 3, 0)),
    max_error=max_error,
)
Baseline
Suppression
Repulsion
Ablation
# Per-color error comparison of all four result sets: first as a notebook
# table, then as LaTeX source.
viz.tab_error_vs_color(baseline_results, suppression_results, repulsion_results, ablation_results)
viz.tab_error_vs_color_latex(baseline_results, suppression_results, repulsion_results, ablation_results)
| Name | Baseline | Suppression | Δ Sup | Repulsion | Δ Rep | Ablated | Δ Abl | RGB |
|---|---|---|---|---|---|---|---|---|
| red | 0.001 | 0.284 | +0.284 | 0.333 | +0.333 | 0.333 | +0.333 | |
| orange | 0.000 | 0.154 | +0.154 | 0.135 | +0.135 | 0.138 | +0.137 | |
| yellow | 0.000 | 0.046 | +0.046 | 0.030 | +0.030 | 0.031 | +0.031 | |
| lime | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.000 | +0.000 | |
| green | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.035 | +0.035 | |
| teal | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.140 | +0.140 | |
| cyan | 0.001 | 0.001 | +0.000 | 0.001 | +0.000 | 0.342 | +0.341 | |
| azure | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.153 | +0.153 | |
| blue | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.034 | +0.034 | |
| purple | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.000 | +0.000 | |
| magenta | 0.000 | 0.049 | +0.048 | 0.033 | +0.033 | 0.035 | +0.034 | |
| pink | 0.000 | 0.158 | +0.158 | 0.144 | +0.144 | 0.148 | +0.148 | |
| black | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.000 | +0.000 | |
| dark gray | 0.000 | 0.000 | +0.000 | 0.000 | +0.000 | 0.000 | +0.000 | |
| gray | 0.000 | 0.000 | -0.000 | 0.000 | -0.000 | 0.000 | -0.000 | |
| light gray | 0.000 | 0.000 | -0.000 | 0.000 | -0.000 | 0.000 | -0.000 | |
| white | 0.000 | 0.000 | -0.000 | 0.000 | -0.000 | 0.000 | -0.000 |
\begin{table}
\centering
\label{tab:placeholder}
\caption{Reconstruction error by color and intervention method}
\sisetup{
round-mode = places,
round-precision = 3,
table-auto-round = true,
% drop-zero-decimal = true,
}
\begin{tabular}{l c g g g g}
\toprule
\multicolumn{2}{c}{{Color}} & \multicolumn{1}{c}{{Baseline}} & \multicolumn{1}{c}{{Suppression}} & \multicolumn{1}{c}{{Repulsion}} & \multicolumn{1}{c}{{Ab}} \\
\midrule
Red & \swatch{FF0000} & 0.000646441 & 0.283565402 & 0.332686901 & 0.332686901 \\
Orange & \swatch{FF7F00} & 0.000063273 & 0.153662667 & 0.135243282 & 0.137482330 \\
Yellow & \swatch{FFFF00} & 0.000411471 & 0.045807652 & 0.029817866 & 0.030973202 \\
Lime & \swatch{7FFF00} & 0.000108202 & 0.000000000 & 0.000000000 & 0.000005553 \\
Green & \swatch{00FF00} & 0.000220426 & 0.000000000 & 0.000000000 & 0.034675885 \\
Teal & \swatch{00FF7F} & 0.000000987 & 0.000000000 & 0.000000000 & 0.140236422 \\
Cyan & \swatch{00FFFF} & 0.001272679 & 0.000000000 & 0.000000000 & 0.341102839 \\
Azure & \swatch{007FFF} & 0.000013033 & 0.000000000 & 0.000000000 & 0.152602226 \\
Blue & \swatch{0000FF} & 0.000310250 & 0.000000000 & 0.000000000 & 0.034143779 \\
Purple & \swatch{7F00FF} & 0.000052661 & 0.000000000 & 0.000000000 & 0.000002188 \\
Magenta & \swatch{FF00FF} & 0.000248737 & 0.048393659 & 0.033117589 & 0.034419138 \\
Pink & \swatch{FF007F} & 0.000200588 & 0.157892689 & 0.143878400 & 0.148011118 \\
Black & \swatch{000000} & 0.000000000 & 0.000009886 & 0.000006749 & 0.000007795 \\
Dark gray & \swatch{3F3F3F} & 0.000121935 & 0.000156103 & 0.000148772 & 0.000153551 \\
Gray & \swatch{7F7F7F} & 0.000064246 & -0.000056821 & -0.000056655 & -0.000056814 \\
Light gray & \swatch{BFBFBF} & 0.000018871 & -0.000007298 & -0.000007729 & -0.000007257 \\
White & \swatch{FFFFFF} & 0.000108906 & -0.000092649 & -0.000093249 & -0.000093859 \\
\bottomrule
\end{tabular}
\end{table}
# Plot reconstruction error against similarity to the anchor for each of the
# three modified evaluations, using the same anchor triple, anchor name and
# power every time.
# NOTE(review): (0, 1, 1) is presumably the HSV triple of pure red - confirm.
viz.plot_error_vs_similarity(
    suppression_results,
    (0, 1, 1),
    anchor_name='red',
    power=2,
)
viz.plot_error_vs_similarity(
    repulsion_results,
    (0, 1, 1),
    anchor_name='red',
    power=2,
)
viz.plot_error_vs_similarity(
    ablation_results,
    (0, 1, 1),
    anchor_name='red',
    power=2,
)
MSE,sim² suppression: r = 1.00, R²: 0.99, p = 0
MSE,sim² repulsion: r = 0.99, R²: 0.98, p = 0
MSE,sim² ablated: r = 0.61, R²: 0.37, p = 0