请教个神经网络的问题 - V2EX
V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
zxCoder
V2EX    问与答

请教个神经网络的问题

  •  
  •   zxCoder 2021-12-30 09:47:05 +08:00 1944 次点击
    这是一个创建于 1463 天前的主题,其中的信息可能已经有所发展或是发生改变。

    在改一份代码,网络结构是一个 bert ,一个 dropout 层和最后一个 linear 层

    我在把多分类改成多标签分类

    就改改数据的输入输出格式,加个 sigmoid 激活函数

    但是结果出现了问题,发现 loss 一直在降,但是 acc 缺掉到 0 ,一看才发现,练着练着,最后 linear 层输出是一个特别小的负数,而且越来越小,然后经过 sigmoid 后全部预测为 0 ,所以 loss 算起来还是很低

    这种情况可能是出现什么问题呢

    23 条回复    2022-01-06 20:41:44 +08:00
    zhucegeqiu
        1
    zhucegeqiu  
       2021-12-30 09:56:39 +08:00
    预测全为 0 为啥 loss 就低,样本不平衡?
    zxCoder
        2
    zxCoder  
    OP
       2021-12-30 09:58:54 +08:00
    @zhucegeqiu 比如一共有 1000 个 label ,但是一个样本可能只有一两个 label 是 1 ,其他都是 0 ,直接预测全 0 ,用 BCE 计算 loss 好像就不高
    sss3barry
        3
    sss3barry  
       2021-12-30 10:27:15 +08:00
    样本也太不平衡了吧,感觉换 loss 也不一定有用,下采样负例稍微平衡一下?
    ipwx
        4
    ipwx  
       2021-12-30 10:30:33 +08:00
    @zxCoder 你已经说出了原因了。神经网络做的是

    1. 最大化 l(x,y) 在训练样本 p(x,y) = p(x)p(y|x) 上的期望。
    2. 几乎不可能找到全局最优解,只能找到局部最优解。

    在你的例子中,998/1000 以上的训练样本都是 0 。也就是说 p(y=0|x) = 0.998 。如果你是神经网络要学习 y=f(x),那么是不是很容易找到 y=f(x) 恒等于 0 这个局部最优解?
    ipwx
        5
    ipwx  
       2021-12-30 10:31:17 +08:00
    p(y=0|x) = 0.998 ,约等于 p(y=0|x) = 1 。如果是后者,显然全局最优解就是 y=f(x) 恒等于 0 。
    jdhao
        6
    jdhao  
       2021-12-30 10:36:58 +08:00 via Android
    使用 weighted cross entropy ,增大类别少的样本权重,或者想办法把各类数据均衡
    ipwx
        7
    ipwx  
       2021-12-30 10:39:37 +08:00
    @jdhao 这个例子里面,imbalance 太严重了。说句难听的,只有 2 个正例,分分钟过拟合。

    这种问题建议先 unsupervised training ,用得到的 hidden representation 做更稳的分类器。
    zxCoder
        8
    zxCoder  
    OP
       2021-12-30 20:25:51 +08:00
    @ipwx 检查了一下,一共 120 个标签,5 万多个训练样本
    标签分布大概是
    [656, 122, 380, 217, 2276, 794, 149, 458, 248, 626, 102, 135, 396, 127, 141, 3431, 161, 687, 352, 157, 1322, 1887, 423, 313, 994, 880, 125, 138, 295, 473, 263, 106, 1153, 221, 297, 1643, 150, 3083, 218, 403, 277, 258, 524, 5553, 246, 128, 205, 735, 115, 1089, 159, 201, 438, 1984, 536, 148, 666, 251, 103, 499, 263, 124, 217, 823, 136, 112, 157, 128, 747, 202, 189, 147, 115, 224, 122, 920, 2176, 235, 247, 1861, 110, 124, 1178, 199, 166, 106, 457, 3718, 154, 12597, 641, 182, 354, 164, 277, 230, 232, 106, 3445, 979, 1589, 132, 105, 150, 219, 188, 168, 535, 284, 2137, 1351, 101, 542, 115, 804, 114, 384, 125, 102, 1334]

    感觉也还行吧,这算不均衡吗
    zxCoder
        9
    zxCoder  
    OP
       2021-12-30 20:32:15 +08:00
    而且同样的数据,之前在其他基于 bert 的模型上跑也没出现这问题啊
    c0xt30a
        10
    c0xt30a  
       2021-12-31 03:39:28 +08:00
    这个问题的根源在于, 激活函数是 softmax 的时候,最后一层的输出强制非负并且和为 1;
    更换到 sigmoid 之后,输出之和为 1 的限制被解去了。
    当有 1000+ 个 label ,而一个样本最多有两个 label 是 1 的时候,网络输出全为 0 明显是一个局部最小,而且基本上不能跳出来的那种。
    c0xt30a
        11
    c0xt30a  
       2021-12-31 03:47:48 +08:00
    我出个主意:
    鉴于最多有两个标签是 1 ,那么可以依旧用 softmax 做激活,用 soft label 处理下数据。

    譬如,当某个样本只有 1 个标签的时候,把这个标签设置为 0.5 ,其他的 999 个平分剩下的 0.5 ;当某个样本有两个标签的时候,把这两个标签都设置为 0.5 ,其余的标签设置为 0 。

    刚拍脑袋想出来的,不保证靠谱。
    ipwx
        12
    ipwx  
       2021-12-31 10:23:47 +08:00
    @zxCoder 噗你不说我看 sigmoid 还以为你这是二分类。

    多分类你用 sigmoid 做各维独立当然不太行。。。
    zxCoder
        13
    zxCoder  
    OP
       2021-12-31 10:32:29 +08:00
    @ipwx 啊 为啥啊,做的是多标签分类,之前一直用 sigmoid
    ipwx
        14
    ipwx  
       2021-12-31 10:39:31 +08:00
    哦哦哦 等等,是多标签分类。

    那你说不定可以试试在输出层的 loss 上,各维独立添加惩罚系数,就正比于根据你上面那个类别数量的倒数。

    不过我觉得你说不定前面的网络哪里有问题。。。
    ipwx
        15
    ipwx  
       2021-12-31 10:53:14 +08:00
    哦又重新看了看你的问题描述。

    我懂了,你的问题果然还是各维独立做多标签这里出了问题。
    记 C = sum(你上面那个频数列表)

    np.percentile(C / C.sum(), [1,10,20,50,80,90,99])
    => array([0.00119775, 0.0013504 , 0.00157821, 0.00292978, 0.00979803,
    0.01954908, 0.06111261])

    这个 imbalance 太严重了。各个维度都输出 0 就是一个非常巨大的局部最小值点。

    解决方法 1:在每一维输出的时候,如果是 1 就 loss 添加惩罚系数 C / 类别频率,0 就添加 C / (1 - 类别频率)。这样能把各维二分类各自的 imbalance 问题降下来。

    解决方法 2:还是 softmax 做 top-k 分类推荐。额外添加一个分类器预测到底输入有多少个输出类。
    ipwx
        16
    ipwx  
       2021-12-31 10:55:57 +08:00
    @zxCoder 从逻辑回归角度看 0-1 分类会有这样的结论:

    对于一个二分类器而言,如果你的类别 A 出现概率是 p ,类别 B 出现概率是 1-p ,并且两类在 loss 上的贡献是等权的,那么最终你的分类器的决策边界大概在 p 附近。

    你的所有类别都严重 imbalance ,p 的 99% 分位数都只有 0.06 ,那么你每一类的分类边界,小的大概在 0.001 ,大的也只有 0.06 。但是你在做预测的时候决策边界是 0.5 ,那么当然全部认为是 0 。

    我说的那个惩罚系数就是强行把这个决策边界拉回 0.5 附近。
    ipwx
        17
    ipwx  
       2021-12-31 10:56:32 +08:00
    想了想惩罚系数也可以这样

    1 => (1 - 类别频率),0 => 类别频率
    zxCoder
        18
    zxCoder  
    OP
       2021-12-31 11:02:37 +08:00
    @ipwx 没辙了。。。加了 weight 也不行,把多标签分类改成多分类也不行,总是预测到同一个类别。。。这破论文怎么就还能发呢。。。
    zxCoder
        19
    zxCoder  
    OP
       2021-12-31 11:14:32 +08:00
    @ipwx
    ”1 => (1 - 类别频率),0 => 类别频率“

    老哥能不能说下代码怎么实现这个啊。这个是要自己循环判断每个 label 的预测结果再计算 loss 吗?直接用 weight 好像只能是给每个 label 分配不同的权重
    ipwx
        20
    ipwx  
       2021-12-31 13:14:52 +08:00   1
    @zxCoder 你们用 pytorch 的都没有自己写 train loop 的经历了么啊,只能 model.fit ??

    class_prob = C / C.sum()

    自己组 loss 不就随便搞了。伪代码:

    binary_cross_entropy(logits, label) * torch.where(label == 1, 1-C[label], C[label])

    肯定有辙啊
    zxCoder
        21
    zxCoder  
    OP
       2022-01-01 17:29:19 +08:00
    @ipwx 还是没啥效果,真是奇怪,同样的数据,在别的模型上就挺正常的,在这个模型上就不行,网络结构都是一个 bert ,一个 dropout 层和最后一个 linear 层
    ipwx
        22
    ipwx  
       2022-01-01 18:12:24 +08:00
    摸不到你的数据所以没法远程给建议了(趴
    zxCoder
        23
    zxCoder  
    OP
       2022-01-06 20:41:44 +08:00
    @ipwx 最后解决方案:去 tm 的开源代码,自己重新复现了一份,精度正常了。。。。
    关于     帮助文档     自助推广系统     博客     API     FAQ     Solana     2359 人在线   最高记录 6679       Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 27ms UTC 03:13 PVG 11:13 LAX 19:13 JFK 22:13
    Do have faith in what you're doing.
    ubao msn snddm index pchome yahoo rakuten mypaper meadowduck bidyahoo youbao zxmzxm asda bnvcg cvbfg dfscv mmhjk xxddc yybgb zznbn ccubao uaitu acv GXCV ET GDG YH FG BCVB FJFH CBRE CBC GDG ET54 WRWR RWER WREW WRWER RWER SDG EW SF DSFSF fbbs ubao fhd dfg ewr dg df ewwr ewwr et ruyut utut dfg fgd gdfgt etg dfgt dfgd ert4 gd fgg wr 235 wer3 we vsdf sdf gdf ert xcv sdf rwer hfd dfg cvb rwf afb dfh jgh bmn lgh rty gfds cxv xcv xcs vdas fdf fgd cv sdf tert sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf sdf shasha9178 shasha9178 shasha9178 shasha9178 shasha9178 liflif2 liflif2 liflif2 liflif2 liflif2 liblib3 liblib3 liblib3 liblib3 liblib3 zhazha444 zhazha444 zhazha444 zhazha444 zhazha444 dende5 dende denden denden2 denden21 fenfen9 fenf619 fen619 fenfe9 fe619 sdf sdf sdf sdf sdf zhazh90 zhazh0 zhaa50 zha90 zh590 zho zhoz zhozh zhozho zhozho2 lislis lls95 lili95 lils5 liss9 sdf0ty987 sdft876 sdft9876 sdf09876 sd0t9876 sdf0ty98 sdf0976 sdf0ty986 sdf0ty96 sdf0t76 sdf0876 df0ty98 sf0t876 sd0ty76 sdy76 sdf76 sdf0t76 sdf0ty9 sdf0ty98 sdf0ty987 sdf0ty98 sdf6676 sdf876 sd876 sd876 sdf6 sdf6 sdf9876 sdf0t sdf06 sdf0ty9776 sdf0ty9776 sdf0ty76 sdf8876 sdf0t sd6 sdf06 s688876 sd688 sdf86