본문 바로가기
카테고리 없음

[pytorch] torch.stack() 텐서합치는 방법

by 최신 마트 정보 2024. 9. 29.
반응형

텐서를 합치는 방법으로 torch.stack 함수를 사용하는 방법이 있습니다. torch.stack함수를 사용하면 축이 하나 늡니다. 자세한 원리는 아래에서 보겠습니다.

 

 

 

 

 

목차

    torch.stack 함수에 대하여?

    stack 함수 사용 시 지정하는 축으로 차원을 확장하여 tensor를 쌓는 것을 의미합니다. tensor를 쌓기 때문에 축이 하나 늘죠. 텐서를 차곡차곡 쌓는 것이기 때문에 두 tensor의 차원이 정확하게 일치하여야만 torch.stack 함수를 사용할 수 있습니다. 그리고 stack을 할 때 새로 추가할 축을 dim 입력변수로 정해주어야 한다.

     

    torch.stack 시각화

     

     

    torch.stack 코드

    import torch
    
    B, H, W = 10, 20, 128
    
    x1 = torch.rand(B, H, W) # [B, H, W]
    x2 = torch.rand(B, H, W) # [B, H, W]
    
    output = torch.stack([x1,x2], dim=1) #[B,2,H,W]

    torch.stack 함수 활용: tensor list를 이용해서 한번에 tensor list만들기

    import torch
    
    #(....중략)
    
    out_list = []
    for data in dataloader:
        out = model(data) # [B, dimension]
        out_list.append(out)
    output = torch.stack(out_list, 0) # [Batch 수, B, dimension]

     

     

    반응형

    댓글