문제 상황
pytorch에서 텐서 연산 시 빠른 처리를 위해 batch로 연산하는 경우가 많다.
그런데 하나씩 하나씩 연산하는 것과, 배치로 한꺼번에 연산하면 결과값이 미세하게 달라지는 현상을 발견했다.
다음은 전체 코드이다.
코드
import torch
from torch.nn import Linear, Module
class NN(Module):
def __init__(self):
super().__init__()
self.fc1 = Linear(64, 2)
def forward(self, x):
x = self.fc1(x)
return x
def main():
batch_size = 10
x = torch.rand((batch_size, 64))
nn = NN()
# one by one
l = []
for element in x:
l.append(nn(element))
# batch process
result = nn(x)
print()
print(f"total batch size: {batch_size}")
print()
for i, item in enumerate(l):
# print(item == result[i])
print(f"{i}th element: {'equal' if torch.equal(item, result[i]) else 'different'}")
if __name__ == '__main__':
main()
결과
출력 결과, 배치 처리한 결과와 하나씩 처리한 결과가
같을 때도 있고 다를 때도 있는 것을 알 수 있다.
그렇기 때문에, 이렇게 비교해야 하는 상황에서는 torch.equal 이나 == 연산자를 쓰면 안 된다!
원인
pytorch에서 Linear layer 처리 시, 즉 batch matrix multiplication 연산 시 한꺼번에 처리하기 위해 추가적인 최적화 과정을 거치게 되는데, 이 때 부동소수점 연산의 순서가 달라진다.
부동소수점 연산에서는 결합법칙이 성립하지 않기 때문에, 미세한 차이가 발생한다.
멀티쓰레드를 싱글쓰레드로 바꾸든, torch.float64로 바꾸든 이 에러는 해결할 수 없다.
문제 해결 방법
torch.allclose 메서드를 사용하면 일반적인 상황에서는 해결 가능하다.
하지만, 후술할 참고문헌을 찾아보면,
error가 propagation 되어서 오차 값이 매우 커질 수도 있다고 한다.
그래서 결론은,
최대한 이런 식으로 비교하지 않도록 코드를 짜는 것이 가장 좋아 보인다.
참고문헌
https://discuss.pytorch.org/t/numerical-error-between-batch-and-single-instance-computation/56735
'NLP lab > 파이썬' 카테고리의 다른 글
[Solved] device-side assert triggered, Assertion `t >= 0 && t < n_classes` failed. (0) | 2022.08.05 |
---|---|
[numpy] + 연산과 += 연산의 차이점 (0) | 2022.02.07 |