设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 2333|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 , @) b$ f- V- R' N8 d2 E! T/ [2 _
      q3 S3 G" S( o, t( L
    为预防老年痴呆,时不时学点新东东玩一玩。. S! w' n+ A( X5 f, z2 z
    Pytorch 下面的代码做最简单的一元线性回归:
    0 t5 \" W5 `) w) Y: u----------------------------------------------
    & {; h% n- z5 `- B( s# t- ^% {import torch
    $ V9 O9 o. n$ U& G2 jimport numpy as np
      s9 E: U1 C0 T% z0 ximport matplotlib.pyplot as plt' J0 D" H3 t# [7 K% ]5 P" |/ t: g
    import random2 n" j! U: r* v1 T0 ^

    4 v4 t  m1 |3 h' @! @$ |4 l$ N" Ax = torch.tensor(np.arange(1,100,1))1 j- m- {! R! s  h
    y = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=156 V- p/ N- k9 m8 y8 E0 i- r8 |
    ! S$ U7 G/ d) y9 [
    w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b
    ; M0 Z- q* a% c! ?b = torch.tensor(0.,requires_grad=True). c5 s$ x/ S, w

    2 @1 e" ?2 a" Jepochs = 1005 C& N; v4 k5 a! G% B" y

    8 A- W. i+ E. p. m& jlosses = []
    1 }. \- Z% W6 j# [/ E, D8 Sfor i in range(epochs):8 e7 v3 \/ X( d/ z
      y_pred = (x*w+b)    # 预测6 t2 Z0 d) v9 V% X
      y_pred.reshape(-1)
    0 W. Y3 y) S0 {' X $ G3 K  |& d$ |! S2 x% c
      loss = torch.square(y_pred - y).mean()   #计算 loss0 u4 E+ W$ X2 @
      losses.append(loss); Z  L/ M6 {8 ~/ ^2 Z/ n2 ]
      & y3 t  f0 S; ?/ u! K- ^/ z
      loss.backward() # autograd
    # l) n$ z! g/ u  with torch.no_grad():1 k% N: w4 D4 _" D3 _% |
        w  -= w.grad*0.0001   # 回归 w
    - V0 |. G6 ?/ B4 _5 {* N: c/ H: {    b  -= b.grad*0.0001    # 回归 b & [! s) x# T$ L3 M, `+ m: P
      w.grad.zero_()  
      u' U( c7 F! Y" F6 Y/ g  b.grad.zero_()0 q# P) t& h/ \
    , @3 P& u2 v+ }1 o) c' N' p
    print(w.item(),b.item()) #结果
    ; t- c# C/ L8 e4 d, F1 R% P; ~9 p  l- n+ Y" c- z
    Output: 27.26387596130371  0.4974517822265625
    ' M, Z6 D) @4 E* E----------------------------------------------
    ! V( k0 Z# M' P; M* R' U最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。
    6 w; Z4 d& A4 p6 x3 e6 v3 x高手们帮看看是神马原因?
    9 g2 E4 p5 F6 w( l5 I

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑 : M- k' y  {* \! ^
    : R+ q" Q) r, w" O5 K' ^- C
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
      l) J8 y- |8 X- n! V-------
      S8 `; e; F1 i, T# T( [不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    7 G9 G2 ?0 y; l5 [$ V, n0 l2 n-------
    + ?* W7 B) ~, Z6 i3 U4 M- @. g算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23, j5 C7 Q) S' e: i
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    # O9 X5 j" J) r; w" r-------
    # k6 Y) u8 l# k不好意思, ...

    2 C9 x! Y, d5 V, i; W谢谢,算法应该没问题,就是最简单的线性回归。
    ; _' u% q# U* M5 \+ L1 f我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    / f5 Z  R  Y: t$ {
    雷达 发表于 2023-2-14 21:523 Y  Y) T' R- i' q' ]' ?
    谢谢,算法应该没问题,就是最简单的线性回归。3 M! j7 O' n8 _1 o
    我特意没有用现成的工具,就是想从最基本的地方深入理解 ...

    & _9 H* _1 w) G/ Z
    . L! ]5 R" U4 v8 ^  A. L刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
    % o8 _: \0 P' P* Z% p4 g8 x% R) P5 y
    - s5 L5 s0 L2 J或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑
    ; ^2 F5 {% h4 q+ _% R6 T
    老福 发表于 2023-2-14 22:000 w& E  v$ }! c; o# z
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。
      ]& L4 M% v+ K  V( V
      E' u( W9 a5 d4 g1 I或者把b但的起点改为1试试。 ...
    1 e, l. p- q% P! P( U* P; b
    1 t. C5 G/ s7 y3 \6 r8 s/ i
    你是对的。
    , M5 r3 j! B' d. K5 x% Q. G去掉了随机部分! n4 w1 s; I* M' K- b6 g
    #y = (x*27+15+random.randint(-2,3)).reshape(-1)
    " B. r; X: T) o% C' o. J% ~y = (x*27+15).reshape(-1); S, e+ R/ @1 d) y- Q
    7 F+ @+ ^. U$ b- G
    循环次数加成10倍,就看到 b 收敛了# N: C# o3 g0 o
    w , b6 e: u* k. B7 }( ~0 o
    27.002620697021484 14.826167106628418; W: v3 m- I3 L7 B7 S
    % ^& u3 C* h9 i1 U/ o
    和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-1-15 03:24 , Processed in 0.031421 second(s), 22 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表