PyTorch中Variable变量

  • 时间:
  • 来源:互联网

一、了解Variable

顾名思义,Variable就是 变量 的意思。实质上也就是可以变化的量,区别于int变量,它是一种可以变化的变量,这正好就符合了反向传播,参数更新的属性。

具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生片花,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化。那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了。(也就是说,pytorch都是有tensor计算的,而tensor里面的参数都是Variable的形式)。如果用Variable计算的话,那返回的也是一个同类型的Variable。
【tensor 是一个多维矩阵】

用一个例子说明,Variable的定义:

import torch
from torch.autograd import Variable # torch 中 Variable 模块
tensor = torch.FloatTensor([[1,2],[3,4]])
# 把鸡蛋放到篮子里, requires_grad是参不参与误差反向传播, 要不要计算梯度
variable = Variable(tensor, requires_grad=True)

print(tensor)
“”"
1 2
3 4
[torch.FloatTensor of size 2x2]
“”"

print(variable)
“”"
Variable containing:
1 2
3 4
[torch.FloatTensor of size 2x2]
“”"

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

注:tensor不能反向传播,variable可以反向传播。

二、Variable求梯度

Variable计算时,它会逐渐地生成计算图。这个图就是将所有的计算节点都连接起来,最后进行误差反向传递的时候,一次性将所有Variable里面的梯度都计算出来,而tensor就没有这个能力。

v_out.backward()    # 模拟 v_out 的误差反向传递

print(variable.grad) # 初始 Variable 的梯度
‘’’
0.5000 1.0000
1.5000 2.0000
‘’’

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

三、获取Variable里面的数据

直接print(Variable) 只会输出Variable形式的数据,在很多时候是用不了的。所以需要转换一下,将其变成tensor形式。

print(variable)     #  Variable 形式
"""
Variable containing:
 1  2
 3  4
[torch.FloatTensor of size 2x2]
"""

print(variable.data) # 将variable形式转为tensor 形式
“”"
1 2
3 4
[torch.FloatTensor of size 2x2]
“”"

print(variable.data.numpy()) # numpy 形式
“”"
[[ 1. 2.]
[ 3. 4.]]
“”"

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

扩展

在PyTorch中计算图的特点总结如下:
autograd根据用户对Variable的操作来构建其计算图。

  1. requires_grad
    variable默认是不需要被求导的,即requires_grad属性默认为False,如果某一个节点的requires_grad为True,那么所有依赖它的节点requires_grad都为True
  2. volatile
    variable的volatile属性默认为False,如果某一个variable的volatile属性被设为True,那么所有依赖它的节点volatile属性都为True。volatile属性为True的节点不会求导,volatile的优先级比requires_grad高
  3. retain_graph
    多次反向传播(多层监督)时,梯度是累加的。一般来说,单次反向传播后,计算图会free掉,也就是反向传播的中间缓存会被清空【这就是动态度的特点】。为进行多次反向传播需指定retain_graph=True来保存这些缓存
  4. .backward()
    反向传播,求解Variable的梯度。放在中间缓存中。

主要参考:
[1] https://blog.csdn.net/liuhongkai111/article/details/81291003
[2] https://zm10.sm-tc.cn/?src=l4uLj4zF0NCIiIjRnJGdk5CYjNGckJLQl5qTk5yei9CejYuWnJOajNDHy8vGz8zO0ZeLkpM%3D&uid=f3737415e9557099fc449754849c7d19&hid=a1e29e2ae407bbcbd5945b3277d7f9e2&pos=5&cid=9&time=1544919030860&from=click&restype=1&pagetype=0020000002000408&bu=web&query=Pytorch中Variable的作用&mode=&v=1&force=true&wap=false&province=辽宁省&city=大连市&uc_param_str=dnntnwvepffrgibijbprsvdsdichei

                                </div>
Arthur-Ji
发布了378 篇原创文章 · 获赞 55 · 访问量 13万+
私信 关注

本文链接http://element-ui.cn/news/show-1459.aspx