Code/Pytorch

[Pytorch] Tensor 에서 특정 index 의 값을 뽑아 새로운 tensor 정의

이성훈 Ethan 2023. 4. 12. 15:52

torch.where

 

torch.index_select

#class_value 라는 label을 가지는 값들의 index 뽑기

target_indexes = torch.where(target_tensor == class_value)[0]

#뽑은 index의 값을 가지는 새로운 tensor 정의

input_with_class_value = torch.index_select(input_tensor, 0, target_indexes)

'Code > Pytorch' 카테고리의 다른 글

[Pytorch] Batch 개수, Data 개수  (0) 2023.05.14
[Pytorch] Seed 고정  (0) 2023.05.09
Pytorch GPU에 맞는 CUDA version 설치  (0) 2023.04.14
[Torch] pytorch version 확인  (0) 2023.04.14
[Pytorch] .detach().cpu().numpy()  (0) 2023.04.10