Large Language Diffusion Models

LLaDA

Masked diffusion modeling

Training:

  • input is a partially masked sequence; the model predicts the masked tokens, with cross-entropy loss computed only on the masked positions
  • randomly sample a masking proportion t ~ Uniform(0, 1) per sequence, then mask each token independently with probability t. Why? So the model practices demasking at every corruption level, from nearly clean to almost fully masked
  • gradients: the loss is an expectation over both the data and the masking level t. Just as SGD samples a data point, we sample a data point AND a t, then take the gradient of that single-sample loss; this still gives an unbiased gradient estimate
  • SFT for instruction following: concatenate prompt and response, but mask only the response tokens, so the model learns to demask a response conditioned on a clean prompt
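A minimal sketch of one training-loss sample, assuming a toy vocabulary and a stand-in `model_logits_fn`; `MASK_ID` and all names here are placeholders, and the 1/t weighting follows the loss as stated in the LLaDA paper:

```python
import numpy as np

MASK_ID = 0   # hypothetical id of the [MASK] token
VOCAB = 10    # toy vocabulary size

def masked_ce_loss(tokens, model_logits_fn, rng):
    """One Monte Carlo sample of the masked-diffusion training loss.

    tokens: 1-D int array, the clean sequence x0.
    model_logits_fn: maps a (possibly masked) sequence to (seq_len, VOCAB) logits.
    """
    t = rng.uniform(1e-3, 1.0)            # sample masking level t ~ U(0, 1]
    mask = rng.random(tokens.shape) < t   # mask each token independently w.p. t
    if not mask.any():                    # ensure at least one masked position
        mask[rng.integers(len(tokens))] = True
    corrupted = np.where(mask, MASK_ID, tokens)
    logits = model_logits_fn(corrupted)
    logp = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
    # cross entropy ONLY on masked positions; the 1/t factor weights
    # lightly-masked samples up so the estimator matches the paper's loss.
    # For SFT, simply restrict `mask` to response positions.
    return -(logp[mask, tokens[mask]]).sum() / t

# usage: a dummy "model" that always outputs uniform logits
rng = np.random.default_rng(0)
dummy = lambda seq: np.zeros((len(seq), VOCAB))
loss = masked_ce_loss(np.array([3, 1, 4, 1, 5]), dummy, rng)
```

Averaging this single-sample loss over many draws of (data, t) recovers the full training objective, which is why sampling both per step is enough for SGD.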

Inference:

  • start from a fully masked response: predict all masked tokens at once, then re-mask the least confident predictions and repeat, committing more tokens each step until none remain masked
  • no KV cache: attention is bidirectional over the whole sequence, so every position must be recomputed at each denoising step
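The inference loop above can be sketched as follows. This is a toy low-confidence remasking decoder, not the paper's implementation: `MASK_ID`, the dummy model, and the linear commit schedule are all illustrative assumptions.

```python
import numpy as np

MASK_ID = 0   # hypothetical id of the [MASK] token
VOCAB = 10    # toy vocabulary size

def diffusion_decode(prompt, resp_len, model_logits_fn, steps):
    """Sketch of low-confidence remasking decoding.

    Each step: predict every masked response token, then keep only the
    most confident predictions and re-mask the rest. The number of
    committed tokens grows linearly to resp_len over `steps` iterations.
    """
    resp = np.full(resp_len, MASK_ID)
    for s in range(steps):
        seq = np.concatenate([prompt, resp])
        logits = model_logits_fn(seq)[len(prompt):]   # response positions only
        probs = np.exp(logits - logits.max(-1, keepdims=True))
        probs /= probs.sum(-1, keepdims=True)
        # keep previously committed tokens; predict the rest greedily
        pred = np.where(resp != MASK_ID, resp, probs.argmax(-1))
        conf = np.where(resp != MASK_ID, np.inf, probs.max(-1))
        keep = resp_len * (s + 1) // steps            # commit schedule
        commit = np.argsort(-conf)[:keep]             # most confident positions
        resp = np.full(resp_len, MASK_ID)             # re-mask everything else
        resp[commit] = pred[commit]
    return resp

# usage: dummy model whose logits always peak on token 5
dummy = lambda seq: np.eye(VOCAB)[np.full(len(seq), 5)] * 4.0
out = diffusion_decode(np.array([7, 8]), resp_len=8, steps=4, model_logits_fn=dummy)
```

Note that the whole prompt+response is re-encoded every step (`model_logits_fn(seq)`), which is exactly why a KV cache in the autoregressive sense does not apply.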

Results:

  • better at reversal tasks (e.g., completing a poem given its later lines), since generation is not constrained to left-to-right order