设为首页收藏本站

爱吱声

 找回密码
 注册
搜索
查看: 3032|回复: 4
打印 上一主题 下一主题

[信息技术] 继续请教问题:关于 Pytorch 的 Autograd

[复制链接]
  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    跳转到指定楼层
    楼主
     楼主| 发表于 2023-2-14 13:09:28 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
    本帖最后由 雷达 于 2023-2-14 13:12 编辑 . N* n, e' g$ I, ?( u; w& o
      z% t9 |- f% {" J- k7 n
    为预防老年痴呆,时不时学点新东东玩一玩。
    3 T. k2 c/ {9 T9 k, vPytorch 下面的代码做最简单的一元线性回归:
    . k0 z' o9 L6 h7 ~7 H. P9 Q----------------------------------------------
    . z9 r/ S; V$ V2 b. Qimport torch$ n9 B8 G8 [) m9 B6 M1 r, c
    import numpy as np# q6 x9 v/ `9 H
    import matplotlib.pyplot as plt
    1 k1 I, y$ r0 M6 Y5 n* N$ Bimport random, h: W4 L; V$ W% Z4 a' u

    8 m& c' W/ I  u/ o& mx = torch.tensor(np.arange(1,100,1))
    ( G- c  ~0 o' _. l8 R5 Ky = (x*27+15+random.randint(-2,3)).reshape(-1)  # y=wx+b, 真实的w0 =27, b0=15
    / s8 l- S' U/ d# q3 f' X
    + ~# `# a! T- }3 \" ^w = torch.tensor(0.,requires_grad=True)  #设置随机初始 w,b6 q3 \+ _7 k2 q) ?8 M, A- |
    b = torch.tensor(0.,requires_grad=True)
    . h1 N( K8 h8 ~$ D9 v: r$ T2 P' E# J3 W/ w
    epochs = 100
    6 X+ a8 G. t- z8 C1 R: c* d' M/ B
      q: a# U: y( q5 y" g8 Y6 [. u0 Vlosses = []! ?% r  r1 c0 A; \) Y
    for i in range(epochs):
    / C2 Z4 d0 o0 [- q* [  y_pred = (x*w+b)    # 预测
    & I& P9 L, ]; X. @% v. `  y_pred.reshape(-1)
    5 C/ w- ]- Y, v% n$ ]# J" _5 M. I
    + W" I: M( ~! V, Z% [( B$ M  loss = torch.square(y_pred - y).mean()   #计算 loss
    ( X' z8 x2 u) C4 ~+ m  losses.append(loss)
      A) ~! f. S5 L8 ^  
    # a5 r/ `4 t  X9 r$ d  loss.backward() # autograd3 o0 L2 |. h4 [: N$ {
      with torch.no_grad():% A4 t6 @9 J6 d, L& @3 x
        w  -= w.grad*0.0001   # 回归 w
    6 ~/ O6 d, z3 ~# P$ V; N    b  -= b.grad*0.0001    # 回归 b 6 b3 ]8 Z% i- ^/ k
      w.grad.zero_()  
    $ q8 G! O0 j6 ]' W  b.grad.zero_()
    9 B; j4 l: Z: z- {7 w
    . \) ^8 ^& \: e8 Tprint(w.item(),b.item()) #结果
    / ?# s3 r9 G0 _( D% y/ P( K
    % S3 R( Z! n6 G! o3 DOutput: 27.26387596130371  0.4974517822265625  f& M2 C$ {) k+ e( O2 L
    ----------------------------------------------
    # P- S5 K' G8 r1 H$ `# _最后的结果,w可以回到 w0 = 27 附近,b却回不去 b0=15。两处红字,损失函数是矢量计算后的均值,感觉 b 的回归表达有问题。+ m% H+ V; @5 b$ T0 O
    高手们帮看看是神马原因?, l6 `. ~- R0 ]: T

    评分

    参与人数 1爱元 +10 收起 理由
    老票 + 10 不明觉厉

    查看全部评分

    该用户从未签到

    沙发
    发表于 2023-2-14 19:23:02 | 只看该作者
    本帖最后由 老福 于 2023-2-14 21:58 编辑
    : C% a# T, J- j0 ~& G8 w7 {! ^
    ) X& ?/ N$ ?; A, A没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?
    0 Y0 M- {  ?, z5 v* ^: Z* b-------
      S5 ~' D0 ^& K9 t8 N3 `不好意思,再看一遍,好像你在自算回归而不是用现成的工具直接出结果,上面的评论只有一点用,就是确认是不是算法有问题。
    / @3 A3 v0 F6 K-------
    % U8 a. m! F  j0 T( i算法诊断部分,建议把循环次数改为1000, 再看看loss是不是收敛。有点怀疑你循环次数不够,因为你起点是0, 步长很小。只是直观建议。

    评分

    参与人数 1爱元 +10 收起 理由
    雷达 + 10 谢谢建议

    查看全部评分

    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    板凳
     楼主| 发表于 2023-2-14 21:52:57 | 只看该作者
    老福 发表于 2023-2-14 19:23! x9 \( T7 X' y; z# H. N
    没有用过pytorch,但你把随机噪音部分改成均值为0的正态分布再试试看是不是符合预期?  r1 B: E6 O+ d+ a% a
    -------5 ?6 U. Q4 M2 x' h
    不好意思, ...

    , K* W+ E$ j, v谢谢,算法应该没问题,就是最简单的线性回归。
    ; W  O2 i. E3 r* d我特意没有用现成的工具,就是想从最基本的地方深入理解一下。
    回复 支持 反对

    使用道具 举报

    该用户从未签到

    地板
    发表于 2023-2-14 22:00:48 | 只看该作者
    本帖最后由 老福 于 2023-2-14 22:02 编辑
    : x; U7 A' T  ^- d
    雷达 发表于 2023-2-14 21:52
    ) S- D: H- ~3 y% k+ |0 x9 \5 R# Z谢谢,算法应该没问题,就是最简单的线性回归。
    . t' e5 m$ l$ ^' T4 z我特意没有用现成的工具,就是想从最基本的地方深入理解 ...
    . u  A3 G5 ^! t% {- O8 y" L& @9 k4 ~

    3 j! \* g! h* X' `刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。8 l4 _2 p) j0 L5 Q

    ' X/ u0 @; L$ {2 y: q% _或者把b但的起点改为1试试。
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2025-9-22 22:19
  • 签到天数: 1183 天

    [LV.10]大乘

    5#
     楼主| 发表于 2023-2-15 00:25:26 | 只看该作者
    本帖最后由 雷达 于 2023-2-15 00:31 编辑 ( w: B9 Z0 t, l( c
    老福 发表于 2023-2-14 22:004 }5 s6 |* m% V! }3 Y, P; Q! `; ~; M! t
    刚才更新了一下,建议增加循环次数或调一下步长,查一下loss曲线。! W$ F* l# g( y, ~5 e" }- V

    " w- y6 ~) s) n( H) y或者把b但的起点改为1试试。 ...

    4 ^1 |) V, z/ S* Q6 b! [0 L' A0 w8 W1 r  q5 Q9 {
    你是对的。0 j$ U, o4 ^& F( t4 [
    去掉了随机部分
    ; v1 ?0 k; Y/ h: e; W#y = (x*27+15+random.randint(-2,3)).reshape(-1)
    0 O# ~& v. W8 @2 t; w: Py = (x*27+15).reshape(-1)
    ( U( x" \/ j' R8 z  V% }+ ?+ r5 P; x  z. j
    循环次数加成10倍,就看到 b 收敛了
    : m0 R6 A. i. E. ]8 A; g! sw , b( E, p5 C9 A: U) G) I0 y
    27.002620697021484 14.826167106628418
    # N5 a# h6 t! Z# f8 H. N6 i- q; O- B
    # g; }! X; r, ]和 b 的起始位置无关,但 labeled data 用 y = (x*27+15+random.randint(-2,3)).reshape(-1) ,收敛就很慢。
    回复 支持 反对

    使用道具 举报

    手机版|小黑屋|Archiver|网站错误报告|爱吱声   

    GMT+8, 2026-6-4 10:50 , Processed in 0.057688 second(s), 18 queries , Gzip On.

    Powered by Discuz! X3.2

    © 2001-2013 Comsenz Inc.

    快速回复 返回顶部 返回列表