반응형
안녕하세요. 파이토치 (pytorch)를 사용하다보면 GPU를 여러개 사용하기 위하여 torch.nn.DataParallel을 이용해 모델을 불러오곤 합니다. 물론 DataParallel을 이용하면 여러 GPU를 사용할 수 있기 때문에 편한데요. 일반 모델과는 다른 형식을 갖기 때문에 사용하는데 있어서 불편함이 있었습니다. 이번 글에서는 torch.nn.DataParallel을 일반 모델로 전환하는 방법에 대해 알아보겠습니다.
torch.nn.DataParallel 사용
아래와 같이 입력하면 해당모델은 multi GPU를 사용할 수 있습니다.
model = torch.nn.DataParallel(model, device_ids=gpu_ids)
torch.nn.DataParallel을 일반모델로 만들기
위와 같이 코드를 입력하고 난후 이 모델을 일반 모델처럼 다루고 싶다면 아래와 같이 입력하면 됩니다.
model = model.module
위와 같이 입력하면 DataParalle로 설정된 모델을 일반모델처럼 전환해서 사용할 수 있어요.
반응형
댓글