设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2712|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    5 t# ]1 L: |0 h: A( s
    ' z, i4 e% }( N: f7 x为预防老年痴呆,时不时学点新东东玩一玩。5 W2 w! F7 s- P7 w. W
    Pytorch 下面的代码做最简单的一元线性回归:" x* k+ x  b+ s/ I/ s
    ----------------------------------------------
    8 a4 y+ b' G1 E# X8 A/ Cimport torch% z/ \& ?/ ~$ V
    import numpy as np
    % l( h4 [$ Z/ M9 ?; r5 D& z0 Vimport matplotlib.pyplot as plt
    1 j9 Q' c$ x* }! X  cimport random2 m0 V4 q" I7 e

    / a- w7 K' w' {6 i- @) W) l; Nx = torch.tensor(np.arange(1,100,1))
    4 b& `' R* A8 W' Iy = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=153 K) f" w9 |3 v1 z

    ( L, h0 W3 q' M* f, W% g2 qw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    3 C; J4 ^. l; f& \4 s! ?b = torch.tensor(0.,requires_grad=True)0 R( e4 `6 q7 K/ |; R7 R8 x! b

    9 P. g  }0 g7 I/ m- qepochs = 100
    4 @$ Q8 r0 d' [& I
    ) r. @7 r% e) p/ t; tlosses = []
    3 q1 ?. ^! N* _for i in range(epochs):: T4 Y- q% m1 O3 i$ y% K( j8 _9 c
      y_pred = (x*w+b)    # 预测& t6 Y, ~  j/ \8 y2 e0 v" j
      y_pred.reshape(-1)5 h: f) S+ \. {7 j* x

    : K  B$ ^& Q: A  X" K  loss = torch.square(y_pred - y).mean()   #计算 loss* U$ K) T+ b# m3 f9 |
      losses.append(loss)
    4 L1 G+ k7 k( ^; o, `7 q  
    % C8 x$ P) S9 q' e, @* x' G  loss.backward() # autograd8 I+ l6 q# S: ^8 K) q+ |
      with torch.no_grad():, Q3 _- h) P2 n* F( v- {
        w  -= w.grad*0.0001   # 回归 w( C. ]) F! I# S
        b  -= b.grad*0.0001    # 回归 b
    4 C; V1 |  |- T& j& M# H9 S( r5 ?  w.grad.zero_()  + n4 \8 a4 p; b
      b.grad.zero_()
    3 `  Y) u3 W6 \, Z6 S
    / @" j) u& d2 c# e- |* ~print(w.item(),b.item()) #结果! U9 B2 E. l9 z

    5 E2 x0 M; R% ^# p& L$ j1 GOutput: 27.26387596130371  0.4974517822265625
    5 g2 J, h9 K- L1 s----------------------------------------------
    ) V" w$ f; n0 ~最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。2 H% C  A: s! _% U/ v9 l- Q
    高手们帮看看是神马原因?
    * c' m7 i& L2 ], L% R! E

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 # a- A0 ~* u# X: C; ], d- k7 U
    5 s( z1 l  `9 d
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    ( _$ E# f# B" J( g-------
    3 s  s% L: j  H' X不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    $ N* O# I* }# A1 u: y-------+ f3 U5 ]" P: A% |
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23) r7 v6 p7 |1 {: \  x, E  m
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?% Y! E6 J; z% T, t. d1 [3 m8 Q+ l
    -------+ S* K+ o; L5 K* y' w
    不好意思, ...

    / P* I. v9 n, J, q- l  L谢谢,算法应该没问题,就是最简单的线性回归。2 C$ u0 v/ U2 W# ^. ?& @- D3 i$ a5 l
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 9 j# a) \% r, M  ^* ?2 S; m* K
    雷达 发表于 2023-2-14 21:52
    ! N! a/ H+ r: Y2 Y4 }谢谢,算法应该没问题,就是最简单的线性回归。7 C6 x, Z% Q3 ?
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    - M; G3 T  F$ ^' u1 J
    & D. W  S; H6 n刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。/ l8 m/ \: a2 W; D
    2 q- o% a3 Y! X5 b6 ~9 h8 I+ m
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 1 W8 O# k/ I) L
    老福 发表于 2023-2-14 22:002 X( m# r; X% O2 g; C$ W& _
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。+ W1 G5 y. S. |3 K
    ) M$ n; w, e/ o
    或者把b但的起点改为1试试。 ...
    7 w5 ]4 @% y* Y7 }& t; O' v# [: m
    * N% Y4 G) `9 ~% G/ ^5 k9 ?1 t
    你是对的。
    1 P1 r) J: m. \7 C- s) B9 c去掉了随机部分0 \5 k  D. j) v- }' B2 c! F
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)$ e+ s4 _  I- A0 c; ?# x
    y = (x*27+15).reshape(-1)- Y& u3 d4 E9 P; W1 }, e
    + H6 V- z2 l( P1 k1 ~
    循环次数加成10倍,就看到 b 收敛了1 Y6 `/ N8 ]7 z0 X3 d
    w , b: ~; U  f! n4 g; R
    27.002620697021484 14.826167106628418* H# o- ^! _/ x0 C; k% U

    " X7 k# K% C1 Z! D和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-4-7 08:27 , Processed in 0.061979 second(s), 22 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表