我參考 ESRGAN [11] 的論文導讀[1], 把其中相關的名詞做了個整理.
基本上兩大重點是 RRDB 和 GAN. 這個網路模型前者的修改是拔掉 BN, 後續又出了一篇 [9] 加上 noise 處理. 後者是修改 GAN 的對抗方式, [9] 又做了一些調整就略過. 以下把相關的名詞稍微整理如下.
首先看到 BN 這個詞, 文章中指的是第二種, 但是也順便把第一種 BN 列進來.
Bayesian Network (BN)
在 AI 中的主要應用是把人類的知識加入類神經網路 [2]. 如果單純使用 data training 叫做 machine learning. 既然它的名字裡面有 Bayesian, 顯然它考慮了條件機率, 所以在 training 的過程中, 我們要引入條件機率表 (conditional probability table – CPT).
條件式的因果關係就是 P(B|A) = C 的形式, 因此 BN 網路需要有一群輸入層的 feature 值, 然後以因果關係建立網路結構, 接著用條件機率把每個 node 到下一層 node 的機率以 CPT 表示, 最後會有一個輸出層, 告訴我們有幾種可能的 state, 以及各自有多少機率.
Batch Normalization (BN)
這是另外一個 BN, 在 RealSRGAN 裡面指的是這個, 事實上它也比較常用. 原理就是把輸入正規化為 N(0,1). 即 mean = 0, variance = 1, 這方法在 pattern recognition 和 classification 中也常用. 不過對 NN 來說, 它是指在 internal layer 去做 normalization, 不只是在輸入層.
RRDBNet Network [3]
在 RRDB 的實作中, 裡面有許多的 basic block 串接. Basic block 包括 Dense Block (DB) [4] 和 Residial Block (RB) [5]. Residual in Resisual Dense Block 就是把 dense block 以 residual network 的方式連起來 [3], 把這些 DB 都當作 RB 裡面的一層 layer 看待.
RRDB 在 [1] 講得比較多. BTW, 關於噪音的改進. 可以參考 [9] 這篇的開頭.
Dense Block [4]
主要是每層都和其他幾層相連. 我把它當作 full connection 的分批簡化形式. Dense Block 的串接連成 Dense Network.
Residual Block (RB) [5]
Residual Block 的概念, 像是 Dense block 的變形. 也就是可以跳過一層網路不連.
GAN
GAN 用國文來解釋很簡單, 就是正邪兩派各自去認親, 能夠訓練到邪不勝正, 那麼正的網路就大功告成了.
本圖取材自 [10].
從 coding 的角度來說, 當然還是要看到程式比較有感覺. 我覺得 [6] 的寫法滿好懂的. 若對於只熟悉C語言, 不熟 Pythone 的人需要克服這個底線障礙. 這邊的底線是 [7] 五種底線的第二種, ‘_’ 表示傳回的值有些是 don’t care. 例如:
validity, _, _ = discriminator(gen_imgs)
GAN 的基本流程當然是把生成對抗網路都各 train 一輪. 此時每輸入每一張圖, 都要生成一堆有噪音的圖 (靠隨機變數) – train generator. 然後把這一大把圖, 拿去 train discrimiator. 在生成 generator 的時候, 可以看到有製造噪音的 code 和最後的 back propogation.
# Train Generator |
optimizer_G.zero_grad() |
# Sample noise and labels as generator input |
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim)))) |
label_input = to_categorical(np.random.randint(0, opt.n_classes, batch_size), num_columns=opt.n_classes) |
code_input = Variable(FloatTensor(np.random.uniform(–1, 1, (batch_size, opt.code_dim)))) |
# Generate a batch of images |
gen_imgs = generator(z, label_input, code_input) |
# Loss measures generator’s ability to fool the discriminator |
validity, _, _ = discriminator(gen_imgs) |
g_loss = adversarial_loss(validity, valid) |
g_loss.backward() |
optimizer_G.step() |
訓練 discrimator 網路時. 要得到 real image 在這個網路的 loss 評分 (終極目標是 loss = 0), 以及 fake image 的 loss 評分 (終極目標是 loss 很大). 兩組 loss 輸出, 取其平均值做 backward propogation 的分數. adversarial_loss 的實作細節看起來是在 Torch 的 loss.py 裡面 [8].
# Train Discriminator |
optimizer_D.zero_grad() |
# Loss for real images |
real_pred, _, _ = discriminator(real_imgs) |
d_real_loss = adversarial_loss(real_pred, valid) |
# Loss for fake images |
fake_pred, _, _ = discriminator(gen_imgs.detach()) |
d_fake_loss = adversarial_loss(fake_pred, fake) |
# Total discriminator loss |
d_loss = (d_real_loss + d_fake_loss) / 2 |
d_loss.backward() |
optimizer_D.step() |
以上是普通的 GAN 的處理方式. ESRGAN 的 GAN 不是比較 real image 和 fake image 誰的分數高, 因為我們早已經知道 real image 和 fake image 各自要扮演的角色. 花太多時間讓參數長對有點浪費時間.
ESRGAN 用的方式是讓真的和假的直接去 PK, 而不是對標準答案. 真的贏過假的參數要加分, 假的贏過真的的參數要扣分. 以下直接貼 [1] 的內容, 但用顏色標註重點.
我们可以看SRGAN的loss:
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
而ESRGAN的loss:
l_d_real = self.cri_gan( real_d_pred - torch.mean(fake_d_pred), True, is_disc=True)
l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred), False, is_disc=True)
以上就是這篇論文會用到的重要模組. 沒有什麼創見, 就是幫入門者的障礙減少一點.
[Note]
- https://zhuanlan.zhihu.com/p/258532044
- https://towardsdatascience.com/how-to-train-a-bayesian-network-bn-using-expert-knowledge-583135d872d7
- https://blog.csdn.net/gwplovekimi/article/details/90032735
- DenseNet 學習心得
- https://ithelp.ithome.com.tw/articles/10204727
- https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/infogan/infogan.py
- https://towardsdatascience.com/5-different-meanings-of-underscore-in-python-3fafa6cd0379
- https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py
- https://zhuanlan.zhihu.com/p/393350811
- https://github.com/jonbruner/generative-adversarial-networks/blob/master/gan-notebook.ipynb
- https://aiqianji.com/blog/article/1