Mixture-of-Experts (MoEs) are neural network architectures that allow different inputs to be processed by different specialized sub-networks ("experts"), with a learned routing (gating) network deciding which experts handle each token. Training that router raises several recurring challenges.

### 1. Challenge: Non-Differentiability of Routing Functions
**The Problem:**
Many routing mechanisms, especially "Top-K routing," involve a discrete, hard selection process. A common function is `KeepTopK(v, k)`, which selects the top `k` scoring elements from a vector `v` and sets others to $-\infty$ or $0$.
$$
KeepTopK(v, k)_i = \begin{cases} v_i & \text{if } v_i \text{ is in the top } k \text{ elements of } v \\ -\infty & \text{otherwise.} \end{cases}
$$

This function is **not differentiable**. Its gradient is zero almost everywhere and undefined at the threshold points, making it impossible to directly train the gating network's parameters (e.g., $W_g$) using standard gradient descent.
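To make the issue concrete, here is a minimal PyTorch sketch (illustrative, not from the original post); `keep_top_k` and the random `expert_outputs` are stand-ins:

```python
import torch

def keep_top_k(v: torch.Tensor, k: int) -> torch.Tensor:
    """KeepTopK: keep the k largest scores per row, set the rest to -inf."""
    kth_largest = v.topk(k, dim=-1).values[..., -1:]
    return torch.where(v >= kth_largest, v, torch.full_like(v, float("-inf")))

scores = torch.randn(4, 8, requires_grad=True)          # 4 tokens, 8 experts
gates = torch.softmax(keep_top_k(scores, k=2), dim=-1)  # Softmax(KeepTopK(v, k))

expert_outputs = torch.randn(8, 16)                     # stand-in expert outputs
y = gates @ expert_outputs                              # weighted combination
y.sum().backward()

# Gradients reach only the two selected scores per token; which experts survive
# is a step function of the scores, so a small change to a non-selected score
# leaves the output exactly unchanged and receives zero gradient.
print(scores.grad)
```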
**Solutions (Stochastic Approximations):**
To enable end-to-end training, non-differentiable routing decisions must be approximated with differentiable or stochastic methods.
* **Stochastic Scoring (e.g., Shazeer et al. 2017):**
The expert score $H(x)_i = (x \cdot W_g)_i + \text{StandardNormal}() \cdot \text{Softplus}((x \cdot W_{noise})_i)$ introduces Gaussian noise. This makes the scores themselves stochastic, which can be leveraged with other methods.
* **Gumbel-Softmax Trick (or Concrete Distribution):**
This method allows for differentiable sampling from categorical distributions. Instead of directly picking the top-k, Gumbel noise is added to the scores, and a Softmax (with a temperature parameter) is applied. This provides a continuous, differentiable approximation of a discrete choice, allowing gradients to flow back (see the Gumbel-Softmax sketch after this list).
* **Straight-Through Estimator (STE):**
A simpler approximation where, during the backward pass, gradients are computed as if the non-differentiable operation were an identity function (or another simple smooth function).
* **Softmax after TopK (e.g., Mixtral, DBRX, DeepSeek v3):**
Instead of `Softmax(KeepTopK(...))`, some models apply a Softmax *only to the scores of the selected TopK experts*, and then assign $0$ to the rest. This provides differentiable weights for the selected experts while still enforcing sparsity.
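As a concrete illustration, the sketch below combines the noisy scoring from the first bullet with this "softmax after TopK" weighting (a minimal PyTorch sketch; `noisy_topk_router`, `W_g`, and `W_noise` follow the notation above, but the implementation details are assumptions):

```python
import torch
import torch.nn.functional as F

def noisy_topk_router(x, W_g, W_noise, k=2):
    """Noisy scores H(x) + softmax over only the selected top-k experts."""
    clean = x @ W_g                                          # (T, N) gate scores
    noise_scale = F.softplus(x @ W_noise)                    # learned noise magnitude
    scores = clean + torch.randn_like(clean) * noise_scale   # H(x) from Shazeer et al. 2017

    topk_scores, topk_idx = scores.topk(k, dim=-1)           # hard top-k selection
    topk_weights = torch.softmax(topk_scores, dim=-1)        # softmax *after* TopK

    # 0 for unselected experts, differentiable weights for the selected ones.
    return torch.zeros_like(scores).scatter(-1, topk_idx, topk_weights)

T, d, N = 4, 32, 8                                           # tokens, hidden size, experts
x = torch.randn(T, d)
W_g = torch.randn(d, N, requires_grad=True)
W_noise = torch.randn(d, N, requires_grad=True)
gates = noisy_topk_router(x, W_g, W_noise)                   # each row: k non-zeros summing to 1
```

Each row of `gates` keeps the mixture sparse (exactly `k` non-zero entries) while the weights on the chosen experts remain differentiable with respect to $W_g$.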
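For the Gumbel-Softmax and straight-through bullets, PyTorch already ships a relaxation, `torch.nn.functional.gumbel_softmax`; a small sketch (the temperature `tau=0.5` is an arbitrary illustrative choice):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 8, requires_grad=True)        # router scores: 4 tokens, 8 experts
expert_outputs = torch.randn(8, 16)                    # stand-in expert outputs

# hard=True: the forward pass uses a one-hot (discrete) expert choice, while the
# backward pass reuses the gradient of the underlying soft sample -- i.e. the
# Gumbel-Softmax relaxation combined with a straight-through estimator.
gates = F.gumbel_softmax(logits, tau=0.5, hard=True)   # (4, 8), one-hot rows

y = gates @ expert_outputs                             # discrete routing in the forward pass
y.sum().backward()
print(logits.grad.abs().sum() > 0)                     # tensor(True): gradients reach the router
```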
### 2. Challenge: Uneven Expert Utilization (Balancing Loss)
**The Problem:**
Left unchecked, the gating network might learn to heavily favor a few experts, leaving the remaining experts under-trained and under-utilized.
**Solution: Heuristic Balancing Losses (e.g., from Switch Transformer, Fedus et al. 2022)**
An auxiliary loss is added to the total model loss during training to encourage more even expert usage.
$$ \text{loss}_{\text{auxiliary}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i $$

Where:

* $\alpha$: A hyperparameter controlling the strength of the auxiliary loss.
* $N$: Total number of experts.
* $T$: The number of tokens in the current batch $B$.
* $f_i$: The **fraction of tokens *actually dispatched* to expert $i$** in the current batch $B$:

  $$ f_i = \frac{1}{T} \sum_{x \in B} \mathbf{1}\{\text{argmax } p(x) = i\} $$

  ($p(x)$ here refers to the output of the gating network, which could be $s_{i,t}$ in the DeepSeek/classic router. The $\text{argmax}$ means it counts hard assignments to expert $i$.)
* $P_i$: The **fraction of the router *probability mass* allocated to expert $i$** in the current batch $B$:

  $$ P_i = \frac{1}{T} \sum_{x \in B} p_i(x) $$

  ($p_i(x)$ is the learned probability (or soft score) from the gating network for token $x$ and expert $i$.)

**How it works:**
The loss is minimized when routing is uniform, i.e. when every expert receives $f_i \approx P_i \approx 1/N$ (in which case the loss equals $\alpha$). If an expert $i$ is overused (high $f_i$ and $P_i$), its term in the sum contributes significantly to the loss. The derivative with respect to $p_i(x)$ is proportional to $f_i$, so "more frequent use = stronger downweighting": the gating network is penalized more for sending additional probability mass to an already busy expert.
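A minimal PyTorch sketch of this auxiliary loss, assuming the router produces per-token probabilities `probs` of shape `(T, N)` (the function name and the `alpha` value are illustrative):

```python
import torch

def load_balancing_loss(probs: torch.Tensor, alpha: float = 0.01) -> torch.Tensor:
    """Auxiliary balancing loss: alpha * N * sum_i f_i * P_i."""
    T, N = probs.shape
    # f_i: fraction of tokens whose hard (argmax) assignment is expert i
    f = torch.bincount(probs.argmax(dim=-1), minlength=N).float() / T
    # P_i: average router probability mass given to expert i
    P = probs.mean(dim=0)
    # Equals alpha when routing is perfectly uniform (f_i = P_i = 1/N).
    return alpha * N * torch.sum(f * P)

# Toy usage: 16 tokens routed over 8 experts.
logits = torch.randn(16, 8, requires_grad=True)
probs = torch.softmax(logits, dim=-1)
load_balancing_loss(probs).backward()   # f is a hard count (no gradient); gradients flow through P
```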
**Relationship to Gating Network:**
* **$p_i(x)$ (or $s_{i,t}$):** This is the output of the **learned gating network** (e.g., from a linear layer followed by Softmax). The gating network's parameters are updated via gradient descent, influenced by this auxiliary loss.
* **$P_i$:** This is *calculated* from the outputs of the learned gating network for the current batch. It's not a pre-defined value.
**Limitation ("Second Best" Scenario):**
Even with this loss, an expert can remain imbalanced if it's consistently the "second best" option (high $P_i$) but never the *absolute top choice* that gets counted in $f_i$ (especially if $K=1$). This is because $f_i$ strictly counts hard assignments based on `argmax`. This limitation highlights why "soft" routing or "softmax after TopK" approaches can be more effective at achieving a truly even distribution.
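A toy numeric illustration of this failure mode (values invented for illustration), reusing the definitions of $f_i$ and $P_i$ above:

```python
import torch

# Expert 1 is always the runner-up (probability 0.4) but never the argmax.
probs = torch.tensor([[0.5, 0.4, 0.1],
                      [0.1, 0.4, 0.5],
                      [0.5, 0.4, 0.1],
                      [0.1, 0.4, 0.5]])
T, N = probs.shape

f = torch.bincount(probs.argmax(dim=-1), minlength=N).float() / T
P = probs.mean(dim=0)

print(f)  # tensor([0.5000, 0.0000, 0.5000]) -> expert 1 never counted by argmax
print(P)  # tensor([0.3000, 0.4000, 0.3000]) -> yet it holds the most probability mass
# f_1 * P_1 = 0, so the auxiliary loss exerts no direct pressure on expert 1's behalf.
```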
### 3. Challenge: Overfitting during Fine-tuning