Train ESRGAN 小整理

我想挑戰一面挖礦, 一面跑 AI training, 一面開艾爾登法環,… 不過看起來有點困難, 只好不跑老環了, 反正一直 “You Died". 好, 在此把 training 的步驟整理如下.

  1. 如先前所說, 用 Anaconda3 做出一個環境, 然後開 terminal.
  2. 到網路上下載訓練用的圖片.  根據原作者所述, 他用的 database 是 DF2K.DF2K is a merged training dataset consisting of 800 DIV2K  training images and 2650 Flickr2K training images.
  3. 雖然不知道有什麼特殊之處, 總之這兩個 databse 的名稱長度不一樣, 我就直接把它們塞到同一個目錄.

新目錄的位置在主目錄下自己建一個 datasets. 再下面一層開個 DF2K 子目錄, 一層 DF2K_HR 子子目錄. 2K 照片全部放這裡. 然後在與 DF2K_HR 同層開個 DF2K_multiscale 目錄. 這樣下一階段就可以開始訓練了.

train-data-150x150

4.  訓練步驟參考 [1], 首先把每張圖做縮圖, 分成 0.75, 0.5, 0.33, 0.25 四種大小. 看起來有 rounding 到特定的數字, 而且有 floor = 400.

python scripts/generate_multiscale_DF2K.py --input datasets/DF2K/DF2K_HR --output datasets/DF2K/DF2K_multiscale

例如 000001.png 就生成 000001T0.png、000001T1.png、000001T2.png 、000001T2.png 四張不同尺寸的小圖.

2048×1356 1530×1017 1020×678 680×452 601×400

5. 然後把圖形全部切齊成同樣大小的方塊. 注意 DF2K_multiscale_sub 這個目錄要留給程式建, 自己先建好會報錯.

python scripts/extract_subimages.py --input datasets/DF2K/DF2K_multiscale --output datasets/DF2K/DF2K_multiscale_sub --crop_size 400 --step 200

這裡的第一個參數顯然就是要把圖形裁切成 400×400 的小圖. 第二個參數看程式註解: “step (int): Step for overlapped sliding window."  這個意思是每個小圖有 200 pixel 和另外一張裁切圖有重疊. However, 我必須說此處就丟掉一些重要的資訊, 我想想怎模申請專利好了~~~

[原始的小圖 xxxT2: 680×452]

000001-300x199

T0 切成  35 個 400×400, T1 切成 15 個, T2 切成 6 個, T3 切成 2 個. T0 的前四張 400×400 長這樣. 也就是由左至右掃過原圖的上方.

000001T0_s001-150x150000001T0_s002-150x150000001T0_s003-150x150000001T0_s004-150x150

6. 在前一個步驟, 很多 400×400 小塊的原稿在哪裡? 這個資訊 (metadata) 需要保留起來, 因此下面這個步驟的重點就是回溯紀錄 metadata.

 python scripts/generate_meta_info.py --input datasets/DF2K/DF2K_HR, datasets/DF2K/DF2K_multiscale --root datasets/DF2K, datasets/DF2K --meta_info datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt

7. 以上算是滿快就跑完了, 跟我打字的速度差不多. 接下來是重點, 作者的說明在不同版本有些許差異, 總之找到前面檔案總管節圖的那個目錄底下的 options, 裡面會有很多個 yml 檔. 以放大四倍來說, 需要改這兩個檔案, 確定目錄都對. 第一個訓練 generator, 第二個訓練 discrimator, 因此 pre-train 的檔案不同.

train_realesrnet_x4plus.yml

train_realesrgan_x4plus.yml

8. 接下來用 debug mode 跑一次,

python realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --debug

馬上發現 Anaconda 的環境沒有安裝 cuda “Torch not compiled with CUDA enabled", 即使補安裝 cudatool, 也只是把 torch 版本由 1.10.2+CPU 變成 1.10.2, cuda 還是 unavailable.

(RealESRgan) C:\Users\ufoca\Real-ESRGAN>python
Python 3.8.12 (default, Oct 12 2021, 03:01:40) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print (torch.__version__)
1.10.2
>>> print (torch.cuda.is_available())
False
>>>

看起來要先下載正確版本. 參考 [2],[3],[4]. 網站 [4] 是 pytorch 官網, 似乎只要把自己環境設定對, 就可以用 UI 生成正確的安裝指令.

conda install pytorch==1.10.2 torchvision cudatoolkit=11.3 -c pytorch

不過這招其實沒用, [5] 說的才是對的. 要用 pip install, anaconda 本身的 server 只有 cpu 的 pytorch 可以選.

pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio===0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

另外, 我本來 pytorch 是 1.10.2, torchvision 是 0.11.3, 但上面那行貼了就會動. 安裝完跑出一些紅字說沒有檢查所有版本的相容性, 所以我又重新安裝成我的版號, 不過我不知道我的 torchaudio 應該要搭哪一版? 反而安裝不成功.

於是我還是用網路版本, 並且重跑一次

pip install -r requirements.txt

讓作者的檢查系統幫我微調一下版本. 總之, AI 的技術進步很快, 網路文章上的版號參考就好. 到時候要隨機應變.

9. 接者就不用跑 debug 版了. 改跑這條.

python realesrgan/train.py -opt options/train_realesrnet_x4plus.yml --auto_resume

我估計應該會做很久很久. 記錄一下起跑時間.

run

UI 顯示要跑 5 天半.

2022-04-10 21:48:02,327 INFO: [train..][epoch: 5, iter: 6,500, lr:(2.000e-04,)] [eta: 5 days, 10:47:47, time (data): 0.476 (0.002)] l_pix: 6.8241e-02

10. 然後是訓練 discriminator

python realesrgan/train.py -opt options/train_realesrgan_x4plus.yml --auto_resume

證明環境沒問題之後, 就可以自己動手腳去改 code 了.

[Note]

1. https://github.com/xinntao/Real-ESRGAN/blob/master/Training.md

2. https://towardsdatascience.com/setting-up-tensorflow-gpu-with-cuda-and-anaconda-onwindows-2ee9c39b5c44

3. https://varhowto.com/install-pytorch-cuda-10-0/

4. https://pytorch.org/

5. https://ithelp.ithome.com.tw/articles/10282119

ESRGAN 小註解

我參考 ESRGAN [11] 的論文導讀[1], 把其中相關的名詞做了個整理.

基本上兩大重點是 RRDB 和 GAN. 這個網路模型前者的修改是拔掉 BN, 後續又出了一篇 [9] 加上 noise 處理. 後者是修改 GAN 的對抗方式, [9] 又做了一些調整就略過. 以下把相關的名詞稍微整理如下.

首先看到 BN 這個詞, 文章中指的是第二種, 但是也順便把第一種 BN 列進來.

Bayesian Network (BN)

在 AI 中的主要應用是把人類的知識加入類神經網路 [2]. 如果單純使用 data training 叫做 machine learning. 既然它的名字裡面有 Bayesian, 顯然它考慮了條件機率, 所以在 training 的過程中, 我們要引入條件機率表 (conditional probability table – CPT).

條件式的因果關係就是 P(B|A) = C 的形式, 因此 BN 網路需要有一群輸入層的 feature 值, 然後以因果關係建立網路結構, 接著用條件機率把每個 node 到下一層 node 的機率以 CPT 表示, 最後會有一個輸出層, 告訴我們有幾種可能的 state, 以及各自有多少機率.

Batch Normalization (BN)

這是另外一個 BN, 在 RealSRGAN 裡面指的是這個, 事實上它也比較常用. 原理就是把輸入正規化為 N(0,1). 即 mean = 0, variance = 1, 這方法在 pattern recognition 和 classification 中也常用.  不過對 NN 來說, 它是指在 internal layer 去做 normalization, 不只是在輸入層.

RRDBNet Network [3]

在 RRDB 的實作中, 裡面有許多的 basic block 串接. Basic block  包括 Dense Block (DB) [4] 和 Residial Block (RB) [5]. Residual in Resisual Dense Block 就是把 dense block 以 residual network 的方式連起來 [3], 把這些 DB 都當作 RB 裡面的一層 layer 看待.

RRDBNet-768x255

RRDB 在 [1] 講得比較多. BTW, 關於噪音的改進. 可以參考 [9] 這篇的開頭.

Dense Block [4]

主要是每層都和其他幾層相連. 我把它當作 full connection 的分批簡化形式. Dense Block 的串接連成 Dense Network.

denseblock-768x653

Residual Block (RB) [5]

Residual Block 的概念, 像是 Dense block 的變形. 也就是可以跳過一層網路不連.

resisual-network-768x668

GAN

GAN 用國文來解釋很簡單, 就是正邪兩派各自去認親, 能夠訓練到邪不勝正, 那麼正的網路就大功告成了.

GAN-block-diagram

本圖取材自 [10].

從 coding 的角度來說, 當然還是要看到程式比較有感覺. 我覺得 [6] 的寫法滿好懂的. 若對於只熟悉C語言, 不熟 Pythone 的人需要克服這個底線障礙. 這邊的底線是 [7] 五種底線的第二種, ‘_’ 表示傳回的值有些是 don’t care. 例如:

validity, _, _ = discriminator(gen_imgs)

GAN 的基本流程當然是把生成對抗網路都各 train 一輪.  此時每輸入每一張圖, 都要生成一堆有噪音的圖 (靠隨機變數) – train generator. 然後把這一大把圖, 拿去 train discrimiator. 在生成 generator 的時候, 可以看到有製造噪音的 code 和最後的 back propogation.

# Train Generator
optimizer_G.zero_grad()
# Sample noise and labels as generator input
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
label_input = to_categorical(np.random.randint(0, opt.n_classes, batch_size), num_columns=opt.n_classes)
code_input = Variable(FloatTensor(np.random.uniform(1, 1, (batch_size, opt.code_dim))))
# Generate a batch of images
gen_imgs = generator(z, label_input, code_input)
# Loss measures generator’s ability to fool the discriminator
validity, _, _ = discriminator(gen_imgs)
g_loss = adversarial_loss(validity, valid)
g_loss.backward()
optimizer_G.step()

訓練 discrimator 網路時. 要得到 real image 在這個網路的 loss 評分 (終極目標是 loss = 0), 以及 fake image 的 loss 評分 (終極目標是 loss 很大). 兩組 loss 輸出, 取其平均值做 backward propogation 的分數. adversarial_loss 的實作細節看起來是在 Torch 的 loss.py 裡面 [8].

# Train Discriminator
optimizer_D.zero_grad()
# Loss for real images
real_pred, _, _ = discriminator(real_imgs)
d_real_loss = adversarial_loss(real_pred, valid)
# Loss for fake images
fake_pred, _, _ = discriminator(gen_imgs.detach())
d_fake_loss = adversarial_loss(fake_pred, fake)
# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.backward()
optimizer_D.step()

以上是普通的 GAN 的處理方式. ESRGAN 的 GAN 不是比較 real image 和 fake image 誰的分數高, 因為我們早已經知道 real image 和 fake image 各自要扮演的角色. 花太多時間讓參數長對有點浪費時間.

ESRGAN 用的方式是讓真的和假的直接去 PK, 而不是對標準答案. 真的贏過假的參數要加分, 假的贏過真的的參數要扣分. 以下直接貼 [1] 的內容, 但用顏色標註重點.

我们可以看SRGAN的loss:

         l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
         l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)

而ESRGAN的loss:

l_d_real = self.cri_gan( real_d_pred - torch.mean(fake_d_pred), True, is_disc=True)         
l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred), False, is_disc=True)

以上就是這篇論文會用到的重要模組. 沒有什麼創見, 就是幫入門者的障礙減少一點.

[Note]

  1. https://zhuanlan.zhihu.com/p/258532044
  2. https://towardsdatascience.com/how-to-train-a-bayesian-network-bn-using-expert-knowledge-583135d872d7
  3. https://blog.csdn.net/gwplovekimi/article/details/90032735
  4. DenseNet 學習心得
  5. https://ithelp.ithome.com.tw/articles/10204727
  6. https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/infogan/infogan.py
  7. https://towardsdatascience.com/5-different-meanings-of-underscore-in-python-3fafa6cd0379
  8. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py
  9. https://zhuanlan.zhihu.com/p/393350811
  10. https://github.com/jonbruner/generative-adversarial-networks/blob/master/gan-notebook.ipynb
  11. https://aiqianji.com/blog/article/1