设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2638|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    4 R4 f. H6 N$ }" t! E9 ]+ \
    " V0 ]. e$ Q- ~  u: Y+ t为预防老年痴呆,时不时学点新东东玩一玩。
      s, ^& Y% T! GPytorch 下面的代码做最简单的一元线性回归:
    9 h# N" t2 e% F" q. n! ~+ |----------------------------------------------! j) v. ?) a% E1 G3 _
    import torch
    ! _1 ]% U, {4 I% [* G9 E  s7 @7 N: Dimport numpy as np) W1 `5 J9 q; b' A* p6 Q4 w
    import matplotlib.pyplot as plt3 H/ R# C0 Q( C8 s/ G2 x+ c
    import random. S7 q$ y# p$ Y
    % l- _6 x7 P1 E; A; K
    x = torch.tensor(np.arange(1,100,1))( M% o5 c1 [4 S
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15+ T( S5 {5 n3 `5 C

    7 H; Z9 E! @- j7 f1 Qw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b3 z) p& i* a' r  o
    b = torch.tensor(0.,requires_grad=True)3 X6 {5 f; ~, R6 p, V

    $ g; s, t9 F+ q  v3 d3 L1 Zepochs = 100% N! H7 n. ?7 u; ]  {0 f3 I

    ( o" I7 v. B1 `* y- ^losses = []9 B+ C, A6 p& F5 m) y8 r" Z
    for i in range(epochs):
    $ F6 W: u8 G$ [; ?2 \' m! ]) \  y_pred = (x*w+b)    # 预测- }5 i; w( g+ N( U. |
      y_pred.reshape(-1)1 F5 v& y2 u( L. u- k

    3 r5 A# p+ v7 f& w" U  loss = torch.square(y_pred - y).mean()   #计算 loss
    . \1 r! J) B8 W5 \! H  losses.append(loss)  m- F2 l  Z5 T7 m, d& T, a) S# D
      
    ! W/ W7 w# F: m  loss.backward() # autograd4 F, I  b" l" z0 O3 U0 x' J
      with torch.no_grad():
    5 e$ ^) {) _; {! ]    w  -= w.grad*0.0001   # 回归 w/ t: Q* a1 k  V
        b  -= b.grad*0.0001    # 回归 b & N! K, a4 D6 o! {- w3 v+ P& ?6 Q6 y- T
      w.grad.zero_()  
    ( ~0 v$ b+ p8 L; g# j, O; g  b.grad.zero_()8 Z1 n5 x( @% ?* L8 k/ E5 o2 b; A
    * h2 `4 [. ]/ y4 o
    print(w.item(),b.item()) #结果
    6 J4 i* t7 t* I% J1 L2 l% [1 S, L$ ?6 R' j! v+ k0 Z
    Output: 27.26387596130371  0.49745178222656254 B3 s! T6 j* k7 M+ E
    ----------------------------------------------9 K' x4 ^$ I9 ]' ^  L: v
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    * b/ J& X& f' N高手们帮看看是神马原因?# A, C3 A; d: Z) S  ~! B% h% j. T

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    $ t* j+ L3 b6 g5 H. J: J+ w4 |$ P+ R( a9 d# }& \  P
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    # J0 r: E7 d2 X-------
    , F! E& D9 _$ \/ u3 c' u, y不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    6 o/ I) i# X# W  J6 V( A3 j8 w-------' V- V1 f' E% K* a
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23* x) D7 ^" S; Y" X1 s
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?/ q8 W' Y) L$ v# W* k- S
    -------
    7 n( v! |$ t8 u1 S8 ~# y不好意思, ...

    3 @' o6 r# y' Z5 ~谢谢,算法应该没问题,就是最简单的线性回归。
    - y/ w5 k  t6 D) S9 M3 K我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 0 L1 M3 E, y# C. k
    雷达 发表于 2023-2-14 21:52* a+ A3 f) ?4 e
    谢谢,算法应该没问题,就是最简单的线性回归。+ ]5 I. s3 E) b: C5 m: w
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    % @6 L8 ?5 H9 S" C" b; ?5 C
    0 j0 f; L5 n1 g
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    8 V8 d) @, m6 z  t% u) M4 z
    7 X: Z  X8 M0 ~( S7 V; x或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    0 [, _9 A: }. n
    老福 发表于 2023-2-14 22:00  ]0 W3 `  Q& l+ b5 |" B
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    2 P4 K' u% E) k0 w# d* `9 _6 ^; B" t9 ^/ _! M3 r9 S' S
    或者把b但的起点改为1试试。 ...

    2 @+ y) ~3 S9 T* n. Q- E% E' i- K2 ^+ J4 }! l2 q
    你是对的。
    3 J' t6 u+ h$ r- X; {去掉了随机部分2 B8 D5 u+ x8 i. ~( v) ^+ w, `
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)
    & m2 c* t1 o* My = (x*27+15).reshape(-1)3 E" f' q+ T/ T- z; F: w$ C0 k
    * P. O% z, S. H) e2 M) j0 @
    循环次数加成10倍,就看到 b 收敛了0 h4 p) C$ f, Z3 q/ v
    w , b3 x/ l3 I$ Y+ L6 y: O% A
    27.002620697021484 14.8261671066284189 @# p7 ]# n0 Z
    4 K+ \) T" s. q4 D5 }& m
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-3-20 11:40 , Processed in 0.057288 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表