理智明 2019-08-13 09:14 采纳率: 0%
浏览 1209

深度学习模型喂参数的时候报错In[0] ndims must be >= 2: 1

图片说明

import json
import tensorflow as tf
import numpy as np
class TrainingConfig(object):
epoches = 10 #迭代次数参数
evaluateEvery = 100 # 每个评估
checkpointEvery = 100 # 每个检查站
learningRate = 0.001 # 学习率

class ModelConfig(object):
embeddingSize = 100
#词嵌入维度参数:词嵌入维度刻画了模型表达词汇的能力,
hiddenSizes = [128] # LSTM结构的神经元个数
dropoutKeepProb = 0.5 # dropout保持率,防止过拟合
l2RegLambda = 0.0 #L2 正则化项的权重系数,越大模型越保守。
class Config(object):
sequenceLength = 100 #序列长度应当与给出数据的结构相匹配,大约是你希望生成句子的长度,通常设置大一点的值会使模型的能力也跟着强一点。
batchSize = 128 #批处理大小

dataSource = "data/data1.xls"

stopWordSource = "data/stopWordsHIT.dic"

numClasses = 3

rate = 0.8  # 训练集的比例

training = TrainingConfig()

model = ModelConfig()

实例化配置参数对象

config = Config()

x = "这 主要 是 隔膜 的 工艺 所 决定 的 隔膜 是 要 除 槽 的 负压 可以 防止 氯气 外逸 氯气 是 含氧 是 可以" \
" 接受 的 氢气 中 含氧 是 不 可以 接受 的 另外 在 电解 过程 中 氢气 出现 大 正压 会 便 得 氯中 含氢 高 危险 并且 会 冲击 电槽 安全 有 风险 氢气 大 负压 会 吸入 空气 氢中 含氧 高 氯气 大 正压 会 产生 泄漏 事故 氯 启动"

注:下面两个词典要保证和当前加载的模型对应的词典是一致的

with open("data/wordJson/wordToIndex.json", "r", encoding="utf-8") as f:
word2idx = json.load(f)

with open("data/wordJson/indexToWord.json", "r", encoding="utf-8") as f:
label2idx = json.load(f)
idx2label = {value: key for key, value in label2idx.items()}

xIds = [word2idx.get(item, word2idx["UNK"]) for item in x.split(" ")]
if len(xIds) >= config.sequenceLength:
xIds = xIds[:config.sequenceLength]
else:
xIds = xIds + [word2idx["PAD"]] * (config.sequenceLength - len(xIds))

graph = tf.Graph()
with graph.as_default():
#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
#session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, gpu_options=gpu_options)
checkpoint_file = tf.train.latest_checkpoint("model/Bi-LSTM/model/")
sess = tf.Session()
with sess.as_default():

    saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
    print(saver)
    saver.restore(sess, checkpoint_file)

    # 获得需要喂给模型的参数,输出的结果依赖的输入值
    inputX = graph.get_operation_by_name("inputX").outputs[0]
    dropoutKeepProb = graph.get_operation_by_name("dropoutKeepProb").outputs[0]
    print(inputX)
    xids=np.array(xIds).reshape(-1,100)
    print(xids.shape)
    feed_dict={inputX: xids, dropoutKeepProb: 0.5}
    pred = sess.run("output/predictions:0", feed_dict)
    print(pred)

pred = [idx2label[item] for item in pred]
print(pred)

借鉴了文章https://www.cnblogs.com/jiangxinyang/p/10208227.html的内容,但是测试运行的时候报错
图片说明

  • 写回答

1条回答 默认 最新

  • dabocaiqq 2019-08-13 09:26
    关注

    维度不同啊,参数不对,先要调整

    评论

报告相同问题?

悬赏问题

  • ¥100 Jenkins自动化部署—悬赏100元
  • ¥15 关于#python#的问题:求帮写python代码
  • ¥20 MATLAB画图图形出现上下震荡的线条
  • ¥15 关于#windows#的问题:怎么用WIN 11系统的电脑 克隆WIN NT3.51-4.0系统的硬盘
  • ¥15 perl MISA分析p3_in脚本出错
  • ¥15 k8s部署jupyterlab,jupyterlab保存不了文件
  • ¥15 ubuntu虚拟机打包apk错误
  • ¥199 rust编程架构设计的方案 有偿
  • ¥15 回答4f系统的像差计算
  • ¥15 java如何提取出pdf里的文字?