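According to [2] (skipping over the assorted prose and math), Sections 3.3 and 5.1 both use C ≈ 6NBS, where C is the training compute, B is the batch size (in tokens), S is the number of parameter updates (so B·S is the total number of training tokens), and N is the non-embedding parameter count.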
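As a quick sanity check on the formula, here is a minimal sketch that plugs numbers into C ≈ 6NBS. The values of N, B, and S are hypothetical and chosen only for illustration; the PF-day conversion (10^15 FLOP/s sustained for one day) is the usual convention for this unit.

```python
def training_compute(n_params: float, batch_tokens: float, steps: float) -> float:
    """Approximate training compute C = 6 * N * B * S, in FLOPs."""
    return 6.0 * n_params * batch_tokens * steps

# Hypothetical run (illustrative values only, not taken from any paper):
N = 1e9   # non-embedding parameters
B = 5e5   # tokens per batch
S = 2e5   # parameter updates

total_tokens = B * S                      # B * S = total training tokens
flops = training_compute(N, B, S)
pf_days = flops / (1e15 * 86400)          # one PF-day = 1e15 FLOP/s for 24 hours

print(f"tokens = {total_tokens:.2e}, C = {flops:.2e} FLOPs, {pf_days:.2f} PF-days")
```

The factor 6 is a per-parameter, per-token cost, so doubling either the model size or the number of training tokens doubles the estimated compute.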
For the $i$-th input token $t_i$ at the $k$-th prediction depth:
The representation of the $i$-th token at the $(k-1)$-th depth, denoted $h_i^{k-1} \in \mathbb{R}^d$, is taken. If $k = 1$, $h_i^{k-1}$ is the representation provided by the main model.
The embedding of the $(i+k)$-th token, $\mathrm{Emb}(t_{i+k}) \in \mathbb{R}^d$, is computed using the shared embedding layer.
Both $h_i^{k-1}$ and $\mathrm{Emb}(t_{i+k})$ are normalized using RMSNorm (Root Mean Square Normalization).
The normalized representations are concatenated ($[\cdot\,;\cdot]$) and linearly projected using the projection matrix $M_k$:
$$h_i^{\prime k} = M_k\,[\mathrm{RMSNorm}(h_i^{k-1});\ \mathrm{RMSNorm}(\mathrm{Emb}(t_{i+k}))].$$
Here, $h_i^{\prime k}$ is the combined representation that serves as the input to the Transformer block at the $k$-th depth.
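The combination step can be sketched in plain Python as follows. This is a toy illustration, not DeepSeek's implementation: the RMSNorm here omits the learned gain, the dimension d = 4 and the random values are placeholders, and the shape of the projection matrix (d x 2d) is inferred from the fact that it maps two concatenated d-vectors back to d dimensions.

```python
import math
import random

def rms_norm(x, eps=1e-6):
    """RMSNorm without a learned gain (omitted here for brevity)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def combine(h_prev, emb_next, M_k):
    """h'_i^k = M_k [RMSNorm(h_i^{k-1}); RMSNorm(Emb(t_{i+k}))]."""
    concat = rms_norm(h_prev) + rms_norm(emb_next)   # list concatenation, length 2d
    # M_k is assumed to be d x 2d: one row per output dimension.
    return [sum(w * c for w, c in zip(row, concat)) for row in M_k]

d = 4                                   # toy hidden size
rng = random.Random(0)
h_prev = [rng.gauss(0, 1) for _ in range(d)]      # stands in for h_i^{k-1}
emb_next = [rng.gauss(0, 1) for _ in range(d)]    # stands in for Emb(t_{i+k})
M_k = [[rng.gauss(0, 0.1) for _ in range(2 * d)] for _ in range(d)]

h_combined = combine(h_prev, emb_next, M_k)
print(len(h_combined))                  # same width d as the inputs
```

Normalizing both inputs before concatenation puts the hidden state and the token embedding on a comparable scale before the projection mixes them.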
Step 2: Transformer Block
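The combined representation $h_i^{\prime k}$ is passed through the $k$-th Transformer block ($\mathrm{TRM}_k(\cdot)$): $h_{1:T-k}^{k} = \mathrm{TRM}_k(h_{1:T-k}^{\prime k})$.
This produces the output representation $h_i^k$ for the $i$-th token at the $k$-th depth. The slice $1:T-k$ is needed because position $i$ requires the embedding of token $t_{i+k}$: with an input sequence of length $T$, only positions $i = 1, \dots, T-k$ have a valid input, so the sequence shortens by one for each additional depth.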
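To make the index bookkeeping concrete, the following sketch lists, for a given depth k, which token supplies the hidden state, which supplies the embedding, and which is the prediction target. The helper name `mtp_alignment` is hypothetical, and the assumption that the token list holds T inputs plus one final label follows from the indices above (the target for position T-k is token T+1).

```python
def mtp_alignment(tokens, k):
    """For depth k, return rows (i, t_i, t_{i+k}, t_{i+k+1}) with 1-based i.

    `tokens` is assumed to hold t_1 .. t_{T+1}: T input tokens plus a final label.
    t_i is the position whose representation h_i^{k-1} is used,
    t_{i+k} is the token whose embedding is concatenated,
    t_{i+k+1} is the token being predicted.
    """
    T = len(tokens) - 1
    rows = []
    for i in range(1, T - k + 1):        # i = 1 .. T-k, matching the slice 1:T-k
        rows.append((i, tokens[i - 1], tokens[i + k - 1], tokens[i + k]))
    return rows

tokens = list("ABCDE")                   # T = 4 inputs, plus "E" as the last label
for k in (1, 2):
    print(k, mtp_alignment(tokens, k))
```

Each extra depth drops one row from the end, which is the shortening that the slice expresses.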
Step 3: Output Head
The output representation $h_i^k$ is passed through the shared output head ($\mathrm{OutHead}(\cdot)$), which:
Linearly maps $h_i^k$ to logits.
Applies the Softmax function to compute the probability distribution over the vocabulary: $P_{i+k+1}^{k} = \mathrm{OutHead}(h_i^k)$.
Here, $P_{i+k+1}^{k} \in \mathbb{R}^V$ represents the probability distribution for the $(i+k+1)$-th token, where $V$ is the vocabulary size.
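A minimal sketch of the output head, assuming it is a single bias-free linear map followed by Softmax. The weight matrix `W_out` (V x d) and the toy sizes d = 3, V = 4 are placeholders for illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def out_head(h, W_out):
    """Map h (length d) to a distribution over V tokens; W_out is V x d."""
    logits = [sum(w * x for w, x in zip(row, h)) for row in W_out]
    return softmax(logits)

h = [0.5, -1.0, 2.0]                     # toy h_i^k with d = 3
W_out = [                                # toy weights with V = 4
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 1.0],
]
P = out_head(h, W_out)                   # stands in for P^k_{i+k+1}
print(len(P), P.index(max(P)))
```

Because the head is shared, the same `W_out` would be reused at every depth; only the hidden state fed into it changes.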
Now for the last key point. DeepSeek uses one-step MTP only during training; at inference time the algorithm is different. "We can also repurpose these MTP modules for speculative decoding (literally a 'prophet' or speculative algorithm) [2] to further improve the generation latency." [1]
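The idea behind speculative decoding can be sketched as a greedy draft-and-verify loop. This is a generic toy version, not DeepSeek's inference code: `target_next` and `draft_next` are hypothetical stand-ins for the main model and for a cheap drafter (the role an MTP module could play), and verification is done token by token here, whereas a real system would score all drafted positions in a single forward pass of the main model.

```python
def speculative_generate(target_next, draft_next, prompt, n_new, n_draft=1):
    """Greedy draft-and-verify decoding.

    target_next(seq) and draft_next(seq) each return the next token for `seq`.
    The result always equals plain greedy decoding with target_next;
    the drafter only affects how many tokens are accepted per round.
    """
    seq = list(prompt)
    goal = len(prompt) + n_new
    accepted = 0
    while len(seq) < goal:
        # 1. Draft: the cheap model guesses n_draft tokens ahead.
        draft = []
        for _ in range(n_draft):
            draft.append(draft_next(seq + draft))
        # 2. Verify: keep the target's token; stop at the first disagreement.
        for tok in draft:
            if len(seq) >= goal:
                break
            expected = target_next(seq)
            seq.append(expected)
            if expected == tok:
                accepted += 1
            else:
                break                    # later drafts were built on a wrong token
    return seq, accepted

# Toy models over tokens 0..4: the target counts upward modulo 5,
# and the drafter agrees except right after a 2.
target = lambda s: (s[-1] + 1) % 5
drafter = lambda s: 0 if s[-1] == 2 else (s[-1] + 1) % 5

seq, accepted = speculative_generate(target, drafter, [0], n_new=6, n_draft=2)
print(seq, accepted)
```

The output sequence is unchanged by drafting; the latency gain in a real system comes from accepting several tokens per main-model pass whenever the drafts are correct.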
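The training loss is also given. First, a loss is computed separately for each depth $k$, where $P$ is the distribution $P$ defined above. Finally, the losses from the different depths are averaged.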
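The two-stage computation described above (a cross-entropy per depth, then an average over depths) can be sketched as follows. This is a simplified illustration: each per-depth loss is averaged over the number of targets supplied, and `lam` is a weighting factor of the kind the DeepSeek-V3 report applies to the MTP loss; with `lam = 1.0` the result is the plain average. The probability values are made up.

```python
import math

def depth_loss(probs, targets):
    """Mean cross-entropy at one depth: -(1/n) * sum(log P[target])."""
    return -sum(math.log(p[t]) for p, t in zip(probs, targets)) / len(targets)

def mtp_loss(per_depth_probs, per_depth_targets, lam=1.0):
    """Average the per-depth losses over D depths, scaled by the weight lam."""
    D = len(per_depth_probs)
    losses = [depth_loss(p, t) for p, t in zip(per_depth_probs, per_depth_targets)]
    return lam * sum(losses) / D

# Toy example with a vocabulary of 2 tokens and D = 2 depths.
probs_k1 = [[0.5, 0.5], [0.25, 0.75]]    # two predicted distributions at depth 1
targets_k1 = [0, 1]                      # indices of the true tokens
probs_k2 = [[0.5, 0.5]]                  # one fewer position at depth 2
targets_k2 = [1]

loss = mtp_loss([probs_k1, probs_k2], [targets_k1, targets_k2])
print(round(loss, 4))
```

With D = 1, as in one-step MTP, the average over depths reduces to the single depth-1 loss.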