Guided Speculative Inference for Efficient Test-Time Alignment of LLMs

Guided speculative decoding (Harvard)

Goal: decode more than one step at a time quickly while also steering toward high-reward outputs. The method does this with importance sampling plus a failsafe, and proves theoretical closeness to the desired distribution.

  • tilted policy: the solution to maximizing expected reward regularized by KL divergence to the base model
    • sampling from it corresponds to importance sampling with weights given by the softmax of the rewards
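A minimal sketch of that softmax-of-rewards reweighting as soft best-of-n (function and variable names are mine, not the paper's):

```python
import math
import random

def soft_best_of_n(samples, rewards, beta=1.0, rng=random):
    """Soft best-of-n: pick one of n sampled completions with probability
    softmax(beta * reward) -- a self-normalized importance-sampling
    approximation of the KL-regularized tilted policy."""
    m = max(rewards)  # subtract the max for a numerically stable softmax
    weights = [math.exp(beta * (r - m)) for r in rewards]
    return rng.choices(samples, weights=weights, k=1)[0]
```

With beta → 0 this degenerates to uniform sampling from the base model; with beta → ∞ it becomes hard best-of-n (argmax of the rewards).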

For each time step t during autoregressive decoding:

  • soft best-of-n (pick one of n candidates with probability given by the softmax of their rewards) is an approximation to the tilted policy, so the ideal procedure would be to autoregressively sample soft best-of-n from the base model at each step (soft BoN = importance sampling)
    • we’re only approximating the reweighting factor using the n discrete samples
    • can’t use the ground-truth reweighting because we only have rewards for the sampled y’s, not all y
  • but ideally we’d use a smaller, cheaper model to draft the n candidates and score those with the reward model
    • this only requires editing the importance-sampling factor, which amounts to replacing the rewards with “tilted rewards” that add the log-likelihood ratio of the base model to the small model
  • the paper gives theoretical guarantees on the KL divergence of this distribution from the tilted policy, and on its expected reward
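The guided step above might look like the following sketch; the exact placement of beta in the tilted-reward correction is my reading of the setup, and all names are hypothetical:

```python
import math
import random

def guided_speculative_step(x, draft_model, base_logprob, small_logprob,
                            reward_fn, n=8, beta=1.0, rng=random):
    """Draft n candidates from the small model, then reweight by the
    "tilted reward": the raw reward plus (1/beta) times the log-likelihood
    ratio of the base model to the small model. The ratio corrects the
    importance weights so the procedure still targets the *base* model's
    tilted policy even though the small model did the sampling."""
    drafts = [draft_model(x) for _ in range(n)]
    tilted = [reward_fn(x, y)
              + (base_logprob(x, y) - small_logprob(x, y)) / beta
              for y in drafts]
    m = max(tilted)  # stable softmax over tilted rewards
    weights = [math.exp(beta * (t - m)) for t in tilted]
    choice = rng.choices(drafts, weights=weights, k=1)[0]
    return choice, max(tilted)  # best tilted reward is reused by the failsafe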

  • failsafe: if even the best tilted reward among the n drafts is too low, revert to sampling autoregressively from the base model
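The reversion rule could be sketched like this, with a threshold tau whose name and default are my own placeholders:

```python
import math
import random

def decode_step_with_failsafe(x, draft_model, base_model, tilted_reward_fn,
                              n=8, beta=1.0, tau=0.0, rng=random):
    """Failsafe step: draft n candidates from the small model; if even the
    best tilted reward falls below the threshold tau, distrust the draft
    model and sample this step from the base model instead."""
    drafts = [draft_model(x) for _ in range(n)]
    tilted = [tilted_reward_fn(x, y) for y in drafts]
    best = max(tilted)
    if best < tau:
        return base_model(x), False  # reverted: drafts all look bad
    weights = [math.exp(beta * (t - best)) for t in tilted]
    return rng.choices(drafts, weights=weights, k=1)[0], True
```

The boolean flag just records whether the draft model was used, which is handy for inspecting when reversion fires.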

Results:

  • does well on math and code benchmarks and beats other speculative reward-guided decoding methods; it usually doesn’t beat best-of-n from the base model, but is faster
  • there are examples where the reversion helps because the small model is simply wrong