保留计算图

PyTorch在多线程使用autograd时的一个潜在问题 - 保留计算图。主要内容如下:

  1. 当不同线程共享autograd计算图的一部分时,例如前向传播的第一部分在单线程,第二部分在多线程中执行。
  2. 这个共享的图会被不同的线程调用grad()或backward()进行反向求导。
  3. 问题在于一个线程在求解时可能会破坏计算图,而另一个线程还在使用相同的计算图,因此会crash。
  4. 此时PyTorch会像调用backward()两次而不使用retain_graph=True那样报错。
  5. 提示用户在这种共享计算图的多线程使用场景下,需要设置retain_graph=True。
  6. 以保留住计算图不被反向求导操作破坏,避免线程间的冲突。

总结来说,多线程下可能共享图的一部分,此时需要retain_graph=True来保留计算图,避免线程间的不确定行为和错误。


<
Blog Archive
Archive of all previous blog posts
>
Next Post
tensor 广播机制