首先利用 libtorch 库封装了一个libgotorch库,已支持最新的 libtorch2.0.1
问题一:cgo 中返回的 tensor 对象在栈上,直接使用可能会有内存安全问题
我做了一层简单的封装来使其创建到堆上,但其引发的问题是需要手动管理内存,因此我编写了 mmgr 包在每一个 tensor 对象创建的时候自动加入 mmgr 的 storage 当中,最后在每一轮训练完毕后通过 GC 方法释放堆上的 tensor 对象
问题二:windows 下的 libtorch 库通过 msvc 编译,提供的是 C++接口,无法在 mingw 中无法正常链接
解决方案是通过在封装一个动态链接库并暴露 C 语言接口,在 mingw 中即可正常链接
通过解决以上两个问题,已可以在 go 语言中使用 libtorch 库并实现自己的模型了
下面进入正题,我在 tnn 库中实现了一个小型的 GPT 模型来实现对对联:couplet,下面让我们来看一下最终效果
$ go run main.go evaluate --model model7M 晚风摇树树还挺 load embedding... model loaded inputs: [472 3 462 148 148 342 1516] map[4.278747:[醉] 5.084207:[润] 8.868446:[晨]] map[3.8447263:[花] 4.750472:[润] 8.635651:[露]] map[5.46043:[花] 6.7003703:[露] 10.768249:[润]] map[4.3850584:[露] 4.875666:[润] 9.896332:[花]] map[3.6241615:[红] 5.611262:[润] 10.782802:[花]] map[4.3855276:[花] 5.48069:[红] 9.480111:[更]] map[3.7904112:[心] 4.269902:[花] 10.3220415:[红]] 晨露润花花更红 $ go run main.go evaluate --model model7M 投石向天跟命斗 load embedding... model loaded inputs: [1233 190 383 11 2623 620 490] map[5.7068815:[门] 5.7826476:[问] 9.79136:[闭]] map[3.0136497:[问] 3.1092193:[人] 8.903796:[门]] map[3.021591:[还] 3.448888:[歌] 8.96453:[问]] map[4.9368696:[地] 5.7390223:[时] 9.438878:[卷]] map[3.5542138:[话] 3.858942:[时] 8.253393:[与]] map[3.025545:[与] 3.2461479:[卷] 9.06726:[时]] map[4.250452:[时] 4.712057:[舟] 10.401218:[争]] 闭门问卷与时争
注意:该模型仅训练了开源数据集couplet-dataset中的前 1 万个样本
模型的参数结构如下:
+------------------------+---------+ | NAME | COUNT | +------------------------+---------+ | transformer0_attention | 1872 | | transformer0_dense | 1256640 | | transformer0_output | 1254960 | | transformer1_attention | 1872 | | transformer1_dense | 1256640 | | transformer1_output | 1254960 | | output | 2488596 | | total | 7515540 | +------------------------+---------+ train 200, cost=2h15m7.877395694s, loss=3.665343e-02
整个模型共有 751 万个参数,模型包含 2 个 transformer 模块,由于在训练时只使用了 8 个 float32 来对每一个字进行表征,因此 attention 层的参数量较少,其他参数配置如下:
const embeddingDim = 8 // 8 个 float32 表示一个字向量 const paddingSize = 70 // 最长为 34*2 ,因此 padding 长度必须大于 68 const heads = 4 const batchSize = 128 const epoch = 200 const lr = 0.001 const transformerSize = 2
最后让我们来看看模型的泛化能力如何
$ go run main.go evaluate --model model7M 我是谁 load embedding... model loaded inputs: [85 62 191] map[4.3809786:[雨] 4.9436274:[染] 7.105626:[绿]] map[3.8163047:[水] 4.013789:[东] 4.088595:[得]] map[4.872726:[唱] 5.4107614:[兰] 6.3983927:[发]] 绿得发 $ go run main.go evaluate --model ./model7M 我在哪 load embedding... model loaded inputs: [85 99 1151] map[1.480957:[思] 2.002811:[得] 4.0260763:[寻]] map[3.4100764:[女] 3.868993:[对] 4.448501:[得]] map[2.2672489:[年] 2.3772364:[历] 4.946753:[谁]] 寻得谁
效果不是很理想,可能还是跟训练的样本数量太少有关
另外还有一些示例可在 example 目录下找到,如使用 RNN 来学习如何画 sin 曲线等
最后是项目地址:
修复了FFN层的实现方式问题,现在参数数量看起来比较正常
+------------------------+---------+ | NAME | COUNT | +------------------------+---------+ | transformer0_attention | 62208 | | transformer0_dense | 66048 | | transformer0_output | 65664 | | transformer1_attention | 62208 | | transformer1_dense | 66048 | | transformer1_output | 65664 | | transformer2_attention | 62208 | | transformer2_dense | 66048 | | transformer2_output | 65664 | | transformer3_attention | 62208 | | transformer3_dense | 66048 | | transformer3_output | 65664 | | output | 572244 | | total | 1347924 | +------------------------+---------+
另外增加了各种mask的支持,修复了loss函数使用问题
![]() | 1 coosir 2023-06-15 11:42:08 +08:00 强哦,对联对得倒是挺好的 |
![]() | 2 vus520 2023-06-16 10:17:31 +08:00 关注下,如果能把 huggingface 的包实现一遍真是造福我众 |
![]() | 3 allegory 2024-04-10 18:16:03 +08:00 libtorch 都还是 beta 版,你再来一个 go 的封装,稳定性/正确性如何保证?不过还是很强 |