Pytorch 是非常注明的机器学习框架,其中的 torch.Tensor 是自带排序的,直接使用 torch.sort()
这个方法即可。排序可以按照升序、降序,可以选择排序的维度,等等。下面介绍一下 Pytorch 中的排序方法。
文章来自:https://hxhen.com/sort-method-in-torch-model/
一、方法原型
torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)
二、返回值
A tuple of (sorted_tensor, sorted_indices) is returned, where the sorted_indices are the indices of the elements in the original input tensor.
三、参数
- input (Tensor) – the input tensor
形式上与 numpy.narray 类似 - dim (int, optional) – the dimension to sort along
维度,对于二维数据:dim=0 按列排序,dim=1 按行排序,默认 dim=1 - descending (bool, optional) – controls the sorting order (ascending or descending)
降序,descending=True 从大到小排序,descending=False 从小到大排序,默认 descending=Flase
四、实例
import torch x = torch.randn(3,4) x #初始值,始终不变 tensor([[-0.9950, -0.6175, -0.1253, 1.3536], [ 0.1208, -0.4237, -1.1313, 0.9022], [-1.1995, -0.0699, -0.4396, 0.8043]]) sorted, indices = torch.sort(x) #按行从小到大排序 sorted tensor([[-0.9950, -0.6175, -0.1253, 1.3536], [-1.1313, -0.4237, 0.1208, 0.9022], [-1.1995, -0.4396, -0.0699, 0.8043]]) indices tensor([[0, 1, 2, 3], [2, 1, 0, 3], [0, 2, 1, 3]]) sorted, indices = torch.sort(x, descending=True) #按行从大到小排序 (即反序) sorted tensor([[ 1.3536, -0.1253, -0.6175, -0.9950], [ 0.9022, 0.1208, -0.4237, -1.1313], [ 0.8043, -0.0699, -0.4396, -1.1995]]) indices tensor([[3, 2, 1, 0], [3, 0, 1, 2], [3, 1, 2, 0]]) sorted, indices = torch.sort(x, dim=0) #按列从小到大排序 sorted tensor([[-1.1995, -0.6175, -1.1313, 0.8043], [-0.9950, -0.4237, -0.4396, 0.9022], [ 0.1208, -0.0699, -0.1253, 1.3536]]) indices tensor([[2, 0, 1, 2], [0, 1, 2, 1], [1, 2, 0, 0]]) sorted, indices = torch.sort(x, dim=0, descending=True) #按列从大到小排序 sorted tensor([[ 0.1208, -0.0699, -0.1253, 1.3536], [-0.9950, -0.4237, -0.4396, 0.9022], [-1.1995, -0.6175, -1.1313, 0.8043]]) indices tensor([[1, 2, 0, 0], [0, 1, 2, 1], [2, 0, 1, 2]])