设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2497|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    6 s6 q  N: D# ~6 |+ o
    1 P; h5 f. g2 E; {* W/ Y为预防老年痴呆,时不时学点新东东玩一玩。
    3 a& c/ ^2 x3 V" F% k% c8 kPytorch 下面的代码做最简单的一元线性回归:
      b* G9 K  m0 d8 v----------------------------------------------
    5 i; G5 \) l3 ]( @$ himport torch; L2 l7 E8 R* b4 d1 K+ ]: w
    import numpy as np
    6 c+ U& S. z2 w- n! O1 qimport matplotlib.pyplot as plt3 E- \; ]  a# X1 ^
    import random4 Y6 N# r* l2 V: ]
    / J9 W  [( K! e  V6 U( k
    x = torch.tensor(np.arange(1,100,1))5 f/ M; N* N+ K' @* k
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=153 f* h, V: {0 f
    , q1 v4 P/ S* h  M5 J' P- d
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b2 N1 i9 T- N4 e' ~1 g' m  N/ a
    b = torch.tensor(0.,requires_grad=True)& v, q4 L9 ~8 S4 F5 q
    2 u/ L) Q0 E/ V
    epochs = 100
    $ ?% F- D1 p  `1 l4 X; Q0 d" M
      u8 i* Z" y$ G. x: ]" H, w* o5 |losses = []
    $ V4 E/ o- N- f1 O( ffor i in range(epochs):  C; j; A8 c4 X8 y4 N7 J
      y_pred = (x*w+b)    # 预测5 l+ g! V& h& b, |6 ?! C
      y_pred.reshape(-1)
    8 U9 b9 F5 x# a) M- \: f8 u% H
    & u3 \7 c, g1 l5 m% W  loss = torch.square(y_pred - y).mean()   #计算 loss
    6 k7 R  V- h' u: W; T  losses.append(loss)* }0 L0 ?! L8 W- F- l
      
    5 _  O! s4 y2 i" W% {  loss.backward() # autograd, @1 }' _8 p  O1 c8 B
      with torch.no_grad():. W, m% X3 x7 M1 S
        w  -= w.grad*0.0001   # 回归 w
    ( }) L- d$ m! r    b  -= b.grad*0.0001    # 回归 b
    2 j  F0 u' p) p. ^/ x  w.grad.zero_()  
    5 T9 V" A) {$ S9 ~  b.grad.zero_()
    1 J) Q1 L& F3 j. K1 }# I5 c( I3 w+ d3 S! ]+ }8 b$ b
    print(w.item(),b.item()) #结果
    9 g5 d) k% @2 f
    . K3 Q- A3 j1 I. p8 F4 qOutput: 27.26387596130371  0.4974517822265625- L1 P/ _' A' T/ Q
    ----------------------------------------------3 k! D/ Q5 G8 N2 ^- U% g, @6 d
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    1 X* n7 a( f1 \1 U高手们帮看看是神马原因?
    + e( {6 C3 U: r

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 . l( f9 b9 n* y5 [! N1 M% T/ s4 r

    8 K- q" J/ V3 v6 b" x; I没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    & \6 o1 t! H+ }0 A- P-------4 n  C5 X5 j0 f& d
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    3 i  C# \$ l* G; C) g( H- T-------7 \% c* w9 k  h2 ^- G" K. G. C  `
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:238 j# R" m4 v5 }$ Z  x$ N8 m
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    6 {$ m1 B! t5 f: n$ k( y-------
    ; k, z6 p9 G( `$ G* z* E+ n不好意思, ...

    7 U; i' a9 \" u1 [8 D谢谢,算法应该没问题,就是最简单的线性回归。: v9 ^9 |8 ^, L
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    ; f, W, R# M8 c
    雷达 发表于 2023-2-14 21:52) q2 l! y" C  u
    谢谢,算法应该没问题,就是最简单的线性回归。
    : w' j* P# L0 D3 I- y/ m我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    5 M) _; k% H( m! K  S" V! x: U: E5 n7 S# n9 m% F
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    & ]3 G: l2 A. o8 h5 N, f% a( o# I& }2 J2 ?- k! C( @
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    1 p# P; Y) H5 u" j: z3 U
    老福 发表于 2023-2-14 22:00
    7 i7 U: m  P  N) i8 x7 {刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。/ q7 S5 F) O& C: e1 _$ p
    ! L  ~9 P/ o! ~( e- [/ m
    或者把b但的起点改为1试试。 ...

    7 g- B: |7 K" J) s. N8 x9 S, ?0 r3 w$ S; y: T5 `5 ]
    你是对的。
    ! ^5 _7 s* _- i. }去掉了随机部分
    ) y0 S3 j1 I# u5 k! h2 I& E#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    & ^" K- T, ~: jy = (x*27+15).reshape(-1)- l3 _* o6 A0 n. ~4 ^4 \5 i0 n

    , S; f# I! Z, ]7 d# t, S) B) T循环次数加成10倍,就看到 b 收敛了
    . v0 h- f# p5 M( r1 t# F& |* b( j  X% W) vw , b: J+ U% n3 K) m& i
    27.002620697021484 14.826167106628418
    6 @% O0 U* A; s) `. H. E- O
    0 C3 s! k# n# c5 Q' S) M5 i和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-2-24 14:35 , Processed in 0.060810 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表