设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2304|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    / E0 d! p' E* n- Z. n5 t
    2 Z6 D# Z6 j7 P0 ^为预防老年痴呆,时不时学点新东东玩一玩。: c8 [! N! h) t9 B
    Pytorch 下面的代码做最简单的一元线性回归:1 H  s; t. L! Y1 M0 `
    ----------------------------------------------3 x- f; ?% Y9 D, K3 z6 D
    import torch
    % R! Q. B* f; w/ Z- aimport numpy as np6 _+ n" ]5 c: a, n
    import matplotlib.pyplot as plt+ ~" r4 U0 G. b
    import random( e, B  Y. B8 f
    6 w9 ~4 m  }) B  E/ n; D5 @2 t
    x = torch.tensor(np.arange(1,100,1))' p9 A! f8 B. i, e
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    # ^/ p4 b5 D" c" ]' v+ N
    % o  e% j' W+ t8 ow = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b- F! {6 g% A' d) C( w$ o
    b = torch.tensor(0.,requires_grad=True): g" _3 J0 e0 U; o+ f

    4 o9 J+ A9 n4 tepochs = 100
    8 {+ o# j$ v5 e, R/ W
    $ q3 _; [* n- U/ j+ g1 ^losses = []
    5 t9 s7 c/ e- B7 }2 C) p; V. \% P; afor i in range(epochs):- {: I. m1 c. u3 L& D  S
      y_pred = (x*w+b)    # 预测3 q# E& W: k6 y3 |9 T5 G
      y_pred.reshape(-1)
    ! S) N3 G* U& R8 Y5 y 5 u+ h- Q' ^$ Z9 h4 h1 `
      loss = torch.square(y_pred - y).mean()   #计算 loss
    3 @6 v8 v( @) f& D) @8 X: S  losses.append(loss)3 g! M9 }4 L# D% }2 m5 U
      
    ( ?; N+ H" X+ P+ c! v/ G& S* U  loss.backward() # autograd" m' w. F' o; B; b* _
      with torch.no_grad():2 H% b/ v5 ?) N9 t( v) d
        w  -= w.grad*0.0001   # 回归 w
    6 ~9 \- j9 y' x& p4 I2 F! N    b  -= b.grad*0.0001    # 回归 b ) S% h8 ^# v* o; j
      w.grad.zero_()  
    6 {# W$ p3 i$ z: O% b& O& S/ I  b.grad.zero_()5 Q7 d& U! g" C, z. u& N6 ^
    * V8 l% x/ @" G! c* i% S
    print(w.item(),b.item()) #结果
    ' G* t1 I+ \7 N3 _" o4 d& t$ u. K/ J" A
    Output: 27.26387596130371  0.4974517822265625
    # V) }# H/ v, u4 I8 Q* `----------------------------------------------
    2 Y- f5 s; `! i最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    , ^! e/ Y: M( F$ H( ]. I; [8 |. o高手们帮看看是神马原因?2 x# |4 y( Q, ?% |4 G" V8 y/ ?- p, @

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    4 l; E# P5 o* j
    1 ^4 G9 N$ z# f' i) r没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?9 _& q) M& z3 e7 I
    -------
    9 k6 @4 k/ d+ Z& [: N1 w4 q; {不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    ( t+ @, t! ^  a1 `-------; F- c, Z# z3 ]! z! k
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    & |" y& w0 x8 [$ f; @( k没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?/ C. w! m7 h3 `  }
    -------! C1 m( Y( @3 n8 V
    不好意思, ...
    5 m' N% k: E* Z4 G! W+ U
    谢谢,算法应该没问题,就是最简单的线性回归。
    ; `4 V4 V3 N% ^# X, a* B* J我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    . k* f- B5 N2 ?0 _
    雷达 发表于 2023-2-14 21:52
    7 R4 L1 ^) `. n谢谢,算法应该没问题,就是最简单的线性回归。1 P) ~% [) Q$ K2 O
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    , r& `- O, q0 P5 L3 B- r1 Z# O* q8 p! j: S) p
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    4 t8 L$ y1 h! [9 m& p
      i( H. f/ }8 i9 u- O& F# Q0 k或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    5 Q% h# {) R8 e4 H1 B0 t0 Z
    老福 发表于 2023-2-14 22:00
    2 v% S; S, J4 a$ P刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    $ s, h) P& k! c" O
    2 I  q9 q3 E8 U# f2 x; r3 m或者把b但的起点改为1试试。 ...
    . W1 x! s" t- q4 r& Q- t; n

    ! c2 V: o8 k$ p/ [- I2 `你是对的。
    8 A' ^, h$ j9 o+ e6 [& q去掉了随机部分8 x9 G, N. J2 F" O2 \7 ]4 c) n8 ]
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)5 U3 A0 T* Z5 {6 X( l
    y = (x*27+15).reshape(-1)
    ( W: `( W$ j6 \+ I- ]7 ~7 w9 b; {5 A% `! L: D& K
    循环次数加成10倍,就看到 b 收敛了
    ' g% e6 ^' y' E! P! w5 y, jw , b. B3 z+ ]$ B, b7 Z% |# C
    27.002620697021484 14.826167106628418
    2 l3 M& }! O4 ~: ]
    9 U, w0 g0 L# q5 O7 j1 J/ N和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-1-9 09:52 , Processed in 0.033612 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表