设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 3031|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    ( u+ G7 `+ V. }. m2 O' i) q. A4 [4 d; l. @6 x1 E* H
    为预防老年痴呆,时不时学点新东东玩一玩。
    & t" Z! R; B" N; o7 TPytorch 下面的代码做最简单的一元线性回归:
    + C2 ~$ v7 ]/ H  @7 B2 ?----------------------------------------------: ^8 D4 S: V6 L
    import torch, P& E6 S( Q* Z1 Z/ l/ K8 ^, W
    import numpy as np
    ' M4 Q- u5 R1 m; E$ N* aimport matplotlib.pyplot as plt
    * M+ n  h3 G  e. e* limport random
    3 w# a8 z/ o, k
    . y# M2 X+ _; ~. E7 l1 wx = torch.tensor(np.arange(1,100,1))7 r- ^8 F3 |  T% r
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    0 o8 `0 |: ~3 R' b' h/ y" D& I: p
    ) F3 {2 r- F' O1 j( Q5 [2 p$ ~2 o2 Gw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b" U/ s( v7 u( j7 T4 u: ?
    b = torch.tensor(0.,requires_grad=True)
    * A$ ?- {8 Q( ^6 _2 m# [- e% j4 Y! n4 D5 M% U; n
    epochs = 1007 K. P2 d2 e" ^6 c6 ]4 ~8 `$ E4 b

      \9 r/ e" _6 n, {; U  nlosses = []
    ; Z0 P5 t! M0 ~1 O+ A+ W0 e* Xfor i in range(epochs):
    . U4 t" @5 S- A. P3 v9 @  y_pred = (x*w+b)    # 预测
    3 o$ s: A) P8 L3 o  y_pred.reshape(-1); E( u; x6 p& i: P1 b! r% d
    2 Z( G9 m. z: @) m
      loss = torch.square(y_pred - y).mean()   #计算 loss
    2 q  n* ?2 ]: l* `9 t* c3 @  losses.append(loss)
    . J; W% n4 h2 M  / B% Z+ Y3 D/ J
      loss.backward() # autograd6 `6 T! A, [8 m- v7 ~
      with torch.no_grad():, g" Z! ^% h" M$ _1 X7 r8 K; c/ a, U
        w  -= w.grad*0.0001   # 回归 w
    ! h3 q* x1 N$ K! i+ q  j    b  -= b.grad*0.0001    # 回归 b
    + d1 O. E6 ]3 O1 M. H1 G- U  w.grad.zero_()  
    $ F8 n: O- k$ I. N, a  b.grad.zero_()
    * s& a+ K4 d' ]% j0 y: D) h" k% i( y/ O- Z+ _- w
    print(w.item(),b.item()) #结果
    / I3 \5 Q' |" I$ t' A
    " m  V8 O& v( q, U4 F# Y" JOutput: 27.26387596130371  0.4974517822265625
    7 d: P2 l2 X- T3 x----------------------------------------------( Q1 Y6 w; N7 F/ B; N4 U
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    : ]5 `/ J1 Q2 a/ P1 V8 O9 }# j高手们帮看看是神马原因?
    2 {9 Y2 K. F3 C% ]0 U1 w

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    ; D5 ~9 B) z. G" g+ M9 }: T4 J& f+ F$ }' g, z$ {9 N" k
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    ' _; e7 @4 e1 {1 b-------
    / g! `) x6 Q3 m$ S" ?: j不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    9 C. A( Y4 T& {+ R-------& _! t, e+ y" X* {
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    4 r1 Y  g* ?, S( v/ }没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?; l- B2 R- \& m$ m/ x
    -------
    8 I) @1 e  [; n, s% j2 j不好意思, ...

    * |5 ?/ N' A$ n8 O. ~谢谢,算法应该没问题,就是最简单的线性回归。$ Q/ s3 `: O7 b. a0 ?4 K4 l
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 , D; v$ }6 K% Y( m0 W5 u
    雷达 发表于 2023-2-14 21:520 P2 c" F/ [9 d  r0 X! A
    谢谢,算法应该没问题,就是最简单的线性回归。" Z2 D+ [; S6 {& Q# m
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    7 |$ r8 V; L, Q  J
    1 U( x) G! p3 s刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。6 l9 K8 c2 o0 V! H( w) p
    + X5 r. |) U8 ?, a. ?; D' _6 s
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    : k+ f) ]- C+ T/ j. m
    老福 发表于 2023-2-14 22:009 J* w" w& h0 q; ?; ]+ j; x
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    * r! t1 z6 x: A7 n/ w3 X8 C2 {. l& B4 c
    或者把b但的起点改为1试试。 ...
    8 \3 n. g0 ]0 Y! U, @# p( B; z

      g0 a( a$ j6 S+ V  \) M2 I  ]' `你是对的。
    ! S( B! T+ y6 P( c去掉了随机部分
    3 `5 V* h! D9 g+ [: u#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    : `$ O2 w0 ~" A* ]* _y = (x*27+15).reshape(-1)) `5 t4 k- A! u3 R$ X

    , ]: P' [$ x7 ~循环次数加成10倍,就看到 b 收敛了
    : i. `' E/ Q  _# x& ?  I& Tw , b
    9 h. ]& G# A) @) v- _7 o27.002620697021484 14.826167106628418
    0 X' l# \7 D' f# e: i! o  l7 B4 q, o* f  _* M/ T& K# _
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-6-4 08:09 , Processed in 0.062011 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表