pytorch 自动微分机制
保留计算图
PyTorch在多线程使用autograd时的一个潜在问题 - 保留计算图。主要内容如下:
- 当不同线程共享autograd计算图的一部分时,例如前向传播的第一部分在单线程,第二部分在多线程中执行。
- 这个共享的图会被不同的线程调用grad()或backward()进行反向求导。
- 问题在于一个线程在求解时可能会破坏计算图,而另一个线程还在使用相同的计算图,因此会crash。
- 此时PyTorch会像调用backward()两次而不使用retain_graph=True那样报错。
- 提示用户在这种共享计算图的多线程使用场景下,需要设置retain_graph=True。
- 以保留住计算图不被反向求导操作破坏,避免线程间的冲突。
总结来说,多线程下可能共享图的一部分,此时需要retain_graph=True来保留计算图,避免线程间的不确定行为和错误。