訂閱
糾錯(cuò)
加入自媒體

利用生成對(duì)抗網(wǎng)絡(luò)生成海洋塑料合成圖像

問(wèn)題陳述

過(guò)去十年來(lái),海洋塑料污染一直是氣候問(wèn)題的首要問(wèn)題。海洋中的塑料不僅能夠通過(guò)勒死或饑餓殺死海洋生物,而且也是通過(guò)捕獲二氧化碳使海洋變暖的一個(gè)主要因素。

近年來(lái),非營(yíng)利組織海洋清潔組織(Ocean Cleanup)多次嘗試清潔環(huán)繞我們海洋的塑料。很多清理過(guò)程的問(wèn)題是,它需要人力,而且成本效益不高。

通過(guò)使用計(jì)算機(jī)視覺(jué)和深度學(xué)習(xí)檢測(cè)海洋碎片,利用ROV和AUV進(jìn)行清理,已經(jīng)有很多研究將這一過(guò)程自動(dòng)化。

這種方法的主要問(wèn)題是關(guān)于訓(xùn)練計(jì)算機(jī)視覺(jué)模型的數(shù)據(jù)集的可用性。JAMSTEC-JEDI數(shù)據(jù)集收集了日本沿海海底的海洋廢棄物。

但是,除了這個(gè)數(shù)據(jù)集,數(shù)據(jù)集的可用性存在巨大差異。因此,我利用了生成對(duì)抗網(wǎng)絡(luò)的幫助。

DCGAN尤其致力于合成數(shù)據(jù)集,理論上,隨著時(shí)間的推移,這些數(shù)據(jù)集可能與真實(shí)數(shù)據(jù)集非常接近。

GAN和DCGAN

2014年,伊恩·古德費(fèi)羅等人提出了GANs或生成對(duì)抗網(wǎng)絡(luò)。GANs由兩個(gè)簡(jiǎn)單的組件組成,分別稱為生成器和鑒別器。

該過(guò)程如下:生成器角色用于生成新數(shù)據(jù),而鑒別器角色用于區(qū)分生成的數(shù)據(jù)和實(shí)際數(shù)據(jù)。在理想情況下,鑒別器無(wú)法區(qū)分生成的數(shù)據(jù)和真實(shí)數(shù)據(jù),從而產(chǎn)生理想的合成數(shù)據(jù)點(diǎn)。

DCGAN是上述GAN結(jié)構(gòu)的直接擴(kuò)展,只是它在鑒別器和發(fā)生器中分別使用了深卷積層。Radford等人在論文中首次描述了深度卷積生成對(duì)抗網(wǎng)絡(luò)的無(wú)監(jiān)督表征學(xué)習(xí)。鑒別器由跨步卷積層組成,而生成器由卷積轉(zhuǎn)置層組成。

PyTorch實(shí)現(xiàn)

在這種方法中,將在DeepTrash數(shù)據(jù)集。如果你不熟悉DeepTrash數(shù)據(jù)集,請(qǐng)考慮閱讀論文。

DeepTrash是海洋表層和深海表層塑料圖像的集合,旨在利用計(jì)算機(jī)視覺(jué)進(jìn)行海洋塑料檢測(cè)。

讓我們開始編碼吧!

代碼

安裝

我們首先安裝構(gòu)建GAN模型的所有基本庫(kù),比如Matplotlib和Numpy。

我們還將利用PyTorch的所有工具(如神經(jīng)網(wǎng)絡(luò)、轉(zhuǎn)換)。

from __future__ import print_function

#%matplotlib inline

import argparse

import os

import random

import torch

import torch.nn as nn

import torch.nn.parallel

import torch.backends.cudnn as cudnn

import torch.optim as optim

import torch.utils.data

import torchvision.datasets as dset

import torchvision.transforms as transforms

import torchvision.utils as vutils

import numpy as np

import matplotlib.pyplot as plt

import matplotlib.a(chǎn)nimation as animation

from IPython.display import HTML

# Set random seem for reproducibility


manualSeed = 999

#manualSeed = random.randint(1, 10000) # use if you want new results

print("Random Seed: ", manualSeed)

random.seed(manualSeed)

torch.manual_seed(manualSeed)

初始化超參數(shù)

這一步相當(dāng)簡(jiǎn)單。我們將設(shè)置我們想要用來(lái)訓(xùn)練神經(jīng)網(wǎng)絡(luò)的超參數(shù)。這些超參數(shù)直接來(lái)自于論文和PyTorch的訓(xùn)練教程。

# Root directory for dataset

# NOTE you don't have to create this. It will be created for you in the next block!

dataroot = "/content/pgan"

# Number of workers for dataloader

workers = 4

# Batch size during training

batch_size = 128

# Spatial size of training images. All images will be resized to this

#   size using a transformer.

image_size = 64

# Number of channels in the training images. For color images this is 3

nc = 3

# Size of z latent vector (i.e. size of generator input)

nz = 100


# Size of feature maps in generator

ngf = 64

# Size of feature maps in discriminator

ndf = 64

# Number of training epochs

num_epochs = 300

# Learning rate for optimizers

lr = 0.0002

# Beta1 hyperparam for Adam optimizers

beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.

ngpu = 1

生成器和鑒別器

現(xiàn)在,我們定義生成器和鑒別器的體系結(jié)構(gòu)。

# Generator

class Generator(nn.Module):

   def __init__(self, ngpu)

       super(Generator, self).__init__()

       self.ngpu = ngpu

       self.main = nn.Sequential(

           nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),

           nn.BatchNorm2d(ngf * 8),

           nn.ReLU(True),

           nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),

           nn.BatchNorm2d(ngf * 4),

           nn.ReLU(True),

           nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),

           nn.BatchNorm2d(ngf * 2),

           nn.ReLU(True),

           nn.ConvTranspose2d( ngf * 2, nc, 4, 2, 1, bias=False),

           nn.Tanh()
       )

  def forward(self, input):

       return self.main(input)
       

       # Discriminator

class Discriminator(nn.Module):

   def __init__(self, ngpu):

       super(Discriminator, self).__init__()

       self.ngpu = ngpu

       self.main = nn.Sequential(

           nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),

           nn.LeakyReLU(0.2, inplace=True),

           nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),

           nn.BatchNorm2d(ndf * 2),

           nn.LeakyReLU(0.2, inplace=True),

           nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),

           nn.BatchNorm2d(ndf * 4),

           nn.LeakyReLU(0.2, inplace=True),

           nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),

           nn.Sigmoid()

       )

   def forward(self, input):

       return self.main(input)

定義訓(xùn)練函數(shù)

在定義生成器和鑒別器類之后,我們繼續(xù)定義訓(xùn)練函數(shù)。

訓(xùn)練函數(shù)采用生成器、鑒別器、優(yōu)化函數(shù)和epoch數(shù)作為參數(shù)。我們通過(guò)遞歸調(diào)用train函數(shù)來(lái)訓(xùn)練生成器和鑒別器,直到達(dá)到所需的epoch數(shù)。

我們通過(guò)迭代數(shù)據(jù)加載器,用生成器中的新圖像更新鑒別器,并計(jì)算和更新?lián)p失函數(shù)來(lái)實(shí)現(xiàn)這一點(diǎn)。

def train(args, gen, disc, device, dataloader, optimizerG, optimizerD, criterion, epoch, iters):

 gen.train()

 disc.train()

 img_list = []

 fixed_noise = torch.randn(64, config.nz, 1, 1, device=device)

 # Establish convention for real and fake labels during training (with label smoothing)

 real_label = 0.9

 fake_label = 0.1

 for i, data in enumerate(dataloader, 0):


     #*****

     # Update Discriminator

     #*****

     ## Train with all-real batch

     disc.zero_grad()

     # Format batch

     real_cpu = data[0].to(device)

     b_size = real_cpu.size(0)

     label = torch.full((b_size,), real_label, device=device)

     # Forward pass real batch through D

     output = disc(real_cpu).view(-1)

     # Calculate loss on all-real batch

     errD_real = criterion(output, label)

     # Calculate gradients for D in backward pass

     errD_real.backward()

     D_x = output.mean().item()

     ## Train with all-fake batch

     # Generate batch of latent vectors

     noise = torch.randn(b_size, config.nz, 1, 1, device=device)

     # Generate fake image batch with G

     fake = gen(noise)

     label.fill_(fake_label)

     # Classify all fake batch with D

     output = disc(fake.detach()).view(-1)

     # Calculate D's loss on the all-fake batch

     errD_fake = criterion(output, label)

     # Calculate the gradients for this batch

     errD_fake.backward()

     D_G_z1 = output.mean().item()

     # Add the gradients from the all-real and all-fake batches

     errD = errD_real + errD_fake

     # Update D

     optimizerD.step()

     #*****

     # Update Generator

     #*****

     gen.zero_grad()

     label.fill_(real_label)  # fake labels are real for generator cost

     # Since we just updated D, perform another forward pass of all-fake batch through D

     output = disc(fake).view(-1)

     # Calculate G's loss based on this output

     errG = criterion(output, label)

     # Calculate gradients for G

     errG.backward()

     D_G_z2 = output.mean().item()

     # Update G

     optimizerG.step()

     # Output training stats

     if i % 50 == 0:

         print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'

               % (epoch, args.epochs, i, len(dataloader),

                   errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

         wandb.log({

             "Gen Loss": errG.item(),

             "Disc Loss": errD.item()})

     # Check how the generator is doing by saving G's output on fixed_noise

     if (iters % 500 == 0) or ((epoch == args.epochs-1) and (i == len(dataloader)-1)):

         with torch.no_grad():

             fake = gen(fixed_noise).detach().cpu()

         img_list.a(chǎn)ppend(wandb.Image(vutils.make_grid(fake, padding=2, normalize=True)))

         wandb.log({

             "Generated Images": img_list})

     iters += 1

監(jiān)督和訓(xùn)練DCGAN

在我們建立了生成器、鑒別器和訓(xùn)練函數(shù)之后,最后一步就是簡(jiǎn)單地調(diào)用我們定義的eoich數(shù)的訓(xùn)練函數(shù)。我還使用了Wandb,它允許我們監(jiān)控我們的訓(xùn)練。

#hide-collapse

wandb.watch_called = False

# WandB – Config is a variable that holds and saves

hyperparameters and inputs

config = wandb.config          # Initialize config

config.batch_size = batch_size

config.epochs = num_epochs        

config.lr = lr              

config.beta1 = beta1

config.nz = nz          

config.no_cuda = False        

config.seed = manualSeed # random seed (default: 42)

config.log_interval = 10 # how many batches to wait before logging training status

def main():

   use_cuda = not config.no_cuda and torch.cuda.is_available()

   device = torch.device("cuda" if use_cuda else "cpu")

   kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
   

   # Set random seeds and deterministic pytorch for reproducibility

   random.seed(config.seed)       # python random seed

   torch.manual_seed(config.seed) # pytorch random seed

   np.random.seed(config.seed) # numpy random seed

   torch.backends.cudnn.deterministic = True

   # Load the dataset

   transform = transforms.Compose(
       [transforms.ToTensor(),

       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

   trainset = datasets.CIFAR10(root='./data', train=True,
                                           download=True, transform=transform)

   trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size,
                                             shuffle=True, num_workers=workers)

   # Create the generator

   netG = Generator(ngpu).to(device)

   # Handle multi-gpu if desired

   if (device.type == 'cuda') and (ngpu > 1):

       netG = nn.DataParallel(netG, list(range(ngpu)))

   # Apply the weights_init function to randomly initialize all weights

   #  to mean=0, stdev=0.2.

   netG.a(chǎn)pply(weights_init)

   # Create the Discriminator

   netD = Discriminator(ngpu).to(device)

   # Handle multi-gpu if desired

   if (device.type == 'cuda') and (ngpu > 1):

       netD = nn.DataParallel(netD, list(range(ngpu)))

   # Apply the weights_init function to randomly initialize all weights

   #  to mean=0, stdev=0.2.
   netD.a(chǎn)pply(weights_init)

   # Initialize BCELoss function

   criterion = nn.BCELoss()

   # Setup Adam optimizers for both G and D

   optimizerD = optim.Adam(netD.parameters(), lr

config.lr, betas=(config.beta1, 0.999))

   optimizerG = optim.Adam(netG.parameters(), lr=config.lr, betas=(config.beta1, 0.999))
   

   # WandB – wandb.watch() automatically fetches all layer dimensions, gradients, model parameters and logs them automatically to your dashboard.

   # Using log="all" log histograms of parameter values in addition to gradients

   wandb.watch(netG, log="all")

   wandb.watch(netD, log="all")

   iters = 0

   for epoch in range(1, config.epochs + 1):

       train(config, netG, netD, device, trainloader, optimizerG, optimizerD, criterion, epoch, iters)
       

   # WandB – Save the model checkpoint. This automatically saves a file to the cloud and associates it with the current run.

   torch.save(netG.state_dict(), "model.h5")

   wandb.save('model.h5')

if __name__ == '__main__':

   main()

結(jié)果

我們繪制了生成器和鑒別器在訓(xùn)練期間的損失。

plt.figure(figsize=(10,5))

plt.title("Generator and Discriminator Loss During Training")

plt.plot(G_losses,label="G")

plt.plot(D_losses,label="D")

plt.xlabel("iterations")

plt.ylabel("Loss")

plt.legend()

plt.show()

我們還可以查看生成器生成的圖像,以查看真實(shí)圖像和虛假圖像之間的差異。

#%%capture

fig = plt.figure(figsize=(8,8))

plt.a(chǎn)xis("off")

ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]

ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

看起來(lái)像這樣:

結(jié)論

在本文中,我們討論了使用深度卷積生成對(duì)抗網(wǎng)絡(luò)生成海洋塑料的合成圖像,研究人員可以使用這些圖像來(lái)擴(kuò)展他們當(dāng)前的海洋塑料數(shù)據(jù)集。這有助于讓研究人員能夠通過(guò)混合真實(shí)和合成圖像來(lái)擴(kuò)展他們的數(shù)據(jù)集。

從結(jié)果中可以看出,GAN仍然需要大量的工作。海洋是一個(gè)復(fù)雜的環(huán)境,光照、渾濁度、模糊度等各不相同。

       原文標(biāo)題 : 利用生成對(duì)抗網(wǎng)絡(luò)生成海洋塑料合成圖像

聲明: 本文由入駐維科號(hào)的作者撰寫,觀點(diǎn)僅代表作者本人,不代表OFweek立場(chǎng)。如有侵權(quán)或其他問(wèn)題,請(qǐng)聯(lián)系舉報(bào)。

發(fā)表評(píng)論

0條評(píng)論,0人參與

請(qǐng)輸入評(píng)論內(nèi)容...

請(qǐng)輸入評(píng)論/評(píng)論長(zhǎng)度6~500個(gè)字

您提交的評(píng)論過(guò)于頻繁,請(qǐng)輸入驗(yàn)證碼繼續(xù)

  • 看不清,點(diǎn)擊換一張  刷新

暫無(wú)評(píng)論

暫無(wú)評(píng)論

人工智能 獵頭職位 更多
掃碼關(guān)注公眾號(hào)
OFweek人工智能網(wǎng)
獲取更多精彩內(nèi)容
文章糾錯(cuò)
x
*文字標(biāo)題:
*糾錯(cuò)內(nèi)容:
聯(lián)系郵箱:
*驗(yàn) 證 碼:

粵公網(wǎng)安備 44030502002758號(hào)