Large Language Diffusion Models
LLaDA
Masked diffusion modeling
Training:
- input is a partially masked sequence; predict the masked tokens, with cross entropy loss computed only on the masked positions
- randomly sample a masking ratio t ~ U(0, 1) and mask each token independently with probability t. Why? So the model practices demasking at every corruption level, not just one fixed ratio
- gradients: the loss is an expectation over both t and the data, so, just as SGD samples a data point, each step samples a data point AND a t and takes the gradient of that single-sample loss (an unbiased Monte Carlo estimate)
- SFT for instruction following: concatenate prompt and response, but mask only the response tokens
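The training steps above can be sketched in a few lines of numpy. This is a minimal illustration, not the paper's implementation: `dummy_logits` stands in for the Transformer mask predictor, and the names, vocab size, and the floor on t are my choices. The 1/t weighting on the masked-token loss follows the LLaDA objective.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, MASK_ID, SEQ_LEN = 100, 0, 16  # toy sizes (assumptions)

def dummy_logits(tokens):
    # Stand-in for the mask predictor; LLaDA uses a Transformer here.
    return rng.normal(size=(len(tokens), VOCAB))

def masked_diffusion_loss(x0):
    # Sample one masking level t ~ U(0, 1) per example (Monte Carlo draw).
    t = rng.uniform(0.01, 1.0)  # small floor so something gets masked
    # Mask each token independently with probability t.
    mask = rng.uniform(size=len(x0)) < t
    if not mask.any():
        mask[rng.integers(len(x0))] = True
    xt = np.where(mask, MASK_ID, x0)
    # Predict the original token at every position; score only masked ones.
    logits = dummy_logits(xt)
    logp = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    nll = -logp[np.arange(len(x0)), x0]
    # Cross entropy on masked positions, scaled by 1/t as in the objective.
    return nll[mask].sum() / (t * len(x0))

x0 = rng.integers(1, VOCAB, size=SEQ_LEN)
loss = masked_diffusion_loss(x0)
```

A real training step would backprop `loss` through the model; with random logits the value itself is meaningless, only the masking/weighting logic matters here.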
Inference:
- start from a fully masked response; predict every masked token, then re-mask the least confident predictions and repeat until no masks remain
- no KV cache: attention is bidirectional and every position's representation can change at each step, so cached keys/values would go stale
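The demask/re-mask loop can be sketched as follows. Again a hedged toy: `dummy_logits` replaces the trained model, the linear re-masking schedule and step count are assumptions, and a real run would prepend the (unmasked) prompt tokens for conditioning.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, MASK_ID = 100, 0  # toy sizes (assumptions)

def dummy_logits(tokens):
    # Stand-in for the trained mask predictor.
    return rng.normal(size=(len(tokens), VOCAB))

def generate(length, steps=4):
    x = np.full(length, MASK_ID)
    for step in range(steps):
        masked = x == MASK_ID
        logits = dummy_logits(x)
        logits[:, MASK_ID] = -np.inf  # never predict the mask token itself
        probs = np.exp(logits - logits.max(-1, keepdims=True))
        probs /= probs.sum(-1, keepdims=True)
        pred, conf = probs.argmax(-1), probs.max(-1)
        # Demask everything in one shot...
        x = np.where(masked, pred, x)
        # ...then re-mask the least confident of this step's predictions.
        # Linear schedule: fewer positions stay masked each step, zero at the end.
        n_remask = int(masked.sum() * (1 - (step + 1) / steps))
        cand = np.flatnonzero(masked)
        worst = cand[np.argsort(conf[cand])][:n_remask]
        x[worst] = MASK_ID
    return x

resp = generate(12)
```

Note each step runs the model over the full sequence, which is why there is nothing to cache: unlike autoregressive decoding, no prefix stays fixed between steps.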
Results:
- better at reversal tasks (e.g. completing text backwards from its ending), where left-to-right autoregressive models suffer from the reversal curse