Study Notes on a Two-Stage Knowledge-Distillation Model for Rain, Snow, and Haze Removal (Part 1)

2023-09-21 21:38:55

In the previous post we finished deploying and training the knowledge-distillation-based derain/desnow/dehaze model; now we turn to studying the code itself, stepping through it with a debugger.
First, the printed network structure. A warning before you expand it: it looks overwhelming, but the model is not conceptually complicated — it simply has a very large number of layers.

Net(
  (conv_input): ConvLayer(
    (reflection_pad): ReflectionPad2d((5, 5, 5, 5))
    (conv2d): Conv2d(3, 16, kernel_size=(11, 11), stride=(1, 1))
  )
  (dense0): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv2x): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2))
  )
  (conv1): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(48, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(80, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion1): Encoder_MDCBlock1(
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (dense1): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv4x): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
  )
  (conv2): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(160, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion2): Encoder_MDCBlock1(
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): DeconvBlock(
        (deconv): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): ConvBlock(
        (conv): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (dense2): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv8x): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
  )
  (conv3): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion3): Encoder_MDCBlock1(
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): DeconvBlock(
        (deconv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): DeconvBlock(
        (deconv): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): ConvBlock(
        (conv): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): ConvBlock(
        (conv): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (dense3): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv16x): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2))
  )
  (conv4): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(512, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion4): Encoder_MDCBlock1(
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): DeconvBlock(
        (deconv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): DeconvBlock(
        (deconv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (3): DeconvBlock(
        (deconv): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): ConvBlock(
        (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): ConvBlock(
        (conv): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (3): ConvBlock(
        (conv): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (dehaze): Sequential(
    (res0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res3): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res4): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res5): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res6): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res7): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res8): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res9): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res10): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res11): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res12): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res13): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res14): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res15): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res16): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res17): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (convd16x): UpsampleConvLayer(
    (conv2d): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2))
  )
  (dense_4): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv_4): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion_4): Decoder_MDCBlock1(
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (convd8x): UpsampleConvLayer(
    (conv2d): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2))
  )
  (dense_3): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv_3): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(160, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion_3): Decoder_MDCBlock1(
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): DeconvBlock(
        (deconv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (convd4x): UpsampleConvLayer(
    (conv2d): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2))
  )
  (dense_2): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv_2): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(48, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(80, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion_2): Decoder_MDCBlock1(
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): ConvBlock(
        (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): DeconvBlock(
        (deconv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): DeconvBlock(
        (deconv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (convd2x): UpsampleConvLayer(
    (conv2d): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))
  )
  (dense_1): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv_1): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(24, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(32, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(40, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion_1): Decoder_MDCBlock1(
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): ConvBlock(
        (conv): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): ConvBlock(
        (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (3): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): DeconvBlock(
        (deconv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): DeconvBlock(
        (deconv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (3): DeconvBlock(
        (deconv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (conv_output): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1))
  )
)
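The printout is built from a handful of repeating building blocks. As a reading aid, here is a minimal sketch of the three most common ones, reconstructed purely from the printed shapes above (this is not the repository's source; the forward logic is the standard form for these blocks and is an assumption): `ConvLayer` is reflection padding followed by a conv, `ResidualBlock` is two `ConvLayer`s with a PReLU and a skip connection, and `RDB` chains densely connected 3x3 convs whose concatenated features are fused by a 1x1 conv.

```python
import torch
import torch.nn as nn

class ConvLayer(nn.Module):
    """Reflection-padded conv, e.g. conv_input: pad 5 + 11x11 conv."""
    def __init__(self, in_ch, out_ch, kernel_size, stride):
        super().__init__()
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size, stride)

    def forward(self, x):
        return self.conv2d(self.reflection_pad(x))

class ResidualBlock(nn.Module):
    """Two 3x3 ConvLayers + PReLU, with an identity skip (assumed)."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, 3, 1)
        self.conv2 = ConvLayer(channels, channels, 3, 1)
        self.relu = nn.PReLU()

    def forward(self, x):
        out = self.relu(self.conv1(x))
        return self.conv2(out) + x  # skip connection

class make_dense(nn.Module):
    """One dense layer: conv produces growth_rate new channels,
    which are concatenated onto the running feature stack."""
    def __init__(self, in_ch, growth_rate):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, growth_rate, 3, 1, 1, bias=False)

    def forward(self, x):
        return torch.cat([x, torch.relu(self.conv(x))], dim=1)

class RDB(nn.Module):
    """Residual dense block: channels grow 16->32->48->64->80 in the
    conv1 instance above, then conv_1x1 fuses 80 back to 16."""
    def __init__(self, channels, num_layers=4):
        super().__init__()
        layers, in_ch = [], channels
        for _ in range(num_layers):
            layers.append(make_dense(in_ch, channels))
            in_ch += channels
        self.dense_layers = nn.Sequential(*layers)
        self.conv_1x1 = nn.Conv2d(in_ch, channels, 1, bias=False)

    def forward(self, x):
        return self.conv_1x1(self.dense_layers(x)) + x  # local residual

# Shapes match the printed conv1 RDB (16-channel variant).
x = torch.randn(1, 16, 32, 32)
print(RDB(16)(x).shape)  # torch.Size([1, 16, 32, 32])
```

The channel arithmetic is the easiest way to sanity-check such a printout: in every `RDB`, each `make_dense` input width grows by the growth rate, and the `conv_1x1` input width equals `channels * (num_layers + 1)`.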

We first enter training mode, i.e., the knowledge-collection (Stage 1) training phase:

def train_kc_stage(model, teacher_networks, ckt_modules, train_loader, optimizer, scheduler, epoch, criterions):
	print(Fore.CYAN + "==> Training Stage 1")
	print("==> Epoch {}/{}".format(epoch, args.max_epoch))
	print("==> Learning Rate = {:.6f}".format(optimizer.param_groups[0]['lr']))
	meters = get_meter(num_meters=5)  # running averages for the logged losses
	criterion_l1, criterion_scr, _ = criterions  # L1 and SCR losses; the third criterion is unused in stage 1
	model.train()
	ckt_modules.train()
	# the teacher networks are frozen: inference only, no gradient updates
	for teacher_network in teacher_networks:
		teacher_network.eval()
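Before reading further, it helps to picture what one iteration of this stage does. The following is a hypothetical, heavily simplified sketch (stand-in single-conv "networks", illustrative names like `target_images`; none of it is the repository's code): the frozen teachers produce predictions under `torch.no_grad()`, the student is trained against both the clean targets and the teacher outputs, and the optimizer only updates the student side.

```python
import torch
import torch.nn as nn

# Stand-ins for one teacher and the student (real ones are the Net above).
teacher = nn.Conv2d(3, 3, 3, padding=1)
student = nn.Conv2d(3, 3, 3, padding=1)
criterion_l1 = nn.L1Loss()
optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)

images = torch.randn(2, 3, 64, 64)         # degraded inputs
target_images = torch.randn(2, 3, 64, 64)  # clean ground truth

teacher.eval()
with torch.no_grad():          # teachers are frozen, as in the loop above
    teacher_preds = teacher(images)

preds = student(images)
loss = criterion_l1(preds, target_images) \
     + criterion_l1(preds, teacher_preds)  # distill from the teacher output

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

In the actual code the distillation term operates on intermediate features projected through the CKT modules rather than on raw outputs, which is why `ckt_modules` is also set to `train()` above.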

Here the required loss functions are unpacked, and the ckt_modules (collaborative knowledge transfer modules) are put into training mode.
The detailed structure of ckt_modules is as follows:

ModuleList(
  (0): CKTModule(
    (teacher_projectors): TeacherProjectors(
      (PFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (IPFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
    (student_projector): StudentProjector(
      (PFP): Sequential(
        (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): ReLU(inplace=True)
        (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )
  (1): CKTModule(
    (teacher_projectors): TeacherProjectors(
      (PFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (IPFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
    (student_projector): StudentProjector(
      (PFP): Sequential(
        (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )
  (2): CKTModule(
    (teacher_projectors): TeacherProjectors(
      (PFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (IPFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
    (student_projector): StudentProjector(
      (PFP): Sequential(
        (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )
  (3): CKTModule(
    (teacher_projectors): TeacherProjectors(
      (PFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (IPFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
    (student_projector): StudentProjector(
      (PFP): Sequential(
        (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )
)
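Functionally, each CKTModule projects every teacher's features into a shared lower-dimensional space (the PFPs), reconstructs them back (the IPFPs) so a reconstruction constraint can be applied, and projects the student's features into the same shared space (the PFP). The following is a minimal sketch matching the printed structure above; the exact forward logic is my assumption, not code copied from the repository:

```python
import torch
import torch.nn as nn

def _proj(in_c, out_c):
    # 3x3 conv -> ReLU -> 3x3 conv, matching the printed PFP/IPFP blocks
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, 3, padding=1, bias=False),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, 3, padding=1, bias=False),
    )

class CKTModuleSketch(nn.Module):
    def __init__(self, channels, num_teachers=3):
        super().__init__()
        mid = channels // 2
        self.teacher_pfps = nn.ModuleList(_proj(channels, mid) for _ in range(num_teachers))
        self.teacher_ipfps = nn.ModuleList(_proj(mid, channels) for _ in range(num_teachers))
        self.student_pfp = _proj(channels, mid)

    def forward(self, t_features, s_feature):
        # t_features: list of per-teacher feature maps at one network layer
        t_proj = [pfp(f) for pfp, f in zip(self.teacher_pfps, t_features)]
        t_recons = [ipfp(p) for ipfp, p in zip(self.teacher_ipfps, t_proj)]
        s_proj = self.student_pfp(s_feature)
        return t_proj, t_recons, s_proj

m = CKTModuleSketch(64)
t_feats = [torch.randn(1, 64, 8, 8) for _ in range(3)]   # one map per teacher
s_feat = torch.randn(3, 64, 8, 8)                        # student sees the concatenated batch
t_proj, t_recons, s_proj = m(t_feats, s_feat)
```

The channel halving (64→32, 128→64, 256→128) matches the printouts: the shared space has half the channels of the layer being distilled.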

The structure of criterions, which defines three loss functions: an L1 loss, the SCR loss, and the HCR loss:

ModuleList(
  (0): L1Loss()
  (1): SCRLoss(
    (vgg): Vgg19(
      (slice1): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (slice2): Sequential(
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
      )
      (slice3): Sequential(
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
      )
      (slice4): Sequential(
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (17): ReLU(inplace=True)
        (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
      )
      (slice5): Sequential(
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
        (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (24): ReLU(inplace=True)
        (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (26): ReLU(inplace=True)
        (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU(inplace=True)
      )
    )
    (l1): L1Loss()
  )
  (2): HCRLoss(
    (vgg): Vgg19(
      (slice1): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (slice2): Sequential(
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
      )
      (slice3): Sequential(
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
      )
      (slice4): Sequential(
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (17): ReLU(inplace=True)
        (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
      )
      (slice5): Sequential(
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
        (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (24): ReLU(inplace=True)
        (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (26): ReLU(inplace=True)
        (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU(inplace=True)
      )
    )
    (l1): L1Loss()
  )
)
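Both SCRLoss and HCRLoss wrap a frozen VGG19 and an L1 loss, which suggests a contrastive regularization: pull the restored image toward the clean target (positive) in VGG feature space while pushing it away from the degraded input (negative). The sketch below is a simplified illustration of that idea, using identity features in place of the VGG slices to stay self-contained; the weighting and the exact formula used in the repository are my assumptions:

```python
import torch
import torch.nn as nn

class SimpleContrastiveLoss(nn.Module):
    """Simplified contrastive-regularization sketch: per feature level,
    minimize distance(pred, clean) / distance(pred, degraded)."""
    def __init__(self, feature_fn=None, weights=(1.0,)):
        super().__init__()
        # SCRLoss extracts features with frozen VGG19 slices; we default
        # to identity features so the sketch runs without model weights.
        self.feature_fn = feature_fn or (lambda x: [x])
        self.weights = weights
        self.l1 = nn.L1Loss()

    def forward(self, pred, positive, negative):
        f_pred = self.feature_fn(pred)
        f_pos = self.feature_fn(positive)
        f_neg = self.feature_fn(negative)
        loss = 0.0
        for w, a, p, n in zip(self.weights, f_pred, f_pos, f_neg):
            d_ap = self.l1(a, p)            # distance to clean target
            d_an = self.l1(a, n.detach())   # distance to degraded input
            loss = loss + w * d_ap / (d_an + 1e-7)
        return loss

crit = SimpleContrastiveLoss()
pred = torch.rand(1, 3, 8, 8)
target, inp = torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)
loss = crit(pred, target, inp)
```

The loss is small when the prediction sits close to the clean target and far from the degraded input, which is exactly the behavior a restoration network should have.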

As you can see, the teacher networks are just three copies of the earlier Net architecture, each loaded with different weights; in other words, three separate models.


Continuing with the training loop:

start = time.time()
pBar = tqdm(train_loader, desc='Training')
for target_images, input_images in pBar:
	
	# Check whether the batch contains all types of degraded data
	if target_images is None: continue

	# move to GPU
	target_images = target_images.cuda()
	input_images = [images.cuda() for images in input_images]

	# Fix all teachers and collect reconstruction results and features from the corresponding teacher
	preds_from_teachers = []
	features_from_each_teachers = []
	with torch.no_grad():
		for i in range(len(teacher_networks)):
			preds, features = teacher_networks[i](input_images[i], return_feat=True)
			preds_from_teachers.append(preds)
			features_from_each_teachers.append(features)	
			
	preds_from_teachers = torch.cat(preds_from_teachers)
	features_from_teachers = []
	for layer in range(len(features_from_each_teachers[0])):
		features_from_teachers.append([features_from_each_teachers[i][layer] for i in range(len(teacher_networks))])

	preds_from_student, features_from_student = model(torch.cat(input_images), return_feat=True)   

	
	# Project the features to common feature space and calculate the loss
	PFE_loss, PFV_loss = 0., 0.
	for i, (s_features, t_features) in enumerate(zip(features_from_student, features_from_teachers)):
		t_proj_features, t_recons_features, s_proj_features = ckt_modules[i](t_features, s_features)
		PFE_loss += criterion_l1(s_proj_features, torch.cat(t_proj_features))
		PFV_loss += 0.05 * criterion_l1(torch.cat(t_recons_features), torch.cat(t_features))

	T_loss = criterion_l1(preds_from_student, preds_from_teachers)
	SCR_loss = 0.1 * criterion_scr(preds_from_student, target_images, torch.cat(input_images))
	total_loss = T_loss + PFE_loss + PFV_loss + SCR_loss

	optimizer.zero_grad()
	total_loss.backward()
	optimizer.step()
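The nested loop that builds features_from_teachers is simply a transpose of a list of lists: it turns the per-teacher layout `features_from_each_teachers[teacher][layer]` into a per-layer layout `features_from_teachers[layer][teacher]`, so that each CKT module (one per layer) receives all three teachers' features at once. A pure-Python illustration with strings standing in for tensors:

```python
# 3 teachers, each returning 4 per-layer feature maps
features_from_each_teachers = [
    [f"t{t}_layer{l}" for l in range(4)] for t in range(3)
]

# Same transpose as in the training loop: group by layer instead of by teacher
features_from_teachers = [
    [features_from_each_teachers[t][layer] for t in range(3)]
    for layer in range(4)
]

print(features_from_teachers[0])  # ['t0_layer0', 't1_layer0', 't2_layer0']
```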

Next comes the evaluation step: given the model and the validation set, it ultimately outputs PSNR and SSIM:

if epoch % args.val_freq == 0:
	psnr, ssim = evaluate(model, val_loader, epoch)
	# Check whether the model is a top-k model
	top_k_state = save_top_k(model, optimizer, scheduler, top_k_state, args.top_k, epoch, args.save_dir, psnr=psnr, ssim=ssim)

Inside evaluate(model, val_loader, epoch), each validation image is passed through the model and the metrics are accumulated. The prediction is produced with:

pred = model(image)

This jumps into Net's forward method for feature extraction.

Input:

The input x is a 640×480 image, so its initial shape is torch.Size([1, 3, 480, 640]).
It then passes through a series of strided convolutions that downsample it into feature maps; I won't walk through every step here.


Output:

The outputs are x and feature; the final x keeps the shape torch.Size([1, 3, 480, 640]).


feature contains 4 feature maps, one per distillation layer.


Here the model is configured to return only x, so pred is simply x. With the output in hand, the metrics can be computed:

psnr_list.append(torchPSNR(pred, target).item())
ssim_list.append(pytorch_ssim.ssim(pred, target).item())

The concrete implementation of torchPSNR:

@torch.no_grad()
def torchPSNR(prd_img, tar_img):
	if not isinstance(prd_img, torch.Tensor):
		prd_img = torch.from_numpy(prd_img)
		tar_img = torch.from_numpy(tar_img)

	imdff = torch.clamp(prd_img, 0, 1) - torch.clamp(tar_img, 0, 1)
	rmse = (imdff**2).mean().sqrt()
	ps = 20 * torch.log10(1/rmse)
	return ps
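As a quick sanity check of this formula (my own worked example, not from the original post): a constant pixel error of 0.25 over the whole image gives RMSE = 0.25, so PSNR = 20·log10(1/0.25) = 20·log10(4) ≈ 12.04 dB:

```python
import torch

pred = torch.full((1, 3, 4, 4), 0.5)
target = torch.full((1, 3, 4, 4), 0.75)

# Same computation as torchPSNR above
imdff = torch.clamp(pred, 0, 1) - torch.clamp(target, 0, 1)
rmse = (imdff ** 2).mean().sqrt()
psnr = 20 * torch.log10(1 / rmse)
print(round(psnr.item(), 2))  # ≈ 12.04
```

Note the clamp to [0, 1]: the function assumes images normalized to that range, with a peak signal value of 1.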

Eventually all 19 validation images are evaluated and psnr_list is filled in; at the point captured in this debug session, only two images had been processed.

Finally, the averages are returned:

return np.mean(psnr_list), np.mean(ssim_list)

which gives the method's final PSNR and SSIM values.
