본문 바로가기

Python/에러 대모험

ImportError: cannot import name 'zero_gradients' from 'torch.autograd.gradcheck'

반응형

gradcheck.py에 아래 코드를 삽입하고 저장하면 해결된다.

 

def zero_gradients(x):
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            x.grad.detach_()
            x.grad.data.zero_()
    elif isinstance(x, container_abcs.Iterable):
        for elem in x:
            zero_gradients(elem)
반응형