设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2466|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 7 ?  M& T% s0 k4 o

    # Y; W3 g$ c0 n- Z+ F) F为预防老年痴呆,时不时学点新东东玩一玩。
    2 c. i) f8 J# u" CPytorch 下面的代码做最简单的一元线性回归:
    " \6 j# l* ?+ }0 r- X, O----------------------------------------------
    , C5 E; m- @" c2 y5 Cimport torch. `5 \, E4 J, x, n
    import numpy as np
    . d! M& F3 S. x) L/ W; himport matplotlib.pyplot as plt9 s0 a% z$ ~( X) L% U' ]8 B
    import random
    . J6 k5 `% R/ z
    ( z, k4 d+ r1 t/ [; xx = torch.tensor(np.arange(1,100,1))
    , e( ^+ g# b" [y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=152 Z; x/ d, J7 ^: J: W. i
    % f+ j& M2 Q, }  T8 f4 ]* l
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    ; i4 i2 u6 v& Y+ f* ^b = torch.tensor(0.,requires_grad=True)9 V* S0 d( v5 I& z8 r8 d7 i2 Y

    ' m" q3 }! S5 j) E! [( Vepochs = 100
    8 z  D& `2 f0 C" ^- F- |; t- J2 F" S. O) C7 H& ]$ z/ G4 n# ?* W7 |
    losses = []
    * P8 J6 [+ f% z1 s7 |- hfor i in range(epochs):
    / l. [4 {5 W% P  y_pred = (x*w+b)    # 预测
    0 z, k" g- ?3 I+ y* h/ @. I  y_pred.reshape(-1)
    : T) [5 f( H) p
    - G2 |3 n: v6 J# k4 F# l! L  loss = torch.square(y_pred - y).mean()   #计算 loss
    7 v" H7 O0 b2 Q3 i. u  losses.append(loss)
    , D. e" _% [( w1 ]8 _: [8 e" \$ l  - Z; H. s' n+ c6 w0 Z
      loss.backward() # autograd
    8 q; |) }) P) p* `  with torch.no_grad():
    + d9 N" r3 _( @    w  -= w.grad*0.0001   # 回归 w* m9 f4 i9 Q8 A$ V" {5 F1 Y
        b  -= b.grad*0.0001    # 回归 b
    , j( X2 w( D# r  w.grad.zero_()  ! O: y* g! ~" g9 `; ^
      b.grad.zero_()
    " m9 q7 h9 _+ r1 K
    , B( _3 s0 {9 L: lprint(w.item(),b.item()) #结果: H8 r# D( [/ ?/ z$ y! e
      `  Y8 d/ c& n" c* z- F
    Output: 27.26387596130371  0.4974517822265625
    $ I2 j, ~8 j2 u7 ?& |$ O7 [7 T----------------------------------------------
    . w# l1 S5 d; j1 q8 z+ ]最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    4 A/ x" Q' m/ M* @高手们帮看看是神马原因?3 l; q4 Q' i+ g3 @9 `

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 ' w! p( i! a9 j
    0 l( g: R6 a  c; i# A* P; E/ H; g# c5 u
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    : P2 `3 r& x% R1 c-------6 l) x/ e$ Y( y/ c( m# K& ^
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    3 n8 l' j$ z7 d; K5 `-------  P6 {) G% v! Y: Q  V
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    6 ?* L. j3 D. _2 s没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?1 T* O7 }8 o) p; R* W4 c
    -------
    8 u9 W- T6 S) o+ ^不好意思, ...

    2 C: q3 q  D8 c4 a+ m2 G谢谢,算法应该没问题,就是最简单的线性回归。
    + ?; P$ N9 ^3 B) S! d我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 3 W& E" \; N' l3 o
    雷达 发表于 2023-2-14 21:52
    2 a- e9 q1 J: F/ x8 T; l; P6 ^1 {谢谢,算法应该没问题,就是最简单的线性回归。  c  U7 v: ]6 i* v4 C* |
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    5 D! ]1 l( I1 j% Z/ p$ L+ J- n# @& X" z5 D, P
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    1 t* `& I; s. `9 V. i. ^- C3 o. B% O* I4 D6 v$ g9 O& b2 E* V
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    5 }! _/ D" t- ^% D
    老福 发表于 2023-2-14 22:00
    2 C* r4 N0 R+ f7 }8 m4 b  U刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。2 [+ }% Y. n2 ~2 |, Q1 j

    3 M3 j. `# A. [1 `$ {9 S或者把b但的起点改为1试试。 ...

      C( r; n# I2 I! u. e
    ' Q# w8 v; |& s! }/ Z你是对的。% c7 E2 a2 {! p5 I6 a- P- c' b* e: ~
    去掉了随机部分5 N$ y; H5 D* T6 C1 @6 l( I3 z
    #y = (x*27+15+random.randint(-2,3)).reshape(-1): v' z! e* B( U
    y = (x*27+15).reshape(-1)
    4 ?. S7 P. }/ `7 \
    ' e$ O# [7 |7 S* L# ?$ u循环次数加成10倍,就看到 b 收敛了
    2 @4 a- X# K# O; @9 f$ Cw , b
    ; I9 a0 o8 P5 t. X8 F27.002620697021484 14.826167106628418! U2 X7 F$ V3 l" B7 M, m
    8 a* h, b  _5 \7 [/ b
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-2-15 09:34 , Processed in 0.064283 second(s), 21 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表