Experiment 2.6: Delete warm/cool axis
So far in this milestone, we've demonstrated that the structured latent space lends itself well to test-time intervention. In Ex 2.4, we applied suppression and repulsion interventions that modified activation vectors by squashing or redirecting them. Now let's see if we can remove information from the model. We'll try to remove the concept of "warmth" from the color autoencoder. The ability to cleanly delete concepts would be extremely useful in advanced AI:
- For closed models, you could be pretty sure that a jailbreak won't be able to elicit the deleted concept
- For open models, it should be more difficult to fine-tune the model to re-learn the concept.
Hypothesis
If we delete or zero-out the parameters associated with a concept dimension, then the model should lose the ability to operate on that concept. This should result in high loss for samples that require that concept dimension.
In our case, we should see high loss at red and cyan (which lie at opposing ends of the warmth dimension), and low loss for lime green, hot pink, white and black, which are orthogonal to warmth.
from __future__ import annotations
# Notebook identity. These values name the experiment, the logger, and the asset files written by later cells.
nbid = '2.6' # ID for tagging assets
nbname = 'Ablate/prune'
# Display name of the experiment; passed to Experiment and to WandbLogger
experiment_name = f'Ex {nbid}: {nbname}'
# Project name; passed to Experiment and to WandbLogger
project = 'ex-preppy'
# Basic setup: Logging, Experiment (Modal)
import logging
import modal
from infra.requirements import freeze, project_packages
from mini.experiment import Experiment
from utils.logging import SimpleLoggingConfig
# Our own packages log at INFO level. matplotlib.axes is limited to errors,
# which silences its warnings about set_aspect.
logging_config = SimpleLoggingConfig()
logging_config = logging_config.info('notebook', 'utils', 'mini', 'ex_color')
logging_config = logging_config.error('matplotlib.axes')
logging_config.apply()

# Logger used by this notebook's own cells
log = logging.getLogger(f'notebook.{nbid}')

# Experiment runner. Its image is built from a slim Debian base with the
# frozen requirements pip-installed and the project's Python sources added.
run = Experiment(experiment_name, project=project)
run.image = (
    modal.Image.debian_slim()
    .pip_install(*freeze(all=True))
    .add_local_python_source(*project_packages())
)
# Register the logging setup as a before-each hook of the experiment
run.before_each(logging_config.apply)

None  # prevent auto-display of this cell
Regularizers
Like Ex 2.4:
- Anchor: pins red to $(1,0,0,0)$
- Separate: angular repulsion to reduce global clumping (applied within each batch)
- Planarity: pulls vibrant hues to the $[0, 1]$ plane
- Unitarity: pulls all embeddings to the surface of the unit hypersphere, i.e. it makes the embedding vectors have unit length.
Unlike Ex 2.5, planarity has been added back in as a regularization term.
import torch
from mini.temporal.dopesheet import Dopesheet
from ex_color.loss import Anchor, Separate, Unitarity, RegularizerConfig, Planarity
from ex_color.training import TrainingModule
# Latent-space target for the anchor regularizer
RED = (1, 0, 0, 0)

# The regularizers are registered one at a time, in the order they should appear.
ALL_REGULARIZERS = []

# Unitarity: pulls embeddings to the surface of the unit hypersphere
ALL_REGULARIZERS.append(
    RegularizerConfig(
        name='reg-unit',
        compute_loss_term=Unitarity(),
        label_affinities=None,
        layer_affinities=['encoder'],
    )
)

# Anchor: pins samples labelled 'red' to RED
ALL_REGULARIZERS.append(
    RegularizerConfig(
        name='reg-anchor',
        compute_loss_term=Anchor(torch.tensor(RED, dtype=torch.float32)),
        label_affinities={'red': 1.0},
        layer_affinities=['bottleneck'],
    )
)

# Separate: angular repulsion to reduce global clumping
ALL_REGULARIZERS.append(
    RegularizerConfig(
        name='reg-separate',
        compute_loss_term=Separate(power=100.0, shift=True),
        label_affinities=None,
        layer_affinities=['bottleneck'],
    )
)

# Planarity: applied to samples labelled 'vibrant'
ALL_REGULARIZERS.append(
    RegularizerConfig(
        name='reg-planar',
        compute_loss_term=Planarity(),
        label_affinities={'vibrant': 1.0},
        layer_affinities=['bottleneck'],
    )
)
from functools import partial
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
import numpy as np
from ex_color.data.color_cube import ColorCube
from ex_color.data.cube_sampler import vibrancy
from ex_color.data.cyclic import arange_cyclic
from ex_color.labelling import collate_with_generated_labels
def prep_data() -> tuple[DataLoader, Tensor]:
    """
    Build the training loader and the validation colors.

    Training samples are the colors of an HSV-parameterized cube, paired with
    their vibrancy, drawn by a weighted sampler and collated with generated
    binary labels. Validation colors are the points of an RGB cube, flattened
    to rows of three channels.

    Returns: (train, val)
    """
    # Training data: colors and vibrancy from a cube sampled in HSV space
    train_cube = ColorCube.from_hsv(
        h=arange_cyclic(step_size=10 / 360),
        s=np.linspace(0, 1, 10),
        v=np.linspace(0, 1, 10),
    )
    train_colors = torch.tensor(train_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
    train_vibrancy = torch.tensor(vibrancy(train_cube).flatten(), dtype=torch.float32)
    train_dataset = TensorDataset(train_colors, train_vibrancy)

    # Binary (stochastic) labels simulate the labelling of internet text
    label_and_collate = partial(
        collate_with_generated_labels,
        soft=False,
        red=0.5,
        vibrant=0.5,
    )

    # Desaturated and dark colors are over-represented in the cube, so samples
    # are drawn with the cube's bias as weights to balance them out
    balanced_sampler = WeightedRandomSampler(
        weights=train_cube.bias.flatten().tolist(),
        num_samples=len(train_dataset),
        replacement=True,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=2,
        sampler=balanced_sampler,
        collate_fn=label_and_collate,
    )

    # Validation data: a cube sampled in RGB space
    val_cube = ColorCube.from_rgb(
        r=np.linspace(0, 1, 8),
        g=np.linspace(0, 1, 8),
        b=np.linspace(0, 1, 8),
    )
    val_colors = torch.tensor(val_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)

    return train_loader, val_colors
import wandb
from ex_color.model import CNColorMLP
# @run.thither(env={'WANDB_API_KEY': wandb.Api().api_key})
async def train(
    dopesheet: Dopesheet,
    regularizers: list[RegularizerConfig],
) -> CNColorMLP:
    """
    Train a fresh model with the given dopesheet and regularizers.

    The run is put into deterministic mode with seed 0, logged through a
    WandbLogger, and limited to `len(dopesheet)` steps. `wandb.finish()` is
    called whether or not fitting succeeds.

    Returns: the trained model
    """
    import lightning as L
    from lightning.pytorch.loggers import WandbLogger

    from ex_color.seed import set_deterministic_mode
    from utils.progress.lightning import LightningProgress

    log.info(f'Training with: {[r.name for r in regularizers]}')

    # Seed before the data and model are created
    set_deterministic_mode(0)

    # Only the training loader is needed here; the validation colors are discarded
    hsv_loader, _ = prep_data()

    model = CNColorMLP(4)
    total_params = 0
    for p in model.parameters():
        if p.requires_grad:
            total_params += p.numel()
    log.debug(f'Model initialized with {total_params:,} trainable parameters.')

    criterion = torch.nn.MSELoss()
    training_module = TrainingModule(model, dopesheet, criterion, regularizers)

    wandb_logger = WandbLogger(experiment_name, project=project)
    trainer = L.Trainer(
        max_steps=len(dopesheet),
        callbacks=[LightningProgress()],
        enable_checkpointing=False,
        enable_model_summary=False,
        logger=wandb_logger,
    )

    print(f'max_steps: {len(dopesheet)}, hsv_loader length: {len(hsv_loader)}')

    # Fit, and always close the W&B run afterwards, even if fitting fails
    try:
        trainer.fit(training_module, hsv_loader)
    finally:
        wandb.finish()

    # The model is small, so it is returned directly rather than being stored
    # and loaded from a remote checkpoint
    return model
async with run():
model = await train(Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv'), ALL_REGULARIZERS)
I 5.1 no.2.6: Training with: ['reg-unit', 'reg-anchor', 'reg-separate', 'reg-planar']
INFO: Seed set to 0
I 5.1 li.fa.ut.se:Seed set to 0 I 5.1 ex.se: PyTorch set to deterministic mode
INFO: GPU available: False, used: False
I 5.1 li.py.ut.ra:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
I 5.1 li.py.ut.ra:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
I 5.1 li.py.ut.ra:HPU available: False, using: 0 HPUs max_steps: 3001, hsv_loader length: 57
wandb: Currently logged in as: z0r to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
./wandb/run-20250905_063238-w3dpldlr
Starting phase: Train
INFO: `Trainer.fit` stopped: `max_steps=3001` reached.
I 27.8 li.py.ut.ra:`Trainer.fit` stopped: `max_steps=3001` reached.
Run history:
| epoch | ▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇████ |
| train_loss | █▄▄▆▆▃▃▃▂▄▆▂▆▃▄▄▂▃▄▃▂▂▃▃▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁ |
| train_recon | ▆▅▅▄▃█▃▃▇▇▄▄▄▅▅▄▇▆▄▅▃▄▅▄▂▂▂▂▁▂▁▁▂▂▁▁▁▁▁▁ |
| train_reg-anchor | ▁▁▁█▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▂▁▁▁▁▁▁ |
| train_reg-planar | ▁█▁▃▁▁▁▁▂▁▂▁▁▁▁▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂ |
| train_reg-separate | █▆▆██▆▄▅▃▃▄▅▁▃▂▄▄▁▂▅▄▃▃▄▄▂▂▂▁▂▂▄▁▃▂▂▃▂▃▄ |
| train_reg-unit | █▇▆▅▃▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| trainer/global_step | ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ |
Run summary:
| epoch | 52 |
| train_loss | 3e-05 |
| train_recon | 3e-05 |
| train_reg-anchor | 0 |
| train_reg-planar | 0.05332 |
| train_reg-separate | 0.43601 |
| train_reg-unit | 0.00924 |
| trainer/global_step | 2999 |
View project at: https://wandb.ai/z0r/ex-color-transformer
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
./wandb/run-20250905_063238-w3dpldlr/logs
The charts and loss values from training look much the same as last time.
- Roughly the same shape overall
- All loss values roughly the same
from ex_color.inference import InferenceModule
async def infer(
    model: CNColorMLP,
    test_data: Tensor,
) -> Tensor:
    """
    Reconstruct colors with the given model.

    `test_data` is viewed as rows of three channels and predicted in batches
    of 64. The predictions are concatenated and reshaped back to the shape of
    `test_data`.
    """
    import lightning as L

    def first_column_stacked(batch):
        # Each dataset row is a 1-tuple; stack the sole element of every row
        return torch.stack([row[0] for row in batch], 0)

    inference_module = InferenceModule(model, [])
    trainer = L.Trainer(
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=True,
    )
    loader = DataLoader(
        TensorDataset(test_data.reshape((-1, 3))),
        batch_size=64,
        collate_fn=first_column_stacked,
    )

    predicted_batches = trainer.predict(inference_module, loader)
    assert predicted_batches is not None

    # Unpack every batch into its individual items before concatenating
    predicted_items = []
    for batch in predicted_batches:
        predicted_items.extend(batch)

    # Restore the shape of the input
    return torch.cat(predicted_items).reshape(test_data.shape)
import torch
import numpy as np
from ex_color.inference import InferenceModule
async def infer_with_latent_capture(
    model: CNColorMLP,
    test_data: Tensor,
    layer_name: str = 'bottleneck',
) -> tuple[Tensor, Tensor]:
    """
    Reconstruct colors and also return the activations captured at one layer.

    `test_data` is viewed as rows of three channels and predicted in batches
    of 64, with `layer_name` registered as a capture layer on the inference
    module.

    Returns: (reconstruction reshaped to `test_data.shape`, activations read
    back for `layer_name`)
    """
    module = InferenceModule(model, [], capture_layers=[layer_name])

    import lightning as L

    def first_column_stacked(batch):
        # Each dataset row is a 1-tuple; stack the sole element of every row
        return torch.stack([row[0] for row in batch], 0)

    trainer = L.Trainer(
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )
    loader = DataLoader(
        TensorDataset(test_data.reshape((-1, 3))),
        batch_size=64,
        collate_fn=first_column_stacked,
    )

    predicted_batches = trainer.predict(module, loader)
    assert predicted_batches is not None

    # Unpack every batch into its individual items before concatenating
    predicted_items = []
    for batch in predicted_batches:
        predicted_items.extend(batch)
    reconstruction = torch.cat(predicted_items).reshape(test_data.shape)

    # Captured activations; expected to be a flat [N, D] tensor
    captured = module.read_captured(layer_name)
    return reconstruction, captured
Quick sense-check: Let's see how well the trained model reconstructs colors.
from IPython.display import clear_output
import importlib
import utils.nb
import utils.plt
# Reload the helper modules before names are imported from them below —
# presumably so that edits made to them during the notebook session are picked up.
importlib.reload(utils.nb)
importlib.reload(utils.plt)
from ex_color.vis import plot_colors
from utils.nb import displayer_mpl
# Coarse HSV cube for the color-patch plots, with its axes permuted to 'svh'
hsv_cube = ColorCube.from_hsv(
h=arange_cyclic(step_size=1 / 24),
s=np.linspace(0, 1, 4),
v=np.linspace(0, 1, 8),
).permute('svh')
x_hsv = torch.tensor(hsv_cube.rgb_grid, dtype=torch.float32)
# Denser HSV cube; later cells use it for the per-color loss charts
hd_hsv_cube = ColorCube.from_hsv(
h=arange_cyclic(step_size=1 / 240),
s=np.linspace(0, 1, 48),
v=np.linspace(0, 1, 48),
)
hd_x_hsv = torch.tensor(hd_hsv_cube.rgb_grid, dtype=torch.float32)
# RGB cube; later cells use it for the latent-space plots
rgb_cube = ColorCube.from_rgb(
r=np.linspace(0, 1, 20),
g=np.linspace(0, 1, 20),
b=np.linspace(0, 1, 20),
)
x_rgb = torch.tensor(rgb_cube.rgb_grid, dtype=torch.float32)
# Plot the input colors of the coarse cube on their own
with displayer_mpl(
f'large-assets/ex-{nbid}-true-colors.png',
alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. The first slice shows a grayscale gradient from black to white; the last shows the fully-saturated color spectrum.""",
) as show:
show(lambda: plot_colors(hsv_cube, title='True colors', colors=x_hsv.numpy()))
from IPython.display import clear_output
from torch.nn import functional as F
from ex_color.vis import plot_colors, plot_cube_series
# NOTE(review): `interventions` is not referenced by any of the cells visible here
interventions = []
# Reconstruct the coarse cube (for the patch plot) and the dense cube (for the loss chart)
y_hsv = await infer(model, x_hsv)
hd_y_hsv = await infer(model, hd_x_hsv)
# Clear whatever the inference runs printed to the cell output
clear_output()
# Patch plot: reconstructed colors compared against the input colors
with displayer_mpl(
f'large-assets/ex-{nbid}-pred-colors-no-intervention.png',
alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well, but some slight differences are visible; for example, "white" is slightly gray, and many of the fully-saturated colors are less saturated than they should be.""",
) as show:
show(
lambda: plot_colors(
hsv_cube,
title='Predicted colors · no intervention',
colors=y_hsv.numpy(),
colors_compare=x_hsv.numpy(),
)
)
# Squared error per element, averaged over the last axis to give one value per color
per_color_loss = F.mse_loss(hd_y_hsv, hd_x_hsv, reduction='none').mean(dim=-1)
# Attach the per-color error to the dense cube under the name 'MSE'
loss_cube = hd_hsv_cube.assign('MSE', per_color_loss.numpy().reshape(hd_hsv_cube.shape))
max_loss = per_color_loss.max().item()
median_loss = per_color_loss.median().item()
with displayer_mpl(
f'large-assets/ex-{nbid}-loss-colors-no-intervention.png',
alt_text=f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue. The range of loss values is small, but there are two notable peaks at all primary and secondary colors (red, yellow, green, etc.).""",
) as show:
show(
lambda: plot_cube_series(
# `n // -5` is negative for a positive n, so this slice steps backwards from the end of the last axis
loss_cube.permute('hsv')[:, -1:, :: (loss_cube.shape[2] // -5)],
# `-(n // -3)` is ceiling division, so these slices take a positive step of ceil(n / 3)
loss_cube.permute('svh')[:, -1:, :: -(loss_cube.shape[0] // -3)],
loss_cube.permute('vsh')[:, -1:, :: -(loss_cube.shape[0] // -3)],
title='Reconstruction error · no intervention',
var='MSE',
figsize=(12, 3),
)
)
print(f'Max loss: {max_loss:.2g}')
print(f'Median MSE: {median_loss:.2g}')
Max loss: 0.0011 Median MSE: 1.5e-05
from IPython.display import clear_output
from ex_color.vis import plot_latent_grid_3d
# Reconstruct the RGB cube and capture the bottleneck activations of the unmodified model
y_rgb, h_rgb = await infer_with_latent_capture(model, x_rgb, 'bottleneck')
# Clear whatever the inference run printed to the cell output
clear_output()
with displayer_mpl(
f'large-assets/ex-{nbid}-latents-no-intervention.png',
alt_text="""Three spherical plots, titled "{title}". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a sphere. The first plot has the appearance of a color wheel, with the full set of vibrant colors around the rim (like a rainbow), varying to black in the center. The other plots show different views of the same sphere, with hue varying across the equator and tone varying from top to bottom, and red in the center. Each ball shows the reconstructed color, with a dot in the center showing the true (input) color. In this plot the true and reconstructor colors agree fairly well, but slight differences can be seen if you look closely.""",
) as show:
show(
lambda theme: plot_latent_grid_3d(
h_rgb,
y_rgb,
x_rgb,
title='Latents · no intervention',
# One triple of latent dimension indices per plot
dims=[(1, 0, 2), (1, 2, 0), (1, 3, 0)],
dot_radius=10,
theme=theme,
)
)
Looks fine.
Ablation
Now that we have our model, let's try ablating (zeroing) red. We have a new function for this:
def ablate[M](model: M, layer_id: str, dims: Sequence[int]) -> M:
"""Return a copy of model where the selected latent dims are effectively nulled."""
...
This zeros out producer (upstream matrix) rows and consumer (downstream) columns for the given dims. Shapes remain unchanged.
from ex_color.surgery import ablate
# Ablate dimension 0 of the bottleneck — the dimension that red is anchored to by RED = (1, 0, 0, 0)
ablated_model = ablate(model, 'bottleneck', [0])
# Reconstruct the coarse cube (for the patch plot) and the dense cube (for the loss chart)
y_hsv = await infer(ablated_model, x_hsv)
hd_y_hsv = await infer(ablated_model, hd_x_hsv)
# Clear whatever the inference runs printed to the cell output
clear_output()
# Patch plot: reconstructed colors compared against the input colors
with displayer_mpl(
f'large-assets/ex-{nbid}-pred-colors-ablated.png',
alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well, but "red" and "cyan" are clearly different: red itself appears as black, and the surrounding colors up to green and blue look more like green and purple. Colors near cyan are similarly affected.""",
) as show:
show(
lambda: plot_colors(
hsv_cube,
title='Predicted colors · ablated',
colors=y_hsv.numpy(),
colors_compare=x_hsv.numpy(),
)
)
# Squared error per element, averaged over the last axis to give one value per color
per_color_loss = F.mse_loss(hd_y_hsv, hd_x_hsv, reduction='none').mean(dim=-1)
# Attach the per-color error to the dense cube under the name 'MSE'
loss_cube = hd_hsv_cube.assign('MSE', per_color_loss.numpy().reshape(hd_hsv_cube.shape))
max_loss = per_color_loss.max().item()
median_loss = per_color_loss.median().item()
with displayer_mpl(
f'large-assets/ex-{nbid}-loss-colors-ablated.png',
alt_text=f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue. There is a significant peak at red at either end of the X-axis and at cyan in the middle, gradually sloping down to lower loss values near yellow-green and blue-magenta. Two smaller line charts show error vs. saturation and error vs. value, with high error near high value/saturation, and low error near low value/saturation.""",
) as show:
show(
lambda: plot_cube_series(
# `n // -5` is negative for a positive n, so this slice steps backwards from the end of the last axis
loss_cube.permute('hsv')[:, -1:, :: (loss_cube.shape[2] // -5)],
# `-(n // -6)` is ceiling division, so these slices take a positive step of ceil(n / 6)
# NOTE(review): the no-intervention and pruned cells divide by -3 here — confirm the difference is intended
loss_cube.permute('svh')[:, -1:, :: -(loss_cube.shape[0] // -6)],
loss_cube.permute('vsh')[:, -1:, :: -(loss_cube.shape[0] // -6)],
title='Reconstruction error · ablated',
var='MSE',
figsize=(12, 3),
)
)
print(f'Max loss: {max_loss:.2g}')
print(f'Median MSE: {median_loss:.2g}')
Max loss: 0.27 Median MSE: 0.0031
This looks like a clean ablation. The effect is far more severe than our previous interventions, because there's no falloff function involved: red and cyan have been completely removed, and all colors have been affected except those that are completely orthogonal: lime-greens, purples, black, white, and grays.
from IPython.display import clear_output
from ex_color.vis import plot_latent_grid_3d
# Reconstruct the RGB cube and capture the bottleneck activations of the ablated model
y_rgb, h_rgb = await infer_with_latent_capture(ablated_model, x_rgb, 'bottleneck')
# Clear whatever the inference run printed to the cell output
clear_output()
with displayer_mpl(
f'large-assets/ex-{nbid}-latents-ablated.png',
alt_text="""Three spherical plots, titled "{title}". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a sphere. The first plot has horizontal line across the middle, with green on the left and blue on the right, and conspicuously empty spaces at the top and bottom. The other plots show different views of the same sphere, with hue varying across the equator from green to blue/purple and tone varying from top to bottom. The centre of the spheres where you might expect to see red or cyan instead show desaturated grays, white, and black. Each ball shows the reconstructed color, with a dot in the center showing the true (input) color. The true and reconstructed colors agree fairly well. Red and cyan and nearby colors are in fact not visible, being buried somewhere inside the sphere.""",
) as show:
show(
lambda theme: plot_latent_grid_3d(
h_rgb,
y_rgb,
x_rgb,
title='Latents · ablated',
# One triple of latent dimension indices per plot; only the first triple includes the ablated dimension 0
dims=[(1, 0, 2), (1, 2, 3), (1, 3, 2)],
dot_radius=10,
theme=theme,
)
)
This looks a lot like the suppression intervention, but again — much more severe. There's nothing on the warmth dimension at all.
Pruning
Let's try something even more severe: instead of just zeroing-out the weights, let's remove them entirely. This will reduce the dimensionality of the bottleneck. Again, we have a new function:
def prune[M](model: M, layer_id: str, dims: Sequence[int]) -> M:
"""Return a copy of model with selected latent dims fully removed."""
...
This reduces the latent width $k$ by $|dims|$ by:
- Removing the rows from the producer (upstream) Linear layer's weight/bias
- Removing the columns from the consumer (downstream) Linear layer's weight
from ex_color.surgery import prune
# Prune dimension 0 of the bottleneck — the dimension that red is anchored to by RED = (1, 0, 0, 0)
pruned_model = prune(model, 'bottleneck', [0])
# Reconstruct the coarse cube (for the patch plot) and the dense cube (for the loss chart)
y_hsv = await infer(pruned_model, x_hsv)
hd_y_hsv = await infer(pruned_model, hd_x_hsv)
# Clear whatever the inference runs printed to the cell output
clear_output()
# Patch plot: reconstructed colors compared against the input colors
with displayer_mpl(
f'large-assets/ex-{nbid}-pred-colors-pruned.png',
alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well, but "red" and "cyan" are clearly different: red itself appears as black, and the surrounding colors up to green and blue look more like green and purple. Colors near cyan are similarly affected.""",
) as show:
show(
lambda: plot_colors(
hsv_cube,
title='Predicted colors · pruned',
colors=y_hsv.numpy(),
colors_compare=x_hsv.numpy(),
)
)
# Squared error per element, averaged over the last axis to give one value per color
per_color_loss = F.mse_loss(hd_y_hsv, hd_x_hsv, reduction='none').mean(dim=-1)
# Attach the per-color error to the dense cube under the name 'MSE'
loss_cube = hd_hsv_cube.assign('MSE', per_color_loss.numpy().reshape(hd_hsv_cube.shape))
max_loss = per_color_loss.max().item()
median_loss = per_color_loss.median().item()
with displayer_mpl(
f'large-assets/ex-{nbid}-loss-colors-pruned.png',
alt_text=f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue. There is a significant peak at red at either end of the X-axis and at cyan in the middle, gradually sloping down to lower loss values near yellow-green and blue-magenta. Two smaller line charts show error vs. saturation and error vs. value, with high error near high value/saturation, and low error near low value/saturation.""",
) as show:
show(
lambda: plot_cube_series(
# `n // -5` is negative for a positive n, so this slice steps backwards from the end of the last axis
loss_cube.permute('hsv')[:, -1:, :: (loss_cube.shape[2] // -5)],
# `-(n // -3)` is ceiling division, so these slices take a positive step of ceil(n / 3)
loss_cube.permute('svh')[:, -1:, :: -(loss_cube.shape[0] // -3)],
loss_cube.permute('vsh')[:, -1:, :: -(loss_cube.shape[0] // -3)],
title='Reconstruction error · pruned',
var='MSE',
figsize=(12, 3),
)
)
print(f'Max loss: {max_loss:.2g}')
print(f'Median MSE: {median_loss:.2g}')
Max loss: 0.27 Median MSE: 0.0031
These look identical to the ablation above.
from IPython.display import clear_output
from ex_color.vis import plot_latent_grid_3d
# Reconstruct the RGB cube and capture the bottleneck activations of the pruned model
y_rgb, h_rgb = await infer_with_latent_capture(pruned_model, x_rgb, 'bottleneck')
# Clear whatever the inference run printed to the cell output
clear_output()
with displayer_mpl(
f'large-assets/ex-{nbid}-latents-pruned.png',
alt_text="""Two spherical plots, titled "{title}". Each plot shows different views of a collection of colored circles or balls scattered over the surface of a sphere. Hue varyies across the equator from green to blue/purple and tone varying from top to bottom. The centre of the spheres where you might expect to see red or cyan instead show desaturated grays, white, and black. Each ball shows the reconstructed color, with a dot in the center showing the true (input) color. The true and reconstructed colors agree fairly well. Red and cyan and nearby colors are in fact not visible, being buried somewhere inside the sphere.""",
) as show:
show(
lambda theme: plot_latent_grid_3d(
h_rgb,
y_rgb,
x_rgb,
title='Latents · pruned',
# One triple of latent dimension indices per plot; only two plots are drawn for the pruned model
dims=[(0, 1, 2), (0, 2, 1)],
dot_radius=10,
theme=theme,
)
)
These look identical to the last two ablation plots.
It wasn't possible to reproduce the first plot, since it requires the warmth dimension, which has been removed.