AI/Computer Vision
torch.cat
HEAD1TON
2025. 4. 6. 04:11
728x90
반응형
torch.cat은 PyTorch에서 제공하는 함수로, 여러 개의 텐서를 특정 차원을 따라 연결하여 하나의 텐서로 만드는 역할을 합니다.
쉽게 말해, 여러 개의 텐서를 하나로 합치는 기능을 한다고 생각하면 됩니다.
사용법:
torch.cat(tensors, dim=0, *, out=None) → Tensor
tensors: 연결할 텐서들을 담은 시퀀스 (튜플, 리스트 등) 입니다.
dim: 텐서를 연결할 차원을 지정합니다. 기본값은 0입니다.
out: 결과를 저장할 텐서 (선택 사항) 입니다.
예시:
import torch
# 2개의 텐서 생성
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])
# dim=0을 따라 연결 (세로로 연결)
tensor3 = torch.cat([tensor1, tensor2], dim=0)
# 결과: tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
# dim=1을 따라 연결 (가로로 연결)
tensor4 = torch.cat([tensor1, tensor2], dim=1)
# 결과: tensor([[1, 2, 5, 6], [3, 4, 7, 8]])
주의 사항:
연결할 텐서들은 연결하려는 차원을 제외한 모든 차원에서 동일한 크기를 가져야 합니다.
dim은 연결할 차원을 나타내며, 0부터 시작합니다.
추가 설명:
torch.cat은 torch.split과 torch.chunk의 역 연산으로 볼 수 있습니다. 즉, torch.split이나 torch.chunk으로 나눈 텐서들을 다시 torch.cat으로 합칠 수 있습니다.
Autograd는 torch.cat 연산을 추적합니다. 따라서, torch.cat으로 생성된 텐서에 대한 기울기를 계산할 수 있습니다.
활용:
torch.cat은 다양한 상황에서 유용하게 사용될 수 있습니다. 예를 들어, 여러 개의 배치를 하나로 합치거나, 여러 개의 특징 맵을 연결하는 등의 작업에 활용할 수 있습니다. 특히, 딥러닝 모델에서 데이터를 전처리하거나, 모델의 출력을 결합하는 등 다양한 용도로 활용됩니다.
728x90
반응형