banner
andrewji8

Being towards death

Heed not to the tree-rustling and leaf-lashing rain, Why not stroll along, whistle and sing under its rein. Lighter and better suited than horses are straw sandals and a bamboo staff, Who's afraid? A palm-leaf plaited cape provides enough to misty weather in life sustain. A thorny spring breeze sobers up the spirit, I feel a slight chill, The setting sun over the mountain offers greetings still. Looking back over the bleak passage survived, The return in time Shall not be affected by windswept rain or shine.
telegram
twitter
github

Pythonを使用して写真の人物の背景を置き換え、髪の毛の一本一本まで精密に(コード付き)

前言
本文の github リポジトリのアドレスは:

替换照片中人物背景

モデルファイルが大きいため、リポジトリには置いていません。以下にモデルのダウンロードアドレスがあります。

プロジェクト説明
プロジェクト構造
まず、プロジェクトの構造を見てみましょう。

640
ここで、model フォルダにはモデルファイルがあり、モデルファイルのダウンロードアドレスは:https://drive.google.com/drive/folders/1NmyTItr2jRac0nLoZMeixlcU1myMiYTs

640 (1)
このモデルをダウンロードして model フォルダに置いてください。

依存ファイル - requirements.txt について説明します。pytorch のインストールには公式サイトからのものを使用する必要があります。これにより、グラフィックカードのドライバーが一致しないことを避けることができます。私の別の記事を参考にして pytorch のインストールについて確認できます:

https://huyi-aliang.blog.csdn.net/article/details/120556923

依存ファイルは以下の通りです:

kornia==0.4.1  
tensorboard==2.3.0  
torch==1.7.0  
torchvision==0.8.1  
tqdm==4.51.0  
opencv-python==4.4.0.44  
onnxruntime==1.6.0  

データ準備
写真とその背景画像、置き換えたい画像を準備する必要があります。私が選んだのは BackgroundMattingV2 が提供するいくつかの参考画像で、元の画像と背景画像は以下の通りです:

640 (2)

640 (3)
新しい背景画像(私が適当に見つけたもの)は以下の通りです

640 (4)
背景画像置き換えコード
無駄話はやめて、核心コードに入ります。


#!/usr/bin/env python  
# -*- coding: utf-8 -*-  
# @Time    : 2021/11/14 21:24  
# @Author  : 剣客阿良_ALiang  
# @Site    :   
# @File    : inferance_hy.py  
import argparse  
import torch  
import os  

from torch.nn import functional as F  
from torch.utils.data import DataLoader  
from torchvision import transforms as T  
from torchvision.transforms.functional import to_pil_image  
from threading import Thread  
from tqdm import tqdm  
from torch.utils.data import Dataset  
from PIL import Image  
from typing import Callable, Optional, List, Tuple  
import glob  
from torch import nn  
from torchvision.models.resnet import ResNet, Bottleneck  
from torch import Tensor  
import torchvision  
import numpy as np  
import cv2  
import uuid  

# --------------- hy ---------------  
class HomographicAlignment:  
    """  
    背景にホモグラフィーアライメントを適用して、ソース画像と一致させます。  
    """  

    def __init__(self):  
        self.detector = cv2.ORB_create()  
        self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE)  

    def __call__(self, src, bgr):  
        src = np.asarray(src)  
        bgr = np.asarray(bgr)  

        keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None)  
        keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None)  

        matches = self.matcher.match(descriptors_bgr, descriptors_src, None)  
        matches.sort(key=lambda x: x.distance, reverse=False)  
        num_good_matches = int(len(matches) * 0.15)  
        matches = matches[:num_good_matches]  

        points_src = np.zeros((len(matches), 2), dtype=np.float32)  
        points_bgr = np.zeros((len(matches), 2), dtype=np.float32)  
        for i, match in enumerate(matches):  
            points_src[i, :] = keypoints_src[match.trainIdx].pt  
            points_bgr[i, :] = keypoints_bgr[match.queryIdx].pt  

        H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC)  

        h, w = src.shape[:2]  
        bgr = cv2.warpPerspective(bgr, H, (w, h))  
        msk = cv2.warpPerspective(np.ones((h, w)), H, (w, h))  

        # 背景の外側の領域については、  
        # ソースからピクセルをコピーします。  
        bgr[msk != 1] = src[msk != 1]  

        src = Image.fromarray(src)  
        bgr = Image.fromarray(bgr)  

        return src, bgr  

class Refiner(nn.Module):  
    # TorchScriptエクスポート最適化のため。  
    __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method']  

    def __init__(self,  
                 mode: str,  
                 sample_pixels: int,  
                 threshold: float,  
                 kernel_size: int = 3,  
                 prevent_oversampling: bool = True,  
                 patch_crop_method: str = 'unfold',  
                 patch_replace_method: str = 'scatter_nd'):  
        super().__init__()  
        assert mode in ['full', 'sampling', 'thresholding']  
        assert kernel_size in [1, 3]  
        assert patch_crop_method in ['unfold', 'roi_align', 'gather']  
        assert patch_replace_method in ['scatter_nd', 'scatter_element']  

        self.mode = mode  
        self.sample_pixels = sample_pixels  
        self.threshold = threshold  
        self.kernel_size = kernel_size  
        self.prevent_oversampling = prevent_oversampling  
        self.patch_crop_method = patch_crop_method  
        self.patch_replace_method = patch_replace_method  

        channels = [32, 24, 16, 12, 4]  
        self.conv1 = nn.Conv2d(channels[0] + 6 + 4, channels[1], kernel_size, bias=False)  
        self.bn1 = nn.BatchNorm2d(channels[1])  
        self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False)  
        self.bn2 = nn.BatchNorm2d(channels[2])  
        self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False)  
        self.bn3 = nn.BatchNorm2d(channels[3])  
        self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True)  
        self.relu = nn.ReLU(True)  

    def forward(self,  
                src: torch.Tensor,  
                bgr: torch.Tensor,  
                pha: torch.Tensor,  
                fgr: torch.Tensor,  
                err: torch.Tensor,  
                hid: torch.Tensor):  
        H_full, W_full = src.shape[2:]  
        H_half, W_half = H_full // 2, W_full // 2  
        H_quat, W_quat = H_full // 4, W_full // 4  

        src_bgr = torch.cat([src, bgr], dim=1)  

        if self.mode != 'full':  
            err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False)  
            ref = self.select_refinement_regions(err)  
            idx = torch.nonzero(ref.squeeze(1))  
            idx = idx[:, 0], idx[:, 1], idx[:, 2]  

            if idx[0].size(0) > 0:  
                x = torch.cat([hid, pha, fgr], dim=1)  
                x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)  
                x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0)  

                y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)  
                y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0)  

                x = self.conv1(torch.cat([x, y], dim=1))  
                x = self.bn1(x)  
                x = self.relu(x)  
                x = self.conv2(x)  
                x = self.bn2(x)  
                x = self.relu(x)  

                x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest')  
                y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0)  

                x = self.conv3(torch.cat([x, y], dim=1))  
                x = self.bn3(x)  
                x = self.relu(x)  
                x = self.conv4(x)  

                out = torch.cat([pha, fgr], dim=1)  
                out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False)  
                out = self.replace_patch(out, x, idx)  
                pha = out[:, :1]  
                fgr = out[:, 1:]  
            else:  
                pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False)  
                fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False)  
        else:  
            x = torch.cat([hid, pha, fgr], dim=1)  
            x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)  
            y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)  
            if self.kernel_size == 3:  
                x = F.pad(x, (3, 3, 3, 3))  
                y = F.pad(y, (3, 3, 3, 3))  

            x = self.conv1(torch.cat([x, y], dim=1))  
            x = self.bn1(x)  
            x = self.relu(x)  
            x = self.conv2(x)  
            x = self.bn2(x)  
            x = self.relu(x)  

            if self.kernel_size == 3:  
                x = F.interpolate(x, (H_full + 4, W_full + 4))  
                y = F.pad(src_bgr, (2, 2, 2, 2))  
            else:  
                x = F.interpolate(x, (H_full, W_full), mode='nearest')  
                y = src_bgr  

            x = self.conv3(torch.cat([x, y], dim=1))  
            x = self.bn3(x)  
            x = self.relu(x)  
            x = self.conv4(x)  

            pha = x[:, :1]  
            fgr = x[:, 1:]  
            ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype)  

        return pha, fgr, ref  

    def select_refinement_regions(self, err: torch.Tensor):  
        """  
        リファインメント領域を選択します。  
        入力:  
            err: エラーマップ (B, 1, H, W)  
        出力:  
            ref: リファインメント領域 (B, 1, H, W)。 FloatTensor。 1が選択され、0は選択されていません。  
        """  
        if self.mode == 'sampling':  
            # サンプリングモード。  
            b, _, h, w = err.shape  
            err = err.view(b, -1)  
            idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices  
            ref = torch.zeros_like(err)  
            ref.scatter_(1, idx, 1.)  
            if self.prevent_oversampling:  
                ref.mul_(err.gt(0).float())  
            ref = ref.view(b, 1, h, w)  
        else:  
            # 閾値モード。  
            ref = err.gt(self.threshold).float()  
        return ref  

    def crop_patch(self,  
                   x: torch.Tensor,  
                   idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],  
                   size: int,  
                   padding: int):  
        """  
        指定されたインデックスから画像の選択されたパッチをクロップします。  
        入力:  
            x: 画像 (B, C, H, W)。  
            idx: 選択インデックスのタプル[(P,), (P,), (P,),]、3つの値は(B, H, W)インデックスです。  
            size: パッチの中心サイズ、クロップのストライドでもあります。  
            padding: パッチの拡張サイズ。  
        出力:  
            patch: (P, C, h, w)、ここでh = w = size + 2 * paddingです。  
        """  
        if padding != 0:  
            x = F.pad(x, (padding,) * 4)  

        if self.patch_crop_method == 'unfold':  
            # unfoldを使用します。PyTorchとTorchScriptの最良のパフォーマンス。  
            return x.permute(0, 2, 3, 1) \  
                .unfold(1, size + 2 * padding, size) \  
                .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]  
        elif self.patch_crop_method == 'roi_align':  
            # roi_alignを使用します。ONNXとの互換性が最良です。  
            idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)  
            b = idx[0]  
            x1 = idx[2] * size - 0.5  
            y1 = idx[1] * size - 0.5  
            x2 = idx[2] * size + size + 2 * padding - 0.5  
            y2 = idx[1] * size + size + 2 * padding - 0.5  
            boxes = torch.stack([b, x1, y1, x2, y2], dim=1)  
            return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)  
        else:  
            # gatherを使用します。ピクセルごとにパッチをクロップします。  
            idx_pix = self.compute_pixel_indices(x, idx, size, padding)  
            pat = torch.gather(x.view(-1), 0, idx_pix.view(-1))  
            pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)  
            return pat  

    def replace_patch(self,  
                      x: torch.Tensor,  
                      y: torch.Tensor,  
                      idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):  
        """  
        指定されたインデックスにパッチを画像に戻します。  
        入力:  
            x: 画像 (B, C, H, W)  
            y: パッチ (P, C, h, w)  
            idx: 選択インデックスのタプル[(P,), (P,), (P,)]、3つの値は(B, H, W)インデックスです。  
        出力:  
            画像: (B, C, H, W)、ここでidxの位置にあるパッチはyで置き換えられます。  
        """  
        xB, xC, xH, xW = x.shape  
        yB, yC, yH, yW = y.shape  
        if self.patch_replace_method == 'scatter_nd':  
            # scatter_ndを使用します。PyTorchとTorchScriptの最良のパフォーマンス。パッチごとに置き換えます。  
            x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5)  
            x[idx[0], idx[1], idx[2]] = y  
            x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW)  
            return x  
        else:  
            # scatter_elementを使用します。ONNXとの互換性が最良です。ピクセルごとに置き換えます。  
            idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0)  
            return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape)  

    def compute_pixel_indices(self,  
                              x: torch.Tensor,  
                              idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],  
                              size: int,  
                              padding: int):  
        """  
        テンソル内の選択されたピクセルインデックスを計算します。  
        crop_method == 'gather'およびreplace_method == 'scatter_element'に使用され、ピクセルごとにクロップおよび置き換えます。  
        入力:  
            x: 画像: (B, C, H, W)  
            idx: 選択インデックスのタプル[(P,), (P,), (P,),]、3つの値は(B, H, W)インデックスです。  
            size: パッチの中心サイズ、クロップのストライドでもあります。  
            padding: パッチの拡張サイズ。  
        出力:  
            idx: (P, C, O, O)のロングテンソル、ここでOは出力サイズ: size + 2 * padding、Pはパッチの数です。  
                 要素は入力x.view(-1)を指すインデックスです。  
        """  
        B, C, H, W = x.shape  
        S, P = size, padding  
        O = S + 2 * P  
        b, y, x = idx  
        n = b.size(0)  
        c = torch.arange(C)  
        o = torch.arange(O)  
        idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1,  
                                                                                                                 O).expand(  
            [C, O, O])  
        idx_loc = b * W * H + y * W * S + x * S  
        idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])  
        return idx_pix  

def load_matched_state_dict(model, state_dict, print_stats=True):  
    """  
    キーと形状が一致する重みのみをロードします。他の重みは無視します。  
    """  
    num_matched, num_total = 0, 0  
    curr_state_dict = model.state_dict()  
    for key in curr_state_dict.keys():  
        num_total += 1  
        if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:  
            curr_state_dict[key] = state_dict[key]  
            num_matched += 1  
    model.load_state_dict(curr_state_dict)  
    if print_stats:  
        print(f'Loaded state_dict: {num_matched}/{num_total} matched')  

def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:  
    """  
    この関数は元のtfリポジトリから取られています。  
    すべてのレイヤーが8で割り切れるチャネル数を持つことを保証します。  
    ここで見ることができます:  
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py  
    """  
    if min_value is None:  
        min_value = divisor  
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)  
    # 切り捨てが10%以上下がらないことを確認します。  
    if new_v < 0.9 * v:  
        new_v += divisor  
    return new_v  

class ConvNormActivation(torch.nn.Sequential):  
    def __init__(  
            self,  
            in_channels: int,  
            out_channels: int,  
            kernel_size: int = 3,  
            stride: int = 1,  
            padding: Optional[int] = None,  
            groups: int = 1,  
            norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,  
            activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,  
            dilation: int = 1,  
            inplace: bool = True,  
    ) -> None:  
        if padding is None:  
            padding = (kernel_size - 1) // 2 * dilation  
        layers = [torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,  
                                  dilation=dilation, groups=groups, bias=norm_layer is None)]  
        if norm_layer is not None:  
            layers.append(norm_layer(out_channels))  
        if activation_layer is not None:  
            layers.append(activation_layer(inplace=inplace))  
        super().__init__(*layers)  
        self.out_channels = out_channels  

class InvertedResidual(nn.Module):  
    def __init__(  
            self,  
            inp: int,  
            oup: int,  
            stride: int,  
            expand_ratio: int,  
            norm_layer: Optional[Callable[..., nn.Module]] = None  
    ) -> None:  
        super(InvertedResidual, self).__init__()  
        self.stride = stride  
        assert stride in [1, 2]  

        if norm_layer is None:  
            norm_layer = nn.BatchNorm2d  

        hidden_dim = int(round(inp * expand_ratio))  
        self.use_res_connect = self.stride == 1 and inp == oup  

        layers: List[nn.Module] = []  
        if expand_ratio != 1:  
            # pw  
            layers.append(ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer,  
                                             activation_layer=nn.ReLU6))  
        layers.extend([  
            # dw  
            ConvNormActivation(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer,  
                               activation_layer=nn.ReLU6),  
            # pw-linear  
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),  
            norm_layer(oup),  
        ])  
        self.conv = nn.Sequential(*layers)  
        self.out_channels = oup  
        self._is_cn = stride > 1  

    def forward(self, x: Tensor) -> Tensor:  
        if self.use_res_connect:  
            return x + self.conv(x)  
        else:  
            return self.conv(x)  

class MobileNetV2(nn.Module):  
    def __init__(  
            self,  
            num_classes: int = 1000,  
            width_mult: float = 1.0,  
            inverted_residual_setting: Optional[List[List[int]]] = None,  
            round_nearest: int = 8,  
            block: Optional[Callable[..., nn.Module]] = None,  
            norm_layer: Optional[Callable[..., nn.Module]] = None  
    ) -> None:  
        """  
        MobileNet V2メインクラス  
        引数:  
            num_classes (int): クラスの数  
            width_mult (float): 幅の乗数 - 各レイヤーのチャネル数をこの量で調整します  
            inverted_residual_setting: ネットワーク構造  
            round_nearest (int): 各レイヤーのチャネル数をこの数の倍数に丸めます  
            1に設定すると丸めをオフにします  
            block: mobilenetのための逆残差構築ブロックを指定するモジュール  
            norm_layer: 使用する正規化レイヤーを指定するモジュール  
        """  
        super(MobileNetV2, self).__init__()  

        if block is None:  
            block = InvertedResidual  

        if norm_layer is None:  
            norm_layer = nn.BatchNorm2d  

        input_channel = 32  
        last_channel = 1280  

        if inverted_residual_setting is None:  
            inverted_residual_setting = [  
                # t, c, n, s  
                [1, 16, 1, 1],  
                [6, 24, 2, 2],  
                [6, 32, 3, 2],  
                [6, 64, 4, 2],  
                [6, 96, 3, 1],  
                [6, 160, 3, 2],  
                [6, 320, 1, 1],  
            ]  

        # 最初の要素のみを確認し、ユーザーがt,c,n,sが必要であることを知っていると仮定します  
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:  
            raise ValueError("inverted_residual_settingは空でない必要があります"  
                             "または4要素のリストである必要があります。取得したのは{}".format(inverted_residual_setting))  

        # 最初のレイヤーを構築  
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)  
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)  
        features: List[nn.Module] = [ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer,  
                                                        activation_layer=nn.ReLU6)]  
        # 逆残差ブロックを構築  
        for t, c, n, s in inverted_residual_setting:  
            output_channel = _make_divisible(c * width_mult, round_nearest)  
            for i in range(n):  
                stride = s if i == 0 else 1  
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))  
                input_channel = output_channel  
        # 最後の数レイヤーを構築  
        features.append(ConvNormActivation(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer,  
                                           activation_layer=nn.ReLU6))  
        # nn.Sequentialにします  
        self.features = nn.Sequential(*features)  

        # 分類器を構築  
        self.classifier = nn.Sequential(  
            nn.Dropout(0.2),  
            nn.Linear(self.last_channel, num_classes),  
        )  

        # 重みの初期化  
        for m in self.modules():  
            if isinstance(m, nn.Conv2d):  
                nn.init.kaiming_normal_(m.weight, mode='fan_out')  
                if m.bias is not None:  
                    nn.init.zeros_(m.bias)  
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):  
                nn.init.ones_(m.weight)  
                nn.init.zeros_(m.bias)  
            elif isinstance(m, nn.Linear):  
                nn.init.normal_(m.weight, 0, 0.01)  
                nn.init.zeros_(m.bias)  

    def _forward_impl(self, x: Tensor) -> Tensor:  
        # これはTorchScriptが継承をサポートしていないため存在します。  
        # スーパークラスのメソッド  
        # (このメソッド)は、サブクラスでアクセスできる名前を持つ必要があります  
        x = self.features(x)  
        # "squeeze"を使用できません。バッチサイズが1になる可能性があるため  
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))  
        x = torch.flatten(x, 1)  
        x = self.classifier(x)  
        return x  

    def forward(self, x: Tensor) -> Tensor:  
        return self._forward_impl(x)  

class MobileNetV2Encoder(MobileNetV2):  
    """  
    MobileNetV2Encoderはtorchvisionの公式MobileNetV2から継承されます。  
    出力ストライド16を維持するために最後のブロックで膨張を使用するように修正され、  
    元々分類に使用されていた分類器ブロックが削除されました。  
    forwardメソッドは、デコーダーの使用のためにすべての解像度でフィーチャーマップを追加で返します。  
    """  

    def __init__(self, in_channels, norm_layer=None):  
        super().__init__()  

        # in_channelsが一致しない場合は最初のconvレイヤーを置き換えます。  
        if in_channels != 3:  
            self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False)  

        # 最後のブロックを削除  
        self.features = self.features[:-1]  

        # 出力ストライド = 16を維持するために膨張を使用するように変更  
        self.features[14].conv[1][0].stride = (1, 1)  
        for feature in self.features[15:]:  
            feature.conv[1][0].dilation = (2, 2)  
            feature.conv[1][0].padding = (2, 2)  

        # 分類器を削除  
        del self.classifier  

    def forward(self, x):  
        x0 = x  # 1/1  
        x = self.features[0](x)  
        x = self.features[1](x)  
        x1 = x  # 1/2  
        x = self.features[2](x)  
        x = self.features[3](x)  
        x2 = x  # 1/4  
        x = self.features[4](x)  
        x = self.features[5](x)  
        x = self.features[6](x)  
        x3 = x  # 1/8  
        x = self.features[7](x)  
        x = self.features[8](x)  
        x = self.features[9](x)  
        x = self.features[10](x)  
        x = self.features[11](x)  
        x = self.features[12](x)  
        x = self.features[13](x)  
        x = self.features[14](x)  
        x = self.features[15](x)  
        x = self.features[16](x)  
        x = self.features[17](x)  
        x4 = x  # 1/16  
        return x4, x3, x2, x1, x0  

class Decoder(nn.Module):  

    def __init__(self, channels, feature_channels):  
        super().__init__()  
        self.conv1 = nn.Conv2d(feature_channels[0] + channels[0], channels[1], 3, padding=1, bias=False)  
        self.bn1 = nn.BatchNorm2d(channels[1])  
        self.conv2 = nn.Conv2d(feature_channels[1] + channels[1], channels[2], 3, padding=1, bias=False)  
        self.bn2 = nn.BatchNorm2d(channels[2])  
        self.conv3 = nn.Conv2d(feature_channels[2] + channels[2], channels[3], 3, padding=1, bias=False)  
        self.bn3 = nn.BatchNorm2d(channels[3])  
        self.conv4 = nn.Conv2d(feature_channels[3] + channels[3], channels[4], 3, padding=1)  
        self.relu = nn.ReLU(True)  

    def forward(self, x4, x3, x2, x1, x0):  
        x = F.interpolate(x4, size=x3.shape[2:], mode='bilinear', align_corners=False)  
        x = torch.cat([x, x3], dim=1)  
        x = self.conv1(x)  
        x = self.bn1(x)  
        x = self.relu(x)  
        x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)  
        x = torch.cat([x, x2], dim=1)  
        x = self.conv2(x)  
        x = self.bn2(x)  
        x = self.relu(x)  
        x = F.interpolate(x, size=x1.shape[2:], mode='bilinear', align_corners=False)  
        x = torch.cat([x, x1], dim=1)  
        x = self.conv3(x)  
        x = self.bn3(x)  
        x = self.relu(x)  
        x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)  
        x = torch.cat([x, x0], dim=1)  
        x = self.conv4(x)  
        return x  

class ASPPPooling(nn.Sequential):  
    def __init__(self, in_channels: int, out_channels: int) -> None:  
        super(ASPPPooling, self).__init__(  
            nn.AdaptiveAvgPool2d(1),  
            nn.Conv2d(in_channels, out_channels, 1, bias=False),  
            nn.BatchNorm2d(out_channels),  
            nn.ReLU())  

    def forward(self, x: torch.Tensor) -> torch.Tensor:  
        size = x.shape[-2:]  
        for mod in self:  
            x = mod(x)  
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)  

class ASPPConv(nn.Sequential):  
    def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:  
        modules = [  
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),  
            nn.BatchNorm2d(out_channels),  
            nn.ReLU()  
        ]  
        super(ASPPConv, self).__init__(*modules)  

class ASPP(nn.Module):  
    def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:  
        super(ASPP, self).__init__()  
        modules = []  
        modules.append(nn.Sequential(  
            nn.Conv2d(in_channels, out_channels, 1, bias=False),  
            nn.BatchNorm2d(out_channels),  
            nn.ReLU()))  

        rates = tuple(atrous_rates)  
        for rate in rates:  
            modules.append(ASPPConv(in_channels, out_channels, rate))  

        modules.append(ASPPPooling(in_channels, out_channels))  

        self.convs = nn.ModuleList(modules)  

        self.project = nn.Sequential(  
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),  
            nn.BatchNorm2d(out_channels),  
            nn.ReLU(),  
            nn.Dropout(0.5))  

    def forward(self, x: torch.Tensor) -> torch.Tensor:  
        _res = []  
        for conv in self.convs:  
            _res.append(conv(x))  
        res = torch.cat(_res, dim=1)  
        return self.project(res)  

class ResNetEncoder(ResNet):  
    layers = {  
        'resnet50': [3, 4, 6, 3],  
        'resnet101': [3, 4, 23, 3],  
    }  

    def __init__(self, in_channels, variant='resnet101', norm_layer=None):  
        super().__init__(  
            block=Bottleneck,  
            layers=self.layers[variant],  
            replace_stride_with_dilation=[False, False, True],  
            norm_layer=norm_layer)  

        # in_channelsが一致しない場合は最初のconvレイヤーを置き換えます。  
        if in_channels != 3:  
            self.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False)  

        # 完全接続レイヤーを削除  
        del self.avgpool  
        del self.fc  

    def forward(self, x):  
        x0 = x  # 1/1  
        x = self.conv1(x)  
        x = self.bn1(x)  
        x = self.relu(x)  
        x1 = x  # 1/2  
        x = self.maxpool(x)  
        x = self.layer1(x)  
        x2 = x  # 1/4  
        x = self.layer2(x)  
        x3 = x  # 1/8  
        x = self.layer3(x)  
        x = self.layer4(x)  
        x4 = x  # 1/16  
        return x4, x3, x2, x1, x0  

class Base(nn.Module):  
    """  
    DeepLabに触発されたベースエンコーダーデコーダーネットワークの一般的な実装。  
    入力と出力のために任意のチャネルを受け入れます。  
    """  

    def __init__(self, backbone: str, in_channels: int, out_channels: int):  
        super().__init__()  
        assert backbone in ["resnet50", "resnet101", "mobilenetv2"]  
        if backbone in ['resnet50', 'resnet101']:  
            self.backbone = ResNetEncoder(in_channels, variant=backbone)  
            self.aspp = ASPP(2048, [3, 6, 9])  
            self.decoder = Decoder([256, 128, 64, 48, out_channels], [512, 256, 64, in_channels])  
        else:  
            self.backbone = MobileNetV2Encoder(in_channels)  
            self.aspp = ASPP(320, [3, 6, 9])  
            self.decoder = Decoder([256, 128, 64, 48, out_channels], [32, 24, 16, in_channels])  

    def forward(self, x):  
        x, *shortcuts = self.backbone(x)  
        x = self.aspp(x)  
        x = self.decoder(x, *shortcuts)  
        return x  

    def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats=True):  
        # 事前学習済みDeepLabV3モデルは<https://github.com/VainF/DeepLabV3Plus-Pytorch>によって提供されています。  
        # このメソッドは、事前学習済みのstate_dictを変換して、私たちのモデル構造に一致させてロードします。  
        # このメソッドは、deeplabの重みからトレーニングする予定がない場合は必要ありません。  
        # 通常の重みのロードにはload_state_dict()を使用します。  

        # asppモジュールのためにstate_dictの命名を変換  
        state_dict = {k.replace('classifier.classifier.0', 'aspp'): v for k, v in state_dict.items()}  

        if isinstance(self.backbone, ResNetEncoder):  
            # ResNetバックボーンは変更の必要がありません。  
            load_matched_state_dict(self, state_dict, print_stats)  
        else:  
            # MobileNetV2バックボーンをstate_dict形式に変更し、ロード後に戻します。  
            backbone_features = self.backbone.features  
            self.backbone.low_level_features = backbone_features[:4]  
            self.backbone.high_level_features = backbone_features[4:]  
            del self.backbone.features  
            load_matched_state_dict(self, state_dict, print_stats)  
            self.backbone.features = backbone_features  
            del self.backbone.low_level_features  
            del self.backbone.high_level_features  

class MattingBase(Base):  

    def __init__(self, backbone: str):  
        super().__init__(backbone, in_channels=6, out_channels=(1 + 3 + 1 + 32))  

    def forward(self, src, bgr):  
        x = torch.cat([src, bgr], dim=1)  
        x, *shortcuts = self.backbone(x)  
        x = self.aspp(x)  
        x = self.decoder(x, *shortcuts)  
        pha = x[:, 0:1].clamp_(0., 1.)  
        fgr = x[:, 1:4].add(src).clamp_(0., 1.)  
        err = x[:, 4:5].clamp_(0., 1.)  
        hid = x[:, 5:].relu_()  
        return pha, fgr, err, hid  

class MattingRefine(MattingBase):  

    def __init__(self,  
                 backbone: str,  
                 backbone_scale: float = 1 / 4,  
                 refine_mode: str = 'sampling',  
                 refine_sample_pixels: int = 80_000,  
                 refine_threshold: float = 0.1,  
                 refine_kernel_size: int = 3,  
                 refine_prevent_oversampling: bool = True,  
                 refine_patch_crop_method: str = 'unfold',  
                 refine_patch_replace_method: str = 'scatter_nd'):  
        assert backbone_scale <= 1 / 2, 'backbone_scaleは1/2を超えてはいけません'  
        super().__init__(backbone)  
        self.backbone_scale = backbone_scale  
        self.refiner = Refiner(refine_mode,  
                               refine_sample_pixels,  
                               refine_threshold,  
                               refine_kernel_size,  
                               refine_prevent_oversampling,  
                               refine_patch_crop_method,  
                               refine_patch_replace_method)  

    def forward(self, src, bgr):  
        assert src.size() == bgr.size(), 'srcとbgrは同じ形状でなければなりません'  
        assert src.size(2) // 4 * 4 == src.size(2) and src.size(3) // 4 * 4 == src.size(3), \  
            'srcとbgrは幅と高さが4で割り切れる必要があります'  

        # バックボーン用にsrcとbgrをダウンサンプル  
        src_sm = F.interpolate(src,  
                               scale_factor=self.backbone_scale,  
                               mode='bilinear',  
                               align_corners=False,  
                               recompute_scale_factor=True)  
        bgr_sm = F.interpolate(bgr,  
                               scale_factor=self.backbone_scale,  
                               mode='bilinear',  
                               align_corners=False,  
                               recompute_scale_factor=True)  

        # ベース  
        x = torch.cat([src_sm, bgr_sm], dim=1)  
        x, *shortcuts = self.backbone(x)  
        x = self.aspp(x)  
        x = self.decoder(x, *shortcuts)  
        pha_sm = x[:, 0:1].clamp_(0., 1.)  
        fgr_sm = x[:, 1:4]  
        err_sm = x[:, 4:5].clamp_(0., 1.)  
        hid_sm = x[:, 5:].relu_()  

        # リファイナー  
        pha, fgr, ref_sm = self.refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm)  

        # 出力をクランプ  
        pha = pha.clamp_(0., 1.)  
        fgr = fgr.add_(src).clamp_(0., 1.)  
        fgr_sm = src_sm.add_(fgr_sm).clamp_(0., 1.)  

        return pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm  

class ImagesDataset(Dataset):  
    def __init__(self, root, mode='RGB', transforms=None):  
        self.transforms = transforms  
        self.mode = mode  
        self.filenames = sorted([*glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True),  
                                 *glob.glob(os.path.join(root, '**', '*.png'), recursive=True)])  

    def __len__(self):  
        return len(self.filenames)  

    def __getitem__(self, idx):  
        with Image.open(self.filenames[idx]) as img:  
            img = img.convert(self.mode)  
        if self.transforms:  
            img = self.transforms(img)  

        return img  

class NewImagesDataset(Dataset):  
    def __init__(self, root, mode='RGB', transforms=None):  
        self.transforms = transforms  
        self.mode = mode  
        self.filenames = [root]  
        print(self.filenames)  

    def __len__(self):  
        return len(self.filenames)  

    def __getitem__(self, idx):  
        with Image.open(self.filenames[idx]) as img:  
            img = img.convert(self.mode)  

        if self.transforms:  
            img = self.transforms(img)  

        return img  

class ZipDataset(Dataset):  
    def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):  
        self.datasets = datasets  
        self.transforms = transforms  

        if assert_equal_length:  
            for i in range(1, len(datasets)):  
                assert len(datasets[i]) == len(datasets[i - 1]), 'データセットの長さが一致しません。'  

    def __len__(self):  
        return max(len(d) for d in self.datasets)  

    def __getitem__(self, idx):  
        x = tuple(d[idx % len(d)] for d in self.datasets)  
        print(x)  
        if self.transforms:  
            x = self.transforms(*x)  
        return x  

class PairCompose(T.Compose):  
    def __call__(self, *x):  
        for transform in self.transforms:  
            x = transform(*x)  
        return x  

class PairApply:  
    def __init__(self, transforms):  
        self.transforms = transforms  

    def __call__(self, *x):  
        return [self.transforms(xi) for xi in x]  

# --------------- Arguments ---------------  

parser = argparse.ArgumentParser(description='hy-replace-background')  

parser.add_argument('--model-type', type=str, required=False, choices=['mattingbase', 'mattingrefine'],  
                    default='mattingrefine')  
parser.add_argument('--model-backbone', type=str, required=False, choices=['resnet101', 'resnet50', 'mobilenetv2'],  
                    default='resnet50')  
parser.add_argument('--model-backbone-scale', type=float, default=0.25)  
parser.add_argument('--model-checkpoint', type=str, required=False, default='model/pytorch_resnet50.pth')  
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])  
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)  
parser.add_argument('--model-refine-threshold', type=float, default=0.7)  
parser.add_argument('--model-refine-kernel-size', type=int, default=3)  

parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')  
parser.add_argument('--num-workers', type=int, default=0,  
                    help='DataLoaderで使用されるワーカースレッドの数。Windowsでは単一スレッド(0)を使用する必要があります。')  
parser.add_argument('--preprocess-alignment', action='store_true')  

parser.add_argument('--output-dir', type=str, required=False, default='content/output')  
parser.add_argument('--output-types', type=str, required=False, nargs='+',  
                    choices=['com', 'pha', 'fgr', 'err', 'ref', 'new'],  
                    default=['new'])  
parser.add_argument('-y', action='store_true')  

def handle(image_path: str, bgr_path: str, new_bg: str):  
    parser.add_argument('--images-src', type=str, required=False, default=image_path)  
    parser.add_argument('--images-bgr', type=str, required=False, default=bgr_path)  
    args = parser.parse_args()  

    assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \  
        'err出力をサポートしているのはmattingbaseとmattingrefineのみです'  
    assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \  
        'ref出力をサポートしているのはmattingrefineのみです'  

    # --------------- Main ---------------  

    device = torch.device(args.device)  

    # モデルをロード  
    if args.model_type == 'mattingbase':  
        model = MattingBase(args.model_backbone)  
    if args.model_type == 'mattingrefine':  
        model = MattingRefine(  
            args.model_backbone,  
            args.model_backbone_scale,  
            args.model_refine_mode,  
            args.model_refine_sample_pixels,  
            args.model_refine_threshold,  
            args.model_refine_kernel_size)  

    model = model.to(device).eval()  
    model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)  

    # 画像をロード  
    dataset = ZipDataset([  
        NewImagesDataset(args.images_src),  
        NewImagesDataset(args.images_bgr),  
    ], assert_equal_length=True, transforms=PairCompose([  
        HomographicAlignment() if args.preprocess_alignment else PairApply(nn.Identity()),  
        PairApply(T.ToTensor())  
    ]))  
    dataloader = DataLoader(dataset, batch_size=1, num_workers=args.num_workers, pin_memory=True)  

    # # 出力ディレクトリを作成  
    # if os.path.exists(args.output_dir):  
    #     if args.y or input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':  
    #         shutil.rmtree(args.output_dir)  
    #     else:  
    #         exit()  

    for output_type in args.output_types:  
        if os.path.exists(os.path.join(args.output_dir, output_type)) is False:  
            os.makedirs(os.path.join(args.output_dir, output_type))  

    # ワーカ関数  
    def writer(img, path):  
        img = to_pil_image(img[0].cpu())  
        img.save(path)  

    # ワーカ関数  
    def writer_hy(img, new_bg, path):  
        img = to_pil_image(img[0].cpu())  
        img_size = img.size  
        new_bg_img = Image.open(new_bg).convert('RGBA')  
        new_bg_img.resize(img_size, Image.ANTIALIAS)  
        out = Image.alpha_composite(new_bg_img, img)  
        out.save(path)  

    result_file_name = str(uuid.uuid4())  

    # 変換ループ  
    with torch.no_grad():  
        for i, (src, bgr) in enumerate(tqdm(dataloader)):  
            src = src.to(device, non_blocking=True)  
            bgr = bgr.to(device, non_blocking=True)  

            if args.model_type == 'mattingbase':  
                pha, fgr, err, _ = model(src, bgr)  
            elif args.model_type == 'mattingrefine':  
                pha, fgr, _, _, err, ref = model(src, bgr)  

            pathname = dataset.datasets[0].filenames[i]  
            pathname = os.path.relpath(pathname, args.images_src)  
            pathname = os.path.splitext(pathname)[0]  

            if 'new' in args.output_types:  
                new = torch.cat([fgr * pha.ne(0), pha], dim=1)  
                Thread(target=writer_hy,  
                       args=(new, new_bg, os.path.join(args.output_dir, 'new', result_file_name + '.png'))).start()  
            if 'com' in args.output_types:  
                com = torch.cat([fgr * pha.ne(0), pha], dim=1)  
                Thread(target=writer, args=(com, os.path.join(args.output_dir, 'com', pathname + '.png'))).start()  
            if 'pha' in args.output_types:  
                Thread(target=writer, args=(pha, os.path.join(args.output_dir, 'pha', pathname + '.jpg'))).start()  
            if 'fgr' in args.output_types:  
                Thread(target=writer, args=(fgr, os.path.join(args.output_dir, 'fgr', pathname + '.jpg'))).start()  
            if 'err' in args.output_types:  
                err = F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False)  
                Thread(target=writer, args=(err, os.path.join(args.output_dir, 'err', pathname + '.jpg'))).start()  
            if 'ref' in args.output_types:  
                ref = F.interpolate(ref, src.shape[2:], mode='nearest')  
                Thread(target=writer, args=(ref, os.path.join(args.output_dir, 'ref', pathname + '.jpg'))).start()  

    return os.path.join(args.output_dir, 'new', result_file_name + '.png')  

if __name__ == '__main__':  
    handle("data/img2.png", "data/bg.png", "data/newbg.jpg")  

コード説明

1、handle メソッドの引数はそれぞれ:元の画像のパス、元の背景画像のパス、新しい背景画像のパスです。

1、元のプロジェクトで inferance_images で使用されていたクラスをすべて 1 つのファイルに移動し、プロジェクト構造を簡素化しました。

2、ImagesDateSet を再構築した新しい NewImagesDateSet を作成しました。主に 1 枚の画像だけを処理するつもりだからです。

3、最終的な画像はすべて同じディレクトリに保存され、重複して uuid をファイル名として使用しないようにしました。

4、この記事で提供されるコードはファイル形式に対して厳密な検証を行っていませんが、それほど重要ではありません。必要であれば補足してください。

効果を検証する
640 (5)

読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。