设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 1766|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    6 ]7 [0 K" d! b- `* D6 \) x* c
    1 m3 ?9 Y8 ^) e+ O为预防老年痴呆,时不时学点新东东玩一玩。
    8 t% H% [$ u3 }: m% x  qPytorch 下面的代码做最简单的一元线性回归:
    : R2 w0 m6 U2 y7 g4 E----------------------------------------------
    6 z7 g$ t1 L4 ~( r# Z+ Aimport torch
    + A$ C# V* ^& }  g* e2 n1 Dimport numpy as np5 y9 g* p+ @( m& ~
    import matplotlib.pyplot as plt2 b" V# f" ~9 g1 J- n) E7 o1 s
    import random
    3 L7 j/ u1 r/ A
    " t* ~7 V2 [# Q+ H- A, _/ Lx = torch.tensor(np.arange(1,100,1))4 X( J: Q% _7 [( R; K" N
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15/ M0 a1 T3 s) l& c1 Y* b3 X

    7 h. t5 C1 P2 B0 e* F5 V9 sw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b8 ]& s9 q9 u4 \2 s
    b = torch.tensor(0.,requires_grad=True)
    9 d) q5 N; Y: s0 h
    2 L$ t7 k1 p8 s: R9 mepochs = 100$ d) g* n9 z: N: Q% ~

    7 P' x$ m8 p% \; Z' v' zlosses = []% {% D" o8 f0 z  C) }5 t$ Q% a
    for i in range(epochs):' w- W5 n& |8 w1 |+ d' j
      y_pred = (x*w+b)    # 预测# U* q& o5 m6 [; H  l, H
      y_pred.reshape(-1)) K1 |2 a: D3 Q. O3 o& ^
    4 n  K- ?* z; H
      loss = torch.square(y_pred - y).mean()   #计算 loss
    + S8 F7 s* G; p3 E  losses.append(loss)
    1 T% S- c" v: F: P* s) W  8 j3 u( e* Z9 [- ^/ ?
      loss.backward() # autograd
    4 ]* N1 [8 j# N; W! y8 Y0 k* G: v& c9 x* G  with torch.no_grad():
    $ G9 L4 I+ l: f" k! a    w  -= w.grad*0.0001   # 回归 w/ M, L0 ~! m; }9 E9 u- Y
        b  -= b.grad*0.0001    # 回归 b ' O) K6 ^& S$ O0 J% q( Q' g
      w.grad.zero_()  & i4 O4 }& j+ l2 ^$ e' ~8 Q0 |
      b.grad.zero_()4 f1 u1 I0 N1 y# X+ _; W

    # B+ T. o' n9 Uprint(w.item(),b.item()) #结果
    9 c* r) Z: `& Q7 }" T: P  e  ]1 @2 W
    Output: 27.26387596130371  0.4974517822265625
    ; O- H7 u, b/ I/ q. ^# s* p----------------------------------------------9 K- K8 T: c! s7 O. y: T* t
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    0 n0 N7 |1 y- p: `高手们帮看看是神马原因?, P9 |$ U. H2 L( `" Y% ]

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 / |! A: i4 D  E; x4 D; E
    6 j6 V- `# |( {; q; u3 n: G1 ?
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    5 e- u6 b: [$ z" x6 |/ [/ J-------
    ! P) j& s# b$ s; _  q/ ]不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。! o7 t1 v$ x5 K) A8 x
    -------8 x2 [: T, K7 j5 \3 V3 M9 N% c
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:231 C7 ?# ?; ]+ Z2 v: U' ^
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    - ^" a4 m1 {! }-------
    ! U7 G3 \$ A7 P( o( p: {2 e不好意思, ...

    0 Z9 f# v9 ]( K8 B, B/ H谢谢,算法应该没问题,就是最简单的线性回归。
    % s. n# I2 e4 k$ |3 W# D我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 # T: _; R- E* X4 Q
    雷达 发表于 2023-2-14 21:52# P, R0 n6 h4 }
    谢谢,算法应该没问题,就是最简单的线性回归。
    ' o+ x$ [! R* ]我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    , F" f2 i; \! m9 ?6 I* b

    0 G8 I. J+ s2 }* M3 f8 P刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    0 S8 ~% e  A, T$ }7 o1 y  e2 Y. p1 k( `& D
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    ; v7 A; N; {! n0 |. t& `% v
    老福 发表于 2023-2-14 22:00
    ) J: k; m9 \8 b' E4 C; `2 \刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    & J4 i+ f* N: \3 F
    * ^$ I0 y) Z# j1 o( X) F9 A  H或者把b但的起点改为1试试。 ...
    2 L2 g+ ?% e6 y+ i; ]4 q, X

    0 M: x& P. q0 w% z你是对的。
    - g9 b0 c  z! R# k4 d# }: u# b去掉了随机部分
    1 ~) y4 S: ?  V/ @5 v. n#y = (x*27+15+random.randint(-2,3)).reshape(-1); \) Z! Q/ m! p3 n+ W6 R
    y = (x*27+15).reshape(-1)3 L" X9 J% p! U9 o

    * E  C" ^8 `8 u, ~( ^) v% Q循环次数加成10倍,就看到 b 收敛了
    8 h  }6 |$ A, y# G! [3 u6 Aw , b
    ) K- m2 S- q+ V; O8 x: F27.002620697021484 14.8261671066284184 I- y1 h. {" |/ M& ]8 w, M% Q9 K# `
    ; t9 d8 I5 |8 z4 |
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-7-9 18:32 , Processed in 0.050021 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表