设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2703|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 8 ^) ]/ ?# e9 s# `, W( h
    ' A; R8 @4 z5 @
    为预防老年痴呆,时不时学点新东东玩一玩。  S+ x% L. |6 p: I
    Pytorch 下面的代码做最简单的一元线性回归:+ b( @& d; y  U5 C8 {' B" s
    ----------------------------------------------& F% O+ }' @! D; q
    import torch- u9 n! h3 m4 u/ {6 g2 v( @( o
    import numpy as np, @5 I2 }, T0 ?" S" e. J
    import matplotlib.pyplot as plt
    % U# ~8 y0 m9 T% Q+ O$ E% j! rimport random
    1 P$ `: s1 a2 |2 x1 ^* Y4 r; N7 ^9 a7 v2 l1 X7 `) n
    x = torch.tensor(np.arange(1,100,1))
    ! X8 V4 u8 a9 [y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    9 T$ j2 u1 N  g) E2 t/ D1 ]4 U; H- {: O) z$ `$ P
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b8 W$ c3 h; u* |0 [
    b = torch.tensor(0.,requires_grad=True)- |. W0 g+ R8 @! l# e

    3 ]" X2 g4 ?  g% H6 F" Mepochs = 100
    - X! a7 l: v8 Q4 K2 n$ U% f. W  |$ l+ Z5 e% @0 I9 i: [
    losses = []
    4 Z! g- f, n. o* rfor i in range(epochs):- f+ p, Q  {3 x5 D
      y_pred = (x*w+b)    # 预测- u6 m# Y. K- q3 v) O* ?5 ^
      y_pred.reshape(-1)
    7 w# l2 v  i4 g3 f* ^; m: T % @% V& O4 H* L0 @/ S( I
      loss = torch.square(y_pred - y).mean()   #计算 loss3 u2 A/ A  j- z( b; N6 T
      losses.append(loss)
    . ]$ V$ k( w. ^. h  % ^8 K7 S% d/ S) E
      loss.backward() # autograd1 k; i( x3 {8 _( ^+ K& c
      with torch.no_grad():
    2 l# e! q, W  T% u    w  -= w.grad*0.0001   # 回归 w
    1 F, L/ b$ \) Z5 \1 I    b  -= b.grad*0.0001    # 回归 b
    0 C% ^6 L5 r; P  w.grad.zero_()  ) [" Z6 }' X4 j& ^. p0 ~
      b.grad.zero_()
    2 S9 j2 [6 r8 ?  L* H4 c3 Q8 c- A1 k4 O9 L# V
    print(w.item(),b.item()) #结果
    & h0 w# v" A- t: x/ s* N! d
    1 a' o# L6 {7 K6 hOutput: 27.26387596130371  0.4974517822265625
    $ Q, }  H5 |3 J* h9 Z; z- j2 z----------------------------------------------
    8 T  n9 I5 e( ~9 F$ Z9 O最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    , O0 `- U, X% V高手们帮看看是神马原因?* q* {. ~0 u! z+ l. [

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 3 B: [/ x; |  x1 l4 w( [

    5 }# h9 b7 A7 j: f没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    $ d. c" B& r* ?! }' k$ w-------
    , P4 i) ^7 I6 ?2 [2 H0 N4 J不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。$ ~5 h( w* |& ]( p" O8 @
    -------
    1 v0 z1 N: g1 A8 T算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23( h8 [) J. Z0 |8 F6 n; Y
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?( f1 h% @& b! L' d0 ]  U3 i
    -------& |# v+ Q- q# z2 j# V
    不好意思, ...
      P( p$ D9 [' z& M  ]- e, a; g
    谢谢,算法应该没问题,就是最简单的线性回归。
    4 ~7 |" S0 t4 G/ Y9 X我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    $ F7 T9 C1 s# n
    雷达 发表于 2023-2-14 21:523 C- W5 p5 R2 f1 m; l) ^7 F
    谢谢,算法应该没问题,就是最简单的线性回归。1 O" l/ y$ F% Z* B
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    ) K9 d- h6 G; Q5 t" X3 g
    8 N' P5 \8 ]& {! z: ~( e刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    8 B% P1 N0 ^4 L% K  R- H6 n6 S/ ^4 h/ C* q
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    3 R+ m- a6 k  F" G! H
    老福 发表于 2023-2-14 22:00* [3 k* R6 s) e9 G" t7 t9 `
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。' Q: s/ W' b/ V5 V, w4 @, ]

    * o  u1 Q; V- F( A+ `或者把b但的起点改为1试试。 ...

    8 {5 y8 W& `& p
    , `; i" W* N5 }你是对的。
    % C  K! O, x! T; l8 d6 s去掉了随机部分
    ' [/ q; B$ ]0 k# t; j& }0 {" ?7 {#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    8 h% n. {$ s6 ^y = (x*27+15).reshape(-1)  n) _  x9 c+ E# [1 \+ D, L9 l0 U
    1 f: q1 {+ ^: v+ q$ o( e; U
    循环次数加成10倍,就看到 b 收敛了
    * @' z: X  R4 i/ zw , b
    3 K- C$ `; X. B& L5 x7 ?27.002620697021484 14.826167106628418
    0 d9 C& T) {+ c: S. U0 }9 J" [
    - t! q) S+ o3 M/ z& r: f$ x和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-4-6 09:15 , Processed in 0.066812 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表