设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2211|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 4 [9 l4 G0 V( o" T
    / a; Y( V  R! J$ I+ m4 M6 }* C$ e) _
    为预防老年痴呆,时不时学点新东东玩一玩。2 }" Y4 |. G/ M
    Pytorch 下面的代码做最简单的一元线性回归:$ a6 e  Q+ n) S( K5 O  M4 G
    ----------------------------------------------8 j2 i! F1 k# F1 r3 N/ B
    import torch2 ^; O9 A7 Y  @, l7 B8 d
    import numpy as np
    + J- H' W' n! w& }import matplotlib.pyplot as plt+ ]( c" w$ x9 z9 k
    import random
    " f4 i. W( `! ?* L4 M- Q( g5 K* O; w2 k; v6 g
    x = torch.tensor(np.arange(1,100,1))
    ! ^) W6 K3 e* M8 i) ^y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=154 S" e! }8 L/ q+ D, l
    % u1 Z$ o8 Z- E8 g/ \: z! O+ R
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b- A6 _  [9 m2 u! V
    b = torch.tensor(0.,requires_grad=True)
    1 K0 m, v( n& v! _" q
    3 j/ S* w! P/ @! e3 V1 u- vepochs = 100
    / {0 O1 v. N+ h; p$ ]+ i% I3 ~  H; V" Z; Y( ?
    losses = []& S2 O2 G5 v3 l  Z; z7 r: Z
    for i in range(epochs):
    9 A3 w% V) [1 m  y_pred = (x*w+b)    # 预测
    7 C2 g0 j, Y- z) v! e7 ]8 d1 p  y_pred.reshape(-1)3 c* w9 [* A+ L) Y' t4 e7 D2 c

      J1 O8 Z0 n" g6 \3 {$ x. x& I* W  loss = torch.square(y_pred - y).mean()   #计算 loss8 h# l8 _6 }4 l0 \" ?% W& F/ u
      losses.append(loss). A# K# `) k1 v& K
        J% d# n0 ~( L4 X! r
      loss.backward() # autograd3 m% A" h0 h! _5 x: U; Y) Y) o
      with torch.no_grad():
    / `5 t5 m* P( M  Z    w  -= w.grad*0.0001   # 回归 w" \- A9 Z  s8 j! a* a* R2 u
        b  -= b.grad*0.0001    # 回归 b 7 ]& T. Y) T% ?  z% K$ |$ o
      w.grad.zero_()  9 {0 ?! j" a. V: I1 l: Z0 F
      b.grad.zero_()
    % Q: c: g- i/ U6 X
    8 H# S' ?* b7 oprint(w.item(),b.item()) #结果5 K9 E; O! Y3 y
    % K0 e% `( J* m  t# F
    Output: 27.26387596130371  0.4974517822265625
    4 Q  H. M% W$ a5 u3 G----------------------------------------------
    7 T  w6 y0 L% K最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    ( l* ~5 [$ ~9 t高手们帮看看是神马原因?" N' O9 i3 [$ l, K& u

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    * P3 `1 X7 n$ H+ ~  ]' c
    6 e0 a, u* X- k, G没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?( y6 t% w9 n2 A4 Z5 `
    -------. B: h& R3 Q6 H5 R& W" K: t
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。7 P6 Z. V9 @, z3 J7 c2 i7 ]* m3 K
    -------' k) O! B, `6 z2 `$ y
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23/ x6 s! x6 u$ h; }- s
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    . r( d6 V, I) R6 f-------  |! J+ l3 O2 B1 J$ _1 F' g
    不好意思, ...
    & y5 P( [, Q$ L& l
    谢谢,算法应该没问题,就是最简单的线性回归。
    ' k$ r/ P, W" d) p我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 ) K- o" _/ a+ {  l- K
    雷达 发表于 2023-2-14 21:52
    1 e; M) i2 ]! J7 {谢谢,算法应该没问题,就是最简单的线性回归。, z- T4 z3 r, h5 K
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    4 I1 [6 |5 P% Q+ N1 @

    5 H& n/ r8 X- z刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。, p$ ]9 O5 P$ y! `! K
    4 A- e: M+ D$ G' E( h- V- s) S
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 3 r8 {% N* s( }; |, ?% Y/ B
    老福 发表于 2023-2-14 22:004 E/ G# Y6 \5 B: C5 D+ u  |
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。. H$ }; Q3 `; j$ O' o9 N9 m# h, r6 {

    9 n' }8 q% d1 Y2 ~9 }或者把b但的起点改为1试试。 ...
    / [1 Q% ]* `6 J& \- v8 D5 v2 @
    3 |5 p+ A9 G( J
    你是对的。- u& P; ?! J; O. U$ j
    去掉了随机部分
    8 ]2 M# g" ?* C$ L) _5 Z8 D4 y5 c#y = (x*27+15+random.randint(-2,3)).reshape(-1)9 d8 u  |/ R# C, ]
    y = (x*27+15).reshape(-1)$ G: \* x* l+ K  M: v
    % w. x- E8 l& M7 e
    循环次数加成10倍,就看到 b 收敛了
    ! k, t1 g2 U) D) w0 fw , b2 ~* o) c5 a3 g& p
    27.002620697021484 14.826167106628418; J; r% ?. o! [3 f0 P1 b  d

    / w( M; H7 A$ O+ P4 L9 w和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-12-14 05:29 , Processed in 0.033105 second(s), 23 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表