torch.cdist 는 document에 따르면 다음과 같은 기능을 한다.
요약하자면, tensor와 tensor 간의 distance matrix를 반환해주는 함수라고 할 수 있다.
대표적으로 3D points 간 거리 계산에 유용하게 쓰인다. x,y,z로 표현 되는 3D point 집합 X와 Y가 있고 각각 N개, M개의 3D point로 구성되어 있다고 했을 때, X와 Y 집합 간 서로 최단 거리로 이웃한 점들을 선별할 때나 특정 거리 이내의 이웃한 점들은 선별할 때 자주 사용하게 된다.
하지만, 가끔 제대로 기능하지 않는 듯한 모습을 보일 때가 있어 이를 정리하고자 한다.
"가끔 제대로 기능하지 않는 듯한 모습"은 크기가 25 이상인 경우에 발생했다. 위 document에 따르면 P나 R이 25 이상일 때, 아래 예시에 따르면 N과 M이 25개 이상일 때다.
이것을 어떻게 발견하게 되었는지 경험을 보여주면, 아래와 같이 같은 입력을 반복해서 넣어줄 경우, 스스로에 대한 distance matrix를 얻을 수 있을 것이다.
# cam_positions : [240 3] xyz
distance_matrix = torch.cdist(cam_positions, cam_positions)
'''
tensor([[ 0.0000, 15.8538, 0.0000, ..., 11.8134, 10.7934, 11.8204],
[15.8538, 0.2454, 15.8536, ..., 10.8047, 12.0143, 10.7941],
[ 0.0000, 15.8536, 0.0000, ..., 11.8024, 10.8071, 11.8080],
...,
[11.8134, 10.8047, 11.8024, ..., 0.2098, 14.8524, 2.1046],
[10.7934, 12.0143, 10.8071, ..., 14.8524, 0.0995, 15.5955],
[11.8204, 10.7941, 11.8080, ..., 2.1046, 15.5955, 0.1893]])
'''
정확하다면 위 distance matrix의 diagonal elements는 무조건 0이 나와야 하지만 보다시피 이상한 값이 껴있는 것을 알 수 있다. 무시하기엔 작은 값이 아니라서 계산이 확실히 의도대로 되지 않았음을 알 수 있다. (개수를 240개가 아닌 20개를 사용한다면 대각 성분이 제대로 0으로 계산이 된다.)
원인은 모르겠다만 이 경우에는 "compute_mode" argument로 "donot_use_mm_for_euclid_dist"를 넣어주어야 해결이 된다.
# cam_positions : [240 3] xyz
distance_matrix = torch.cdist(cam_positions, cam_positions, compute_mode="donot_use_mm_for_euclid_dist")
'''
tensor([[ 0.0000, 15.8549, 0.0190, ..., 11.8137, 10.7930, 11.8202],
[15.8549, 0.0000, 15.8546, ..., 10.8038, 12.0130, 10.7930],
[ 0.0190, 15.8546, 0.0000, ..., 11.8026, 10.8066, 11.8078],
...,
[11.8137, 10.8038, 11.8026, ..., 0.0000, 14.8505, 2.0976],
[10.7930, 12.0130, 10.8066, ..., 14.8505, 0.0000, 15.5939],
[11.8202, 10.7930, 11.8078, ..., 2.0976, 15.5939, 0.0000]])
'''
argument를 바꿔주면 제대로 대각 성분이 0으로 계산되는 것을 볼 수 있다. 내부 구현을 볼 수 없으니 어떠한 이유로 값이 차이가 나는지, 그리고 얼마나 잘못 계산되는지 확인할 수 없지만 이론적으로 나와야 하는 값에서 꽤나 벗어나고 있기 때문에 항상 확인하고 쓰는 것이 좋겠다.
결론적으로 크기가 25 이상일 경우에는 torch.cdist 사용 시 "donot_use_mm_for_euclid_dist"를 꼭 넣어주고, 대각 성분을 검토하는 정도의 확인은 꼭 해줘야 한다.