Deriving Muon

Muon

  • Goal: keep the RMS norms of each activation small (bounded by 1)
  • For the update ∆W (the change to the weight matrix), choose it to maximize its effect on the loss while ensuring the output doesn’t change too much
  • How much can the output change? That scales with the operator norm of ∆W; hence we choose the update ∆W to minimize <gradient, ∆W> (make the loss go down) subject to ||∆W||_RMS <= eta
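A quick numerical sanity check of this constrained problem (a sketch, not from the paper): among all updates with spectral norm at most eta, the SVD-based update aligns best with the gradient, and its alignment equals eta times the nuclear norm of the gradient. (The descent step then uses the negative of this direction.)

```python
import numpy as np

rng = np.random.default_rng(0)
G = rng.standard_normal((6, 4))          # stand-in gradient matrix
eta = 0.1

U, S, Vt = np.linalg.svd(G, full_matrices=False)
dW_muon = eta * U @ Vt                   # the Muon update direction

# <G, dW_muon> equals eta times the nuclear norm (sum of singular values)...
align_muon = np.sum(G * dW_muon)
assert np.isclose(align_muon, eta * S.sum())

# ...and no other update with spectral norm <= eta aligns better.
for _ in range(1000):
    D = rng.standard_normal(G.shape)
    D *= eta / np.linalg.norm(D, ord=2)  # scale to spectral norm exactly eta
    assert np.sum(G * D) <= align_muon + 1e-9
```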

Two (three) ways to view the closed form of this constrained optimization

  1. If the gradient’s SVD is U Sigma V^T, then ∆W = eta * UV^T: “orthogonalization,” because it amounts to setting all the singular values to eta.
  2. ∆W = eta * (gradient) * (gradient^T gradient)^(-1/2). This is useful for seeing that ∆W depends only on the gradient (and eta).
  3. The third, less important interpretation is that ∆W is (eta times) the closest semi-orthogonal matrix (as measured in Frobenius norm) to the gradient. This is just the formulation given in the second blog post https://kellerjordan.github.io/posts/muon/.
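Views 1 and 2 can be checked against each other numerically; a minimal sketch (for a full-column-rank gradient, so that G^T G is invertible):

```python
import numpy as np

rng = np.random.default_rng(1)
G = rng.standard_normal((5, 3))          # full column rank with probability 1

# View 1: set all singular values to 1 via the SVD
U, S, Vt = np.linalg.svd(G, full_matrices=False)
ortho_svd = U @ Vt

# View 2: G (G^T G)^(-1/2), computed from the eigendecomposition of G^T G
w, Q = np.linalg.eigh(G.T @ G)
inv_sqrt = Q @ np.diag(w ** -0.5) @ Q.T
ortho_poly = G @ inv_sqrt

# Both constructions give the same orthogonalized matrix
assert np.allclose(ortho_svd, ortho_poly)
```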

Discussion / Comparison to SGD, external to the paper

  • How does this compare to a variational formulation of SGD? The insight is that SGD treats weights as a vector (we can flatten a matrix into a vector if needed), whereas Muon treats weights as a matrix (connects to Why Muon?)
  • can write SGD as min <gradient, ∆v> subject to ||∆v||_2 <= eta * ||gradient||_2 (the closed form is ∆v = -eta * gradient, by Cauchy–Schwarz)
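Filling in the Cauchy–Schwarz step for that last bullet:

```latex
\min_{\Delta v}\ \langle g, \Delta v \rangle
\quad \text{s.t.} \quad \|\Delta v\|_2 \le \eta \, \|g\|_2 .
```

By Cauchy–Schwarz, `<g, ∆v> >= -||g||_2 ||∆v||_2 >= -eta ||g||_2^2`, with equality exactly when ∆v points opposite to g and saturates the constraint, i.e. ∆v = -eta * g. That is the plain SGD step, so SGD is the same variational problem as Muon but with the vector 2-norm in place of a matrix norm.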

Why use Muon? It converges much faster than AdamW

  • in part because it bases the update on the whole matrix gradient and uses matrix properties (i.e. the SVD) of that matrix, while SGD/AdamW update weights component by component, as vectors
  • this results in the AdamW updates being close to low rank, i.e. some principal components have low singular values, so there’s little learning in those directions, while Muon makes all singular values equal
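A toy illustration of that contrast (my own sketch): build an update dominated by a couple of rank-one directions, so its singular value spectrum is heavily skewed, then orthogonalize it and observe that every singular value becomes 1.

```python
import numpy as np

rng = np.random.default_rng(2)

# An update dominated by 2 outer products plus small noise:
# a few singular values are large, the rest are tiny.
M = sum(rng.standard_normal((8, 1)) @ rng.standard_normal((1, 8))
        for _ in range(2))
M += 0.01 * rng.standard_normal((8, 8))

S_raw = np.linalg.svd(M, compute_uv=False)
print("condition number:", S_raw[0] / S_raw[-1])   # large: skewed learning

# Orthogonalization (the Muon step, up to the eta factor)
U, _, Vt = np.linalg.svd(M)
S_ortho = np.linalg.svd(U @ Vt, compute_uv=False)
assert np.allclose(S_ortho, 1.0)                   # equal weight in all directions
```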

The rest of Muon is just a neat trick (Newton–Schulz iteration) to efficiently compute the UV^T term. If you apply a polynomial p with only odd-degree terms to U Sigma V^T in the matrix sense, then p(U Sigma V^T) = U p(Sigma) V^T. So we choose p such that iterating it drives the singular values to 1, i.e. p(p(p(…p(Sigma)…))) ≈ identity, and apply p repeatedly to the gradient; the result converges to UV^T.
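A minimal sketch of the idea using the classic cubic Newton–Schulz polynomial p(X) = 1.5 X - 0.5 X X^T X (the Muon blog post uses a tuned degree-5 polynomial for faster convergence; the cubic is the simplest instance):

```python
import numpy as np

def newton_schulz(G, steps=30):
    """Approximate U V^T from G = U Sigma V^T without computing the SVD.

    Each step applies the odd matrix polynomial X -> 1.5 X - 0.5 X X^T X,
    which maps every singular value s to 1.5 s - 0.5 s^3 and so drives
    the singular values toward 1 (for s in (0, sqrt(3))).
    """
    X = G / np.linalg.norm(G)   # Frobenius normalization => spectral norm <= 1
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X

rng = np.random.default_rng(3)
G = rng.standard_normal((6, 4))
U, _, Vt = np.linalg.svd(G, full_matrices=False)
assert np.allclose(newton_schulz(G), U @ Vt, atol=1e-4)
```

Note that the iteration only uses matrix multiplies, which is why it maps well onto GPUs, unlike an explicit SVD.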