设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 3127|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 8 v# B; j& P2 r. Z# J* {6 U
    / x' }! E; E* g
    为预防老年痴呆,时不时学点新东东玩一玩。
    ! \- m2 b) g, Z; ZPytorch 下面的代码做最简单的一元线性回归:  N1 \( S$ H7 S8 f: {0 M3 O8 C
    ----------------------------------------------3 N  z, a- A% ?  _8 N, U' c
    import torch5 K" n6 D# ~! \( U1 M  k
    import numpy as np) n( Z  [; T6 b( @
    import matplotlib.pyplot as plt; C9 ~$ c2 e8 M: d2 G& Y
    import random. u" F0 ]  ]6 t* i! J: Y

    + ~9 B" ^+ Z3 H' x9 {x = torch.tensor(np.arange(1,100,1))6 U6 n/ L5 k  Z! |# B
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=153 v, i  }% @7 M' N6 w% i& d
    0 f' w3 g, p- C3 l( s! d
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b/ V4 A$ _$ x9 b+ y
    b = torch.tensor(0.,requires_grad=True)1 e& @( ^1 _, @
      I# J# @; j5 {- S. Y
    epochs = 1002 h) z. z- J' v8 U

    5 g; {; Y& p% ~3 h7 S' G; Hlosses = []
    ; Q! {7 Q) S& I+ R; efor i in range(epochs):7 J$ Y; \  F( b0 L
      y_pred = (x*w+b)    # 预测
    " m+ P: O, z! V9 o2 @  y_pred.reshape(-1)7 L4 c# ~) B! W3 V& f# d+ z
    9 e/ i9 o  G. z6 A  B. j
      loss = torch.square(y_pred - y).mean()   #计算 loss
    . x  O. J2 ]* o4 r$ q, M" H) C9 C  losses.append(loss)* S4 ?9 z# y( w, T3 h8 ~
      4 w$ O0 h! Z8 e8 }' o3 U$ v
      loss.backward() # autograd5 ^  B& C& {: i& h
      with torch.no_grad():
    & c! ?9 S* K2 B+ o2 B) d# p3 c, d    w  -= w.grad*0.0001   # 回归 w. M, [8 z' K8 s
        b  -= b.grad*0.0001    # 回归 b
    ! p- Y3 u, {! |" _; A% q  w.grad.zero_()  
    * B( n$ @. A! a  G& E3 [  b.grad.zero_()& l# a6 G- F; q0 [  N

    ! H! U& M# P6 J3 r2 pprint(w.item(),b.item()) #结果& C( A1 _( C3 [4 t3 G
    $ s& B9 \3 X$ u# f* l! g
    Output: 27.26387596130371  0.4974517822265625" P  l  t$ l, r. D2 g, [7 H
    ----------------------------------------------
    ; T* W+ r( {1 [最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    , V3 A+ g* i/ k* E: M高手们帮看看是神马原因?* m# \3 x+ a5 x' M6 i

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 ; u" y6 v( f, @* e7 z5 ?' b# I: D

    4 \% K5 V6 B+ k* b没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    - N* m1 {  F' @) O( a+ G% }# u-------
    % g9 F# q/ ]! E3 |1 T不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    8 P2 n$ u0 o. m5 f  |-------
    * t& S& s8 @+ ^# N( n算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    ) I8 H. U% h0 O7 F* d$ S没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?6 e& j5 I5 o$ r. j! E. r
    -------' U1 c# e( i) U& C& ~$ M1 I
    不好意思, ...
    2 Y1 e2 ?, I: l4 v9 h; |# J
    谢谢,算法应该没问题,就是最简单的线性回归。
    $ j2 d+ h# t5 f7 G5 P* w. u0 V. }我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    " {7 w: |7 ~# M; y* x) n
    雷达 发表于 2023-2-14 21:52; L' I9 X6 J& m! e/ A0 T
    谢谢,算法应该没问题,就是最简单的线性回归。8 K. P  I) N' H. D
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    $ I, W, A; F* j8 v
    . R- t( q3 q9 f1 v% h刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。+ U2 X# J: J3 `1 l" X; g+ x

    . T. ^2 I  I7 S' k# s+ v或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 . [0 S6 V7 h7 r% d
    老福 发表于 2023-2-14 22:00
    7 L. ~6 T( T$ S4 l8 R5 g: G% `刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。! R& H' V% N* P, s1 [

    - h3 |& d$ b: p" L( A或者把b但的起点改为1试试。 ...
    7 X/ O$ N9 J6 M; \

    9 i2 T* v! S$ H# M你是对的。) Y6 ?9 y6 L' ~& y
    去掉了随机部分
    ( {' o$ |( I" C. }! T#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    4 k* N5 c$ Z8 wy = (x*27+15).reshape(-1)" y) U( ?/ p, i7 A1 g% t

    5 w( m" p. d1 C! Z7 K循环次数加成10倍,就看到 b 收敛了
    : K9 F* i& A7 i4 @1 `w , b
    ) E/ P6 u0 m: [27.002620697021484 14.826167106628418
    ; n: ^; [2 f8 `) ]3 ?1 _* ^' B# _  q% ]$ w+ k
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-6-21 23:14 , Processed in 0.061108 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表