设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2473|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    3 ?. |  M& T' Y( y
    * c' p* A* f. I4 ~. N+ O为预防老年痴呆,时不时学点新东东玩一玩。5 U9 y. t# t) s4 j7 s
    Pytorch 下面的代码做最简单的一元线性回归:; d4 i$ s8 f/ l- n& Y" |6 D6 \4 s
    ----------------------------------------------
    0 y& e" w1 f7 O7 C3 D$ ?import torch
    ' t" j6 K7 n+ U5 d) Z9 Limport numpy as np8 L, Y% |9 J9 e* {2 e
    import matplotlib.pyplot as plt
    ) t! Z5 p. o9 F8 C) Aimport random
    ) m1 S* R7 r1 y. }
    6 o7 B' o$ |. {+ \: o& z& I8 sx = torch.tensor(np.arange(1,100,1))
    ! X& l/ \' M% by = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    & w: D) [  Q! p' {2 Z% b- Y2 a. K+ \$ l7 z* S3 ~
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b) l$ t+ N, N0 M
    b = torch.tensor(0.,requires_grad=True)  ]7 R$ Z/ d* d& |
    / r4 w0 M; i( R# W# ?# [3 j
    epochs = 100
    ( l2 P7 H; h1 q2 D. {: U8 H" v; w1 [7 N. Y. ~
    losses = []$ F* F6 T) L9 N+ a6 G/ W% ?
    for i in range(epochs):7 z8 c2 v9 u# H% O. m
      y_pred = (x*w+b)    # 预测
    . n2 F  ?9 S2 u6 J1 L9 z  y_pred.reshape(-1)
    0 E3 p; V: O# |$ q
      I: ^$ b8 w: ]6 _# s  loss = torch.square(y_pred - y).mean()   #计算 loss
    4 t9 u, L4 J5 k( M# z" w  losses.append(loss)& c' z, I# e- ~
      ( X" E" K$ L5 W7 |) T8 J
      loss.backward() # autograd) o% [7 b3 p: P! I
      with torch.no_grad():) [% S- M# `/ c6 |; L
        w  -= w.grad*0.0001   # 回归 w' }2 }7 U/ j/ ~5 ^$ y0 q
        b  -= b.grad*0.0001    # 回归 b 2 @1 H% h. i* T; k; _/ y2 V5 M& I7 |
      w.grad.zero_()  ) L1 i' |# y( ~$ {, M) v
      b.grad.zero_()+ X1 J1 r1 @( n0 {" I$ w

    & u+ P4 y; z' B1 z3 Z/ I8 uprint(w.item(),b.item()) #结果' o+ A, M+ ^1 L: V8 s8 N& S; Y

    / C* U% q2 z* ?- POutput: 27.26387596130371  0.4974517822265625( k& {* k: R: K
    ----------------------------------------------: i- f, _  \4 O8 V. V
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    / b8 k- q7 d  Z! l高手们帮看看是神马原因?! k3 t& a& P- q) p

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 , Y$ a$ S% e' q: m6 d
    : A' W' I- g' U% ^2 I" B
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?; D: k8 n) z% Q3 i, C
    -------3 r3 z) e9 O4 X. D/ E
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。. b/ l0 e3 v9 S& o2 S; Y5 w' M6 `
    -------& g. N; I; f5 \! V" L
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23' c* v% t. ~2 G7 o4 i% c
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    - J( O: g" t0 R0 a$ x3 j7 T; V8 j; ?-------
    ' S, _, A, x3 Z: b1 q7 @# U不好意思, ...

      A9 I  M; T; i4 A% Z; ]2 |$ K: }# y谢谢,算法应该没问题,就是最简单的线性回归。3 @, n( a: w- l
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 7 P& h  f( b& F+ u5 ^; i& S& d
    雷达 发表于 2023-2-14 21:52& N. J& K0 N% O" z: C
    谢谢,算法应该没问题,就是最简单的线性回归。
    % {4 j6 k0 L8 Y, n. }我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    : a+ I' G! t; H, [# M8 e
      q4 |+ `4 m9 v4 F刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。0 k6 o7 ^: O9 D( @) r/ p

    & S; ~0 P8 k/ g* Z) B+ S或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    6 @9 l0 V8 R5 Z2 g7 K( h
    老福 发表于 2023-2-14 22:00
    " o% f" M( R2 Z  L- `# [刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。0 W8 ?( i# i/ z7 n3 q# O8 E3 ~
    9 X! P7 O. H& O7 g2 n0 @3 |8 k
    或者把b但的起点改为1试试。 ...

    7 M) d1 I8 z! I. J) x" f( l. f6 ~2 N
    你是对的。: L! ~2 Y4 R& p% J1 r1 ~1 o' n$ v
    去掉了随机部分
    * p% u2 b2 o# o#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    ' G: ?0 J9 z, b/ n/ b4 U' e* qy = (x*27+15).reshape(-1)" m8 m! p! s. e) j

    2 a6 Z/ t2 o! K7 i  ?' J循环次数加成10倍,就看到 b 收敛了
    9 W0 d  o6 R. A" w& ~# Cw , b
    ! x* D- k$ R* W" n$ K% s/ `/ Q27.002620697021484 14.826167106628418
    ! k/ t% a( }/ s7 n
    - |) |+ ]8 k" s3 C8 d  d和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-2-17 05:53 , Processed in 0.061333 second(s), 22 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表