设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 1896|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 - n( A* y& h( g  F

    + o9 |& H3 x" M. Q  e为预防老年痴呆,时不时学点新东东玩一玩。- h: C% x9 _- s/ q9 h; |: m
    Pytorch 下面的代码做最简单的一元线性回归:
    2 l5 t2 N. D* ^: U! ]----------------------------------------------: h. Z( T. K1 H+ g% T# r3 ]
    import torch
    7 L/ O" f# X# H  {$ Qimport numpy as np
    $ B" r* k( m1 X6 F3 q9 Fimport matplotlib.pyplot as plt
    2 o% g  x+ l" v/ h* r, d) g2 jimport random
    ' A7 ?( {: u* C/ q2 O# }( _
    & \8 Z# Y$ z  s5 f7 o; |1 s3 ?, \1 a7 }x = torch.tensor(np.arange(1,100,1))
    % _" Q5 ?( n% jy = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15: U8 }- [" Q) N# }0 H9 O

    ' a. I9 Z: k+ c& Mw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b. n/ L3 l) T3 ~$ R4 v& ~/ ]
    b = torch.tensor(0.,requires_grad=True)) V  c% x0 }0 S9 }( K0 R/ R9 I

    ' q; q7 `- i( ~$ ~# [epochs = 100
    ' C# r7 ^5 L& f) ]5 F0 u$ M& O
    , |3 [4 t* w' Z# {' R: l# ilosses = []
    / H$ F" T. ?8 z6 m. c, ufor i in range(epochs):
      i8 S! R* v" }& s  Q/ \( Y  y_pred = (x*w+b)    # 预测
    ' X9 P: ]0 q6 k5 Y  o: G  y_pred.reshape(-1)& @# u& T* ~: T# _* a+ w. m
    $ I) i$ J. H. c) ]4 E$ s
      loss = torch.square(y_pred - y).mean()   #计算 loss
    ! r  k' W/ u6 g  \  losses.append(loss)
    * ~6 Q- E+ g0 m2 r  0 f' o6 {3 L7 X; w. @
      loss.backward() # autograd
    1 o1 [7 a5 E; X- c* v4 K  with torch.no_grad():
    $ J8 G' g# P8 {' @4 g$ T' L' w    w  -= w.grad*0.0001   # 回归 w
    ' h7 e$ d" `( E, N% F: y1 O    b  -= b.grad*0.0001    # 回归 b
    # X. R: j$ f3 a  w.grad.zero_()  & T2 H  S3 Z& }
      b.grad.zero_()& f5 [9 l) M) O3 |' R7 W
    6 o% T- z& M6 T$ {9 _/ p# k
    print(w.item(),b.item()) #结果
    4 A3 h) E! b  d8 d+ i4 V. Q0 E5 M/ M
    Output: 27.26387596130371  0.4974517822265625
    # g2 [3 u8 M2 t( S0 t- h----------------------------------------------: }9 O/ m# X& e; a9 ~% ]5 g
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    5 B. K+ M) U. ^7 |5 H6 {5 _高手们帮看看是神马原因?$ U$ I" O3 E7 X& N; K# T

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 7 g, C/ q  l8 f& B
    ; s# U) j! x) Q, h& X: V6 H; s# z) {
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?: g/ C  |' ?2 I' C: a7 _
    -------8 K0 t$ s# A( K( g' Q; K7 X4 ~
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。. l* @, N) q. k" x  h
    -------) |& K- p2 E. Q; v9 k8 Y; u2 T
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    2 k3 \0 A! N' [. @, W没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?5 }$ v" _6 `7 f5 Z0 ^
    -------) Z$ x4 ^5 t" D* X" U
    不好意思, ...
    $ M# v2 U; _" Z$ a. t
    谢谢,算法应该没问题,就是最简单的线性回归。/ {# H9 [3 F$ F/ O% n  r1 V, m
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
      |  r, N' U$ ]6 V3 z1 b
    雷达 发表于 2023-2-14 21:52
    4 v4 r/ Z5 N6 ]3 J1 q! z3 z谢谢,算法应该没问题,就是最简单的线性回归。0 D1 \6 r) ^4 @/ c* N1 f# X
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    $ {6 a" I3 Y* a0 j0 D: W, M8 M: A+ _8 F+ o' O4 x4 v
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    6 N( P6 [) B3 ~" ~  {
    / ?' G, _4 c& g+ `或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    6 ?  k* t- U7 x0 r4 K
    老福 发表于 2023-2-14 22:00
    * I7 ~3 q) x, x刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。. C2 |& E2 J& b* P% P2 }

    1 w: t* t, r: s. Y( F! `& G或者把b但的起点改为1试试。 ...
    - u# l6 I/ F9 B/ ?

    8 O/ }. n, C, i1 b, v. j+ o你是对的。
    6 R3 g8 A" I, n9 T/ ~去掉了随机部分
    / K! |; z% w& u& C. Q#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    " @. P) z5 t' M! f! _3 }* L: @y = (x*27+15).reshape(-1)
    6 `; }# o. [. M2 J: J. E6 m, C7 F+ }( G! O2 t' r
    循环次数加成10倍,就看到 b 收敛了, N2 K! f( O! a0 g3 [6 T: N
    w , b
    + R# c  ?  a5 y27.002620697021484 14.8261671066284182 Z3 [" g2 \& z2 u

    + ?  P+ z, q$ Y5 o+ O和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-8-22 01:49 , Processed in 0.044028 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表