torch.where
torch.index_select
#class_value 라는 label을 가지는 값들의 index 뽑기
target_indexes = torch.where(target_tensor == class_value)[0]
#뽑은 index의 값을 가지는 새로운 tensor 정의
input_with_class_value = torch.index_select(input_tensor, 0, target_indexes)'Code & Framework > Pytorch' 카테고리의 다른 글
| [Pytorch] Batch 개수, Data 개수 (0) | 2023.05.14 |
|---|---|
| [Pytorch] Seed 고정 (0) | 2023.05.09 |
| [Pytorch] GPU에 맞는 CUDA version 설치 (0) | 2023.04.14 |
| [Pytorch] pytorch version 확인 (0) | 2023.04.14 |
| [Pytorch] .detach().cpu().numpy() (0) | 2023.04.10 |