반응형
pytorch에서 텐서를 합치는 방법으로 torch.cat과 torch.stack이 있는데 이번 글에서는 torch.cat에 대해서만 알아보자.
목차

torch.cat 함수에 대하여?
orch.cat 함수를 이용하면 텐서들을 concatenate 해준다. concatenate하면서 텐서의 축은 늘어나지않고 이미 있던 축의 차원을 늘려준다. concatenate할 때 축을 지정하면 그 축으로 두 tensor의 축에 해당하는 차원의 합만큼 축의 차원이 변경된다. concatenate 할 때 dimensión 변수를 따로 넣어주지 않으면 default로 0으로 설정된다. 뉴럴넷 내의 다른 두 feature를 fusion하거나 batch 단위의 output을 모아 모든 데이터셋에 대한 output을 저장할 때 사용한다.
torch.cat 함수 그림으로 나타내기

torch.cat Python 코드
import torch
B, H, W = 16, 12, 128
x1 = torch.rand(batch_size, H, W) # [B, H, W]
x2 = torch.rand(batch_size, H, W) # [B, H, W]
output1 = torch.cat([x1, x2], dim=1) #[B, H+H, W]
output2 = torch.cat([x1, x2], dim=2) #[B, H, W+W]
Cat 함수 활용: 여러 Tensor 들을 cat해서 하나의 텐서 만들기
import torch
#(....중략)
output_list = []
for data in dataloader:
output = model(data) # [batch_size, dimension]
output_list.append(output)
output = torch.cat(output_list, dim=0) # [batch_size * 배치수, dimension ]반응형
댓글