Knowledge/Vision

Spectral bias of Neural network

침닦는수건 2023. 3. 10. 16:27
반응형

딥러닝을 공부하다가 inductive bias라는 개념은 CNN을 다룰 때나 transformer를 다룰 때 자주 들어보았지만, 최근 spectral bias라는 개념을 처음 접했다. 특정 task에서만 발생하는 bias라면 그냥 지나갔겠지만 짧게 알아본 바로는 모든 neural network 학습 과정에서 전반적으로 나타나는 bias여서 알아두어야 할 것 같았다.

 

그래서 이 글은 간략하게 나마 spectral bias가 뭔지 설명하고 그 파훼법을 정리하기 위해 적는다.

 

Spectral bias 

On the Spectral Bias of Neural Networks 라는 제목의 논문에서 저자들이 처음 주장했다.

By using tools from Fourier analysis, we highlight a learning bias of deep networks towards low frequency functions – i.e. functions that vary globally without local fluctuations – which manifests itself as a frequency-dependent learning speed.

네트워크는 low frequency 정보를 우선적으로 학습하는 경향이 있다는 말인데 간략하게 말하자면 네트워크는 큰 그림부터 배우고 세부사항은 나중에 배운다는 것이다. 

 

논문에서는 이를 증명하기 위해서 fourier spectrum을 소개하고 다양한 수식과 이론을 전개하는데 그 과정이 복잡하기 때문에 생략하고 실험적 내용만 점프해서 보면 다음과 같다.

 

첫번째 그림은, 입력을 low frequency ~ high frequency를 전부 다 갖고 있는 형태로 사용해서 학습했을 때, 네트워크 내에서 해당 frequency에 대응되는 node 활성도를 시각화 한것이다. 밝을 수록 활성이 많이 된 것인데 학습 초기에 low frequency영역이 먼저 활성화되고 학습이 진행됨에 따라 점점 high frequency 영역까지 활성화되는 것을 볼 수 있다. 

 

두번째 그림은, 네트워크에게 복잡한 signal을 맞추도록 학습시키는 과정이다. 이 때 네트워크 출력은 low frequency signal로 시도를 시작해서 점점 high frequency를 더해나가면서 최종 signal을 맞추는 것을 볼 수 있다. 

 

위 두 실험의 결과로 미루어보아, 네트워크는 학습 과정에서 low frequency 정보 먼저 학습하는 경향이 있다고 볼 수 있다고 하는 것이다. 

 

 How to reduce spectral bias

위 논문은 spectral bias가 존재한다는 것을 소개하는 논문이기 때문에 이 bias가 구체적으로 어떤 문제를 야기하는지까지는 짚지 않았지만, 간단하게 생각해보면 generative model이나 reconstruction model에서 디테일한 부분 복원이 중요할 경우 spectral bias는 반드시 방해되는 bias라는 것을 알 수 있다. 그렇다면 어떻게 피할 수 있을까?

 

그 힌트는 본 논문과 Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains 논문에서 찾을 수 있었다. 

 

아까 증명을 위해서 입력을 frequency domain으로 보내서 사용했다고 했는데 이때 얼마나 높은 frequency까지 사용해서 mapping했는지에 따라 네트워크 학습 양상이 달라졌다. high frequency가 포함된 입력 형태일수록 고른 학습이 진행되었다. 

 

추가 논문에서 보여준 자료에서도, high frequency가 많이 포함된 입력 형태(p=0일 때)이면 target signal을 좀 더 잘 맞추는 결과를 볼 수 있다. 

 

다시 말하면 high frequency가 포함되도록 입력 형태를 변형해주면 조금은 spectral bias를 줄일 수 있다는 뜻이다. 

 

그러면 그 구체적 방법은 뭐가 있을까? 추가 논문에서는 3가지를 보여줬다.

 

일반 값보다는 낫지만 귀찮으니 cos, sin만 씌운 basic, positional encoding으로 알려진 형태, Gaussian을 섞은 cos, sin 형태다. 

 

 결과를 보면 3가지 모두 일반 입력 형태에 비해선 성능이 높아짐을 볼 수 있다. Gaussian이 제일 좋게 나왔지만 이건 실험한 task에만 한정된 것일 수 있으니 실제로는 basic, PE, Gaussian 다 써보고 좋은 것을 쓰면 될 것 같다.

 

Conclusion

네트워크 입력 형태를 정할 때 low-to-high frequency로 나눠서 표현할 수 있다면 바꿔서 사용하자.

반응형