TO-GA 2022-03-20 17:21 采纳率: 0%
浏览 335

使用OpenKE中TransH模型在自制数据集训练时遇到错误

各位大佬好,最近在用OpenKE中TransH模型在自制简单数据集上训练时遇到了麻烦,在用GPU时报“CUDA ERROR: device-side assert triggered”的错误,在用CPU时报“IndexError: index out of range in self”的错误,我猜测这应该是哪个维度出错了,但一直找不到原因,也不知道该改哪里,下面是我的训练代码和完整的运行情况,希望各位大佬多多指正:
训练代码如下:

import openke
from openke.config import Trainer, Tester
from openke.module.model import TransH
from openke.module.loss import MarginLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader

# dataloader for training
train_dataloader = TrainDataLoader(
    in_path = "./benchmarks/XG02/",
    nbatches = 24,
    threads = 8, 
    sampling_mode = "normal", 
    bern_flag = 1, 
    filter_flag = 1, 
    neg_ent = 1,
    neg_rel = 0)

# dataloader for test
test_dataloader = TestDataLoader("./benchmarks/XG02/", "link")

# define the model
transh = TransH(
    ent_tot = train_dataloader.get_ent_tot(),
    rel_tot = train_dataloader.get_rel_tot(),
    dim = 24, 
    p_norm = 1, 
    norm_flag = True)

# define the loss function
model = NegativeSampling(
    model = transh, 
    loss = MarginLoss(margin = 4.0),
    batch_size = train_dataloader.get_batch_size()
)


# train the model
trainer = Trainer(model = model, data_loader = train_dataloader, train_times = 1000, alpha = 0.5, use_gpu = False)
print(model)
print(transh)
print(trainer.model.parameters)
trainer.run()
transh.save_checkpoint('./checkpoint/transh.ckpt')

# test the model
transh.load_checkpoint('./checkpoint/transh.ckpt')
tester = Tester(model = transh, data_loader = test_dataloader, use_gpu = False)
tester.run_link_prediction(type_constrain = False)

运行情况如下:

img

  • 写回答

4条回答 默认 最新

  • 小喽啰罗 2022-03-22 16:11
    关注

    会不会是你的数据格式有问题,openke要求的(e1,e2,r),而很多数据集确是(e1,r,e2)

    评论

报告相同问题?

问题事件

  • 创建了问题 3月20日

悬赏问题

  • ¥15 CSS通配符清除内外边距为什么可以覆盖默认样式?
  • ¥15 SPSS分类模型实训题步骤
  • ¥15 求解决扩散模型代码问题
  • ¥15 工创大赛太阳能电动车项目零基础要学什么
  • ¥20 limma多组间分析最终p值只有一个
  • ¥15 nopCommerce开发问题
  • ¥15 torch.multiprocessing.spawn.ProcessExitedException: process 1 terminated with signal SIGKILL
  • ¥15 QuartusⅡ15.0编译项目后,output_files中的.jdi、.sld、.sof不更新怎么解决
  • ¥15 pycharm输出和导师的一样,但是标红
  • ¥15 想问问富文本拿到的html怎么转成docx的