分散式機器學習小註解

先前學過的 Tensorflow model training, 久而久之也還給老師了. 今天趁颱風假學習分散式的機器學習, 順便把基本功也補一些回來!

需要分散式學習的原因就是 AI (ML) model 太大了, 因此有各式各樣的方法把工作分散給不同的 CPU, GPU, TPU (後續用 NPU 涵蓋之). 分散的方法包括把 training data 分散給大量 NPU, 或是把 model 的不同層拆給不同的 NPU 做. 後者比較沒有參數及時互通的問題就略過不談.

Data 分散出去給不同的 NPU 學習, 最明顯的問題就是大家拿到的 data 不同, 算出來的 gradient 理論上也不同, 那要麼收斂到同一版呢? 解決這個問題的架構 (architecture) 有兩個大方向, 分別為 synchronous 和 Asynchronous 兩種.

Synchronous 架構下, 每個 NPU 都要等其它 NPU 做完一個段落. 然後大家同步一下參數. 同步的方式可能只是將大家得到的數取個平均. 假設原本一個 update grdient 的段落是一個 batch. 現在有 N 個NPU 均分 data, 所以每個 NPU 只要做一個 mini-batch.

mini-batch = batch * strategy.num_replicas_in_sync (e.g. 1/N)

在一台機器有多個 NPU 的時候, 我們稱這個策略為 mirrorstrategy. 因為每個 NPU 做的事情都一樣, 像是照鏡子. 假如我有多台機器, 每一台機器 (worker) 都執行 mirrorstrategy, 此時稱為 multi-worker mirror strategy. 實際上的精神都是一樣的. 就是三個動作:

  1. 初始化:所有 worker 從相同的初始模型參數開始。
  2. 同步:取一個段落, 每個 worker 把自己的 grafient 傳出去. 每個 worker 看到其他所有 worker 的資料後, 做個計算 (像是取平均). 因為只要有一個 worker 沒更新, 大家都缺一筆資料, 所以不需要中控也可以自動進行.
  3. 參數更新:每個 worker 使用同步的梯度, 以 optimizer 更新其模型參數, 確保所有工作者在每一步都有相同的模型參數。

至於 asynchronous architecture 就沒有互等的機制, 參數 (weight, bias) 會放在一個以上的 parameter server (PS). 大家都去跟它要參數就對了. 等到 worker 算出自己的 gradient, 就把它傳給 PS, PS 負責用大家的 gradient 更新出新的參數給下一個 worker 抓取. 講義原文如下:

Each worker independently fetches the latest parameters from the parameter servers and computes gradients based on a subset of training samples. It then sends the gradients back to the parameter server, which then updates its copy of the parameters with those gradients.

由 Asynchronous 的架構會有大量的資料要傳遞 (weight, bias), 所以適合用在參數大多是 0 的 sparse model. 而參數大多不是 0 的 dense model (如 BERT, GPT) 就更適合 synchronous architecture. 言下之意是傳的時候多少會做壓縮吧!

MirroredStrategy 的 sample code 如下, 藍色字是和普通 training 不一樣的地方. 特別留意是 model 這行退縮到 with strategy.scope(): 的下一層.

import tensorflow as tf
import tensorflow_datasets as tfds

# Load the dataset

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

# Define the distribution strategy
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

# Set up the input pipeline
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

# Build and compile the model within the strategy scope
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])

# Train the model
model.fit(train_dataset, epochs=10, validation_data=test_dataset)