GRASP: Parallel Stochastic Gradient-Based Planning for World Models

Michael Psenka1,2,○, Michael Rabbat2, Aditi Krishnapriyan1, Yann LeCun2,3,*, Amir Bar2,*

1University of California, Berkeley    2Meta FAIR    3New York University

*Equally advised    ○Work done at Meta

tl;dr: A new gradient-based planner for world models, incorporating a lifted-states approach, stochastic optimization, and a reshaped gradient structure for a smoother optimization landscape.

Paper | Code (coming soon!)
GRASP examples

Virtual states learned through planning. All examples are instantiations of our planner at horizon 50 in the Point-Maze, Wall-Single, and Push-T environments. All world models operate on visual input. Despite the relaxed dynamics constraint and state noising, the directly optimized states find realistic, non-greedy paths toward the goal.

Ball navigation demo

Exploration through state noising and a relaxed dynamics loss in a synthetic, state-based navigation environment. Noising the states allows for more stable exploration around multiple greedy minima.

Abstract

World models simulate environment dynamics from raw sensory inputs like video. However, using them for planning can be challenging due to the vast and unstructured search space. We propose a robust and highly parallelizable planner that leverages the differentiability of the learned world model for efficient optimization, solving long-horizon control tasks from visual input. Our method treats states as optimization variables ("virtual states") with soft dynamics constraints, enabling parallel computation and easier optimization. To facilitate exploration and avoid local optima, we introduce stochasticity into the states. To mitigate sensitive gradients through high-dimensional vision-based world models, we modify the gradient structure to descend towards valid plans while only requiring action-input gradients. Our planner, which we call GRASP (Gradient RelAxed Stochastic Planner), can be viewed as a stochastic version of a non-condensed or collocation-based optimal controller. We provide theoretical justification and experiments on video-based world models, where our resulting planner outperforms existing planning algorithms like the cross-entropy method (CEM) and vanilla gradient-based optimization (GD) on long-horizon experiments, both in success rate and time to convergence.

Planning with world models is deceptively hard

We assume a learned world model $F_\theta$ predicts the next latent state:

$$s_{t+1} = F_\theta(s_t, a_t), \qquad a_t \in \mathbb{R}^k.$$

Given $s_0$ and a goal $g$, the standard optimization-based planner chooses actions $\mathbf{a}=(a_0,\dots,a_{T-1})$ by rolling out the model and minimizing terminal error:

$$\min_{\mathbf{a}} \;\;\big\|s_T(\mathbf{a}) - g\big\|_2^2, \quad \text{where } s_T(\mathbf{a}) = F_\theta^T(s_0,\mathbf{a}).$$

This looks simple, but it creates a deep computation graph (a $T$-step composition of $F_\theta$), which is both expensive and often poorly conditioned. The loss landscape can be highly non-greedy: successful plans may need to move away from the goal before moving toward it.
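
For concreteness, here is a minimal sketch of this baseline planner in PyTorch. The world model interface `world_model(s, a) -> s_next`, the latent dimension, and the action dimension `k` are illustrative placeholders, not an API from the paper.

```python
import torch

def rollout_terminal_loss(world_model, s0, actions, goal):
    """Serial T-step rollout: each model call depends on the previous one."""
    s = s0
    for a in actions:                      # actions: (T, k) tensor
        s = world_model(s, a)              # T-deep composition of F_theta
    return ((s - goal) ** 2).sum()         # terminal error ||s_T - g||^2

def plan_by_rollout_gd(world_model, s0, goal, T, k, steps=200, lr=1e-2):
    actions = torch.zeros(T, k, requires_grad=True)
    opt = torch.optim.Adam([actions], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        rollout_terminal_loss(world_model, s0, actions, goal).backward()
        opt.step()                         # gradient descent on actions only
    return actions.detach()
```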

Loss landscape visualization

Planning is non-greedy and the loss can be jagged. (a) Successful trajectories can temporarily increase distance to the goal. (b) Standard rollout losses can be sharp and brittle. (c) Our objective smooths the landscape during exploration.

Lifting planning into virtual states to parallelize and stabilize

Instead of forcing a serial rollout during optimization, we introduce intermediate virtual states $\mathbf{s}=(s_1,\dots,s_{T-1})$ and enforce dynamics only through pairwise consistency:

$$\min_{\mathbf{s},\mathbf{a}} \;\; \sum_{t=0}^{T-1} \big\|F_\theta(s_t,a_t)-s_{t+1}\big\|_2^2, \quad \text{with } s_0\text{ fixed and } s_T=g.$$

This "lifted" view has a big practical benefit:

  • All model evaluations $F_\theta(s_t,a_t)$ are parallel across time. No $T$-step serial rollout is needed just to compute a training signal.
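
A sketch of this lifted loss, under the same placeholder interface as before and additionally assuming the world model accepts batched inputs, so all $T$ evaluations run in one call:

```python
import torch

def lifted_dynamics_loss(world_model, s0, states, actions, goal):
    """Pairwise consistency: the T model evaluations are independent across time."""
    s_in  = torch.cat([s0.unsqueeze(0), states])    # s_0, s_1, ..., s_{T-1}   -> (T, d)
    s_out = torch.cat([states, goal.unsqueeze(0)])  # s_1, ..., s_{T-1}, s_T=g -> (T, d)
    pred  = world_model(s_in, actions)              # one batched call, parallel over t
    return ((pred - s_out) ** 2).sum()
```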

But naïvely optimizing states introduces two common failure modes:

  1. Local minima in state space (e.g., "phantom" trajectories that cut through walls).
  2. Brittle state gradients in high-dimensional latent spaces (the optimizer can exploit tiny off-manifold perturbations in $s_t$).
(a) Expected behavior    (b) Adversarial example

Sensitivity of the state gradient structure. Examples of three states far from the goal (shown on the right), either in-distribution or out-of-distribution, such that taking a small step along the gradient $s' = s - \epsilon \nabla_s \mathcal{L}(s)$, where $\mathcal{L}(s) = \| F_\theta(s, a=\mathbf{0}) - g\|_2^2$, leads to a nearby state $s'$ that solves the planning problem in a single step: $F_\theta(s', \mathbf{0}) = g$. Optimizing states directly through the world model $F_\theta$ can therefore be quite challenging.

That leads to the two ingredients that make the method work.

Ingredient 1: Exploration by injecting noise into the virtual states

Even with a smoother objective, planning remains nonconvex. We add lightweight exploration by noising the state iterates during optimization:

  • Update $\mathbf{a}$ by gradient descent on $\mathcal{L}$,
  • Update $\mathbf{s}$ by a descent-like step plus Gaussian noise.

A simple version is:

$$s_t \leftarrow s_t - \eta_s \,\nabla_{s_t}\mathcal{L} \;+\; \sigma_{\text{state}}\,\xi,\qquad a_t \leftarrow a_t - \eta_a\nabla_{a_t}\mathcal{L}, \qquad \xi\sim\mathcal{N}(0,I).$$

In practice, this helps the planner "hop" out of bad basins in the lifted space, while actions remain guided by stable gradients.
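
A sketch of one such noised update, reusing `lifted_dynamics_loss` from the earlier sketch; `states` holds the virtual states $s_1,\dots,s_{T-1}$, both tensors are leaves with `requires_grad=True`, and the step sizes and noise scale are illustrative placeholders.

```python
import torch

def noisy_lifted_step(world_model, s0, states, actions, goal,
                      lr_s=1e-2, lr_a=1e-2, sigma_state=0.1):
    loss = lifted_dynamics_loss(world_model, s0, states, actions, goal)
    grad_s, grad_a = torch.autograd.grad(loss, [states, actions])
    with torch.no_grad():
        states  -= lr_s * grad_s
        states  += sigma_state * torch.randn_like(states)  # Gaussian noise on states only
        actions -= lr_a * grad_a                            # actions follow plain descent
```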

Ingredient 2: Cut brittle state-input gradients, keep action gradients

Directly backpropagating through the state input $s_t$ can be adversarial: the optimizer can "cheat" by nudging $s_t$ off-manifold so that the model outputs whatever is convenient. To prevent this, we stop gradients through the state inputs to the world model.

Let $\bar{s}_t$ denote a stop-gradient copy of $s_t$ (same value, no gradient). We use:

$$\mathcal{L}_{\text{dyn}}^{\text{sg}}(\mathbf{s},\mathbf{a}) = \sum_{t=0}^{T-1}\big\|F_\theta(\bar{s}_t,a_t)-s_{t+1}\big\|_2^2.$$

Now:

  • $a_t$ still gets a clean gradient signal through $F_\theta(\bar{s}_t,a_t)$,
  • but we avoid optimizing $s_t$ via fragile $\nabla_{s}F_\theta$ directions.
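
In PyTorch terms, the stop-gradient $\bar{s}_t$ is just a `.detach()` on the state inputs. A sketch under the same placeholder interface as above:

```python
import torch

def lifted_dynamics_loss_sg(world_model, s0, states, actions, goal):
    s_in  = torch.cat([s0.unsqueeze(0), states]).detach()   # \bar{s}_t: value only, no gradient
    s_out = torch.cat([states, goal.unsqueeze(0)])           # states still get a (well-conditioned)
                                                             # gradient as prediction targets
    pred  = world_model(s_in, actions)                       # actions get clean gradients
    return ((pred - s_out) ** 2).sum()
```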

Dense goal shaping

If we only enforce pairwise consistency, the optimization can drift toward "matching itself" rather than progressing toward $g$. So we add a simple one-step-to-goal shaping term:

$$\mathcal{L}_{\text{goal}}^{\text{sg}}(\mathbf{s},\mathbf{a}) = \sum_{t=0}^{T-1}\big\|F_\theta(\bar{s}_t,a_t)-g\big\|_2^2.$$

Final objective (two terms, one knob):

$$\mathcal{L}(\mathbf{s},\mathbf{a}) = \mathcal{L}_{\text{dyn}}^{\text{sg}}(\mathbf{s},\mathbf{a}) + \gamma \,\mathcal{L}_{\text{goal}}^{\text{sg}}(\mathbf{s},\mathbf{a}).$$

Intuition:

  • Cutting state-input gradients smooths state optimization and removes the kind of adversarial state examples shown above.
  • The goal term makes each action step push the predicted next state toward the goal.
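
A sketch of the combined objective, continuing the placeholder interface above; the value of the single knob `gamma` is an illustrative default, not a value from the paper.

```python
import torch

def grasp_loss(world_model, s0, states, actions, goal, gamma=0.1):
    s_in  = torch.cat([s0.unsqueeze(0), states]).detach()  # stop-gradient on state inputs
    s_out = torch.cat([states, goal.unsqueeze(0)])
    pred  = world_model(s_in, actions)
    dyn_loss  = ((pred - s_out) ** 2).sum()                # pairwise consistency
    goal_loss = ((pred - goal) ** 2).sum()                 # one-step-to-goal shaping
    return dyn_loss + gamma * goal_loss
```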

Periodic "sync": briefly switch back to true rollout gradients for refinement

The lifted, grad-cut objective is great for fast parallel exploration, but it's still an approximation of the original serial rollout objective. So every $K_{\text{sync}}$ iterations we "sync":

  1. Roll out from $s_0$ using the current actions $\mathbf{a}$ (a standard serial rollout).
  2. Take a few small gradient steps on the original terminal loss:

    $$\mathbf{a} \leftarrow \mathbf{a} - \eta_{\text{sync}}\nabla_{\mathbf{a}}\big\|s_T(\mathbf{a})-g\big\|_2^2.$$

Intuition:

  • The parallel lifted phase explores a smoother landscape, while a few steps of full GD give us the benefits of both: the smooth landscape of the lifted dynamics optimization and the sharpness of the full GD landscape.
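
A sketch of one sync phase, under the same placeholder interface; this is the only place where backpropagation through the full $T$-step rollout is used, and the number of refinement steps and learning rate are illustrative.

```python
import torch

def sync_refine(world_model, s0, actions, goal, n_steps=5, lr=1e-3):
    opt = torch.optim.SGD([actions], lr=lr)
    for _ in range(n_steps):
        opt.zero_grad()
        s = s0
        for a in actions:                   # true serial rollout from s_0
            s = world_model(s, a)
        ((s - goal) ** 2).sum().backward()  # full backprop through time, actions only
        opt.step()
```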

Putting it together

Each optimization loop alternates between:

  • Parallel lifted updates on $(\mathbf{s},\mathbf{a})$ using $\mathcal{L}$ (grad-cut on state inputs + dense goal shaping + state noise)
  • Occasional rollout sync updating only $\mathbf{a}$ with full backprop through time

This gives a planner that is:

  • parallelizable during exploration,
  • robust to brittle state gradients,
  • and still grounded by periodic true rollout refinement.
Method diagram

Serial rollout vs. our lifted planning. (a) Standard planners require a $T$-step rollout before getting a learning signal. (b) We optimize virtual states directly with pairwise consistency, enabling parallel model evaluations; we periodically sync with a short serial rollout for refinement.

Minimal pseudocode

  1. Initialize actions $\mathbf{a}$ and virtual states $\mathbf{s}$ (with $s_0$ fixed, $s_T=g$).
  2. Repeat for $k=1..K$:
    • Compute $\mathcal{L}(\mathbf{s},\mathbf{a})$.
    • Gradient step on actions: $a_t \leftarrow a_t - \eta_a \nabla_{a_t}\mathcal{L}$.
    • State step + noise: $s_t \leftarrow s_t - \eta_s \nabla_{s_t}\mathcal{L} + \sigma_{\text{state}}\xi$.
    • Every $K_{\text{sync}}$ steps: roll out from $s_0$ and do a few small steps on $\|s_T(\mathbf{a})-g\|_2^2$.
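
The pseudocode maps onto a short loop that reuses the `grasp_loss` and `sync_refine` sketches from earlier sections; all hyperparameters below are illustrative placeholders rather than values from the paper.

```python
import torch

def grasp_plan(world_model, s0, goal, T, k, d, n_iters=500, sync_every=50,
               lr_s=1e-2, lr_a=1e-2, sigma_state=0.1, gamma=0.1):
    states  = torch.randn(T - 1, d, requires_grad=True)   # virtual states s_1..s_{T-1}
    actions = torch.zeros(T, k, requires_grad=True)
    for it in range(1, n_iters + 1):
        # Parallel lifted update: grad-cut dynamics + goal shaping + state noise.
        loss = grasp_loss(world_model, s0, states, actions, goal, gamma)
        grad_s, grad_a = torch.autograd.grad(loss, [states, actions])
        with torch.no_grad():
            states  -= lr_s * grad_s
            states  += sigma_state * torch.randn_like(states)   # exploration noise
            actions -= lr_a * grad_a
        # Occasional rollout sync: refine actions with full backprop through time.
        if it % sync_every == 0:
            sync_refine(world_model, s0, actions, goal)
    return actions.detach()
```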

Citation

@article{psenka2026grasp,
    title={GRASP: Parallel Stochastic Gradient-Based Planning for World Models},
    author={Psenka, Michael and Rabbat, Michael and Krishnapriyan, Aditi and LeCun, Yann and Bar, Amir},
    journal={arXiv preprint arXiv:2602.00475},
    year={2026}
}