参数量分析#

为了方便描述,首先定义各符号的含义:

  • v:表示embedding层有多少个词;
  • b:表示batch_size;
  • s:表示seq_length,为文本长度;
  • h:表示hidden_dim,为隐藏层的维度;
  • a:表示多头注意力中有多个头;
  • h_a:表示hidden_dim_per_head,为多头注意力中每个头的隐藏层维度;
  • n:表示总共有 n 层 transformer;

另外,在实际使用时一般都有 h_a * a = h 成立。

1、BERT 参数量#

1.1 BERT类模型#

embedding 层:参数量为 vh

MHA 层:这一层带有参数的有三部分,分别为 QKV 的权重矩阵、dense 层、Layer Norm层,分别计算:

  • QKV 的权重矩阵:每个权重矩阵的输入向量的维度为 h,有 a 个 head,每个 head 的隐层维度是 h_a,则参数量为 hah_a=h^2,三个权重矩阵的参数量就是再乘以3,即 3h^2

  • dense 层:这个比较简单,直接就是 h^2

  • Layer Norm层:这一层里面不是矩阵乘法,而是向量的相应位置的元素相乘,所以 \gamma\beta 的参数量是 2h

FFN 层:这一层包含两个 dense 层,这两个 dense 层先将维度由 h 升到 4h,再由 4h 降到 h,然后接一个 Layer Norm 层,所以参数量为 h*4h+4h*h+2h=8h^2+2h

将上述几部分加起来得到每层 transformer 的参数量为 12h^2+4h,一般会把 4h 忽略掉,则参数量为 12h^2

将 embedding 层和 n 层 transformer 加起来总的参数量为 n \cdot (12h^2+4h) + vh

1.2 BERT 模型结构#

BERT 的模型结构为:vocab 总共有 21128 个,隐藏层维度为 768,12 层 transformer,MHA 中是 12 个 head,每个 head 的隐层维度是 64。官方论文中总参数量为 110M。

将这些参数代入到上一小节的公式中得到:

\begin{split} &12 * (12 * 768 * 768 + 4 * 768) + 21128 * 768 \\ = &12*(7077888+3072)+16226304 \\ = &101197824 \\ = &101M \end{split}

也就是说大概 101M 的参数量,相比于官方所说的 110M,这其中的差距主要包括最后的分类器层,以及上述计算过程中都没有考虑偏置项(bias)。

BERT结构如下所示:

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)

2、LLAMA-65B 参数量#

2.1 LLAMA-65B 参数量#

LLAMA 中会使用 SWiGLU 结构,增加一个新的符号:

  • h_{\text{swi}}:表示 SWiGLU 中隐藏层的维度;

embedding 层:参数量为 vh

MHA 层:这一层带有参数的有三部分,分别为 QKV 的权重矩阵、dense 层、Layer Norm层,分别计算

  • QKV 的权重矩阵:每个权重矩阵的输入向量的维度为 h,有 a 个 head,每个 head 的隐层维度是 h_a,则参数量为 hah_a=h^2,三个权重矩阵的参数量就是再乘以3,即 3h^2

  • dense 层:这个比较简单,直接就是 h^2

  • Layer Norm层:这一层里面不是矩阵乘法,而是向量的相应位置的元素相乘,所以 \gamma\beta 的参数量是 2h

MLP 层:LLAMA 中的 MLP 层使用的是 SWiGLU 结构,从 2.2 小节中可以看到有三个 Linear 结构,总的参数量为:3*h*h_{\text{swi}}

将上述几部分加起来得到每层 transformer 的参数量为 4h^2+3hh_{\text{swi}}+4h,一般会把 4h 忽略掉,则参数量为 4h^2+3hh_{\text{swi}}

将 embedding 层和 n 层 transformer 加起来总的参数量为 n \cdot (4h^2+3hh_{\text{swi}}+4h) + vh

2.2 LLAMA-65B 模型结构#

vocab 总共有 32000 个,隐藏层维度为 8192,80 层 transformer,MHA 中是 64 个 head,每个 head 的隐层维度是 128。

将这些参数代入到上一小节的公式中得到:

\begin{split} &80 * (4 * 8192 * 8192 + 3 * 8192 * 22016 + 4 * 8192) + 32000 * 8192 \\ = &80*(268435456+541065216+32768)+262144000 \\ = &65024819200 \\ = &65B \end{split}

也就是说大概 65B 的参数量。

结构如下所示:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 8192, padding_idx=0)
    (layers): ModuleList(
      (0-79): 80 x LlamaDecoderLayer(
        (input_layernorm): LlamaRMSNorm()
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=8192, out_features=8192, bias=False)
          (k_proj): Linear4bit(in_features=8192, out_features=8192, bias=False)
          (v_proj): Linear4bit(in_features=8192, out_features=8192, bias=False)
          (o_proj): Linear4bit(in_features=8192, out_features=8192, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (post_attention_layernorm): LlamaRMSNorm()
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=8192, out_features=22016, bias=False)
          (down_proj): Linear4bit(in_features=22016, out_features=8192, bias=False)
          (up_proj): Linear4bit(in_features=8192, out_features=22016, bias=False)
          (act_fn): SiLUActivation()
        )
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=8192, out_features=32000, bias=False)
)

Reference#