訂閱
糾錯
加入自媒體

使用JojoGAN創(chuàng)建風(fēng)格化的面部圖

介紹

風(fēng)格遷移是神經(jīng)網(wǎng)絡(luò)的一個發(fā)展領(lǐng)域,它是一個非常有用的功能,可以集成到社交媒體和人工智能應(yīng)用程序中。幾個神經(jīng)網(wǎng)絡(luò)可以根據(jù)訓(xùn)練數(shù)據(jù)將圖像樣式映射和傳輸?shù)捷斎雸D像。在本文中,我們將研究 JojoGAN,以及僅使用一種參考樣式來訓(xùn)練和生成具有該樣式的任何圖像的過程。

JoJoGAN:One Shot Face Stylization

One Shot Face Stylization(一次性面部風(fēng)格化)可用于 AI 應(yīng)用程序、社交媒體過濾器、有趣的應(yīng)用程序和業(yè)務(wù)用例。隨著 AI 生成的圖像和視頻濾鏡的日益普及,以及它們在社交媒體和短視頻、圖像中的使用,一次性面部風(fēng)格化是一個有用的功能,應(yīng)用程序和社交媒體公司可以將其集成到最終產(chǎn)品中。

因此,讓我們來看看用于一次性生成人臉樣式的流行 GAN 架構(gòu)——JojoGAN。

JojoGAN 架構(gòu)

JojoGAN 是一種風(fēng)格遷移程序,可讓將人臉圖像的風(fēng)格遷移為另一種風(fēng)格。它通過GAN將參考風(fēng)格圖像反轉(zhuǎn)為近似的配對訓(xùn)練數(shù)據(jù),根據(jù)風(fēng)格化代碼生成真實的人臉圖像,并與參考風(fēng)格圖像相匹配。然后將該數(shù)據(jù)集用于微調(diào) StyleGAN,并且可以使用新的輸入圖像,JojoGAN 將根據(jù) GAN 反轉(zhuǎn)(inversion)將其轉(zhuǎn)換為該特定樣式。

JojoGAN 架構(gòu)和工作流程

JojoGAN 只需一種參考風(fēng)格即可在很短的時間內(nèi)(不到 1 分鐘)進行訓(xùn)練,并生成高質(zhì)量的風(fēng)格化圖像。

JojoGan 的一些例子

JojoGAN 生成的風(fēng)格化圖像的一些示例:

風(fēng)格化的圖像可以在各種不同的輸入風(fēng)格上生成并且可以修改。

JojoGan 代碼深潛

讓我們看看 JojoGAN 生成風(fēng)格化人像的實現(xiàn)。有幾個預(yù)訓(xùn)練模型可用,它們可以在我們的風(fēng)格圖像上進行訓(xùn)練,或者可以修改模型以在幾分鐘內(nèi)更改風(fēng)格。

JojoGAN 的設(shè)置和導(dǎo)入

克隆 JojoGAN 存儲庫并導(dǎo)入必要的庫。在 Google Colab 存儲中創(chuàng)建一些文件夾,用于存儲反轉(zhuǎn)代碼、樣式圖像和模型。

!git clone https://github.com/mchong6/JoJoGAN.git

%cd JoJoGAN

!pip install tqdm gdown scikit-learn==0.22 scipy lpips dlib opencv-python wandb

!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip

!sudo unzip ninja-linux.zip -d /usr/local/bin/

import torch

torch.backends.cudnn.benchmark = True

from torchvision import transforms, utils

from util import *

from PIL import Image

import math

import random

import os

import numpy

from torch import nn, autograd, optim

from torch.nn import functional

from tqdm import tqdm

import wandb

from model import *

from e4e_projection import projection

from google.colab import files

from copy import deepcopy

from pydrive.a(chǎn)uth import GoogleAuth

from pydrive.drive import GoogleDrive

from google.colab import auth

from oauth2client.client import GoogleCredentials

模型文件

使用 Pydrive 下載模型文件。一組驅(qū)動器 ID 可用于預(yù)訓(xùn)練模型。這些預(yù)訓(xùn)練模型可用于隨時隨地生成風(fēng)格化圖像,并具有不同的準(zhǔn)確度。之后,可以訓(xùn)練用戶創(chuàng)建的模型。

#Download models

#optionally enable downloads with pydrive in order to authenticate and avoid drive download limits.

download_with_pydrive = True  

device = 'cuda' #['cuda', 'cpu']

!wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2

!bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2

!mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat

%matplotlib inline

drive_ids = {

   "stylegan2-ffhq-config-f.pt": "1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK",

   "e4e_ffhq_encode.pt": "1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7",

   "restyle_psp_ffhq_encode.pt": "1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd",

   "arcane_caitlyn.pt": "1gOsDTiTPcENiFOrhmkkxJcTURykW1dRc",

   "arcane_caitlyn_preserve_color.pt": "1cUTyjU-q98P75a8THCaO545RTwpVV-aH",

   "arcane_jinx_preserve_color.pt": "1jElwHxaYPod5Itdy18izJk49K1nl4ney",

   "arcane_jinx.pt": "1quQ8vPjYpUiXM4k1_KIwP4EccOefPpG_",

   "arcane_multi_preserve_color.pt": "1enJgrC08NpWpx2XGBmLt1laimjpGCyfl",

   "arcane_multi.pt": "15V9s09sgaw-zhKp116VHigf5FowAy43f",

   "sketch_multi.pt": "1GdaeHGBGjBAFsWipTL0y-ssUiAqk8AxD",

   "disney.pt": "1zbE2upakFUAx8ximYnLofFwfT8MilqJA",

   "disney_preserve_color.pt": "1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi",

   "jojo.pt": "13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4",

   "jojo_preserve_color.pt": "1ZRwYLRytCEKi__eT2Zxv1IlV6BGVQ_K2",

   "jojo_yasuho.pt": "1grZT3Gz1DLzFoJchAmoj3LoM9ew9ROX_",

   "jojo_yasuho_preserve_color.pt": "1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L",

   "art.pt": "1a0QDEHwXQ6hE_FcYEyNMuv5r5UnRQLKT",



# from StyelGAN-NADA

class Downloader(object):

   def __init__(self, use_pydrive):

       self.use_pydrive = use_pydrive

       if self.use_pydrive:

           self.a(chǎn)uthenticate()

   def authenticate(self):

       auth.a(chǎn)uthenticate_user()

       gauth = GoogleAuth()

       gauth.credentials = GoogleCredentials.get_application_default()

       self.drive = GoogleDrive(gauth)

   def download_file(self, file_name):

       file_dst = os.path.join('models', file_name)

       file_id = drive_ids[file_name]

       if not os.path.exists(file_dst):

           print(f'Downloading {file_name}')

           if self.use_pydrive:

               downloaded = self.drive.CreateFile({'id':file_id})

               downloaded.FetchMetadata(fetch_all=True)

               downloaded.GetContentFile(file_dst)

           else:

               !gdown --id $file_id -O $file_dst

downloader = Downloader(download_with_pydrive)


downloader.download_file('stylegan2-ffhq-config-f.pt')

downloader.download_file('e4e_ffhq_encode.pt')


加載生成器

加載原始和微調(diào)生成器。設(shè)置用于調(diào)整圖像大小和規(guī)范化圖像的 transforms。

latent_dim = 512

# Load original generator

original_generator = Generator(1024, latent_dim, 8, 2).to(device)

ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)

original_generator.load_state_dict(ckpt["g_ema"], strict=False)

mean_latent = original_generator.mean_latent(10000)

# to be finetuned generator

generator = deepcopy(original_generator)

transform = transforms.Compose(

   [

       transforms.Resize((1024, 1024)),

       transforms.ToTensor(),

       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

   ]

輸入圖像

設(shè)置輸入圖像位置。對齊和裁剪面并重新設(shè)置映射的樣式。

#image to the test_input directory and put the name here

filename = 'face.jpeg' #@param {type:"string"}

filepath = f'test_input/{filename}'

name = strip_path_extension(filepath)+'.pt'

# aligns and crops face

aligned_face = align_face(filepath)

# my_w = restyle_projection(aligned_face, name, device, n_iters=1).unsqueeze(0)

my_w = projection(aligned_face, name, device).unsqueeze(0)

預(yù)訓(xùn)練圖

選擇預(yù)訓(xùn)練好的圖類型,選擇不保留顏色的檢查點,效果更好。

plt.rcParams['figure.dpi'] = 150

pretrained = 'sketch_multi' #['art', 'arcane_multi', 'sketch_multi', 'arcane_jinx', 'arcane_caitlyn', 'jojo_yasuho', 'jojo', 'disney']

#Preserve color tries to preserve color of original image by limiting family of allowable transformations.

if preserve_color:

   ckpt = f'{pretrained}_preserve_color.pt'

else:

   ckpt = f'{pretrained}.pt'

生成結(jié)果

加載檢查點和生成器并設(shè)置種子值,然后開始生成風(fēng)格化圖像。用于 Elon Musk 的輸入圖像將根據(jù)圖類型進行風(fēng)格化。

#Generate results

n_sample =  5#{type:"number"}

seed = 3000 #{type:"number"}

torch.manual_seed(seed)

with torch.no_grad():

   generator.eval()

   z = torch.randn(n_sample, latent_dim, device=device)

   original_sample = original_generator([z], truncation=0.7, truncation_latent=mean_latent)

   sample = generator([z], truncation=0.7, truncation_latent=mean_latent)

   original_my_sample = original_generator(my_w, input_is_latent=True)

   my_sample = generator(my_w, input_is_latent=True)

# display reference images

if pretrained == 'arcane_multi':

   style_path = f'style_images_aligned/arcane_jinx.png'

elif pretrained == 'sketch_multi':

   style_path = f'style_images_aligned/sketch.png'

else:   

   style_path = f'style_images_aligned/{pretrained}.png'

style_image = transform(Image.open(style_path)).unsqueeze(0).to(device)

face = transform(aligned_face).unsqueeze(0).to(device)


my_output = torch.cat([style_image, face, my_sample], 0)

生成的結(jié)果

結(jié)果生成為預(yù)先訓(xùn)練的類型“Jojo”,看起來相當(dāng)準(zhǔn)確。

現(xiàn)在讓我們看一下在自創(chuàng)樣式上訓(xùn)練 GAN。

使用你的風(fēng)格圖像進行訓(xùn)練

選擇一些面部圖,甚至創(chuàng)建一些自己的面部圖并加載這些圖像以訓(xùn)練 GAN,并設(shè)置路徑。裁剪和對齊人臉并執(zhí)行 GAN 反轉(zhuǎn)。

names = ['1.jpg', '2.jpg', '3.jpg']

targets = []

latents = []

for name in names:

   style_path = os.path.join('style_images', name)

   assert os.path.exists(style_path), f"{style_path} does not exist。

   name = strip_path_extension(name)

   # crop and align the face

   style_aligned_path = os.path.join('style_images_aligned', f'{name}.png')

   if not os.path.exists(style_aligned_path):

       style_aligned = align_face(style_path)

       style_aligned.save(style_aligned_path)

   else:

       style_aligned = Image.open(style_aligned_path).convert('RGB')

   # GAN invert

   style_code_path = os.path.join('inversion_codes', f'{name}.pt')

   if not os.path.exists(style_code_path):

       latent = projection(style_aligned, style_code_path, device)

   else:

       latent = torch.load(style_code_path)['latent']

   latents.a(chǎn)ppend(latent.to(device))

targets = torch.stack(targets, 0)

latents = torch.stack(latents, 0)

微調(diào) StyleGAN

通過調(diào)整 alpha、顏色保留和設(shè)置迭代次數(shù)來微調(diào) StyleGAN。加載感知損失的鑒別器并重置生成器。

#Finetune StyleGAN

#alpha controls the strength of the style

alpha =  1.0 # min:0, max:1, step:0.1

alpha = 1-alpha

#preserve color of original image by limiting family of allowable transformations

preserve_color = False 

#Number of finetuning steps.

num_iter = 300

#Log training on wandb and interval for image logging

use_wandb = False 

log_interval = 50

if use_wandb:

   wandb.init(project="JoJoGAN")

   config = wandb.config

   config.num_iter = num_iter

   config.preserve_color = preserve_color

   wandb.log(

   {"Style reference": [wandb.Image(transforms.ToPILImage()(target_im))]},

   step=0)

# load discriminator for perceptual loss

discriminator = Discriminator(1024, 2).eval().to(device)

ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)

discriminator.load_state_dict(ckpt["d"], strict=False)

# reset generator

del generator

generator = deepcopy(original_generator)

g_optim = optim.Adam(generator.parameters(), lr=2e-3, betas=(0, 0.99))

訓(xùn)練生成器從潛在空間生成圖像,并優(yōu)化損失。

if preserve_color:

   id_swap = [9,11,15,16,17]

z = range(numiter)

for idx in tqdm( z):

   mean_w = generator.get_latent(torch.randn([latents.size(0), latent_dim]).to(device)).unsqueeze(1).repeat(1, generator.n_latent, 1)

   

in_latent = latents.clone()

   in_latent[:, id_swap] = alpha*latents[:, id_swap] + (1-alpha*mean_w[:, id_swap]


   img = generator(in_latent, input_is_latent=True)

   with torch.no_grad():

       real_feat = discriminator(targets)
   

   fake_feat = discriminator(img)

   loss = sum([functional.l1_loss(a, b) for a, b in zip(fake_feat, real_feat)])/len(fake_feat)  
   

   if use_wandb:

       wandb.log({"loss": loss}, step=idx)

       if idx % log_interval == 0:

           generator.eval()

           my_sample = generator(my_w, input_is_latent=True)

           generator.train()

           wandb.log(

           {"Current stylization": [wandb.Image(my_sample)]},

           step=idx)

   g_optim.zero_grad()

   loss.backward()

   g_optim.step()

使用 JojoGAN 生成結(jié)果

現(xiàn)在生成結(jié)果。下面已經(jīng)為原始圖像和示例圖像生成了結(jié)果以進行比較。

#Generate resultsn_sample =  5

seed = 3000

torch.manual_seed(seed)

with torch.no_grad():

   generator.eval()

   z = torch.randn(n_sample, latent_dim, device=device)

   original_sample = original_generator([z], truncation=0.7, truncation_latent=mean_latent)

   sample = generator([z], truncation=0.7, truncation_latent=mean_latent)

   original_my_sample = original_generator(my_w, input_is_latent=True)

   my_sample = generator(my_w, input_is_latent=True)


# display reference images

style_images = []

for name in names:

   style_path = f'style_images_aligned/{strip_path_extension(name)}.png'

   style_image = transform(Image.open(style_path))

   style_images.a(chǎn)ppend(style_image)

face = transform(aligned_face).to(device).unsqueeze(0)

style_images = torch.stack(style_images, 0).to(device)

my_output = torch.cat([face, my_sample], 0)

output = torch.cat([original_sample, sample], 0)

生成的結(jié)果

現(xiàn)在,你可以使用 JojoGAN 生成你自己風(fēng)格的圖像。結(jié)果令人印象深刻,但可以通過調(diào)整訓(xùn)練方法和訓(xùn)練圖像中的更多特征來進一步改進。

結(jié)論

JojoGAN 能夠以快速有效的方式準(zhǔn)確地映射和遷移用戶定義的樣式。關(guān)鍵要點是:

· JojoGAN 可以只用一種風(fēng)格進行訓(xùn)練,以輕松映射并創(chuàng)建任何面部的風(fēng)格化圖

· JojoGAN 非?焖儆行В梢栽诓坏揭环昼姷臅r間內(nèi)完成訓(xùn)練

· 結(jié)果非常準(zhǔn)確,類似于逼真的肖像

· JojoGAN 可以輕松微調(diào)和修改,使其適用于 AI 應(yīng)用程序

因此,無論風(fēng)格類型、形狀和顏色如何,JojoGAN 都是用于風(fēng)格轉(zhuǎn)移的理想神經(jīng)網(wǎng)絡(luò),因此可以成為各種社交媒體應(yīng)用程序和 AI 應(yīng)用程序中非常有用的功能。

       原文標(biāo)題 : 使用JojoGAN創(chuàng)建風(fēng)格化的面部圖

聲明: 本文由入駐維科號的作者撰寫,觀點僅代表作者本人,不代表OFweek立場。如有侵權(quán)或其他問題,請聯(lián)系舉報。

發(fā)表評論

0條評論,0人參與

請輸入評論內(nèi)容...

請輸入評論/評論長度6~500個字

您提交的評論過于頻繁,請輸入驗證碼繼續(xù)

暫無評論

暫無評論

人工智能 獵頭職位 更多
掃碼關(guān)注公眾號
OFweek人工智能網(wǎng)
獲取更多精彩內(nèi)容
文章糾錯
x
*文字標(biāo)題:
*糾錯內(nèi)容:
聯(lián)系郵箱:
*驗 證 碼:

粵公網(wǎng)安備 44030502002758號