设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2490|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    ! K' h$ I5 X" n
    0 ~! [+ P" T$ B: I# T* P! s为预防老年痴呆,时不时学点新东东玩一玩。
    5 F' Z$ n& L. BPytorch 下面的代码做最简单的一元线性回归:  H4 [, h, A* h( t
    ----------------------------------------------8 ]2 V& @/ G! G, W+ I/ F8 s9 W
    import torch6 O" v! J1 b: ?! G! i
    import numpy as np
    % {5 J- @0 F$ t3 v% X6 `import matplotlib.pyplot as plt
    : f' a% \1 X7 Z# D( e% Yimport random
    ) B: G* ~3 i! [& u% y# O9 Y: e3 f. x2 _4 q: L9 d! r! f% n& f
    x = torch.tensor(np.arange(1,100,1))
    1 _' H% `$ C2 F! v- ty = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15, _- K- r* W/ Y4 ^) X$ b3 o" K

    " s4 S9 T1 M2 G- W" Z' vw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    ( u/ U# w4 h" n+ \3 pb = torch.tensor(0.,requires_grad=True)5 h( Q  y' N: M, f; I. o4 O4 W; X

    - m3 i% k3 ?$ ]2 x' F0 n, Uepochs = 100
    0 z/ v- B2 l, r' i  k# r# q- I2 |9 C1 l
    losses = []- x: O2 w' @* L7 Y$ s8 K7 a3 ~
    for i in range(epochs):0 U0 I& k8 A2 q/ F* A
      y_pred = (x*w+b)    # 预测3 q3 w2 G5 t( u5 Z. P" m
      y_pred.reshape(-1)$ d" N( s9 n+ _9 D
    3 J/ G# w0 A5 J
      loss = torch.square(y_pred - y).mean()   #计算 loss
      g& |' l( R3 x# [6 ~  losses.append(loss)/ f; L" @5 t5 [2 P% y  I/ ~
      ) R( K9 L+ D0 i1 v
      loss.backward() # autograd9 t6 p3 l3 J2 ]$ ^
      with torch.no_grad():' l) A- M+ b3 V3 h5 x; T
        w  -= w.grad*0.0001   # 回归 w
    ( u. k3 u! h1 A    b  -= b.grad*0.0001    # 回归 b * _" f. j; Z) d- V! }: R/ z
      w.grad.zero_()  ! G# B; L, u0 R! O5 q' Q1 h* g+ y
      b.grad.zero_()* X& n* ?& J# e# m6 [9 q* {6 n
    ! a# u5 [: _7 G- `3 Q3 V+ v
    print(w.item(),b.item()) #结果2 @$ S& O. Z! J

    + ~+ Z! D* M' b# K) vOutput: 27.26387596130371  0.4974517822265625
    5 \, D  E3 H* a5 w----------------------------------------------
    $ j: v+ n8 }4 n最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    4 N; X% Z3 c; S( @& d$ y( y高手们帮看看是神马原因?2 @$ P8 a4 U. U4 E1 D8 i. `$ {  G

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    ' m; W: }  }) v3 K$ Z
    ( X4 p5 M: S. w: i: [5 d" n没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?: H7 g! L3 b( G: l, Q
    -------. t; c+ i8 P3 _- T- N
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    # H9 ?7 E9 j3 k; u2 A: |-------
    & B8 O' i3 c1 v) l. |/ y* q算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23+ K4 s: v/ ^9 I8 A( T6 N
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    1 b0 M0 y/ F5 u-------$ g+ u& v' p1 o0 Q
    不好意思, ...
    $ D: v/ w4 S5 Q5 K' g
    谢谢,算法应该没问题,就是最简单的线性回归。- g' r( F9 B, Y) W9 ~( u6 E
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    4 n; ^3 {' z( P. ~8 B
    雷达 发表于 2023-2-14 21:52
    " _; L( V* k7 r$ [1 i  k0 Y谢谢,算法应该没问题,就是最简单的线性回归。" S  c% U4 Q( n) U- s3 [; B
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    # ]4 A+ b) {# E! B& ?. Y1 Y4 a

    5 Q9 |! s3 A- N刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    3 ?; U5 h) W3 \* W4 s7 S% T+ _
    3 N5 r; p' q1 y: b: _; I9 {8 ], c或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    ( a3 {/ a0 L: T: A/ m% h
    老福 发表于 2023-2-14 22:005 e0 O2 F  B1 N) P8 O( A( i
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。! v/ A6 \( \3 G. q& z% Q: y

    ' K: ]7 |/ y& c" ?: n& ?$ D/ Q或者把b但的起点改为1试试。 ...

    . l- g2 v9 T. @0 K" X. M
    2 U' v3 U. G, s$ h7 v) A你是对的。
    & n) l' _7 p3 e5 Z0 i去掉了随机部分
    : }9 u* {  a. B6 Z4 s" U#y = (x*27+15+random.randint(-2,3)).reshape(-1): e/ T: u9 Y9 w/ ~
    y = (x*27+15).reshape(-1)# C3 Q  y) t7 F# u+ P
    % H7 M2 }9 x, v2 ]1 A
    循环次数加成10倍,就看到 b 收敛了
    3 r  J, f8 j* q: `" f5 k* i) U0 ow , b
    7 x6 l9 k7 f8 w0 k7 v27.002620697021484 14.826167106628418
    8 J8 j* C2 |( a$ ~% C5 Y, _& b) X1 d/ O
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-2-23 08:02 , Processed in 0.060605 second(s), 22 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表