torch.where
torch.index_select
#class_value 라는 label을 가지는 값들의 index 뽑기
target_indexes = torch.where(target_tensor == class_value)[0]
#뽑은 index의 값을 가지는 새로운 tensor 정의
input_with_class_value = torch.index_select(input_tensor, 0, target_indexes)
'Code > Pytorch' 카테고리의 다른 글
[Pytorch] Batch 개수, Data 개수 (0) | 2023.05.14 |
---|---|
[Pytorch] Seed 고정 (0) | 2023.05.09 |
Pytorch GPU에 맞는 CUDA version 설치 (0) | 2023.04.14 |
[Torch] pytorch version 확인 (0) | 2023.04.14 |
[Pytorch] .detach().cpu().numpy() (0) | 2023.04.10 |