设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 1794|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 . @& Z: l1 p# Y$ ~9 n& e. ]3 }
    3 w9 _8 ^3 h. K$ R; C
    为预防老年痴呆,时不时学点新东东玩一玩。
    5 Y2 o/ w5 R' T% V. N6 BPytorch 下面的代码做最简单的一元线性回归:$ }  i8 I2 S0 P2 @- r4 n5 E0 Y
    ----------------------------------------------6 @5 F7 O, G4 K# _
    import torch7 X& N. A! [+ D5 b
    import numpy as np! f' H' v5 B9 _. h5 |5 t
    import matplotlib.pyplot as plt! ~3 v- O5 }" J0 ^/ Z
    import random
    * j% o" T  s0 ]' z3 C" E2 F* k5 Y) H$ L7 e1 P: A
    x = torch.tensor(np.arange(1,100,1))
    & w8 v- ~0 [( V2 u  w" K) m% i& Ty = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15% [& c: q! y6 r) O9 g
    ! l( M/ L0 B( @
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    # {! L: b$ T9 L4 gb = torch.tensor(0.,requires_grad=True)
    ' s2 j2 m' Q7 s. }- B  I- n6 S: W4 ?! R7 [& B  z# W/ G
    epochs = 100
    # J# Q$ P% ~4 b2 F& Z* t- V9 O# H( ^
    losses = []) |1 O) p7 u% w! w7 D4 d3 g
    for i in range(epochs):9 S9 U* d9 R& R4 {
      y_pred = (x*w+b)    # 预测
    ) w$ e2 T; ^* w5 }. u  y_pred.reshape(-1)
    8 G$ u) e) [; I/ P
    0 U: Q9 ]7 x: T  loss = torch.square(y_pred - y).mean()   #计算 loss
    % p2 v2 r4 ~4 v" p. ~" X  losses.append(loss)
    $ K3 e3 ~6 d& _' z! @  
    ' I6 p( ?7 U7 X/ V( x  loss.backward() # autograd; L/ ]+ ~  J, J. z% i" Y
      with torch.no_grad():% p% A2 ^# r) B# @
        w  -= w.grad*0.0001   # 回归 w# t, x+ M4 ?1 K8 A5 G! O$ Z  |
        b  -= b.grad*0.0001    # 回归 b 5 \+ _9 ~' \+ [+ A# D) j
      w.grad.zero_()  
    " s+ D  E& j' N& ~6 j  b.grad.zero_()
    8 T" d: x+ O  {! C
    * j7 t# V9 w, hprint(w.item(),b.item()) #结果
    0 T4 w' Y( ^! t* U- r2 P1 g; q# v( D6 u) h" R! U  @  L
    Output: 27.26387596130371  0.4974517822265625
    3 S. G) ]! o& N) [----------------------------------------------9 n' v% D; D8 G* @! A
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。* m3 f: E# ?, [! Q9 _1 z  b4 P
    高手们帮看看是神马原因?3 p' F8 g4 `1 c  t  Z, O

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    , h8 V/ z$ {: S0 A8 H& s9 |1 g/ Q2 _7 w8 q9 x, n; d4 d8 W
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    , N; @7 |. B5 K0 I# ]) }-------+ G$ M1 a- r2 @( u, u% Z
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    1 s4 V( r2 S! B4 f; C/ O8 |-------5 s5 T, v' |3 f2 K/ X
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:237 e7 j$ S* g% H2 D9 }: @6 e  @
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    1 W# P; C4 \' U) I! g0 n-------, D, V2 R0 N  I: S0 y8 d, i/ y
    不好意思, ...

    0 O/ a& U; {6 U谢谢,算法应该没问题,就是最简单的线性回归。. m( _$ a, J$ L3 h9 j& Z
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 0 D/ s/ v9 g( P. m9 {' m( R
    雷达 发表于 2023-2-14 21:52
    7 b6 G; k" b% _; [2 _* ]1 |4 p谢谢,算法应该没问题,就是最简单的线性回归。) p6 d- g0 G1 b. a
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    / M3 o: q: L) |3 E7 O$ {) l

    8 \; J! H9 {5 @- P- d$ v/ _3 [: W刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。7 b. Z& M; N6 H5 a& w

    * ?* S0 N7 a' r: B: g! \4 S或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 , h0 a) W4 j) l- M
    老福 发表于 2023-2-14 22:00# c6 ]9 K6 Z' G1 j9 t* }
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。( Q/ C' \# S, t/ h

    ' B4 j5 w& p$ g或者把b但的起点改为1试试。 ...

    - Z, L4 |( a$ [5 i3 F2 @' f; C- Y0 _( m. E/ c. S
    你是对的。/ Y' Q+ K8 f9 M) |) t* B& B* v
    去掉了随机部分: V2 C$ j4 f9 g0 q/ J
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)0 |" Z+ I' b( V& Q, b6 ^& D
    y = (x*27+15).reshape(-1)
    2 O( Z; b  x7 B1 G; b; X
    % g5 n% S0 c$ c; _9 K循环次数加成10倍,就看到 b 收敛了
    3 p2 F" {( y+ w( ^w , b( ~4 |' `* `! o2 O
    27.002620697021484 14.8261671066284189 q2 r% `& o$ ?, m; t4 }3 y

    ' j" y, l9 L- d6 y2 F& e- e和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-7-19 11:15 , Processed in 0.035767 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表