Experiment 2.4: Post-norm regularization

In Ex 2.3, we applied interventions to an autoencoder that included regularization and an explicit normalization step. That performed significantly better than earlier model versions that relied solely on regularization. In that experiment, all regularizers ran immediately before the normalization, to get the latent space into fairly good shape before it was normalized. In this experiment, we'll see what happens if we apply some of the regularizers after normalization; only the unitarity regularizer should need to run beforehand.

The model will be the same as the linear output variant from 2.3, i.e. without sigmoid applied to the output.

Hypothesis

The separate, planarity, and anchor regularizers are (approximately) angular. If we move them to run after the explicit normalization, they will be more effective at shaping latent space. We should see lower reconstruction loss overall and a more regular structure in latent space.
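
To make the change concrete, here's a minimal sketch of the two orderings. It isn't the project's actual hook wiring, and the MSE-style pull toward red is an assumed form; the point is only that a not-purely-angular loss behaves differently when it sees explicitly normalized activations, because its gradient then flows back through the normalization and only steers the direction of each embedding.

import torch
import torch.nn.functional as F

z = torch.randn(8, 4, requires_grad=True)  # stand-in for bottleneck activations
red = torch.tensor([1.0, 0.0, 0.0, 0.0])

# Ex 2.3 ordering: the regularizer sees the raw activations; normalization happens later.
loss_pre = ((z - red) ** 2).sum(dim=-1).mean()

# This experiment: the regularizer sees explicitly normalized activations, so its
# gradient is backpropagated through F.normalize, which removes the radial component.
loss_post = ((F.normalize(z, dim=-1) - red) ** 2).sum(dim=-1).mean()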

from __future__ import annotations

nbid = '2.4'  # ID for tagging assets
nbname = 'Post-norm regularization'
experiment_name = f'Ex {nbid}: {nbname}'
project = 'ex-preppy'
# Basic setup: Logging, Experiment (Modal)
import logging

import modal

from infra.requirements import freeze, project_packages
from mini.experiment import Experiment
from utils.logging import SimpleLoggingConfig

logging_config = (
    SimpleLoggingConfig()
    .info('notebook', 'utils', 'mini', 'ex_color')
    .error('matplotlib.axes')  # Silence warnings about set_aspect
)
logging_config.apply()

# This is the logger for this notebook
log = logging.getLogger(f'notebook.{nbid}')

run = Experiment(experiment_name, project=project)
run.image = modal.Image.debian_slim().pip_install(*freeze(all=True)).add_local_python_source(*project_packages())
run.before_each(logging_config.apply)
None  # prevent auto-display of this cell

Regularizers

Like Ex 2.3:

  • Anchor: pins red to $(1,0,0,0)$
  • Separate: angular repulsion to reduce global clumping (applied within each batch)
  • Planarity: pulls vibrant hues to the $[0, 1]$ plane
  • Unitarity: pulls all embeddings to the surface of the unit hypersphere, i.e. it makes the embedding vectors have unit length.

Like 2.3, unitarity is applied to the output of the encoder. But unlike 2.3:

  • All other regularizers are applied after the activations are explicitly normalized; an illustrative sketch of the loss terms follows below.
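
For a rough sense of what these terms compute, here's an illustrative sketch. These are assumed forms inferred from the descriptions above, not the actual implementations in ex_color.loss, and the planarity term assumes the target plane is spanned by the first two latent dimensions.

import torch
import torch.nn.functional as F


def unitarity_sketch(z: torch.Tensor) -> torch.Tensor:
    # Penalize deviation from unit length (distance from the hypersphere surface).
    return ((z.norm(dim=-1) - 1.0) ** 2).mean()


def anchor_sketch(z: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Pull embeddings toward a fixed point; in training this only applies to 'red'-labelled samples.
    return ((z - target) ** 2).sum(dim=-1).mean()


def separate_sketch(z: torch.Tensor, power: float = 100.0) -> torch.Tensor:
    # Angular repulsion within a batch: heavily penalize near-parallel pairs.
    d = F.normalize(z, dim=-1)
    sim = (d @ d.T).clamp(min=0.0) - torch.eye(len(z))  # remove self-similarity
    return (sim.clamp(min=0.0) ** power).mean()


def planarity_sketch(z: torch.Tensor) -> torch.Tensor:
    # Penalize any component outside the plane of the first two latent dimensions.
    return (z[..., 2:] ** 2).sum(dim=-1).mean()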
import torch

from mini.temporal.dopesheet import Dopesheet
from ex_color.loss import Anchor, Separate, Planarity, Unitarity, RegularizerConfig

from ex_color.training import TrainingModule

RED = (1, 0, 0, 0)

ALL_REGULARIZERS = [
    RegularizerConfig(
        name='reg-unit',
        compute_loss_term=Unitarity(),
        label_affinities=None,
        layer_affinities=['encoder'],
    ),
    RegularizerConfig(
        name='reg-anchor',
        compute_loss_term=Anchor(torch.tensor(RED, dtype=torch.float32)),
        label_affinities={'red': 1.0},
        layer_affinities=['bottleneck'],
    ),
    RegularizerConfig(
        name='reg-separate',
        compute_loss_term=Separate(power=100.0, shift=True),
        label_affinities=None,
        layer_affinities=['bottleneck'],
    ),
    RegularizerConfig(
        name='reg-planar',
        compute_loss_term=Planarity(),
        label_affinities={'vibrant': 1.0},
        layer_affinities=['bottleneck'],
    ),
]

Data

Data is the same as last time:

  • Train: an HSV cube (of RGB values)
  • Test: an RGB cube
from functools import partial
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
import numpy as np

from ex_color.data.color_cube import ColorCube
from ex_color.data.cube_sampler import vibrancy
from ex_color.data.cyclic import arange_cyclic
from ex_color.labelling import collate_with_generated_labels


def prep_data() -> tuple[DataLoader, Tensor]:
    """
    Prepare data for training.

    Returns: (train, val)
    """
    hsv_cube = ColorCube.from_hsv(
        h=arange_cyclic(step_size=10 / 360),
        s=np.linspace(0, 1, 10),
        v=np.linspace(0, 1, 10),
    )
    hsv_tensor = torch.tensor(hsv_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
    vibrancy_tensor = torch.tensor(vibrancy(hsv_cube).flatten(), dtype=torch.float32)
    hsv_dataset = TensorDataset(hsv_tensor, vibrancy_tensor)

    labeller = partial(
        collate_with_generated_labels,
        soft=False,  # Use binary labels (stochastic) to simulate the labelling of internet text
        red=0.5,
        vibrant=0.5,
    )
    # Desaturated and dark colors are over-represented in the cube, so we use a weighted sampler to balance them out
    hsv_loader = DataLoader(
        hsv_dataset,
        batch_size=64,
        num_workers=2,
        sampler=WeightedRandomSampler(
            weights=hsv_cube.bias.flatten().tolist(),
            num_samples=len(hsv_dataset),
            replacement=True,
        ),
        collate_fn=labeller,
    )

    rgb_cube = ColorCube.from_rgb(
        r=np.linspace(0, 1, 8),
        g=np.linspace(0, 1, 8),
        b=np.linspace(0, 1, 8),
    )
    rgb_tensor = torch.tensor(rgb_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
    return hsv_loader, rgb_tensor

Training

Like in Ex 2.2, the model is trained with PyTorch Lightning, with regularizers applied as custom hooks.

import wandb
from ex_color.inference import InferenceModule
from ex_color.intervention.intervention import InterventionConfig
from ex_color.model import CNColorMLP


# @run.thither(env={'WANDB_API_KEY': wandb.Api().api_key})
async def train(
    dopesheet: Dopesheet,
    regularizers: list[RegularizerConfig],
) -> CNColorMLP:
    """Train the model with the given dopesheet and variant."""
    import lightning as L
    from lightning.pytorch.loggers import WandbLogger

    from ex_color.seed import set_deterministic_mode

    from utils.progress.lightning import LightningProgress

    log.info(f'Training with: {[r.name for r in regularizers]}')

    seed = 0
    set_deterministic_mode(seed)

    hsv_loader, _ = prep_data()

    model = CNColorMLP(4)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.debug(f'Model initialized with {total_params:,} trainable parameters.')

    training_module = TrainingModule(model, dopesheet, torch.nn.MSELoss(), regularizers)

    logger = WandbLogger(experiment_name, project=project)

    trainer = L.Trainer(
        max_steps=len(dopesheet),
        callbacks=[
            LightningProgress(),
        ],
        enable_checkpointing=False,
        enable_model_summary=False,
        # enable_progress_bar=True,
        logger=logger,
    )

    print(f'max_steps: {len(dopesheet)}, hsv_loader length: {len(hsv_loader)}')

    # Train the model
    try:
        trainer.fit(training_module, hsv_loader)
    finally:
        wandb.finish()
    # This is only a small model, so it's OK to return it rather than storing and loading a checkpoint remotely
    return model


async with run():
    model = await train(Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv'), ALL_REGULARIZERS)
I 5.6 no.2.4:  Training with: ['reg-unit', 'reg-anchor', 'reg-separate', 'reg-planar']
INFO: Seed set to 0
I 5.6 li.fa.ut.se:Seed set to 0
I 5.6 ex.se:   PyTorch set to deterministic mode
INFO: GPU available: False, used: False
I 5.7 li.py.ut.ra:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
I 5.7 li.py.ut.ra:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
I 5.7 li.py.ut.ra:HPU available: False, using: 0 HPUs
max_steps: 3001, hsv_loader length: 57
wandb: Currently logged in as: z0r to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
creating run (0.5s)
Tracking run with wandb version 0.21.0
Run data is saved locally in ./wandb/run-20250905_064851-8qu5myaj
Syncing run Ex 2.4: Post-norm regularization to Weights & Biases (docs)
View project at https://wandb.ai/z0r/ex-color-transformer
View run at https://wandb.ai/z0r/ex-color-transformer/runs/8qu5myaj
Training: 100.0% [3001/3001] [00:19<00:00, 156.44 it/s]
v_num: myaj · train_loss: 3.036e-05
Starting phase: Train
INFO: `Trainer.fit` stopped: `max_steps=3001` reached.
I 27.0 li.py.ut.ra:`Trainer.fit` stopped: `max_steps=3001` reached.


Run history:


epoch ▁▁▁▁▂▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇████
train_loss █▄▄▆▃▄▃▃▂▄▆▂▆▄▃▃▄▂▃▄▂▂▂▃▃▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁
train_recon █▂▂▂▂▁▁▁▁▂▂▃▂▂▂▂▂▂▂▂▂▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_reg-anchor ▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▂▁▁▁
train_reg-planar ▁▅▁▂█▁▁▁▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂
train_reg-separate █▆▆██▆▆▄▅▃▃▄▅▁▃▂▄▁▂▅▄▁▄▃▂▂▁▂▃▃▂▄▁▃▂▃▂▃▄▄
train_reg-unit █▆▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step ▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


Run summary:


epoch 52
train_loss 3e-05
train_recon 3e-05
train_reg-anchor 0
train_reg-planar 0.05332
train_reg-separate 0.43601
train_reg-unit 0.00924
trainer/global_step 2999


View run Ex 2.4: Post-norm regularization at: https://wandb.ai/z0r/ex-color-transformer/runs/8qu5myaj
View project at: https://wandb.ai/z0r/ex-color-transformer
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20250905_064851-8qu5myaj/logs

The charts and loss values from training look much the same as last time:

  • Roughly the same shape overall
  • All loss values slightly higher

Inference ('test-time')

We wrap the model that we trained above in an InferenceModule, which knows how to apply our interventions.

# @run.thither
async def infer(
    model: CNColorMLP,
    interventions: list[InterventionConfig],
    test_data: Tensor,
) -> Tensor:
    """Run inference with the given model and interventions."""
    import lightning as L

    inference_module = InferenceModule(model, interventions)
    trainer = L.Trainer(
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=True,
    )
    reconstructed_colors_batches = trainer.predict(
        inference_module,
        DataLoader(
            TensorDataset(test_data.reshape((-1, 3))),
            batch_size=64,
            collate_fn=lambda batch: torch.stack([row[0] for row in batch], 0),
        ),
    )
    assert reconstructed_colors_batches is not None
    # Flatten the list of batches to a single list of tensors
    reconstructed_colors = [item for batch in reconstructed_colors_batches for item in batch]
    # Reshape to match input
    return torch.cat(reconstructed_colors).reshape(test_data.shape)

Let's see how well the model reconstructs colors without any interventions.

from IPython.display import clear_output

from utils.nb import displayer_mpl
from ex_color.vis import plot_colors


hsv_cube = ColorCube.from_hsv(
    h=arange_cyclic(step_size=1 / 24),
    s=np.linspace(0, 1, 4),
    v=np.linspace(0, 1, 8),
).permute('svh')
x_hsv = torch.tensor(hsv_cube.rgb_grid, dtype=torch.float32)

hd_hsv_cube = ColorCube.from_hsv(
    h=arange_cyclic(step_size=1 / 240),
    s=np.linspace(0, 1, 48),
    v=np.linspace(0, 1, 48),
)
hd_x_hsv = torch.tensor(hd_hsv_cube.rgb_grid, dtype=torch.float32)

rgb_cube = ColorCube.from_rgb(
    r=np.linspace(0, 1, 20),
    g=np.linspace(0, 1, 20),
    b=np.linspace(0, 1, 20),
)
x_rgb = torch.tensor(rgb_cube.rgb_grid, dtype=torch.float32)

with displayer_mpl(
    f'large-assets/ex-{nbid}-true-colors.png',
    alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. The first slice shows a grayscale gradient from black to white; the last shows the fully-saturated color spectrum.""",
) as show:
    show(lambda: plot_colors(hsv_cube, title='True colors', colors=x_hsv.numpy()))
Plot showing four slices of the HSV cube, titled "True colors · V vs H by S". Each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. The first slice shows a grayscale gradient from black to white; the last shows the fully-saturated color spectrum.
from IPython.display import clear_output
from torch.nn import functional as F

from ex_color.vis import plot_colors, plot_cube_series

interventions = []
y_hsv = await infer(model, interventions, x_hsv)
hd_y_hsv = await infer(model, interventions, hd_x_hsv)
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-pred-colors-no-intervention.png',
    alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well, but some slight differences are visible; for example, "white" is slightly gray, and many of the fully-saturated colors are less saturated than they should be.""",
) as show:
    show(
        lambda: plot_colors(
            hsv_cube,
            title='Predicted colors without intervention',
            colors=y_hsv.numpy(),
            colors_compare=x_hsv.numpy(),
        )
    )

per_color_loss = F.mse_loss(hd_y_hsv, hd_x_hsv, reduction='none').mean(dim=-1)
loss_cube = hd_hsv_cube.assign('MSE', per_color_loss.numpy().reshape(hd_hsv_cube.shape))
max_loss = per_color_loss.max().item()
median_loss = per_color_loss.median().item()
with displayer_mpl(
    f'large-assets/ex-{nbid}-loss-colors-no-intervention.png',
    alt_text=f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue. The range of loss values is small, but there are two notable peaks at all primary and secondary colors (red, yellow, green, etc.).""",
) as show:
    show(
        lambda: plot_cube_series(
            loss_cube.permute('hsv')[:, -1:, :: (loss_cube.shape[2] // -5)],
            loss_cube.permute('svh')[:, -1:, :: -(loss_cube.shape[0] // -3)],
            loss_cube.permute('vsh')[:, -1:, :: -(loss_cube.shape[0] // -3)],
            title='Reconstruction error · no intervention',
            var='MSE',
            figsize=(12, 3),
        )
    )
print(f'Max loss: {max_loss:.2g}')
print(f'Median MSE: {median_loss:.2g}')
Plot showing four slices of the HSV cube, titled "Predicted colors without intervention · V vs H by S". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well, but some slight differences are visible; for example, "white" is slightly gray, and many of the fully-saturated colors are less saturated than they should be.
Line chart showing loss per color, titled "Reconstruction error · no intervention". Y-axis: mean square error, ranging from zero to 0.0011. X-axis: hue. The range of loss values is small, but there are two notable peaks at all primary and secondary colors (red, yellow, green, etc.).
Max loss: 0.0011
Median MSE: 1.5e-05

That's a good baseline: visually similar to 2.3 (linear variant), with a slightly lower maximum reconstruction loss. One notable feature is that the primary and secondary colors are more regular in their loss spikes: they all have roughly the same peak loss, whereas last time green and blue were worse than the others. I expect this means latent space will be slightly more globally regular.

There's an odd, conspicuously straight line for a set of very dark colors — maybe black — at around 4e-4 (compared to the median of 1.5e-5). I'm not sure what's happening there, but these numbers are all pretty small, so I'll ignore it for now.
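
If I wanted to chase that down, a quick (hypothetical, not run here) follow-up would be to list which inputs sit in that band of loss values:

flat_loss = per_color_loss.flatten()
flat_rgb = hd_x_hsv.reshape(-1, 3)
band = (flat_loss > 3e-4) & (flat_loss < 5e-4)  # around the ~4e-4 plateau
print(f'{int(band.sum())} colors in the band')
print(flat_rgb[band][:10])  # peek at a few: near-black, or something else?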

# Capture encoder activations (latents) without interventions
import torch
import numpy as np

from ex_color.inference import InferenceModule


# Build a tiny helper that runs predict while capturing latents from 'encoder'
async def infer_with_latent_capture(
    model: CNColorMLP, interventions: list[InterventionConfig], test_data: Tensor, layer_name: str = 'bottleneck'
) -> tuple[Tensor, Tensor]:
    import lightning as L

    module = InferenceModule(model, interventions, capture_layers=[layer_name])
    trainer = L.Trainer(enable_checkpointing=False, enable_model_summary=False, enable_progress_bar=False)
    batches = trainer.predict(
        module,
        DataLoader(
            TensorDataset(test_data.reshape((-1, 3))),
            batch_size=64,
            collate_fn=lambda batch: torch.stack([row[0] for row in batch], 0),
        ),
    )
    assert batches is not None
    preds = [item for batch in batches for item in batch]
    y = torch.cat(preds).reshape(test_data.shape)
    # Read captured activations as a flat [N, D] tensor
    latents = module.read_captured(layer_name)
    return y, latents
from IPython.display import clear_output

from ex_color.vis import plot_latent_grid_3d

y_rgb, h_rgb = await infer_with_latent_capture(model, [], x_rgb, 'bottleneck')
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-latents-no-intervention.png',
    alt_text="""Three spherical plots, titled "{title}". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a black sphere. The first plot has the appearance of a color wheel, with the full set of vibrant colors around the rim (like a rainbow), varying to black in the center. The other plots show different views of the same sphere, with hue varying across the equator and tone varying from top to bottom, and red in the center. Each ball shows the reconstructed color, with a dot in the center showing the true (input) color. In this plot the true and reconstructor colors agree fairly well, but slight differences can be seen if you look closely.""",
) as show:
    show(
        lambda theme: plot_latent_grid_3d(
            h_rgb,
            y_rgb,
            x_rgb,
            title='Latents · no intervention',
            dims=[(1, 0, 2), (1, 2, 0), (1, 3, 0)],
            dot_radius=10,
            theme=theme,
        )
    )
Three spherical plots, titled "Latents · no intervention". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a black sphere. The first plot has the appearance of a color wheel, with the full set of vibrant colors around the rim (like a rainbow), varying to black in the center. The other plots show different views of the same sphere, with hue varying across the equator and tone varying from top to bottom, and red in the center. Each ball shows the reconstructed color, with a dot in the center showing the true (input) color. In this plot the true and reconstructor colors agree fairly well, but slight differences can be seen if you look closely.

Ah, here we see why the peak losses for primary and secondary colors were more uniform: the middle plot shows that they have been positioned more uniformly around the hub of the dome, i.e. they are more planar. That suggests that the planarity regularizer was more effective, having been applied after the normalization. Or rather, given that the overall planarity loss was not lower than last time, perhaps its effect is just more uniform (less variance).

Also: the middle plot in 2.3 had a "tail" of dark colors extending down from the middle. That's absent in these plots. On dimensions $(1,2)$, the latent embeddings appear cleanly hemispherical.
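
A rough way to quantify that impression is the ad-hoc check below (my own, and it assumes the planarity regularizer targets the plane of latent dimensions 0 and 1): measure how far the most vibrant colors' latents leak out of that plane.

vib = torch.tensor(vibrancy(rgb_cube).flatten(), dtype=torch.float32)
vibrant_latents = h_rgb[vib > 0.9]
out_of_plane = vibrant_latents[:, 2:].norm(dim=-1)
print(f'{len(vibrant_latents)} vibrant colors; mean out-of-plane magnitude: {out_of_plane.mean():.3g}')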

Suppression

Now that we have our model, let's try suppressing red. We'll use the Suppression function developed in Ex 2.1.

from math import cos, pi

import torch

from ex_color.intervention import BoundedFalloff, InterventionConfig, Suppression
from ex_color.vis import plot_colors, plot_cube_series


suppression = Suppression(
    torch.tensor(RED, dtype=torch.float32),  # Constant!
    BoundedFalloff(
        0,  # within 60°
        1,  # completely squash fully-aligned vectors
        2,  # soft rim, sharp hub
    ),
)

interventions = [InterventionConfig(suppression, ['bottleneck'])]
y_hsv = await infer(model, interventions, x_hsv)
hd_y_hsv = await infer(model, interventions, hd_x_hsv)
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-pred-colors-suppression.png',
    alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well, but "red" and nearby colors are clearly different: red itself appears as middle-gray, and the surrounding colors up to orange and pink look washed out. "Red-orange" actually appears to be green, moreso even than yellow (which is geometrically closer to green).""",
) as show:
    show(
        lambda: plot_colors(
            hsv_cube,
            title='Predicted colors with suppression',
            colors=y_hsv.numpy(),
            colors_compare=x_hsv.numpy(),
        )
    )

per_color_loss = F.mse_loss(hd_y_hsv, hd_x_hsv, reduction='none').mean(dim=-1)
loss_cube = hd_hsv_cube.assign('MSE', per_color_loss.numpy().reshape(hd_hsv_cube.shape))
max_loss = per_color_loss.max().item()
median_loss = per_color_loss.median().item()
with displayer_mpl(
    f'large-assets/ex-{nbid}-loss-colors-suppression.png',
    alt_text=f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue. There is a significant peak at red at either end of the X-axis, gradually sloping down to lower loss values near yellow and pink. Two very small peaks are at green and blue (apparently around 1% of the height of the peaks at red).""",
) as show:
    show(
        lambda: plot_cube_series(
            loss_cube.permute('hsv')[:, -1:, :: (loss_cube.shape[2] // -5)],
            loss_cube.permute('svh')[:, -1:, :: -(loss_cube.shape[0] // -6)],
            loss_cube.permute('vsh')[:, -1:, :: -(loss_cube.shape[0] // -6)],
            title='Reconstruction error · suppression',
            var='MSE',
            figsize=(12, 3),
        )
    )
print(f'Max loss: {max_loss:.2g}')
print(f'Median MSE: {median_loss:.2g}')
Plot showing four slices of the HSV cube, titled "Predicted colors with suppression · V vs H by S". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well, but "red" and nearby colors are clearly different: red itself appears as middle-gray, and the surrounding colors up to orange and pink look washed out. "Red-orange" actually appears to be green, moreso even than yellow (which is geometrically closer to green).
Line chart showing loss per color, titled "Reconstruction error · suppression". Y-axis: mean square error, ranging from zero to 0.16. X-axis: hue. There is a significant peak at red at either end of the X-axis, gradually sloping down to lower loss values near yellow and pink. Two very small peaks are at green and blue (apparently around 1% of the height of the peaks at red).
Max loss: 0.16
Median MSE: 2e-05

This looks great: the reconstructed colors look similar to the intervened ones in 2.3, but the loss curves are even smoother. They show less crossing over each other near red, which I think should mean the intervention has introduced less high-frequency perturbation and should have fewer unwanted downstream effects. They also rise steadily to a peak at red, whereas the ones in 2.3 tended to flatten out near the top. The peak is what we should expect, given that the falloff function is quadratic. I think these results may be sharper because red is more precisely aligned with $(1, 0, 0, 0)$.
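
One crude way to check the "less high-frequency perturbation" claim would be an ad-hoc roughness metric (mine, and it assumes the cube axes are ordered hue, saturation, value): the total second difference of the loss-vs-hue curve at full saturation and value.

loss_grid = per_color_loss.reshape(hd_hsv_cube.shape)  # assumes (h, s, v) axis order
hue_curve = loss_grid[:, -1, -1]  # loss vs hue at full saturation and value
roughness = hue_curve.diff().diff().abs().sum()
print(f'Loss-vs-hue roughness (sum of |second difference|): {roughness:.3g}')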

from IPython.display import clear_output

from ex_color.vis import plot_latent_grid_3d, ConicalAnnotation

# Capture latents with suppression
y_rgb, h_rgb = await infer_with_latent_capture(model, interventions, x_rgb, 'bottleneck')
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-latents-suppression.png',
    alt_text="""Three spherical plots, titled "{title}". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a black sphere. The first plot has the appearance of a partial color wheel, with  vibrant colors around the rim (like a rainbow), with with a conspicuously absent space at the top where "red" should be. The other plots show different views of the same sphere, with hue varying across the equator and tone varying from top to bottom, and warm-but-not-red colors in the center (where "red" should be). Each ball shows the reconstructed color, with a dot in the center showing the true (input) color. The true and reconstructor colors agree fairly well, even for the warmer colors. "Red" and nearby colors are in fact not visible, being buried somewhere inside the sphere.""",
) as show:
    show(
        lambda theme: plot_latent_grid_3d(
            h_rgb,
            y_rgb,
            x_rgb,
            title='Latents · suppression',
            dims=[(1, 0, 2), (1, 2, 0), (1, 3, 0)],
            dot_radius=10,
            theme=theme,
            annotations=[
                ConicalAnnotation(
                    RED,
                    2 * (np.pi / 2 - suppression.falloff.a),  # type: ignore
                    color=theme.val('black', dark='#fffa'),
                    linewidth=theme.val(0.5, dark=1),
                    dashes=theme.val((8, 8), dark=(4, 4)),
                    gapcolor=theme.val('#fff4', dark='#0004'),
                ),
            ],
        )
    )
Three spherical plots, titled "Latents · suppression". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a black sphere. The first plot has the appearance of a partial color wheel, with  vibrant colors around the rim (like a rainbow), with with a conspicuously absent space at the top where "red" should be. The other plots show different views of the same sphere, with hue varying across the equator and tone varying from top to bottom, and warm-but-not-red colors in the center (where "red" should be). Each ball shows the reconstructed color, with a dot in the center showing the true (input) color. The true and reconstructor colors agree fairly well, even for the warmer colors. "Red" and nearby colors are in fact not visible, being buried somewhere inside the sphere.

Again, the effect of suppression on the latent space looks really good — visually, around as good as 2.3.

Repulsion

from math import cos, pi

import torch

from ex_color.intervention import FastBezierMapper, InterventionConfig, Repulsion
from ex_color.vis import plot_colors, plot_cube_series

repulsion = Repulsion(
    torch.tensor(RED, dtype=torch.float32),  # Constant!
    FastBezierMapper(
        0,  # Constrain effect to within 60° cone
        cos(pi / 3),  # Create 30° hole (negative cone) around concept vector
    ),
)


interventions = [InterventionConfig(repulsion, ['bottleneck'])]
y_hsv = await infer(model, interventions, x_hsv)
hd_y_hsv = await infer(model, interventions, hd_x_hsv)
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-pred-colors-repulsion.png',
    alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well, but "red" and nearby colors are clearly different, and different again from how the suppression intervention looked: "red" itself appears as pink or hot pink, and the surrounding colors up to orange and pink look shifted. "Red-orange" actually appears to be fully-saturated yellow. Overall the effect is as though the nearby colors have bled into the neighborhood of red.""",
) as show:
    show(
        lambda: plot_colors(
            hsv_cube,
            title='Predicted colors with repulsion',
            colors=y_hsv.numpy(),
            colors_compare=x_hsv.numpy(),
        )
    )


per_color_loss = F.mse_loss(hd_y_hsv, hd_x_hsv, reduction='none').mean(dim=-1)
loss_cube = hd_hsv_cube.assign('MSE', per_color_loss.numpy().reshape(hd_hsv_cube.shape))
max_loss = per_color_loss.max().item()
median_loss = per_color_loss.median().item()
with displayer_mpl(
    f'large-assets/ex-{nbid}-loss-colors-repulsion.png',
    alt_text=f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue. There is a significant peak at red at either end of the X-axis, gradually sloping down to lower loss values near yellow and pink. Two very small peaks are at green and blue (apparently less than 1% of the height of the peaks at red).""",
) as show:
    show(
        lambda: plot_cube_series(
            loss_cube.permute('hsv')[:, -1:, :: (loss_cube.shape[2] // -5)],
            loss_cube.permute('svh')[:, -1:, :: -(loss_cube.shape[0] // -6)],
            loss_cube.permute('vsh')[:, -1:, :: -(loss_cube.shape[0] // -6)],
            title='Reconstruction error · repulsion',
            var='MSE',
            figsize=(12, 3),
        )
    )
print(f'Max loss: {max_loss:.2g}')
print(f'Median MSE: {median_loss:.2g}')
Plot showing four slices of the HSV cube, titled "Predicted colors with repulsion · V vs H by S". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well, but "red" and nearby colors are clearly different, and different again from how the suppression intervention looked: "red" itself appears as pink or hot pink, and the surrounding colors up to orange and pink look shifted. "Red-orange" actually appears to be fully-saturated yellow. Overall the effect is as though the nearby colors have bled into the neighborhood of red.
Line chart showing loss per color, titled "Reconstruction error · repulsion". Y-axis: mean square error, ranging from zero to 0.094. X-axis: hue. There is a significant peak at red at either end of the X-axis, gradually sloping down to lower loss values near yellow and pink. Two very small peaks are at green and blue (apparently less than 1% of the height of the peaks at red).
Max loss: 0.094
Median MSE: 2.5e-05

Again, visually, the colors look similar to 2.3. And again, the loss curves look similar but more regular (less crossing over). They also rise to more of a peak than 2.3, although they are a bit flat on the purple side of red.

from IPython.display import clear_output

from ex_color.vis import plot_latent_grid_3d, ConicalAnnotation

# Capture latents with repulsion
y_rgb, h_rgb = await infer_with_latent_capture(model, interventions, x_rgb, 'bottleneck')
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-latents-repulsion.png',
    alt_text="""Three spherical plots, titled "{title}". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a black sphere. The first plot has the appearance of a partial color wheel, with  vibrant colors around the rim (like a rainbow), with with a conspicuously absent space at the top where "red" should be. The other plots show different views of the same sphere, with hue varying across the equator and tone varying from top to bottom. The central region of the second and third plots show something interesting: "Red" and nearby colors have been arranged into a wide ring or disc, rather than being clustered in the center. Each ball shows the reconstructed color, with a dot in the center showing the true (input) color. The true and reconstructor colors agree fairly well, except for colors close to "red", which roughly agree in saturation but differ significantly in tone and hue.""",
) as show:
    show(
        lambda theme: plot_latent_grid_3d(
            h_rgb,
            y_rgb,
            x_rgb,
            title='Latents · repulsion',
            dims=[(1, 0, 2), (1, 2, 0), (1, 3, 0)],
            dot_radius=10,
            theme=theme,
            annotations=[
                ConicalAnnotation(
                    RED,
                    2 * (np.pi / 2 - repulsion.mapper.a),  # type: ignore
                    color=theme.val('black', dark='#fffa'),
                    linewidth=theme.val(0.5, dark=1),
                    dashes=theme.val((8, 8), dark=(4, 4)),
                    gapcolor=theme.val('#fff4', dark='#0004'),
                ),
                ConicalAnnotation(
                    RED,
                    2 * (np.pi / 2 - repulsion.mapper.b),  # type: ignore
                    color=theme.val('black', dark='#fffa'),
                    linewidth=theme.val(0.5, dark=1),
                ),
            ],
        )
    )
Three spherical plots, titled "Latents · repulsion". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a black sphere. The first plot has the appearance of a partial color wheel, with  vibrant colors around the rim (like a rainbow), with with a conspicuously absent space at the top where "red" should be. The other plots show different views of the same sphere, with hue varying across the equator and tone varying from top to bottom. The central region of the second and third plots show something interesting: "Red" and nearby colors have been arranged into a wide ring or disc, rather than being clustered in the center. Each ball shows the reconstructed color, with a dot in the center showing the true (input) color. The true and reconstructor colors agree fairly well, except for colors close to "red", which roughly agree in saturation but differ significantly in tone and hue.

This looks really clean: reds have been pushed into a nearly smooth arc toward the top of the middle plot, and the distribution of reds in the right plot is more regular than in 2.3.

Conclusion

Visually, these results are mostly better than Ex 2.3:

  • Loss vs. hue curves are smoother and more regular
  • Latent space looks very regular, with primary and secondary colors mostly co-planar
  • Repulsion pushed reds into a clean arc

Numerically, the global training loss values are slightly worse, so it's difficult to say if post-norm regularization would be better in more complicated networks.

Overall, I feel like it's probably a small improvement. Let's adopt this configuration for future experiments, but be open to trying the previous configuration again.