1
# --------------------------------------------------------------------------------------------------
# --------------------------- MODULE :: Detail Enhancer and Text Remover ---------------------------
# --------------------------------------------------------------------------------------------------
from . Network import Network
import torch.nn.functional as functional
import torch.nn as nn
import torch
# --------------------------------------------------------------------------------------------------
# -------------------------------- CLASS :: Enhancer & Text Remover --------------------------------
# --------------------------------------------------------------------------------------------------
class Sharpener(Network):
    """Shallow convolutional model mapping a 3-channel input to a 3-channel output in (0, 1).

    Pipeline: Conv2d -> BatchNorm2d -> ReLU -> Dropout -> ConvTranspose2d -> sigmoid.

    Both convolutions use kernel 4, stride 1 and padding 1. The Conv2d therefore removes
    one pixel from each spatial dimension and the ConvTranspose2d adds it back, so the
    output has the same height and width as the input.
    """

    # ----------------------------------------------------------------------------------------------
    # ---------------------------------- CONSTRUCTOR :: Build Layers -------------------------------
    # ----------------------------------------------------------------------------------------------
    def __init__(self, name: str) -> None:
        """Create the layers of the network.

        Args:
            name: Passed unchanged to the ``Network`` base-class constructor.
        """
        super().__init__(name)

        # Shared, parameter-free layers.
        self.noise = nn.Dropout(p=0.10)
        self.funct = nn.ReLU()

        # Encoder stage: 3 -> 64 channels.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=64,
            kernel_size=4,
            stride=1,
            padding=1,
        )
        self.norm1 = nn.BatchNorm2d(64)

        # NOTE: a deeper variant (three further Conv2d + BatchNorm2d stages followed by
        # three ConvTranspose2d + BatchNorm2d stages, all 32 -> 32 channels, registered as
        # conv2..conv7 / norm2..norm7) used to be sketched here and is disabled. Its
        # 32-channel widths do not match the 64 channels used by conv1 and conv8.

        # Decoder stage: 64 -> 3 channels.
        self.conv8 = nn.ConvTranspose2d(
            in_channels=64,
            out_channels=3,
            kernel_size=4,
            stride=1,
            padding=1,
        )

    # ----------------------------------------------------------------------------------------------
    # -------------------------------- METHOD :: Forward Activation --------------------------------
    # ----------------------------------------------------------------------------------------------
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` and return the sigmoid-activated reconstruction.

        Args:
            x: Input tensor with 3 channels, as required by ``conv1``.

        Returns:
            Tensor with 3 channels whose values lie in (0, 1).
        """
        hidden = self.conv1(x)
        hidden = self.norm1(hidden)
        hidden = self.funct(hidden)
        hidden = self.noise(hidden)

        logits = self.conv8(hidden)
        return torch.sigmoid(logits)

    # ----------------------------------------------------------------------------------------------
    # -------------------------------- METHOD :: Compute Loss Value --------------------------------
    # ----------------------------------------------------------------------------------------------
    def loss(self, reconstruction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the squared error between the two tensors, summed over all elements.

        Args:
            reconstruction: Output produced by the network.
            target: Reference tensor to compare against.

        Returns:
            Scalar tensor holding the summed (not averaged) squared error.
        """
        return functional.mse_loss(reconstruction, target, reduction='sum')
# For immediate assistance, please email our customer support: [email protected]