Experiment 1.5: Smooth curriculum and anchoring

This experiment combines the insights and tools from our previous work:

  • 3D Bottleneck & Curriculum: Like Experiment 1.3, we use a low-dimensional latent space and a curriculum to encourage the model to learn hue first.
  • Smooth Transitions: We replace the abrupt phase changes of Ex 1.3 with the smooth parameter transitions developed in Experiment 1.4, using the SmoothProp mechanism driven by a dopesheet.
  • Anchoring: We introduce an "anchor" regularization term, also controlled via the dopesheet, to fix the positions of key colors (primaries/secondaries) after the initial phase, preventing later phases from disrupting the learned hue structure.

This time we use four dimensions: two for hue, and one each for value and saturation. Hue is cyclic, so it needs two dimensions to embed as a circle. This is actually more than is strictly needed: consider that both HSV and RGB only use three dimensions! But they use a dense cube, whereas our latent space will be the surface of a hypersphere — and the surface of a 4D hypersphere is itself a three-dimensional manifold, so the effective degrees of freedom still match.

We hope to achieve a stable, well-structured latent space in which hue forms a planar color wheel while value and saturation extend into the remaining dimensions. We expect this approach to be less sensitive to initial conditions and to the exact timing and weighting of curriculum phases than the discrete steps of Ex 1.3 were.
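To make the target geometry concrete, here is a small illustrative sketch (not code from the experiment) of where an ideal model would place a fully saturated, full-value hue: on the unit circle spanned by the first two latent dimensions, with the remaining dimensions at zero.

import numpy as np


def ideal_hue_latent(h: float) -> np.ndarray:
    """Hypothetical ideal latent for hue h in [0, 1): a point on the unit circle in dims 0-1."""
    return np.array([np.cos(2 * np.pi * h), np.sin(2 * np.pi * h), 0.0, 0.0])


print(ideal_hue_latent(0.0))  # red → [1, 0, 0, 0]
print(np.linalg.norm(ideal_hue_latent(1 / 3)))  # 1.0 — unit norm, as the hypersphere implies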

from __future__ import annotations
import logging
from utils.logging import SimpleLoggingConfig

logging_config = SimpleLoggingConfig().info('notebook', 'utils', 'mini', 'ex_color')
logging_config.apply()

# This is the logger for this notebook
log = logging.getLogger('notebook')

Model architecture

We use the same simple 2-layer MLP autoencoder with a bottleneck as in previous experiments. The key difference lies not in the architecture, but in the training process governed by the smooth curriculum.

import torch
import torch.nn as nn

E = 4  # Size of the latent bottleneck (embedding dimension)


class ColorMLP(nn.Module):
    def __init__(self, normalize_bottleneck=False):
        super().__init__()
        # RGB input (3D) → hidden layer → bottleneck → hidden layer → RGB output
        self.encoder = nn.Sequential(
            nn.Linear(3, 16),
            nn.GELU(),
            # nn.Linear(16, 16),
            # nn.GELU(),
            nn.Linear(16, E),  # Our critical bottleneck!
        )

        self.decoder = nn.Sequential(
            nn.Linear(E, 16),
            nn.GELU(),
            # nn.Linear(16, 16),
            # nn.GELU(),
            nn.Linear(16, 3),
            nn.Sigmoid(),  # Keep RGB values in [0,1]
        )

        self.normalize = normalize_bottleneck

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Get our bottleneck representation
        bottleneck = self.encoder(x)

        # Optionally normalize to unit vectors (like nGPT)
        if self.normalize:
            norm = torch.norm(bottleneck, dim=1, keepdim=True)
            bottleneck = bottleneck / (norm + 1e-8)  # Avoid division by zero

        # Decode back to RGB
        output = self.decoder(bottleneck)
        return output, bottleneck

Training machinery with timeline and events

The train_color_model function orchestrates the training process based on a Timeline derived from the dopesheet. It handles:

  • Iterating through training steps.
  • Fetching the correct data loader for the current phase.
  • Updating hyperparameters (like learning rate and loss weights) smoothly based on the timeline state.
  • Calculating the combined loss from reconstruction and various regularizers.
  • Executing the optimizer step.
  • Emitting events at different points (phase start/end, pre-step, actions like 'anchor', step metrics) to trigger callbacks like plotting, recording, or updating loss terms.
from dataclasses import dataclass
from typing import Protocol, runtime_checkable
from torch import Tensor
import torch.optim as optim

from mini.temporal.timeline import State


@dataclass
class InferenceResult:
    outputs: Tensor
    latents: Tensor

    def detach(self):
        return InferenceResult(self.outputs.detach(), self.latents.detach())

    def clone(self):
        return InferenceResult(self.outputs.clone(), self.latents.clone())

    def cpu(self):
        return InferenceResult(self.outputs.cpu(), self.latents.cpu())


@runtime_checkable
class LossCriterion(Protocol):
    def __call__(self, data: Tensor, res: InferenceResult) -> Tensor: ...


@runtime_checkable
class SpecialLossCriterion(LossCriterion, Protocol):
    def forward(self, model: ColorMLP, data: Tensor) -> InferenceResult | None: ...


@dataclass(eq=False, frozen=True)
class Event:
    name: str
    step: int
    model: ColorMLP
    timeline_state: State
    optimizer: optim.Optimizer


@dataclass(eq=False, frozen=True)
class PhaseEndEvent(Event):
    validation_data: Tensor
    inference_result: InferenceResult


@dataclass(eq=False, frozen=True)
class StepMetricsEvent(Event):
    """Event carrying metrics calculated during a training step."""

    total_loss: float
    losses: dict[str, float]


class EventHandler[T](Protocol):
    def __call__(self, event: T) -> None: ...


class EventBinding[T]:
    """A class to bind events to handlers."""

    def __init__(self, event_name: str):
        self.event_name = event_name
        self.handlers: list[tuple[str, EventHandler[T]]] = []

    def add_handler(self, event_name: str, handler: EventHandler[T]) -> None:
        self.handlers.append((event_name, handler))

    def emit(self, event_name: str, event: T) -> None:
        for name, handler in self.handlers:
            if name == event_name:
                handler(event)


class EventHandlers:
    """A simple event system to allow for custom callbacks."""

    phase_start: EventBinding[Event]
    pre_step: EventBinding[Event]
    action: EventBinding[Event]
    phase_end: EventBinding[PhaseEndEvent]
    step_metrics: EventBinding[StepMetricsEvent]

    def __init__(self):
        self.phase_start = EventBinding[Event]('phase-start')
        self.pre_step = EventBinding[Event]('pre-step')
        self.action = EventBinding[Event]('action')
        self.phase_end = EventBinding[PhaseEndEvent]('phase-end')
        self.step_metrics = EventBinding[StepMetricsEvent]('step-metrics')
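Note how dispatch works: each EventBinding stores (name, handler) pairs, and emit only invokes handlers whose registered name matches. The training loop below emits every event twice — once under a specific name like action:anchor and once under the generic action — so handlers can subscribe at either granularity. A tiny illustrative sketch (not part of the experiment setup):

# Illustrative only — handlers registered under a specific name fire only for that name.
handlers = EventHandlers()
handlers.action.add_handler('action:anchor', lambda e: print(f'anchor at step {e.step}'))
# An emit under 'action:anchor' would invoke the handler above;
# an emit under the generic 'action' name would not.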
from typing import Iterable, Iterator
from torch.utils.data import DataLoader
import torch.optim as optim

from mini.temporal.dopesheet import Dopesheet
from mini.temporal.timeline import Timeline
from utils.progress import RichProgress


def reiterate[T](it: Iterable[T]) -> Iterator[T]:
    """
    Iterates over an iterable indefinitely.

    When the iterable is exhausted, it starts over from the beginning. Unlike
    `itertools.cycle`, yielded values are not cached — so each iteration may be
    different: a shuffling DataLoader, for example, yields a fresh shuffle on
    each pass rather than replaying the first epoch's batches.
    """
    while True:
        yield from it


def train_color_model(  # noqa: C901
    model: ColorMLP,
    datasets: dict[str, tuple[DataLoader, Tensor]],
    dopesheet: Dopesheet,
    loss_criteria: dict[str, LossCriterion | SpecialLossCriterion],
    event_handlers: EventHandlers | None = None,
):
    if event_handlers is None:
        event_handlers = EventHandlers()

    # --- Validate inputs ---
    # Check if all phases in dopesheet have corresponding data
    dopesheet_phases = dopesheet.phases
    missing_data = dopesheet_phases - set(datasets.keys())
    if missing_data:
        raise ValueError(f'Missing data for dopesheet phases: {missing_data}')

    # Check if 'lr' is defined in the dopesheet properties
    if 'lr' not in dopesheet.props:
        raise ValueError("Dopesheet must define the 'lr' property column.")
    # --- End Validation ---

    timeline = Timeline(dopesheet)
    optimizer = optim.Adam(model.parameters(), lr=0)
    device = next(model.parameters()).device

    data_iterators = {
        phase_name: iter(reiterate(dataloader))  #
        for phase_name, (dataloader, _) in datasets.items()
    }

    total_steps = len(timeline)

    with RichProgress(total=total_steps, description='Training Steps') as pbar:
        for step in range(total_steps):
            # Get state *before* advancing timeline for this step's processing
            current_state = timeline.state
            current_phase_name = current_state.phase

            # Assuming TensorDataset yields a tuple with one element
            (batch,) = next(data_iterators[current_phase_name])

            # --- Event Handling ---
            event_template = {
                'step': step,
                'model': model,
                'timeline_state': current_state,
                'optimizer': optimizer,
            }

            if current_state.is_phase_start:
                event = Event(name=f'phase-start:{current_phase_name}', **event_template)
                event_handlers.phase_start.emit(event.name, event)
                event_handlers.phase_start.emit('phase-start', event)

            for action in current_state.actions:
                event = Event(name=f'action:{action}', **event_template)
                event_handlers.action.emit(event.name, event)
                event_handlers.action.emit('action', event)

            event = Event(name='pre-step', **event_template)
            event_handlers.pre_step.emit('pre-step', event)

            # --- Training Step ---

            current_lr = current_state.props['lr']
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.zero_grad()

            outputs, latents = model(batch.to(device))
            current_results = InferenceResult(outputs, latents)

            total_loss = torch.tensor(0.0, device=device)
            losses_dict: dict[str, float] = {}
            for name, criterion in loss_criteria.items():
                weight = current_state.props.get(name, 0.0)
                if weight == 0:
                    continue

                if isinstance(criterion, SpecialLossCriterion):
                    # Special criteria might run on their own data (like Anchor)
                    # or potentially use the current batch (depends on implementation).
                    # The forward method gets the model and the *current batch*
                    special_results = criterion.forward(model, batch)
                    if special_results is None:
                        continue
                    term_loss = criterion(batch, special_results)
                else:
                    term_loss = criterion(batch, current_results)

                total_loss += term_loss * weight
                losses_dict[name] = term_loss.item()

            if total_loss > 0:
                total_loss.backward()
                optimizer.step()
            # --- End Training Step ---

            # Emit step metrics event
            step_metrics_event = StepMetricsEvent(
                name='step-metrics',
                **event_template,
                total_loss=total_loss.item(),
                losses=losses_dict,
            )
            event_handlers.step_metrics.emit('step-metrics', step_metrics_event)

            # --- Post-Step Event Handling ---
            if current_state.is_phase_end:
                # Trigger phase-end for the *current* phase
                _, validation_data = datasets[current_phase_name]
                with torch.no_grad():
                    val_outputs, val_latents = model(validation_data.to(device))
                event = PhaseEndEvent(
                    name=f'phase-end:{current_phase_name}',
                    **event_template,
                    validation_data=validation_data,
                    inference_result=InferenceResult(val_outputs, val_latents),
                )
                event_handlers.phase_end.emit(event.name, event)
                event_handlers.phase_end.emit('phase-end', event)
            # --- End Event Handling ---

            # Update progress bar
            pbar.update(
                metrics={
                    'PHASE': current_phase_name,
                    'lr': f'{current_lr:.6f}',
                    'loss': f'{total_loss.item():.4f}',
                    **{name: f'{lt:.4f}' for name, lt in losses_dict.items()},
                },
            )

            # Advance timeline *after* processing the current step
            if step < total_steps - 1:  # Avoid stepping past the end
                timeline.step()

    log.info('Training finished!')

Phase plotting callback

This PhasePlotter class acts as an event handler. It listens for phase-end events emitted by the training loop. When a phase ends, it captures the model's latent representations for the validation data of that phase and generates plots showing the state of the latent space (projected onto different 2D planes). This allows us to visualize how the structure evolves across the curriculum.

import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from torch import Tensor
from IPython.display import HTML

from utils.nb import save_fig


class PhasePlotter:
    """Event handler to plot latent space at the end of each phase."""

    def __init__(self, dim_pairs: list[tuple[int, int]] | None = None):
        from utils.nb import displayer

        # Store (phase_name, end_step, data, result) - data comes from event now
        self.history: list[tuple[str, int, Tensor, InferenceResult]] = []
        self.display = displayer()
        self.dim_pairs = dim_pairs or [(0, 1), (0, 2)]

    # Expect PhaseEndEvent specifically
    def __call__(self, event: PhaseEndEvent):
        """Handle phase-end events."""
        if not isinstance(event, PhaseEndEvent):
            raise TypeError(f'Expected PhaseEndEvent, got {type(event)}')

        # TODO: Don't assume device = CPU
        # TODO: Split this class so that the event handler is separate from the plotting, and so the plotting can happen locally with @run.hither
        phase_name = event.timeline_state.phase
        end_step = event.step
        phase_dataset = event.validation_data
        inference_result = event.inference_result

        log.info(f'Plotting end of phase: {phase_name} at step {end_step} using provided results.')

        # Append to history
        self.history.append((phase_name, end_step, phase_dataset.cpu(), inference_result.cpu()))

        # Plotting logic remains the same as it already expected CPU tensors
        fig = self._plot_phase_history()
        self.display(
            HTML(
                save_fig(
                    fig,
                    'large-assets/ex-1.5-color-phase-history.png',
                    alt_text='Visualizations of latent space at the end of each curriculum phase.',
                )
            )
        )

    def _plot_phase_history(self):
        num_phases = len(self.history)
        plt.style.use('dark_background')
        if num_phases == 0:
            fig, ax = plt.subplots()
            fig.set_facecolor('#333')
            ax.set_facecolor('#222')
            ax.text(0.5, 0.5, 'Waiting...', ha='center', va='center')
            return fig

        fig, axes = plt.subplots(
            num_phases, len(self.dim_pairs), figsize=(5 * len(self.dim_pairs), 5 * num_phases), squeeze=False
        )
        fig.set_facecolor('#333')

        for row_idx, (phase_name, end_step, data, res) in enumerate(self.history):
            _latents = res.latents.numpy()
            _colors = data.numpy()

            for col_idx, (dim1, dim2) in enumerate(self.dim_pairs):
                ax = axes[row_idx, col_idx]
                ax.set_facecolor('#222')
                ax.scatter(_latents[:, dim1], _latents[:, dim2], c=_colors, s=50, alpha=0.7)

                # Set y-label differently for the first column
                if col_idx == 0:
                    ax.set_ylabel(
                        f'Phase: {phase_name}\n(End Step: {end_step})',
                        fontsize='medium',
                        rotation=90,  # Rotate vertically
                        labelpad=15,  # Adjust padding
                        verticalalignment='center',
                        horizontalalignment='center',
                    )
                else:
                    # Standard y-label for other columns
                    ax.set_ylabel(f'Dim {dim2}')

                # Set title only for the top row
                if row_idx == 0:
                    ax.set_title(f'Dims {dim1} vs {dim2}')

                # Standard x-label for all columns
                ax.set_xlabel(f'Dim {dim1}')

                # Keep other plot settings
                ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
                ax.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
                ax.add_patch(Circle((0, 0), 1, fill=False, linestyle='--', color='gray', alpha=0.3))
                ax.set_aspect('equal')

        fig.tight_layout()
        return fig

Smooth curriculum via dopesheet

Instead of defining discrete phases with fixed parameters, we now use a dopesheet (as CSV) to define keyframes for our hyperparameters. The Timeline class interpolates these values smoothly between keyframes using the minimum jerk approach from Experiment 1.4.

The dopesheet controls:

  • Learning Rate (lr): Gradually decreased over training.
  • Reconstruction Loss Weight (loss-recon): Kept constant.
  • Regularization Weights (reg-separate, reg-planar, reg-norm, reg-anchor): Faded in and out to guide the model. For example, reg-separate and reg-planar are strong early on to establish the color wheel, while reg-anchor activates later to lock it in place.
  • Data Fraction (data-fraction): Controls the DynamicWeightedRandomBatchSampler to smoothly transition the training data distribution from vibrant colors towards the full color space (details below).
  • Actions (ACTION): Triggers specific events, like the anchor action which tells the Anchor regularizer to capture the current latent positions of the primary/secondary colors.

This allows for a more continuous and potentially more stable learning process.
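To illustrate the mechanics, here is a hypothetical miniature dopesheet and timeline. The real schedule is loaded from ex-1.5-dopesheet.csv below; the exact CSV layout here is an assumption, inferred from the rendered schedule. Blank cells mean "no keyframe for this property at this step"; the Timeline interpolates each property between its keyframes.

from pathlib import Path

from mini.temporal.dopesheet import Dopesheet
from mini.temporal.timeline import Timeline

# A tiny, made-up schedule: lr glides from 0.01 to 0.001 over 100 steps,
# and an 'anchor' action fires at step 50.
Path('tmp-dopesheet.csv').write_text(
    'STEP,PHASE,ACTION,lr,loss-recon\n'
    '0,warmup,,0.01,1\n'
    '50,main,anchor,,\n'
    '100,main,,0.001,1\n'
)
sheet = Dopesheet.from_csv('tmp-dopesheet.csv')
timeline = Timeline(sheet)
for step in range(len(timeline)):
    state = timeline.state
    # state.props['lr'] is the smoothly interpolated value at this step;
    # state.actions contains 'anchor' only at step 50; state.phase tracks PHASE.
    if state.actions:
        print(f'step {step}: actions={state.actions}')
    if step < len(timeline) - 1:
        timeline.step()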

from IPython.display import display, HTML, Markdown
from matplotlib import pyplot as plt
from matplotlib.figure import Figure

from mini.temporal.vis import group_properties_by_scale, plot_timeline, realize_timeline
from mini.temporal.dopesheet import Dopesheet
from mini.temporal.timeline import Timeline
from utils.nb import save_fig

dopesheet = Dopesheet.from_csv('ex-1.5-dopesheet.csv')
display(
    Markdown(f"""
## Parameter schedule
{dopesheet.to_markdown()}
""")
)

timeline = Timeline(dopesheet)
history_df = realize_timeline(timeline)
keyframes_df = dopesheet.as_df()

groups = group_properties_by_scale(keyframes_df[dopesheet.props])
fig, ax = plot_timeline(history_df, keyframes_df, groups)
# Add assertion to satisfy type checker
assert isinstance(fig, Figure), 'plot_timeline should return a Figure'
display(
    HTML(
        save_fig(
            fig,  # Now type checker is happy
            'large-assets/ex-1.5-color-timeline.png',
            alt_text='Line chart showing the hyperparameter schedule over time.',
        )
    )
)

Parameter schedule

STEP PHASE ACTION lr loss-recon reg-separate reg-planar reg-norm reg-anchor data-fraction
0 Primary & secondary 1 0 0.2
1200 0.8 0.3
1800 0.4 0.1
3000 All hues anchor 0 0.25 0 0
3350 0.01
6500 0.8 0
8600 0.3
10000 Full color space 0.25
10500
13000 1 0.1 1
20000 0.001 0.75
I 386.3 ut.nb: Figure saved: 'large-assets/ex-1.5-color-timeline.png'
Line chart showing the hyperparameter schedule over time.

Loss functions and regularizers

We use mean squared error for the main reconstruction loss (loss-recon). The following regularizers, weighted according to the dopesheet schedule, guide the latent space structure:

  • unitarity (reg-norm): Encourages latent vectors to lie on a unit hypersphere by penalizing deviations from a norm of 1.
  • planarity (reg-planar): Pushes dimensions beyond the first two towards zero, encouraging the primary hue structure to form in the first two dimensions.
  • Separate (reg-separate): Pushes latent points away from each other, primarily used in the early phase to spread out the primary/secondary colors.
  • Anchor (reg-anchor): This is a SpecialLossCriterion. When the anchor action is triggered by the timeline, its on_anchor method captures the current latent positions of a reference dataset (the primary/secondary colors). Subsequently, its __call__ method calculates a loss based on how far the current model places those reference colors from their captured anchor positions. This penalizes drift in the established structure.
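In symbols, with batch latents $z_i \in \mathbb{R}^4$ and a small constant $\epsilon$ (a summary of the implementations below, not an independent specification):

  • reg-norm: $\mathrm{mean}_i\,(\lVert z_i \rVert_2 - 1)^2$
  • reg-planar: $\mathrm{mean}\,(z_{i,2:}^2)$
  • reg-separate: $\mathrm{mean}_{i \ne j}\,\big(\lVert z_{i,:2} - z_{j,:2} \rVert^2 + \epsilon\big)^{-1}$
  • reg-anchor: $\mathrm{mean}\,\big((z_i - z_i^{\mathrm{ref}})^2\big)$, where $z^{\mathrm{ref}}$ are the latents captured when the anchor action fires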
from torch import linalg as LA

from ex_color.data.color_cube import ColorCube
from ex_color.data.cyclic import arange_cyclic


def objective(fn):
    """Adapt loss function to look like a regularizer"""

    def wrapper(data: Tensor, res: InferenceResult) -> Tensor:
        return fn(data, res.outputs)

    return wrapper


def unitarity(data: Tensor, res: InferenceResult) -> Tensor:
    """Regularize latents to have unit norm (vectors of length 1)"""
    norms = LA.vector_norm(res.latents, dim=-1)
    return torch.mean((norms - 1.0) ** 2)


def planarity(data: Tensor, res: InferenceResult) -> Tensor:
    """Regularize latents to be planar in the first two channels (so zero in other channels)"""
    return torch.mean(res.latents[:, 2:] ** 2)


class Separate(LossCriterion):
    def __init__(self, channels: tuple[int, ...] = (0, 1)):
        self.channels = channels

    def __call__(self, data: Tensor, res: InferenceResult) -> Tensor:
        """Regularize latents to be separated from each other in first two channels"""
        # Get pairwise differences in the first two dimensions
        points = res.latents[:, self.channels]  # [B, C]
        diffs = points.unsqueeze(1) - points.unsqueeze(0)  # [B, B, C]

        # Calculate squared distances
        sq_dists = torch.sum(diffs**2, dim=-1)  # [B, B]

        # Encourage separation by minimizing inverse square distances (stronger repulsion between close points).
        # The mask zeroes out the diagonal *after* inversion, so the zero self-distances
        # don't each contribute a constant 1/epsilon to the loss.
        epsilon = 1e-6  # Prevent division by zero
        mask = 1.0 - torch.eye(sq_dists.shape[0], device=sq_dists.device)
        inv_dists = mask / (sq_dists + epsilon)
        return torch.sum(inv_dists) / torch.sum(mask)


class Anchor(SpecialLossCriterion):
    """Regularize latents to be close to their position in the reference phase"""

    ref_data: Tensor
    _ref_latents: Tensor | None = None

    def __init__(self, ref_data: Tensor):
        self.ref_data = ref_data
        self._ref_latents = None
        log.info(f'Anchor initialized with reference data shape: {ref_data.shape}')

    def forward(self, model: ColorMLP, data: Tensor) -> InferenceResult | None:
        """Run the *stored reference data* through the *current* model."""
        # Note: The 'data' argument passed by the training loop for SpecialLossCriterion
        # is the *current training batch*, which we IGNORE here.
        # We only care about running our stored _ref_data through the model.
        device = next(model.parameters()).device
        ref_data = self.ref_data.to(device)

        outputs, latents = model(ref_data)
        return InferenceResult(outputs, latents)

    def __call__(self, data: Tensor, special: InferenceResult) -> Tensor:
        """Calculates loss between current model's latents (for ref_data) and the stored reference latents."""
        if self._ref_latents is None:
            # This means on_anchor hasn't been called yet, so the anchor loss is zero.
            # This prevents errors during the very first phase before the anchor point is set.
            log.debug('Anchor.__call__ invoked before reference latents captured. Returning zero loss.')
            return torch.tensor(0.0, device=special.latents.device)
        ref_latents = self._ref_latents.to(special.latents.device)
        return torch.mean((special.latents - ref_latents) ** 2)

    def on_anchor(self, event: Event):
        # Called when the 'anchor' event is triggered
        log.info(f'Capturing anchor latents via Anchor.on_anchor at step {event.step}')

        device = next(event.model.parameters()).device
        ref_data = self.ref_data.to(device)

        with torch.no_grad():
            _, latents = event.model(ref_data)
        self._ref_latents = latents.detach().cpu()
        log.info(f'Anchor state captured internally. Ref data: {ref_data.shape}, Ref latents: {latents.shape}')

Data loading, sampling, and event handling

Here we set up:

  • Datasets: Define the datasets used (primary/secondary colors, full color grid).
  • Sampler: Use DynamicWeightedRandomBatchSampler for the full dataset. Its weights are updated by the update_sampler_weights callback, which responds to the data-fraction parameter from the dopesheet. This smoothly shifts the sampling focus from highly vibrant colors early on to the full range of colors later.
  • Recorders: ModelRecorder and MetricsRecorder are event handlers that save the model state and loss values at each step.
  • Event bindings: Connect event handlers to specific events (e.g., plotter to phase-end, reg_anchor.on_anchor to action:anchor, recorders to pre-step and step-metrics).
  • Training execution: Finally, call train_color_model with the model, datasets, dopesheet, loss criteria, and configured event handlers.
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from ex_color.data.cube_sampler import DynamicWeightedRandomBatchSampler, vibrancy
from ex_color.data.filters import levels


class ModelRecorder(EventHandler):
    """Event handler to record model parameters."""

    history: list[tuple[int, dict[str, Tensor]]]
    """A list of tuples (step, state_dict) where state_dict is a copy of the model's state dict."""

    def __init__(self):
        self.history = []

    def __call__(self, event: Event):
        # It's crucial to get a *copy* of the state dict and move it to the CPU
        # so we don't hold onto GPU memory or track gradients unnecessarily.
        model_state = {k: v.cpu().clone() for k, v in event.model.state_dict().items()}
        self.history.append((event.step, model_state))
        log.debug(f'Recorded model state at step {event.step}')


class MetricsRecorder(EventHandler):
    """Event handler to record training metrics."""

    history: list[tuple[int, float, dict[str, float]]]
    """A list of tuples (step, total_loss, losses_dict)."""

    def __init__(self):
        self.history = []

    def __call__(self, event: StepMetricsEvent):
        # Ensure we are handling the correct event type
        if not isinstance(event, StepMetricsEvent):
            log.warning(f'MetricsRecorder received unexpected event type: {type(event)}')
            return

        self.history.append((event.step, event.total_loss, event.losses.copy()))
        log.debug(f'Recorded metrics at step {event.step}: loss={event.total_loss:.4f}')


primary_cube = ColorCube.from_hsv(h=arange_cyclic(step_size=1 / 6), s=np.ones(1), v=np.ones(1))
primary_tensor = torch.tensor(primary_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
primary_dataset = TensorDataset(primary_tensor)
primary_loader = DataLoader(primary_dataset, batch_size=len(primary_tensor))

full_cube = ColorCube.from_hsv(
    h=arange_cyclic(step_size=10 / 360),
    s=np.linspace(0, 1, 10),
    v=np.linspace(0, 1, 10),
)
full_tensor = torch.tensor(full_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
full_dataset = TensorDataset(full_tensor)
full_sampler = DynamicWeightedRandomBatchSampler(
    bias=full_cube.bias.flatten(),
    batch_size=256,
    steps_per_epoch=100,
)
vibrancy_weights = vibrancy(full_cube).flatten()
full_loader = DataLoader(full_dataset, batch_sampler=full_sampler)

rgb_cube = ColorCube.from_rgb(
    r=np.linspace(0, 1, 10),
    g=np.linspace(0, 1, 10),
    b=np.linspace(0, 1, 10),
)
rgb_tensor = torch.tensor(rgb_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)


def update_sampler_weights(event: Event):
    frac = event.timeline_state.props['data-fraction']
    # When the fraction is near zero, in_low is almost 1 — everything is scaled down to 0 except weights near 1, so only the most vibrant colors are sampled
    # When the fraction is 0.5, in_low and out_low are both 0, so the weights are unchanged
    # When the fraction is 1, in_low is 0 and out_low is 1, so all weights are scaled up to 1 (near-uniform sampling)
    in_low = np.interp(frac, [0, 0.5], [0.99, 0])
    out_low = np.interp(frac, [0.5, 1], [0, 1])
    full_sampler.weights = levels(vibrancy_weights, in_low=in_low, out_low=out_low)


recorder = ModelRecorder()
metrics_recorder = MetricsRecorder()

# Phase -> (train loader, validation tensor)
datasets: dict[str, tuple[DataLoader, Tensor]] = {
    'Primary & secondary': (primary_loader, primary_tensor),
    'All hues': (full_loader, rgb_tensor),
    'Full color space': (full_loader, rgb_tensor),
}

model = ColorMLP(normalize_bottleneck=False)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
log.info(f'Model initialized with {total_params:,} trainable parameters.')

event_handlers = EventHandlers()
event_handlers.pre_step.add_handler('pre-step', recorder)
event_handlers.pre_step.add_handler('pre-step', update_sampler_weights)
event_handlers.step_metrics.add_handler('step-metrics', metrics_recorder)

plotter = PhasePlotter(dim_pairs=[(0, 1), (0, 2), (0, 3)])
event_handlers.phase_end.add_handler('phase-end', plotter)

reg_anchor = Anchor(ref_data=primary_tensor)
event_handlers.action.add_handler('action:anchor', reg_anchor.on_anchor)

train_color_model(
    model,
    datasets,
    dopesheet,
    loss_criteria={
        'loss-recon': objective(nn.MSELoss()),
        'reg-separate': Separate((0, 1)),
        'reg-planar': planarity,
        'reg-norm': unitarity,
        'reg-anchor': reg_anchor,
    },
    event_handlers=event_handlers,
)
I 420.9 no:    Model initialized with 263 trainable parameters.
I 420.9 no:    Anchor initialized with reference data shape: torch.Size([6, 3])
Training Steps: 100.0% [20001/20001] [00:49<00:00, 405.11 it/s] — PHASE: Full color space, lr: 0.001000, loss: 0.0002, loss-recon: 0.0002, reg-norm: 0.0001, reg-anchor: 0.0000
I 425.3 no:    Plotting end of phase: Primary & secondary at step 2999 using provided results.
I 425.8 ut.nb: Figure saved: 'large-assets/ex-1.5-color-phase-history.png'
Visualizations of latent space at the end of each curriculum phase.
I 425.8 no:    Capturing anchor latents via Anchor.on_anchor at step 3000
I 425.8 no:    Anchor state captured internally. Ref data: torch.Size([6, 3]), Ref latents: torch.Size([6, 4])
I 444.1 no:    Plotting end of phase: All hues at step 9999 using provided results.
I 444.5 ut.nb: Figure saved: 'large-assets/ex-1.5-color-phase-history.png'
I 469.6 no:    Plotting end of phase: Full color space at step 20000 using provided results.
I 470.3 ut.nb: Figure saved: 'large-assets/ex-1.5-color-phase-history.png'
I 470.3 no:    Training finished!

Latent space evolution analysis

Let's visualize how the latent space evolved over time. We use the ModelRecorder's history to load the model state at each recorded step and evaluate the latent positions for a fixed set of input colors (the full RGB grid). This gives us a sequence of latent space snapshots.

import numpy as np


def eval_latent_history(
    recorder: ModelRecorder,
    rgb_tensor: Tensor,
):
    """Evaluate the latent space for each step in the recorder's history."""
    from utils.progress import RichProgress

    # Create a new model instance to load the recorded states into
    model = ColorMLP(normalize_bottleneck=False)

    latent_history: list[tuple[int, np.ndarray]] = []
    # Iterate over the recorded history
    for step, state_dict in RichProgress(recorder.history, description='Evaluating latents'):
        # Load the model state dict
        model.load_state_dict(state_dict)
        model.eval()
        with torch.no_grad():
            # Get the latents for the RGB tensor
            _, latents = model(rgb_tensor.to(next(model.parameters()).device))
            latents = latents.cpu().numpy()
            latent_history.append((step, latents))
    return latent_history


latent_history = eval_latent_history(recorder, rgb_tensor)

Animation of latent space

This final visualization combines multiple views into a single animation:

  • Latent space: Shows the 2D projection (Dims 0 vs 1) of the latent embeddings for the full RGB color grid, colored by their true RGB values. We can see the color wheel forming and potentially expanding/contracting.
  • Hyperparameters: Replots the parameter schedule from the dopesheet, with a vertical line indicating the current step in the animation.
  • Training metrics: Plots the total loss and the contribution of each individual loss/regularization term (on a log scale), again with a vertical line for the current step.

(Note: A variable stride is used for sampling frames to focus on periods of rapid change.)

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import imageio_ffmpeg
from matplotlib import rcParams
import pandas as pd

from mini.temporal.dopesheet import RESERVED_COLS
from utils.progress import RichProgress
from mini.temporal.vis import group_properties_by_scale, plot_timeline

rcParams['animation.ffmpeg_path'] = imageio_ffmpeg.get_ffmpeg_exe()


def animate_latent_evolution_with_metrics(
    latent_history: list[tuple[int, np.ndarray]],
    metrics_history: list[tuple[int, float, dict[str, float]]],
    param_history_df: pd.DataFrame,
    param_keyframes_df: pd.DataFrame,
    colors: np.ndarray,
    dim_pair: tuple[int, int] = (0, 1),
    interval=1 / 30,
):
    """Create an animation of the latent space evolution alongside hyperparameter and metric plots."""
    plt.style.use('dark_background')
    # Create a figure with 3 subplots: 1 for latent, 2 for lines
    fig = plt.figure(figsize=(12, 6))
    gs = fig.add_gridspec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1])

    ax_latent = fig.add_subplot(gs[:, 0])  # Latent space on the left, spanning rows
    ax_params = fig.add_subplot(gs[0, 1])  # Params top right
    ax_metrics = fig.add_subplot(gs[1, 1])  # Metrics bottom right

    fig.patch.set_facecolor('#333')
    ax_latent.patch.set_facecolor('#222')
    ax_params.patch.set_facecolor('#222')
    ax_metrics.patch.set_facecolor('#222')

    # --- Setup Latent Plot ---
    ax_latent.set_xlim(-1.5, 1.5)
    ax_latent.set_ylim(-1.5, 1.5)
    ax_latent.set_aspect('equal')
    ax_latent.set_xlabel(f'Dim {dim_pair[0]}')
    ax_latent.set_ylabel(f'Dim {dim_pair[1]}')
    step, current_latents = latent_history[0]
    scatter = ax_latent.scatter(
        current_latents[:, dim_pair[0]], current_latents[:, dim_pair[1]], c=colors, s=30, alpha=0.7
    )
    title_latent = ax_latent.set_title(f'Latent Space (Step {step})')

    # --- Setup Parameter Plot ---
    # Filter out 'STEP' and other reserved columns before grouping
    param_props = param_keyframes_df.columns.difference(list(RESERVED_COLS)).tolist()
    param_groups = group_properties_by_scale(param_keyframes_df[param_props])
    # Pass only the first group if plotting on a specific axis
    plot_timeline(param_history_df, param_keyframes_df, [param_groups[0]], ax=ax_params, show_legend=True)
    param_vline = ax_params.axvline(step, color='white', linestyle='--', lw=1)

    ax_params.set_title('Hyperparameters')
    ax_params.set_xlabel('')  # Remove x-label as it shares with metrics
    ax_params.tick_params(axis='x', labelbottom=False)

    # --- Setup Metrics Plot ---
    metrics_steps = [h[0] for h in metrics_history]
    total_losses = [h[1] for h in metrics_history]
    # Collect the union of loss keys across all steps (early steps may lack terms that only activate later)
    all_loss_keys = sorted({k for h in metrics_history for k in h[2].keys()})
    loss_components = {k: [h[2].get(k, np.nan) for h in metrics_history] for k in all_loss_keys}

    ax_metrics.plot(metrics_steps, total_losses, label='Total Loss', lw=2)
    for name, values in loss_components.items():
        ax_metrics.plot(metrics_steps, values, label=name, lw=1, alpha=0.8)

    ax_metrics.set_xlabel('Step')
    ax_metrics.set_ylabel('Loss (log scale)')  # Update label
    ax_metrics.set_title('Training Metrics')
    ax_metrics.legend(fontsize='small')
    ax_metrics.set_yscale('log')  # Set log scale
    ax_metrics.set_ylim(bottom=1e-6)  # Set bottom slightly above zero for log scale
    metrics_vline = ax_metrics.axvline(step, color='white', linestyle='--', lw=1)
    # Use uppercase 'STEP' for accessing the history_df column
    max_step = param_history_df['STEP'].max()
    ax_metrics.set_xlim(left=0, right=max_step)
    ax_params.set_xlim(left=0, right=max_step)

    fig.tight_layout()

    def update(frame: int):
        # frame is the index in the *sampled* history
        latent_step, current_latents = latent_history[frame]

        # Update latent space
        scatter.set_offsets(current_latents[:, dim_pair])
        title_latent.set_text(f'Latent Space (Step {latent_step})')

        # Update vertical lines
        param_vline.set_xdata([latent_step])
        metrics_vline.set_xdata([latent_step])

        return scatter, title_latent, param_vline, metrics_vline

    # Use the length of the (potentially strided) latent_history for frames
    num_frames = len(latent_history)
    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=interval * 1000, blit=True)
    return fig, ani


# --- Variable Stride Logic ---
def get_stride(step: int):
    import math

    # Logarithmic stride: ≈1 at step 0, rising to ≈20 by step 20,000, so frames
    # are sampled densely early (when the latents change fastest) and sparsely later
    a = 7.9236
    b = 0.0005
    return a * math.log(b * step + 1) + 1


sampled_indices = [0]
last_sampled_index = 0
while True:
    # Get the step number corresponding to the last sampled frame
    current_step = latent_history[round(last_sampled_index)][0]
    # Determine the stride based on that step number
    stride = get_stride(current_step)
    # Calculate the index of the next potential frame
    next_index = last_sampled_index + stride
    # Stop if we've gone past the end of the history
    if round(next_index) >= len(latent_history):
        break
    # Add the calculated index to our list
    sampled_indices.append(round(next_index))
    # Update the last sampled index for the next iteration
    last_sampled_index = next_index

# Use the sampled indices to select frames from the full history
sampled_latent_history = [latent_history[i] for i in sampled_indices]
# --- End Variable Stride Logic ---

# Filter metrics history to align with the *new* sampled latent history steps
sampled_steps_set = {step for step, _ in sampled_latent_history}
filtered_metrics_history = [h for h in metrics_recorder.history if h[0] in sampled_steps_set]

# Make sure we have the parameter history dataframes (they were created earlier)
# history_df, keyframes_df

fig, ani = animate_latent_evolution_with_metrics(
    latent_history=sampled_latent_history,  # Use variable stride history
    metrics_history=filtered_metrics_history,  # Use filtered metrics
    param_history_df=history_df,  # Full parameter history
    param_keyframes_df=keyframes_df,  # Keyframes for plotting
    colors=rgb_tensor.cpu().numpy(),
    dim_pair=(0, 1),
)

video_file = 'large-assets/ex-1.5-latent-evolution-with-metrics.mp4'  # New filename
num_frames_to_render = len(sampled_latent_history)  # Update frame count
with RichProgress(total=num_frames_to_render, description='Rendering video') as pbar:
    ani.save(
        video_file,
        fps=30,
        extra_args=['-vcodec', 'libx264'],
        progress_callback=lambda i, n: pbar.update(1),
    )
plt.close(fig)

from random import randint
from IPython.display import display, HTML

cache_buster = randint(1, 1_000_000)

display(
    HTML(
        f"""
        <video width="960" height="480" controls loop>
            <source src="{video_file}?v={cache_buster:d}" type="video/mp4">
            Your browser does not support the video tag.
        </video>
        """
    )
)
Rendering video: 100.0% [1905/1905] [04:34/<00:00, 6.95 it/s]
Video: animation of the latent space evolving alongside the hyperparameter schedule and training metrics.