Tiny struggles

Always hacking something πŸ§‘β€πŸ”¬.

Implementing SARM on your VLA dataset in practice

1. Motivation: use big video dataset optimally for VLA training

In the first part of this series, I explained what is SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation and how it can be used with a challenge such as Stanford Behavior Challenge (1200h of demonstrations over 50 diverse long horizon tasks).

To sum up: Our main model for the robot (VLA) is trained on short windows (chunks) of data for which it predicts actions. Our SARM model is used to estimate progress within an episode and evaluates windows. We want to prioritize training on trajectory segments where the robot made meaningful progress toward task completion.

In this post I will explain how I actually implemented this in practice. The code is now open on github. The core of the implementation follows closely the original paper.

We will cover the following key areas:

  • The design of the model Sequential Multimodal Architecture that utilizes a Global Anchor Frame.
  • Data input shape and preparation
  • Using the model for scoring the episodes
  • Visual Validation of the predicted progress against our Stage-Aware Ground Truth.

2. The SARM Model Implementation

See the source here.

The model tackles a dual prediction problem: determining which stage of a task is being performed (classification) and how much progress has been made within that stage (regression).

SARM provides a principled approach to stage-aware reward modeling by:

  1. Leveraging pretrained vision models (CLIP) for robust visual understanding
  2. Fusing multimodal information through Transformers
  3. Making hierarchical predictions (stage β†’ progress within stage)
  4. Handling variable-length sequences efficiently

Architecture Overview

This architecture is particularly well-suited for tasks that have clear sequential structure and require fine-grained progress estimation within each stage.

The SARM model follows a three-part design:

  1. Encoders - Process multimodal inputs (visual and proprioceptive)
  2. Shared Backbone - A Transformer that fuses information across time and modalities
  3. Dual Heads - Separate outputs for stage classification and progress regression

Where the symbols are:

  • B - batch size
  • N - sequence size (multiple frames of images/data - more on that later)
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                         INPUT LAYER                                β”‚
β”‚                                                                    β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”              β”‚
β”‚  β”‚ Image Frames β”‚  β”‚ Joint States β”‚  β”‚  Task Index  β”‚              β”‚
β”‚  β”‚  (B,N,3,     β”‚  β”‚  (B,N,256)   β”‚  β”‚     (B,)     β”‚              β”‚
β”‚  β”‚   224,224)   β”‚  β”‚              β”‚  β”‚              β”‚              β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜  β””β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜  β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜              β”‚
└─────────┼─────────────────-β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
          β”‚                  β”‚                  β”‚
          β–Ό                  β–Ό                  β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                      ENCODER LAYER                                  β”‚
β”‚                                                                     β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”               β”‚
β”‚  β”‚  CLIP (ViT)  β”‚  β”‚  LayerNorm   β”‚  β”‚  Embedding   β”‚               β”‚
β”‚  β”‚   [Frozen]   β”‚  β”‚      +       β”‚  β”‚    Layer     β”‚               β”‚
β”‚  β”‚      ↓       β”‚  β”‚   Linear     β”‚  β”‚              β”‚               β”‚
β”‚  β”‚   Linear     β”‚  β”‚              β”‚  β”‚              β”‚               β”‚
β”‚  β”‚  Projection  β”‚  β”‚  Projection  β”‚  β”‚              β”‚               β”‚
β”‚  β”‚              β”‚  β”‚              β”‚  β”‚              β”‚               β”‚
β”‚  β”‚ (512β†’768)    β”‚  β”‚ (256β†’768)    β”‚  β”‚ (50β†’768)     β”‚               β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜  β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜  β””β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜               β”‚
β”‚         β”‚                 β”‚                  β”‚                      β”‚
β”‚         β”‚   Visual        β”‚   State          β”‚   Task               β”‚
β”‚         β”‚   Embeddings    β”‚   Embeddings     β”‚   Embedding          β”‚
β”‚         β”‚   (B,N,768)     β”‚   (B,N,768)      β”‚   (B,1,768)          β”‚
β”‚         β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                      β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                    β”‚
                    β–Ό
          β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
          β”‚  Element-wise   β”‚
          β”‚      Sum        β”‚
          β”‚                 β”‚
          β”‚  Visual + State β”‚
          β”‚    + Task       β”‚
          β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                   β”‚
                   β–Ό
          β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
          β”‚  Add Positional β”‚
          β”‚  Bias to Frame 0β”‚
          β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                   β”‚
                   β”‚  Combined Embeddings
                   β”‚  (B,N,768)
                   β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                   TRANSFORMER BACKBONE                             β”‚
β”‚                                                                    β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚
β”‚  β”‚  Transformer Encoder (8 layers)                               β”‚ β”‚
β”‚  β”‚                                                               β”‚ β”‚
β”‚  β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”  β”‚ β”‚
β”‚  β”‚  β”‚  Multi-Head Self-Attention (12 heads)                   β”‚  β”‚ β”‚
β”‚  β”‚  β”‚  d_model = 768,                                         β”‚  β”‚ β”‚
β”‚  β”‚  β”‚  Dropout = 0.1                                          β”‚  β”‚ β”‚
β”‚  β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜  β”‚ β”‚
β”‚  β”‚                           Γ—8                                  β”‚ β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚
β”‚                                                                    β”‚
β”‚                   (with padding mask support)                      β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                               β”‚
                               β”‚  Aggregated Features
                               β”‚  (B,N,768)
                               β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                         OUTPUT HEADS                                β”‚
β”‚                                                                     β”‚
β”‚         β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€-────────┐            β”‚
β”‚         β”‚                                              β”‚            β”‚
β”‚         β–Ό                                              β–Ό            β”‚
β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                          β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”   β”‚
β”‚  β”‚  Stage Head     β”‚                          β”‚  Subtask Head   β”‚   β”‚
β”‚  β”‚  (Classifier)   β”‚                          β”‚  (Regressor)    β”‚   β”‚
β”‚  β”‚                 β”‚                          β”‚                 β”‚   β”‚
β”‚  β”‚  Linear(768β†’512)β”‚         β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€  Concat:        β”‚   β”‚
β”‚  β”‚      ReLU       β”‚         β”‚                β”‚  - Features(768)β”‚   β”‚
β”‚  β”‚   Dropout(0.1)  β”‚         β”‚                β”‚  - Logits(100)  β”‚   β”‚
β”‚  β”‚  Linear(512β†’100)β”‚β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                β”‚                 β”‚   β”‚
β”‚  β”‚                 β”‚                          β”‚ Linear(868β†’512) β”‚   β”‚
β”‚  β”‚  Stage Logits   β”‚                          β”‚      ReLU       β”‚   β”‚
β”‚  β”‚  (B,N,100)      β”‚                          β”‚   Dropout(0.1)  β”‚   β”‚
β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜                          β”‚   Linear(512β†’1) β”‚   β”‚
β”‚                                               β”‚     Sigmoid     β”‚   β”‚
β”‚                                               β”‚                 β”‚   β”‚
β”‚                                               β”‚ Scalar Progress β”‚   β”‚
β”‚                                               β”‚    (B,N)        β”‚   β”‚
β”‚                                               β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Data Flow Explained

1. Input Processing

The model accepts three types of inputs for each sequence:

  • Image Frames (B, N, 3, 224, 224): A batch of N RGB images per sequence
  • Joint States (B, N, D_state): Robot proprioceptive information (joint angles, positions, etc.)
  • Task Index (B,): An integer identifying which task is being performed

Where B is the batch size and N is the maximum sequence length, the actual data can be shorter and then we pad it.

For every prediction, our model processes a sequence of frames, deliberately structured to provide maximum temporal context:

  • Global Anchor Frame: The first frame of the episode is included in every sequence. This is a crucial engineering choice for long-horizon tasks, as it gives the Transformer a global, unchanging reference point for the task’s initial state.
  • Subsampled Context Frames: Several preceding frames are included to capture recent history.
  • Current Frame: The frame for which the progress prediction is required.

2. Encoding Stage

Each modality is processed through its own encoder:

Visual Encoding:

  • Images are flattened from (B, N, 3, 224, 224) to (B*N, 3, 224, 224)
  • Passed through a frozen CLIP ViT-B/32 model to extract visual features
  • CLIP outputs 512-dimensional features per image
  • Features are projected to the model dimension (768) via a linear layer
  • Reshaped back to (B, N, 768)

State Encoding:

  • Joint states are normalized using LayerNorm
  • Projected from dimension 256 to 768 via a linear layer
  • Output: (B, N, 768)

Task Encoding:

  • Task index is converted to a learned embedding vector
  • The embedding is replicated across the sequence: (B,) β†’ (B, 1, 768)
  • This embedding is broadcast and added to all timesteps

3. Multimodal Fusion

The three encoded representations are combined:

input_embeddings = visual_embeddings + state_embeddings + task_embedding

Additionally, a learned positional bias is added only to the first frame:

input_embeddings[:, 0, :] += positional_bias

This creates a unified representation (B, N, 768) that contains information from all modalities.

4. Transformer Backbone

The combined embeddings are processed through an 8-layer Transformer encoder:

  • Architecture: Standard Transformer encoder with 12 attention heads
  • Dimensions: 768-dimensional hidden states, 3072-dimensional feedforward layers
  • Padding Support: The model accepts an optional padding mask (B, N) where True indicates padded positions
  • Output: Aggregated features (B, N, 768) that capture temporal and multimodal dependencies

5. Dual Output Heads

The model produces two types of predictions:

Stage Head (Classification):

  • Takes the aggregated features (B, N, 768)
  • Passes through: Linear(768β†’512) β†’ ReLU β†’ Dropout β†’ Linear(512β†’100)
  • Outputs stage logits (B, N, 100) representing 100 possible task stages:
    • 100 is a maximum number of task stages supported, in practice the stages will be task dependent
  • Trained with cross-entropy loss

Subtask Head (Regression):

  • Concatenates aggregated features with stage logits: [features, stage_logits] β†’ (B, N, 868)
  • This conditioning allows progress estimation to be stage-aware
  • Passes through: Linear(868β†’512) β†’ ReLU β†’ Dropout β†’ Linear(512β†’1) β†’ Sigmoid
  • Outputs scalar progress (B, N) in the range [0, 1]
  • Trained with MSE loss

Loss Computation

The SARMWithLoss wrapper handles training:

  1. Masking: Only non-padded positions are included in loss calculation
  2. Stage Loss: Cross-entropy between predicted logits and ground truth stage labels
  3. Progress Loss: MSE between predicted progress and ground truth progress values
  4. Total Loss: Weighted sum of both losses (default weights: 1.0 each)
total_loss = (stage_loss_weight Γ— stage_loss) + (progress_loss_weight Γ— progress_loss)

Including the loss calculation within the model wrapper made the training code simpler and more standard.

Key Design Decisions

These decisions follow the original SARM paper:

  • Why freeze CLIP? CLIP is pretrained on massive image-text datasets and provides robust visual features. Freezing it prevents overfitting on smaller robotics datasets and reduces computational cost.

  • Why condition subtask head on stage predictions? We estimate progress within a stage, not the whole episode, so it’s stage dependent.

  • Why add positional bias only to the first frame? The first frame often contains important context about the initial state. The positional bias helps the model distinguish the starting point from subsequent frames. Supposedly such anchoring is very effective for video models.

  • Why use variable-length sequences with padding?

We use Rewind Augmentation when we sometimes generate longer sequences that ‘mess up’ progress on purpose by replaying older frames in the reverse order. Because of that we need to handle sequences of varied length. This augmentation is critical for the model to learn how undoing progress looks like.

3. Complex data preparation

The data for SARM has to be prepared in a very particular way.

The core of the sampling is implemented in the custom dataloaders here.

Temporal Sampling Strategy

SARM doesn’t sample frames uniformly. Instead, it uses a sophisticated sampling strategy designed to provide temporal context:

def prepare_indices(ep_first_frame_idx, idx, skip_count=30, 
                   default_length=8, rewind_prob=0.05):
    # Sample backwards in time with skip_count intervals
    indices = [idx - i * skip_count for i in range(default_length)]
    
    # Always include the first frame of the episode
    indices.append(ep_first_frame_idx)
    
    # Reverse so time flows forward
    indices = list(reversed(indices))
    
    # 5% chance: add "rewound" frames for temporal augmentation
    if random() < rewind_prob:
        num_extra = random.integers(2, 5)
        indices += [indices[-1 - i] for i in range(1, num_extra + 1)]
    
    return indices

This creates sequences with the following structure:

Sequence construction (skip_count=30, ~1 second at 30 FPS):
β”Œβ”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”
β”‚Frame 0β”‚ t-7s  β”‚ t-6s  β”‚ t-5s  β”‚ t-4s  β”‚ t-3s  β”‚ t-2s  β”‚ t-1s  β”‚  t    β”‚
β”‚(start)β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚       β”‚(curr) β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”˜

With 5% probability, add rewind frames:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”
β”‚...    β”‚  t    β”‚ t-1s (again)β”‚ t-2s  β”‚ t-3s  β”‚(curr) β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”˜

Why this design?

  1. Always anchor to episode start: Frame 0 provides consistent context about initial conditions
  2. Uniform temporal spacing: 1-second intervals capture motion patterns without redundancy
  3. Rewind augmentation: Teaches the model temporal reversibility and robustness
  4. Future context avoided: Model only sees past and present, not future frames

Delta Timestamps Pattern

The sampling strategy is complemented by a clever timestamping scheme:

DELTA_TIMESTAMPS = [HIGH_NEGATIVE_TIMEDELTA] + [-7 + i for i in range(8)]
# Results in: [1e6, -7, -6, -5, -4, -3, -2, -1, 0]

When applied to current timestamp t:

  • Frame 0: Gets timestamp β‰ˆ -∞ (approximated as episode start)
  • Frames 1-7: Get timestamps [t-7, t-6, …, t-1]
  • Frame 8: Gets timestamp t (current)

This ensures consistent temporal windows regardless of where you are in the episode. My custom dataset SARMDataset uses a dataset provided by the BEHAVIOR codebase under the hood that allows specifying delta timestamps for more efficient sampling.

Variable-Length Sequence Handling

Real episodes have variable lengths, and sequences can have different numbers of frames. SARM handles this with:

def collate_fn(batch):
    # Find max sequence length in batch
    max_length = max(sample["sequence_length"] for sample in batch)
    
    # Pad all sequences to max_length
    batched_images = torch.zeros(batch_size, max_length, C, H, W)
    batched_padding_mask = torch.ones(batch_size, max_length, dtype=torch.bool)
    
    for i, sample in enumerate(batch):
        seq_len = sample["sequence_length"]
        batched_images[i, :seq_len] = sample["images"]
        batched_padding_mask[i, :seq_len] = False  # False = valid, True = padding

The padding mask is then passed to the Transformer to ensure padded positions don’t contribute to attention or loss:

Example batch with lengths [9, 11, 13, 9]:

Padded to max_length=13:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Seq 1 (9)   β”‚ Seq 2 (11)  β”‚ Seq 3 (13)  β”‚ Seq 4 (9)   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ [V][V]...[V]β”‚ [V][V]...[V]β”‚ [V][V]...[V]β”‚ [V][V]...[V]β”‚
β”‚ [P][P][P][P]β”‚ [P][P]      β”‚             β”‚ [P][P][P][P]β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
  V = Valid token
  P = Padding token (masked out)

Inference Sampling Strategy

To use SARM for VLA training, we need to run our original video dataset through SARM.

But that dataset was huge to begin with! But we don’t need to evaluate every frame (with its proceeding sequence).

With 5-second sampling at 30 FPS, I evaluate only 1 out of every 150 frames (5s Γ— 30 FPS), reducing computational cost by 150Γ—.

I implemented a following dataloader (source).

Why jitter?

Jitter prevents the model from overfitting to fixed timestamps and produces more robust progress estimates by sampling at slightly varied intervals rather than exact multiples of 5 seconds.

See INFERENCE_README for more details on the inference.

Translating the progress to VLA training weights

Additionally I implement the progress mapping to the weights following the SARM paper (equations 8-9): - Computes progress deltas rΜ‚α΅’ = Ο†(t+Ξ”) - Ο†(t) - Uses running statistics (ΞΌ, Οƒ) to normalize - Applies linear ramp between (ΞΌ - 2Οƒ) and (ΞΌ + 2Οƒ) - Optionally uses threshold ΞΊ for decisive weighting

See the code in weight utils.

4. Training & Results

This implementation of SARM was multi-task, however since there was so much data, I decided that it would be easier to evaluate and visualize it on a single task first.

Having 200 episodes for each task, I divided the data into the following sets:

  • “train_episodes”: 1-90,
  • “val_episodes”: 91-105,
  • “test_episodes”: 106-200

In general performance on the validation set wasn’t the best indicator of actual model performance when I analyzed it on the test dataset. Training for more steps was helpful.

For training details see the config and the training script. The 10k-step snapshot has been trained on a single RTX5090 over several hours. The model is also compatible with training on MPS. The key bottleneck was the dataset access and video processing.

Visualizations & analysis

I performed detailed analysis on how well the models were predicting the progress here.

Ground Truth

First, it’s important how the ‘ground truth’ data looks like, here is a visualization:

Based on the data annotations and the stage statistics I was able to generate our ‘Ground truth’ of progress. It was also a useful sanity check if the ground truth data looks right, e.g. having negative progress in the ground truth data would mean that there were bugs. We were only adding ’negative progress’ through the Rewind augmentation later on.

Comparing models vs ground truth and each other

Model checkpoint comparison vs ‘ground truth’ on a sample of episodes:

The first 3 episodes were in the training data and the 2 last ones weren’t present. You can see here that the yellow (10k steps) model is better fitted to the data in the training set.

Understanding the bias of the model

So the model wasn’t perfect, what type of mistakes was it making?

Overall, the model was leaning towards underestimating the progress. And the key problem was from predicting wrong stage number.

Applicability and Limitations

The caveat here that the task 8 was multimodal, the stages could be done in variable order, breaking the fundamental assumption of the fix stage order in SARM. Visualizations for a task fitting SARM assumptions would look better.

  • The SARM Assumption: Stage-Aware modeling assumes a generally linear path through semantic checkpoints (Stage 1 $\rightarrow$ Stage 2 $\rightarrow$ Stage 3…).
  • The Failure Case: Multimodal Progress: If a task allows for subtasks to be completed in an arbitrary order (e.g., “Tidy up the room”), the progress estimation becomes inherently multimodal, and our regression model, forced to average these possibilities, loses accuracy.

Handling different sequences of stages in demonstrations

Removing outliers

If the majority of demonstrations are done in a consistent way, then we can remove the outlier demonstrations that create confusion. (Annotations to generate ground truth are enough to blacklist such episodes).

Subtask splitting

The SARM model implemented here can handle multiple tasks.

Therefore, if a task can be done using different sequences of stages, we can transform it into set of related tasks with different demonstration variations.

If there are multiple different ways represented in similar proportions, e.g. ‘pick up toy 1’ then ‘pick up toy 2’, and the reverse, we can change into two tasks pick_up_toys_1_2, pick_up_toys_2_1.

Equal sampling

We can also decide not use SARM for such tasks.

SARM in Behavior challenge

For the BEHAVIOR challenge specifically, we were very time constrained, and we ended up not having enough time to apply the model for the final checkpoint training, additionally, only about 30% of tasks fulfilled the fixed stage ordering for SARM.

We performed quick fine tuning with weighted sampling earlier on a subset of data (not based on SARM), but it was difficult to see if it was actually helpful (eval in general was pretty challenging).

Despite these challenges, SARM remains a promising approach for datasets with proper stage annotations and sequential task structure. Our analysis on the 30% of tasks with fixed stage ordering showed the model could accurately track progress. For the remaining tasks, the subtask splitting approach outlined above could make SARM applicable, potentially enabling more efficient VLA training through intelligent data selection across the full dataset.

This is my mathjax support partial