设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2600|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 * e! m" C& F5 p
    ! J, Z0 d4 b& @) x
    为预防老年痴呆,时不时学点新东东玩一玩。1 B, k7 }: {0 N. q8 S0 ]9 T
    Pytorch 下面的代码做最简单的一元线性回归:4 l! N& }) @7 }, h& X* h
    ----------------------------------------------
    * I3 Q, c  _, @# E9 A0 u% f2 }import torch9 g! D/ w) r4 @3 d: P* ~* g
    import numpy as np+ T1 C) k5 A3 r- \. o
    import matplotlib.pyplot as plt
    : g  h, k5 Y! Z: ~: y  Limport random
    ' _9 e4 ~; \9 @2 e
    # X. {, j% e% b2 J7 xx = torch.tensor(np.arange(1,100,1))3 N, U/ r* M/ n$ H" X; X
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15% d/ z, m6 ~0 V( ~+ F  D

    ' |8 p& q, y- X5 Gw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    % T7 |2 C5 u  }2 h" T3 K" eb = torch.tensor(0.,requires_grad=True)
    * N; Z) o1 c, o; e7 o
    ; ^+ a) U/ _0 V9 _: a/ G$ mepochs = 1006 d) [) k' I# o' a5 q
    % e% [8 l2 M( L( I) j! L6 u' h8 S
    losses = []2 {" k1 F4 h5 E- i) k
    for i in range(epochs):9 E( {, h" F& z( j# a" k
      y_pred = (x*w+b)    # 预测
    & n! T$ _, L/ m) L  y: }; E+ [3 ~  y_pred.reshape(-1)5 m2 l( L+ F$ i4 D* u7 a% {
    9 H/ D# r" E; H4 _8 h) s$ C
      loss = torch.square(y_pred - y).mean()   #计算 loss
    " O/ E$ G4 R0 }( `4 I  losses.append(loss)
    " B+ M7 O1 k6 G. j$ P  $ I7 v& o& f0 n' @$ A: Q) {
      loss.backward() # autograd
    6 `3 D7 @- D; ?% K  with torch.no_grad():
    5 t; m- a; r4 A    w  -= w.grad*0.0001   # 回归 w5 S: I( F9 p( x4 Q+ T: H
        b  -= b.grad*0.0001    # 回归 b
    . L6 y9 V, s6 F, h  f  w.grad.zero_()  
    " x" ]) i% N" W& V5 p  b.grad.zero_()
    4 `7 {+ p5 H& a# h6 ~: _# M4 `8 {& A9 e( [9 K% A. D
    print(w.item(),b.item()) #结果* Y4 x: Y3 L5 ?" Q

    5 a2 {7 S; @; B  s) \Output: 27.26387596130371  0.4974517822265625
    5 D0 D! i, F9 U$ e/ J9 l, M----------------------------------------------, c. d/ I$ }6 z! O% m
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    8 \7 w" h+ f; u. a6 o高手们帮看看是神马原因?
      U# O$ C; w: m0 Y8 @( q

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑   p. V) t/ {% |3 G

    8 g! y- v9 u8 W$ ?8 R6 \2 G没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    1 D: @' R  H- ]) \2 L9 q0 n-------
    + ^+ M& h) g; V# p不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。. P2 T7 }  ~( z+ P) {$ W
    -------. s1 H. b2 K8 ]3 P0 `( d
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23$ T& Q  ], N3 z* ?4 d
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    " O9 z. {/ g/ X0 t$ s-------
      N4 @4 P9 k3 y# U7 a5 A) J不好意思, ...

    ) V) A9 v7 U. ~1 u  T6 t+ M$ [谢谢,算法应该没问题,就是最简单的线性回归。
    9 ~) c- L' i+ D, C我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    ( W! h7 o) g8 k3 F9 f; C
    雷达 发表于 2023-2-14 21:52
    % h7 d: K7 u9 ]. Z* f6 [谢谢,算法应该没问题,就是最简单的线性回归。! |6 W' x, ^- F
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    * _2 ?6 N; s& _; h
    / ^, x) E# L; ~% h: p+ [刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。! f2 w! o  T8 w9 L; Z7 ]

    & C' D# l3 P' d0 E7 D+ I. a或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    ' n7 H8 T: I  B/ `4 l$ s4 @
    老福 发表于 2023-2-14 22:00
    8 n( @$ J# @* A  h刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    " p: z8 f$ |' g. X6 `( b7 E& u; h/ i+ P! P3 n3 D7 X) F
    或者把b但的起点改为1试试。 ...
    % g0 o9 H6 f, l) r

    3 F+ |+ Z( P/ m  ?& c% T4 g你是对的。! }* [. f7 k1 s
    去掉了随机部分( L( O; y/ r* L" Y# Y. S/ X
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)) k2 g% J- t2 t; q! }
    y = (x*27+15).reshape(-1)
    ) D+ i4 y+ W, T/ S9 K. V/ J0 L) W$ C% Y/ s* [: {) i7 i
    循环次数加成10倍,就看到 b 收敛了9 ^: ]8 A% g& j( I
    w , b; _. Q9 J1 s0 w0 \
    27.002620697021484 14.826167106628418( c6 `* D8 O4 `: a5 K: h
    ; }! }  @9 y" _/ Y1 U  M
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-3-14 16:46 , Processed in 0.058087 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表