设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 3087|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    2 _5 R+ Q/ y2 ]9 L9 A2 i
    & g# {! ]- H* O为预防老年痴呆,时不时学点新东东玩一玩。
    8 m0 n3 l+ @6 dPytorch 下面的代码做最简单的一元线性回归:
    : o) v( j: p5 Y% ]----------------------------------------------  g+ V. E) \  ~
    import torch
    ) q6 R7 ]0 O, o# F& Z' h; m+ Yimport numpy as np$ E9 ^  c' w# i( M# O+ M
    import matplotlib.pyplot as plt
    ) @; a# o$ @6 r9 w$ Timport random
    ; n% I% O& h% d0 ]& n8 P/ _5 Z5 f  J# c* U
    x = torch.tensor(np.arange(1,100,1))- |' u3 l% V3 v, a; e- Q$ Q
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15  d, L" j+ [4 d/ l3 Q- z

    ; }  d) \! ^6 g+ gw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    8 Z4 h1 R( q3 ib = torch.tensor(0.,requires_grad=True)
    & x9 L! e/ z! l' k4 ?' x/ c$ ^& H6 C1 f' I( U
    epochs = 100
    # u, @% o, a7 e6 ]9 a
    ; a6 j* i' P, H# T& Mlosses = []7 i$ ~0 R6 y( W( M$ t
    for i in range(epochs):! I$ Z0 D7 D3 y
      y_pred = (x*w+b)    # 预测
    & f  k" }! J) T$ v; N) N  y_pred.reshape(-1)
    8 ?  r# F- U7 c 3 X& W  P% R9 q! v+ y
      loss = torch.square(y_pred - y).mean()   #计算 loss
      c/ B0 j$ z; q, d6 z; Z  losses.append(loss)+ G# ^, ~  f/ F( Q$ T) ~
      0 s# ]* V9 F" Y
      loss.backward() # autograd
    ! U" U  x; w; h; A  with torch.no_grad():  @  r. R5 ~+ y( u) W- ^/ s7 U' B
        w  -= w.grad*0.0001   # 回归 w
    " [/ P; O, Y7 e( X* y" [- \; O, K    b  -= b.grad*0.0001    # 回归 b
    + c9 v: ]. B! N0 d% w: `; U  w.grad.zero_()  & S% l8 z- W+ V1 Z
      b.grad.zero_()* q8 ]! o" ~- Y: q8 {

    : h: Y6 Z" }* Z1 ?' j# k8 pprint(w.item(),b.item()) #结果
    8 E" O; i4 b. d# I6 ~; D! W, q. P2 f, a  T
    Output: 27.26387596130371  0.4974517822265625- B7 T9 M/ Z+ g+ C9 L+ o
    ----------------------------------------------
    + F- n$ F4 X# k- ~) i最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    9 T. R5 N9 Z3 z; r# M) a! V高手们帮看看是神马原因?
    8 y: U1 Y3 M/ J6 n

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 5 w/ w4 h) @: m+ v
    4 f8 v' l5 O* r% I$ S4 z
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?$ F6 v+ i0 B! ~5 m& O
    -------
    ' f0 m: E/ x) H不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    ; e/ Y2 d0 k8 \" S+ w$ d2 J-------
    8 P7 u* r9 }; D0 t算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    1 E% `8 R: v" C/ {- ]8 A没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    # v: @- i, r: V8 m- [  X-------
    / a0 e! o% |% P6 n; k不好意思, ...

    9 ?/ G! |+ E5 c  [: t* ?谢谢,算法应该没问题,就是最简单的线性回归。
    0 m! w: E6 c; J$ @* [0 ~" g我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 5 g2 w" U( z; E- a2 p' P% `$ \
    雷达 发表于 2023-2-14 21:52
    4 K0 s" m/ L  u  m. j9 g/ n! D谢谢,算法应该没问题,就是最简单的线性回归。
    7 l4 s3 ?* L  G. o我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    : d# B1 k( n! \7 c4 C. S
    6 V- X" R8 i' o: O刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。' F$ L7 u; t: y; t2 |6 D
    / L; Y- I  o+ F( F6 w7 x3 e+ z
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    . V! N8 O/ h3 n* o/ U* \" v, B
    老福 发表于 2023-2-14 22:00& T, P% M1 ~- w. J
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。% [3 D! l4 G: S' U# |4 a

    9 W* m) i5 A2 a  f( D) P或者把b但的起点改为1试试。 ...
    + n# N2 P3 H/ I& ?* `1 U. B! g( P

    $ `* R! ]) U1 J4 S# z你是对的。3 {' L) p1 |/ f; g% @& K
    去掉了随机部分
    ! ?+ b& P/ A+ G; O, q#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    ; N% U% x% z" {y = (x*27+15).reshape(-1)3 c6 G5 j9 R9 `9 \5 P( {* t& E0 y

    3 T. o! m3 x$ f/ V2 W8 b循环次数加成10倍,就看到 b 收敛了
    ) G. c, N$ T/ Q. o/ U' H3 Rw , b
    & U, p2 ^% R# \4 s6 u& u7 c# n27.002620697021484 14.826167106628418: D& O! y& K0 s% P

    : N- h1 L* @! V; D和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-6-13 11:07 , Processed in 0.105755 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表