본문 바로가기
카테고리 없음

torch.nn.DataParallel을 일반 모델로 전환하기

by 최신 마트 정보 2024. 2. 5.
반응형

안녕하세요. 파이토치 (pytorch)를 사용하다보면 GPU를 여러개 사용하기 위하여 torch.nn.DataParallel을 이용해 모델을 불러오곤 합니다. 물론 DataParallel을 이용하면 여러 GPU를 사용할 수 있기 때문에 편한데요. 일반 모델과는 다른 형식을 갖기 때문에 사용하는데 있어서 불편함이 있었습니다. 이번 글에서는 torch.nn.DataParallel을 일반 모델로 전환하는 방법에 대해 알아보겠습니다.

torch.nn.DataParallel 사용

아래와 같이 입력하면 해당모델은 multi GPU를 사용할 수 있습니다.

model = torch.nn.DataParallel(model, device_ids=gpu_ids)

torch.nn.DataParallel을 일반모델로 만들기

위와 같이 코드를 입력하고 난후 이 모델을 일반 모델처럼 다루고 싶다면 아래와 같이 입력하면 됩니다.

 

model = model.module

위와 같이 입력하면 DataParalle로 설정된 모델을 일반모델처럼 전환해서 사용할 수 있어요.

반응형

댓글