Python/에러 대모험
ImportError: cannot import name 'zero_gradients' from 'torch.autograd.gradcheck'
ikaros0427
2023. 2. 15. 11:52
반응형
gradcheck.py에 아래 코드를 삽입하고 저장하면 해결된다.
def zero_gradients(x):
if isinstance(x, torch.Tensor):
if x.grad is not None:
x.grad.detach_()
x.grad.data.zero_()
elif isinstance(x, container_abcs.Iterable):
for elem in x:
zero_gradients(elem)
반응형