GAN里面什么时候用detach的说明

2023-09-22 11:52:21

在生成对抗网络(GAN)中,生成器(G)和判别器(D)通常是两个独立的神经网络,它们之间会有梯度传播的互动。下面是一个简单的GAN的PyTorch实现,用于生成一维数据,以展示何时应该使用detach()。

import torch
import torch.nn as nn
import torch.optim as optim

# 生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(10, 50),
            nn.ReLU(),
            nn.Linear(50, 1)
        )
    
    def forward(self, x):
        return self.model(x)

# 判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(1, 50),
            nn.ReLU(),
            nn.Linear(50, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.model(x)

# 实例化生成器和判别器
G = Generator()
D = Discriminator()

# 定义优化器和损失函数
optimizer_G = optim.Adam(G.parameters(), lr=0.001)
optimizer_D = optim.Adam(D.parameters(), lr=0.001)
loss_func = nn.BCELoss()

# 训练循环
for epoch in range(1000):
    # 训练判别器
    D.zero_grad()
    real_data = torch.randn(100, 1)  # 真实数据
    real_labels = torch.ones(100, 1) # 真实标签
    fake_data = G(torch.randn(100, 10)).detach() # 使用detach(), 因为我们不想在这一步更新生成器
    fake_labels = torch.zeros(100, 1) # 假的标签

    real_loss = loss_func(D(real_data), real_labels)
	# real_loss = loss_func(D(real_data.detach), real_labels)
    fake_loss = loss_func(D(fake_data), fake_labels)
    d_loss = real_loss + fake_loss
    d_loss.backward()
    optimizer_D.step()

    # 训练生成器
    G.zero_grad()
    noise_data = torch.randn(100, 10) # 噪声数据
    fake_data = G(noise_data) # 没有使用detach(), 因为我们想在这一步更新生成器
    g_loss = loss_func(D(fake_data), torch.ones(100, 1))
    g_loss.backward()
    optimizer_G.step()

在这个例子中:

  1. 当训练判别器(D)时,我们使用了detach()来中断梯度传播到生成器(G)。这是因为在这一步中,我们仅关心优化判别器,而不希望更新生成器的参数。
  2. 当训练生成器(G)时,我们没有使用detach(),因为我们需要通过反向传播的梯度来更新生成器的参数。

注意:在训练判别器时,不使用real_loss = loss_func(D(real_data.detach), real_labels), 也就是这里不需要对real_data进行detach操作。

而且即使对real_data进行.detach()操作实际上应该不会有明显影响,原因在于real_data并不是通过模型参数生成的,也不是一个需要优化的变量。.detach()方法主要用于将一个张量从当前计算图中分离出来,阻止反向传播过程中对其计算梯度。但在本例中,real_data本身就没有与需要优化的模型参数有直接关系,也不是由其他需要优化的变量通过一些运算得到的。

注意: 在训练判别器时,使用fake_data = G(torch.randn(100, 10)).detach(), 注意是因为这个fake_data是由生成器G生成的, 为了保证分开训练判别器和生成器,即在训练判别器的时候,不对生成器的参数进行更新,这里就要把G生成的数据进行detach操作

在训练生成器时, 也用到了判别器,用判别器去判别生成器生成的内容,希望判别器能把G生成的内容当做真的,这样就说明G的生成的内容可以以假乱真

fake_data = G(noise_data) # 没有使用detach(), 因为我们想在这一步更新生成器
g_loss = loss_func(D(fake_data), torch.ones(100, 1))
g_loss.backward()
optimizer_G.step()

上面没有对传进D的fake_data进行detach,是因为下面的代码只有g_loss_backward(),也就是只对G进行参数更新,当然这里也不能对fake_data进行detach,如果detach了,就无法更新G的参数了

更多推荐

P1827 [USACO3.4] 美国血统 American Heritage(前序 + 中序 生成后序)

P1827[USACO3.4]美国血统AmericanHeritage(前序+中序生成后序)一、前言二叉树入门题。涉及到树的基本知识、树的结构、树的生成。本文从会从结构,到完成到,优化。二、基础知识Ⅰ、二叉树的遍历前序遍历:根左右中序遍历:左根右后序遍历:左右根通过上面的观察,可得根在那,就是什么方式的遍历Ⅱ、二叉树的

Kotlin Coroutines包下的select函数简介

在工作中,发现了kotlinCoroutines包下有大量功能非常强大的API,这篇文章中,我们主要来聊一聊select函数1.什么是select函数想象一下这个场景,在程序应用中,为了实现一个业务逻辑,你可能有好几种方式来实现,但是我只需要最快实现结果的一种方式,这时候我们就可以使用select函数了。如果还不是很清

地球系统模式(CESM)技术应用

近年升级的CESM2.0在大气、陆地、海洋、海冰、陆冰、径流等几大模块以及一个中央耦合器(CIME)中都有较大更新,可以在不同的硬件平台上移植使用,尤其可以用于CMIP6的研究。CESM中CIME(CommonInfrastructureforModelingtheEarth)为模式配置、编译和运行提供个例控制器。CA

web浏览器公网远程访问jupyter notebook【内网穿透】

文章目录前言1.Python环境安装2.Jupyter安装3.启动JupyterNotebook4.远程访问4.1安装配置cpolar内网穿透4.2创建隧道映射本地端口5.固定公网地址前言JupyterNotebook,它是一个交互式的数据科学和计算环境,支持多种编程语言,如Python、R、Julia等。它在数据科学

数据不平衡GPT调研

数据不平衡判别式和生成式的区别是什么判别式模型(DiscriminativeModels)生成式模型(GenerativeModels)对比对于AE或者VAE这种生成式模型,其实更关注数据本身,那这种有什么好处?那对于判别式模型,它更关注什么呢?它存在什么样的弊端?比如可能落入局部最优,无法进行优化啥的展开讲讲这个判别

android 存储新特性

分区存储本页内容应用访问限制将分区存储与FUSE搭配使用FUSE和SDCardFSFUSE性能微调减轻与FUSE相关的性能影响隐私优势远超性能劣势MediaProvider和FUSE更新分区存储会限制应用访问外部存储空间。在Android11或更高版本中,以API30或更高版本为目标平台的应用必须使用分区存储。之前,在

性能测试监控指标及分析调优 | 京东云技术团队一、哪些因素会成为系统的瓶颈?

1.什么是MAF和MEF?MEF和MEF微软官方介绍:ManagedExtensibilityFramework(MEF)-.NETFramework|MicrosoftLearnMEF是轻量化的插件框架,MAF是复杂的插件框架。因为MAF有进程隔离和程序域隔离可选。我需要插件进程隔离同时快速传递数据,最后选择了MAF

【FAQ】安防视频监控平台EasyNVR无法控制云台,该如何解决?

TSINGSEE青犀视频安防监控平台EasyNVR可支持设备通过RTSP/Onvif协议接入,并能对接入的视频流进行处理与多端分发,包括RTSP、RTMP、HTTP-FLV、WS-FLV、HLS、WebRTC等多种格式。在智慧安防等视频监控场景中,EasyNVR可提供视频实时监控直播、云端录像、云存储、录像检索与回看、

vue3如何导入使用自定义yaml配置文件

要在Vue3中导入和使用自定义的YAML配置文件,你可以按照以下步骤进行操作:安装所需的依赖:首先,确保你的项目中已经安装了vue和vue-router。你还需要使用js-yaml库来解析YAML文件,可以使用以下命令进行安装:npminstalljs-yaml创建YAML文件:在你的项目中创建一个YAML文件,用于存

Linux 软件包管理器-yum使用

文章目录前言一、yum使用1、什么是软件包2、yum源3、yumlist指令4、yuminstall指令5、yumremove指令二、git的使用1、gitee中仓库的创建2、仓库的克隆3、提交代码到远程仓库4、提交时可能遇到的问题5、.gitignore文件6、删除文件前言一、yum使用1、什么是软件包在Linux下

SLAM从入门到精通(消息传递)

【声明:版权所有,欢迎转载,请勿用于商业用途。联系信箱:feixiaoxing@163.com】前面我们只是编写了一个publisher节点,以及一个subscribe节点。有了这两个节点,它们之间就可以通信了。在实际生产中,我们除了简单的通信之外,要传递的数据可能还有很多。这个时候,我们就要构建一个消息体。这个消息体

热文推荐