Kaggle TPU VM에서 Pytorch Lightning 2.0 훈련 오류 해결하기

PJRT 런타임을 끄면 됩니다

문제점

Kaggle에서 제공하는 TPU VM v3-8 accelerator를 활용해 훈련을 돌리려고 해보았습니다. 자동으로 TPU에서 Distributed Training을 지원하는 Pytorch Lightning을 사용해 보았는데 훈련이 돌아가지 않고 오류가 나거나 멈춰 버립니다.

해결책

결론

PJRT 런타임을 꺼버리면 됩니다.

import os
os.environ.pop('PJRT_DEVICE', None)

자세한 설명

TPU는 XRT와 PJRT라는 두가지 런타임을 지원합니다. 이 중 XRT가 기존의 런타임, PJRT가 새 런타임입니다. 그리고 Pytorch Lightning 2.0.x는 PJRT를 지원하지 않기에 XRT를 이용해야 합니다. 그러기 위해서는 Google Cloud 문서에 따르면 XRT_TPU_CONFIG 환경 변수를 설정해주어야 합니다. 문제는 Kaggle에서 TPU VM을 할당받으면 PJRT를 활성화 하는 PJRT_DEVICE 환경 변수가 미리 설정되어 있어서 XRT_TPU_CONFIG가 설정되어도 무시됩니다. 그래서 이 환경 변수를 지워버리면 훈련이 정상적으로 돌아갑니다.

이렇게 하자 훈련이 잘 돌아갑니다. 위 사진은 Weight & Bias 서비스를 이용해 기록한 그래프인데, 여기서 tpu의 그래프를 보면 정상적으로 훈련이 되고 있는 걸 볼 수 있습니다. 실제로 훈련에 사용한 노트북은 여기서, 사용한 코드는 여기서 볼 수 있습니다.

또 다른 해결책

현재(2023년 4월 기준) Pytorch Lightning 쪽에서 PJRT 런타임을 지원하기 위한 작업이 활발히 진행되고 있습니다. 2.1 버전이 나오면 아마 PJRT가 지원되며 문제가 해결될 것으로 보입니다.

마치며

인터넷에서 관련 정보가 검색되지 않아 이 간단한 방법을 찾는 데 이틀이나 걸렸습니다. 혹시라도 저와 같은 상황을 겪고 있으신 분들이 있다면 도움이 되길 바라며 글을 써보았습니다.