设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 1505|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑
    / B, p& T$ m) B' D* [; h! [% D5 e1 x) \) i; `% _0 c- l( U& ]. k
    为预防老年痴呆,时不时学点新东东玩一玩。, {% S1 S# |  u6 ?$ S
    Pytorch 下面的代码做最简单的一元线性回归:
    0 z2 N$ B2 y7 \1 L% u----------------------------------------------& W! s& p. [) H6 ~, P5 G$ I
    import torch
    ' w, w  w1 R( G- y) wimport numpy as np( J  P- p" q1 n3 i8 ^/ X
    import matplotlib.pyplot as plt
    4 S8 m" Z2 P  ^0 ^/ S* vimport random* n7 |  O% `) i5 m- Z, K

    ' i) U5 O+ @7 `+ f0 N( Q0 Wx = torch.tensor(np.arange(1,100,1))
    8 x+ u3 ]  T" m1 Z+ cy = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    : I7 Q6 G1 L" {2 \7 C+ ~$ }/ e$ d- k/ z& u" ?) V
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    2 h3 t- u6 f  {+ F( E0 \4 M5 Bb = torch.tensor(0.,requires_grad=True)  ~, t  Z" ]* z

    / {9 G. U  E) }8 E3 W' r5 ]epochs = 1008 B* {; U/ ]' n: k

    # ]1 P) t8 X8 K! Plosses = []3 y( m5 v8 y# O
    for i in range(epochs):
    " w5 A( h3 m3 g) O' J' E  y_pred = (x*w+b)    # 预测
    ) U9 x% v) t/ j' F; ]. t! ]  y_pred.reshape(-1)
    / `2 e: _6 W' V3 Z6 A* r
    ) m) ?$ [5 b1 ^# [  loss = torch.square(y_pred - y).mean()   #计算 loss0 q) P* K6 C; M# y
      losses.append(loss)
    - U- o  O' V7 }: p+ ^* B  9 \% k* n, n9 `
      loss.backward() # autograd9 \! v3 }# |5 H" ?2 k( K7 R2 w# e% i
      with torch.no_grad():: x' I2 N) P  P: j0 w: Z
        w  -= w.grad*0.0001   # 回归 w
    8 y2 f# \' P/ ^    b  -= b.grad*0.0001    # 回归 b
    / k+ o0 L! o9 z+ ?; u  w.grad.zero_()  
    # }+ N; O2 u2 S3 P/ z9 g; |  b.grad.zero_()/ B* {; J: `" {
    3 ~" A+ v4 w9 X& W6 p) Z& A
    print(w.item(),b.item()) #结果
    , ^& R5 r6 `% c5 K% d% h+ I! ^( o0 B7 D: ~  o5 t6 p, F
    Output: 27.26387596130371  0.4974517822265625, x$ L2 q# |; L! I
    ----------------------------------------------; j1 s9 S+ v( c( B% ~! G' R1 r1 ]
    最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
      B& R' A, ?: B9 O高手们帮看看是神马原因?$ k  r. O: P3 M. T9 F/ U

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    7 O( v" u* w) u7 k6 X6 K/ i, X! p$ v
    8 V6 V( M4 u, }2 x# [没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?# M) i, \& i# N8 u5 t( I- w
    -------
    ( b/ B( z( F9 {% l6 J9 N' t5 B不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    4 C8 ~5 T$ ?7 P# }* ?: m# a-------
    3 |0 u# \& _1 T* o+ p8 i算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23
    3 @) ?9 Y+ b6 Z% _' O% B: j/ K没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?! h3 y+ R( O, V6 p
    -------
    3 A) \2 W- c: S不好意思, ...

    4 y1 q2 N3 T% n谢谢,算法应该没问题,就是最简单的线性回归。
    9 x% B' A; W, d) h  c* m我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    ! M' y- M& U. r$ F5 N. i
    雷达 发表于 2023-2-14 21:52
    0 K4 a& t8 t* V) y- y谢谢,算法应该没问题,就是最简单的线性回归。: v1 l& ~  {0 k6 X  z
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    - D; G5 g9 e9 k) l& H; y

    + w1 b" j+ x7 l7 ]8 n9 d刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    4 Y% c. T/ D  E# N! N8 Y. Q5 y/ Y2 R7 L% v7 Q5 M$ k
    或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    擦汗
    2024-12-25 23:22
  • 签到天数: 1182 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 3 S  \' v  L2 \( E
    老福 发表于 2023-2-14 22:007 o- [! o3 A& q+ C
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    1 C; Y3 j9 E- l5 J! Z! Y/ E; c2 u( f+ L0 X
    或者把b但的起点改为1试试。 ...
    * O- d5 |' Z# C6 Q4 K

    * o0 x1 b; q: ?1 Y! [# m, g你是对的。
    + e$ D' N4 @( R5 P9 ~8 ]5 H去掉了随机部分
    5 h, C' d9 B5 }( Q  A#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    + R2 k7 H3 c, Y; Y  M0 _y = (x*27+15).reshape(-1)
    ; a# y8 i# s. M& X: e+ h0 B: t  {3 w# @; V- u
    循环次数加成10倍,就看到 b 收敛了2 b5 D: X' J1 w, l2 z
    w , b
    & ]/ \6 h! o- q# ?: }27.002620697021484 14.826167106628418
    ( W$ _- q) m' R7 C
    ; E. ^7 [5 v. D3 `和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2025-4-3 15:05 , Processed in 0.032282 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表