设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2486|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    ! H6 I* r7 }: O' A1 E# a
    % S; }) `; r7 y: v- v- M- @为预防老年痴呆,时不时学点新东东玩一玩。
    + V) X. ~, _4 E( x* ^- k1 LPytorch 下面的代码做最简单的一元线性回归:
      M% \, `1 n3 o) M" a  V" G! W----------------------------------------------# A9 [8 ?% `  L0 ~. R+ r: L- s# a
    import torch
    3 n4 p4 b; |) f( ^9 C  V; ximport numpy as np
    2 M+ I7 x0 D- s" F4 q4 O. |3 H2 Fimport matplotlib.pyplot as plt
    ' B1 d* j' ]0 z0 D0 Rimport random  H: [# o; w% Y
    5 \: d% `( e7 ]% O. V' P/ B7 X' V
    x = torch.tensor(np.arange(1,100,1))
    ; ?4 o! g7 ~0 w  b! _: ^5 ey = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=156 C5 f3 S2 a4 c6 K1 l
    9 z8 ^: w# A1 l( q: e% d
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b5 B; i2 q9 A5 T+ a1 C/ n' Z
    b = torch.tensor(0.,requires_grad=True)
    ; x9 i/ O6 B* r0 n: u
    $ C5 O& N6 F# C, g* q6 Hepochs = 100/ @( V6 D. H$ U9 K

    9 |  J  K7 o+ D4 L0 W- w, }5 P) Closses = []6 D9 Z: f1 _3 O9 F$ u
    for i in range(epochs):3 O: a) Z; w+ W' f% o
      y_pred = (x*w+b)    # 预测( x) ]3 u: j( f
      y_pred.reshape(-1)) d" i9 M& @4 P! E! s$ \+ H
    # m3 m  M& j  U! j: G7 b
      loss = torch.square(y_pred - y).mean()   #计算 loss
    6 L) |4 u+ h0 o0 t. ~. j$ A! Y; T: l3 U  losses.append(loss)1 G+ P( C0 \# n9 u- Y) S
      
    8 p" @/ W8 n0 M8 o6 G9 R: u  loss.backward() # autograd
    4 b! E( M6 u& @: _2 K% h$ A. m* B  with torch.no_grad():
    ' e; O- G  `: ]& I3 @; P9 w9 ]# a& F    w  -= w.grad*0.0001   # 回归 w! O0 ^+ }& u% c2 C! K+ f
        b  -= b.grad*0.0001    # 回归 b # u: a+ i( l& f  T7 x6 f$ H, |
      w.grad.zero_()  
    8 W: ~" }1 p: j  I# O; j  b.grad.zero_()
    3 W! w9 p  m( r' g+ r0 n0 n/ d* I- Z1 U" d7 y) q8 P) I$ _
    print(w.item(),b.item()) #结果! R; v6 w( P# k* r" i
    . ^3 a  u2 I. U$ a  A
    Output: 27.26387596130371  0.4974517822265625
    1 x2 |3 T" k; f! s+ k/ Q# `----------------------------------------------
    : v0 a  ?( S4 M* i2 \  l* z$ K最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。) j6 P% R- F3 ]1 F) L6 r
    高手们帮看看是神马原因?
    5 v% f+ N! M& h

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    0 Z0 K8 e1 K+ @# y3 j, n2 L. \+ O6 J
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    % z0 A0 P7 k& P-------% M; r2 {% ]3 y% `
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    3 {9 `1 N) B8 d" A: s3 T9 G-------1 n& M1 _# u9 ~# G; Q
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23/ o1 [: a" V! w
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?  ~# [7 [5 c# V9 b2 I4 F
    -------
    2 o" F# w$ h  X不好意思, ...

    4 H; [) |) P/ A谢谢,算法应该没问题,就是最简单的线性回归。
    - Q6 q+ v# K+ i1 A( a( Q我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 7 R; s6 e8 ~. h* w# F
    雷达 发表于 2023-2-14 21:52# N7 T6 n4 W7 \8 C5 p1 Y  s" o
    谢谢,算法应该没问题,就是最简单的线性回归。
    4 {% p. r4 I! C. x8 o4 g我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    ' P7 s; ?/ d% F: d7 i
    + G0 R% ^$ Y$ o0 f$ n( F
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    , S: T' |, I0 z. e! l3 b: |. r" U
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    9 a6 g$ x+ H* x% s3 E3 O3 d6 P
    老福 发表于 2023-2-14 22:00
    ; c1 Y) q7 l. W# q$ S3 U刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    ; Y/ L* Q; _( Q9 G0 D3 L4 ]
    ) h5 v9 V9 ]" T2 }6 K或者把b但的起点改为1试试。 ...
    9 C2 i6 l# C$ N( V
    ' ]4 o6 {0 U6 x2 x! Q6 m8 G$ z
    你是对的。
    + B3 J. R' K8 w! u8 S去掉了随机部分
    # K  C0 J. h4 f, ?  E( ^#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    % n, M7 y5 s8 t# X; S, Y9 K: ~y = (x*27+15).reshape(-1)
    ' W; [1 M/ n; Q+ r) I& F. @/ O7 F) @
    循环次数加成10倍,就看到 b 收敛了: z1 j, n& Q- [# ~
    w , b
    ' ~- h1 m: y9 J27.002620697021484 14.826167106628418; _. F0 d/ Z. [8 _
    : x% V. [9 A" ^  {3 r4 y+ R* v  u
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-2-21 20:56 , Processed in 0.054341 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表