1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
class InversionNet(nn.Module):
    """Convolutional encoder-decoder built from ``ConvBlock``/``DeconvBlock`` layers.

    The encoder takes a 5-channel input (the channel count is hard-coded in
    ``convblock1``). It ends with ``convblock8``, a ``kernel_size=8``,
    ``padding=0`` block that projects to ``dim5`` channels. The decoder
    upsamples with ``DeconvBlock``/``ConvBlock`` pairs. It then crops the two
    trailing dimensions with a negative ``F.pad`` and finishes with the
    ``ConvBlock_Tanh(dim1, 1)`` head.
    """

    def __init__(self, dim1=32, dim2=64, dim3=128, dim4=256, dim5=512, **kwargs):
        """Build the encoder and decoder layers.

        :param dim1: Number of channels in the 1st layer
        :param dim2: Number of channels in the 2nd layer
        :param dim3: Number of channels in the 3rd layer
        :param dim4: Number of channels in the 4th layer
        :param dim5: Number of channels in the 5th layer (output of
            ``convblock8`` and input of the decoder)
        :param kwargs: Extra keyword arguments are accepted and ignored.
            In particular ``sample_spatial`` is NOT a parameter of this
            network: passing it has no effect on any layer.
        """
        super().__init__()

        # Encoder, stage 1: (k, 1) kernels with zero padding on the last axis.
        # These blocks convolve and stride only along the second-to-last axis
        # (assuming ConvBlock forwards these arguments to a 2-D convolution).
        self.convblock1 = ConvBlock(5, dim1, kernel_size=(7, 1), stride=(2, 1), padding=(3, 0))
        self.convblock2_1 = ConvBlock(dim1, dim2, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))
        self.convblock2_2 = ConvBlock(dim2, dim2, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.convblock3_1 = ConvBlock(dim2, dim2, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))
        self.convblock3_2 = ConvBlock(dim2, dim2, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.convblock4_1 = ConvBlock(dim2, dim3, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))
        self.convblock4_2 = ConvBlock(dim3, dim3, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.convblock5_1 = ConvBlock(dim3, dim3, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0))
        self.convblock5_2 = ConvBlock(dim3, dim3, kernel_size=(3, 1), stride=1, padding=(1, 0))

        # Encoder, stage 2: square 3x3 kernels; the *_1 blocks stride 2 on both axes.
        self.convblock6_1 = ConvBlock(dim3, dim4, kernel_size=(3, 3), stride=2, padding=1)
        self.convblock6_2 = ConvBlock(dim4, dim4, kernel_size=(3, 3), stride=1, padding=1)
        self.convblock7_1 = ConvBlock(dim4, dim4, kernel_size=(3, 3), stride=2, padding=1)
        self.convblock7_2 = ConvBlock(dim4, dim4, kernel_size=(3, 3), stride=1, padding=1)

        # NOTE(review): kernel_size=8 with padding=0 reduces the feature map to
        # 1x1 only if it is exactly 8x8 at this point. This presumably ties the
        # network to one input size -- confirm against the data it is used with.
        self.convblock8 = ConvBlock(dim4, dim5, kernel_size=8, padding=0)

        # Decoder: each stage is an upsampling DeconvBlock followed by a
        # kernel-3, stride-1 ConvBlock.
        self.deconv1_1 = DeconvBlock(dim5, dim5, kernel_size=5, stride=1, padding=0)
        self.deconv1_2 = ConvBlock(dim5, dim5, kernel_size=3, stride=1)
        self.deconv2_1 = DeconvBlock(dim5, dim4, kernel_size=4, stride=2, padding=1)
        self.deconv2_2 = ConvBlock(dim4, dim4, kernel_size=3, stride=1)
        self.deconv3_1 = DeconvBlock(dim4, dim3, kernel_size=4, stride=2, padding=1)
        self.deconv3_2 = ConvBlock(dim3, dim3, kernel_size=3, stride=1)
        self.deconv4_1 = DeconvBlock(dim3, dim2, kernel_size=4, stride=2, padding=1)
        self.deconv4_2 = ConvBlock(dim2, dim2, kernel_size=3, stride=1)
        self.deconv5_1 = DeconvBlock(dim2, dim1, kernel_size=4, stride=2, padding=1)
        self.deconv5_2 = ConvBlock(dim1, dim1, kernel_size=3, stride=1)
        # Final head: dim1 -> 1 channel (presumably (in, out) like the other blocks).
        self.deconv6 = ConvBlock_Tanh(dim1, 1)

    def forward(self, x):
        """Run the encoder, then the decoder, on ``x``.

        :param x: Input tensor; presumably (batch, 5, H, W) given the 5 input
            channels of ``convblock1`` -- TODO confirm.
        :return: Output of the ``deconv6`` (``ConvBlock_Tanh``) head.
        """
        # Encoder
        x = self.convblock1(x)
        x = self.convblock2_1(x)
        x = self.convblock2_2(x)
        x = self.convblock3_1(x)
        x = self.convblock3_2(x)
        x = self.convblock4_1(x)
        x = self.convblock4_2(x)
        x = self.convblock5_1(x)
        x = self.convblock5_2(x)
        x = self.convblock6_1(x)
        x = self.convblock6_2(x)
        x = self.convblock7_1(x)
        x = self.convblock7_2(x)
        x = self.convblock8(x)

        # Decoder
        x = self.deconv1_1(x)
        x = self.deconv1_2(x)
        x = self.deconv2_1(x)
        x = self.deconv2_2(x)
        x = self.deconv3_1(x)
        x = self.deconv3_2(x)
        x = self.deconv4_1(x)
        x = self.deconv4_2(x)
        x = self.deconv5_1(x)
        x = self.deconv5_2(x)
        # Negative padding crops. F.pad's list is (left, right) for the last
        # dim, then (top, bottom) for the second-to-last dim. This therefore
        # removes 2 leading and 3 trailing elements from each of the last two dims.
        x = F.pad(x, [-2, -3, -2, -3], mode="constant", value=0)
        x = self.deconv6(x)
        return x
|