设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2290|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 - }! Y  z/ Q; s; V# s, p
    $ B2 A# V0 u) C  N# T
    为预防老年痴呆,时不时学点新东东玩一玩。' a1 u) c% Y/ k" ]; q' ?! u
    Pytorch 下面的代码做最简单的一元线性回归:% l+ V& J1 b7 S
    ----------------------------------------------9 Z9 Q: g( J) o: H
    import torch
    ( P3 z1 |) Y0 _import numpy as np/ V% N; n8 s9 p3 ]' O
    import matplotlib.pyplot as plt
    8 e! ?8 ?! w  ~7 u4 a, nimport random
    0 `  q( i" v) Q7 i9 ~: n2 z7 A2 v, V# G3 d' A# S  q
    x = torch.tensor(np.arange(1,100,1))0 j2 G' }/ C' ]7 |+ v
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15& _4 @7 V+ S( e/ Q
    ; |4 B6 F; H7 D$ G5 y4 `' ^# o6 T
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    6 T7 X# L) m9 zb = torch.tensor(0.,requires_grad=True)
    + D$ _/ S+ L* f7 H9 i  `5 U: A
    ! X$ g2 ^% n' O4 w! C" nepochs = 1002 U/ M' ^. b, q$ k4 y: q- J
    ; J+ ]9 ?% R, Z/ t' j+ x& T
    losses = []
    9 W, {! t( D4 \! j0 ffor i in range(epochs):
    9 X5 v7 x- F; w3 `6 s6 L  y_pred = (x*w+b)    # 预测
    ! R  z$ }4 U4 A5 r  y_pred.reshape(-1)
    8 m1 r' k% n! G5 ~ 6 I+ i/ H  A/ g" d
      loss = torch.square(y_pred - y).mean()   #计算 loss  Q' w6 ^6 |! l: Z
      losses.append(loss)
    ( k1 r& P5 b# b* ~6 s    P8 \5 J% ~$ J& P! a# n
      loss.backward() # autograd' ^( `# M% S9 M8 _7 X
      with torch.no_grad():
    9 a+ {% Z) ~- U    w  -= w.grad*0.0001   # 回归 w% }' M1 y; ~) B1 f
        b  -= b.grad*0.0001    # 回归 b
    9 p! e- X/ {& B6 D  w.grad.zero_()  
    # X7 w* a, a: Z3 \* t, O3 u  b.grad.zero_()
    - t/ b2 `% M6 N* x8 @* S+ ^
    : [' Z4 h6 Z+ T& v6 g& q4 Xprint(w.item(),b.item()) #结果
    $ A, c; p/ a1 S0 P$ t' S# }9 e9 P: o5 {: q5 o$ }, D; r
    Output: 27.26387596130371  0.4974517822265625) V) [2 q1 l: [: u5 ~+ g
    ----------------------------------------------
      d2 P# O: r0 L6 O0 U' N最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    0 i8 g1 m4 L# b2 o* [& n5 L% L3 X高手们帮看看是神马原因?
    ; O& m5 `! d- X, R" y" j

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 / ]$ M8 S2 N1 P& F# }! Q" G" [/ }4 z

    $ M( j) p! O$ }5 p8 C2 q& x  }0 N没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?) C& X( F0 n+ L$ B8 V
    -------
    ( j4 J: T, @$ q0 k不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。0 D! v% \" p6 k" L& c
    -------* R9 f7 j$ P2 z/ |# o
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    # ~  n: V- F* U- {  t没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?2 u% C- t% f! }
    -------
    4 @/ A" t  U, x" u2 ]; q6 c不好意思, ...
    ; [9 {! D  `1 Q7 C
    谢谢,算法应该没问题,就是最简单的线性回归。
    9 h# A+ P# M7 }' C我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 % V, H2 M$ G/ T# t
    雷达 发表于 2023-2-14 21:52( F* Q2 I2 N9 k( |4 `
    谢谢,算法应该没问题,就是最简单的线性回归。
      D8 [. D" D2 Y/ e: T我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    ' w7 X) y% k4 c3 W- S; h
    # a" @6 E2 ^3 U* E
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    * K2 x' u2 T( I
    ) m* k" [: m/ y' s; Z6 ?或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    ! f# z$ e/ v. H# o5 D
    老福 发表于 2023-2-14 22:00
    3 n2 O+ C! d2 \刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。) Q3 q8 k& P* O2 a' l$ v: M
    6 m( l0 X' Q8 ^1 X# s
    或者把b但的起点改为1试试。 ...

    - }% b4 k0 F+ M) Q8 r: |
    2 ]9 b2 \6 N0 w6 ^. l0 O( F你是对的。" L  V+ x& C4 K. z
    去掉了随机部分
    * h/ l: Q8 h. S$ k( M$ \' f* D! n0 ?#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    - H6 e2 W- Q- ?! ~9 Cy = (x*27+15).reshape(-1)
    * _+ s! u) C) j; S( \# U
    ' J$ {; v1 Q; Y' G1 E4 N循环次数加成10倍,就看到 b 收敛了% R# K, p9 a6 j  s4 P1 z* u1 m
    w , b& |3 c' y* E" V4 A
    27.002620697021484 14.826167106628418
    , {0 w& h: t5 w. A/ g
    # f3 ~! A2 }4 u( j9 H和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-1-7 04:15 , Processed in 0.045903 second(s), 22 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表