设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 1771|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 7 T. X6 _$ }3 g: f% d

    ' b* k7 D% v# N% |  f" g为预防老年痴呆,时不时学点新东东玩一玩。
    9 n  T7 u! y0 _  n4 W* fPytorch 下面的代码做最简单的一元线性回归:
    ' m3 A! i5 I) ]3 \- {2 d4 [, X$ i----------------------------------------------
    - V0 _3 }. X) ]4 a  }4 \import torch
    2 m! }" V" u1 Z% R/ Dimport numpy as np
      q/ i( h# k0 s5 e$ o) b8 uimport matplotlib.pyplot as plt
    5 ~' \; J9 ^# M5 w1 aimport random
    # h! o, u, O8 i+ L! S  k) S
    ; s4 }8 t5 _8 t  T0 x/ Mx = torch.tensor(np.arange(1,100,1))
    - L) v/ y. {9 Xy = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    / g/ b4 Y- s# T
    / k6 L8 w/ Q2 t7 _" X$ ?0 D% Vw = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    " `+ m' v% m0 Q8 t" z& ob = torch.tensor(0.,requires_grad=True)" X+ w3 I9 w4 }; B8 v# B$ v8 l
      n1 Q- c$ K2 s- h- M
    epochs = 100
    0 G( \2 F: k; S5 y: U  f
    ' k( B6 E5 `5 r( C( tlosses = []
    ; c6 i1 m+ D4 P2 Y5 x( Qfor i in range(epochs):
    ) [- k: i6 L& ^: u3 z" @7 G# ?: u  y_pred = (x*w+b)    # 预测
    0 i4 _$ b4 [4 N1 N* q% }  y_pred.reshape(-1)
    / q1 I# v9 K+ s$ R2 p/ [
    & Z1 _3 A9 }& G: u6 N# [* r" o  loss = torch.square(y_pred - y).mean()   #计算 loss
    9 D6 M+ \/ O2 b% D  losses.append(loss)" f1 L) n+ r$ A! j. U) i$ m
      
    2 \) E# |# W, b  loss.backward() # autograd
    4 q: R, F; M" V) d7 t# t5 u  with torch.no_grad():
    : h) R" u0 C" A6 G5 T$ j6 ?3 L    w  -= w.grad*0.0001   # 回归 w- l2 N) ?3 Z2 ^6 Q/ }
        b  -= b.grad*0.0001    # 回归 b . q5 h5 C: i% ^% x. y( P
      w.grad.zero_()  
    3 s/ n) n8 q. k' B  b.grad.zero_()3 w9 c. d- u' V3 s" `3 q7 j

    ) P$ t) y9 h" X$ {print(w.item(),b.item()) #结果
    7 j; b, Z' @5 h# l2 U+ o
    , a8 ^- U& D' b6 DOutput: 27.26387596130371  0.4974517822265625
    7 H1 n9 o/ [5 _% x- `" v& o----------------------------------------------! I& g# a% [! w( p# j) t
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    4 r: n1 n& W$ c+ t$ [1 `4 C; Q: E* y高手们帮看看是神马原因?; H& K" c" m/ n1 x0 |

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 " R' @0 A4 E5 o9 l  v$ i& z

    3 Q4 y5 \- a! ]4 M1 }. \没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?# A; {1 y& S3 L/ n! t  z/ J& N
    -------/ O/ V& e% E3 O- |6 o& h
    不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。! a3 H! e( @+ z9 L
    -------, N$ o8 Y9 p% a; u6 l! N) E8 R
    算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    # U5 @6 m% o, H* [. q( F  c0 n没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?1 \& g8 q" f, r; Q$ K; B' q; E
    -------' H8 d7 ]! s$ X' J
    不好意思, ...

    ) D7 ^! z5 Y- r* k% [- U3 q% j: L谢谢,算法应该没问题,就是最简单的线性回归。9 K# L$ G4 ~5 \* e
    我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑 3 U; e$ _# S( `0 B( ^
    雷达 发表于 2023-2-14 21:52% z  O4 |$ R6 ~$ S3 T4 p
    谢谢,算法应该没问题,就是最简单的线性回归。( C; |1 D- _, R7 _8 ~
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    ' U9 w3 M; {+ _! d
    $ q3 f3 c% V; ?
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    ) S% ^( ]: f/ `" h6 m3 \( X3 X
    6 o+ l, Y2 e) g. w7 o5 o或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 4 ^) n$ g% l# ?+ C) t7 O
    老福 发表于 2023-2-14 22:00
    2 l" a9 ~* R4 h# e9 `# v" y刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。- b8 S& _: u* `+ K

    9 O& C! v+ s7 k- z/ c& a& {/ J# g! A或者把b但的起点改为1试试。 ...
    ) E8 M. e/ o2 u$ E5 e5 n
    " _$ g" [4 _4 s7 O; B1 k; `/ N; F
    你是对的。" t5 F2 R7 Q3 P- Y3 _7 E* L0 f
    去掉了随机部分
    : G, ]7 t) @7 x#y = (x*27+15+random.randint(-2,3)).reshape(-1)) J8 o9 z$ W2 V/ w
    y = (x*27+15).reshape(-1)6 @8 [0 I- L0 E+ ?9 J( }
    # S+ R5 `1 `8 m) t  E2 H1 G" j9 @9 K* r
    循环次数加成10倍,就看到 b 收敛了
    4 L3 l# k0 g) ]) F8 ~w , b
    2 R. ~; l. n+ b& q27.002620697021484 14.826167106628418
    * B. J/ ^1 E& H" Z1 j% x5 E, |! f" c5 m5 ~
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-7-11 04:00 , Processed in 0.042364 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表