class data_generator:
    """Endless batch generator over ``data``.

    Each element ``d`` of ``data`` is read as ``d[0]`` (the text, truncated to
    the module-level ``maxlen`` before tokenizing) and ``d[1]`` (the label).
    Batches are yielded as ``([X1, X2], Y)``, each part padded with
    ``seq_padding``.

    NOTE(review): relies on module-level names defined elsewhere in the file:
    ``np``, ``maxlen``, ``tokenizer`` and ``seq_padding``. ``tokenizer`` is
    presumably a keras_bert ``Tokenizer`` whose ``encode`` returns
    (token ids, segment ids) — confirm.
    """

    def __init__(self, data, batch_size=32, shuffle=True):
        """Store the dataset and precompute the number of batches per epoch.

        Args:
            data: sized, integer-indexable dataset (see the class docstring).
            batch_size: maximum number of samples per yielded batch; must be >= 1.
            shuffle: if true, reshuffle the sample order at the start of every epoch.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Ceiling division: a trailing partial batch still counts as one step.
        full_batches, remainder = divmod(len(self.data), self.batch_size)
        self.steps = full_batches + (1 if remainder else 0)

    def __len__(self):
        """Return the number of batches in one pass over ``data``."""
        return self.steps

    def __iter__(self):
        """Yield padded ``([X1, X2], Y)`` batches forever, one epoch after another.

        Within an epoch, a batch is emitted whenever ``batch_size`` samples have
        been collected, and once more for the leftover samples at epoch end.

        Raises:
            ValueError: if ``data`` is empty; the endless loop below would
                otherwise never yield anything.
        """
        if len(self.data) == 0:
            raise ValueError("data_generator: cannot iterate over an empty dataset")
        while True:
            idxs = list(range(len(self.data)))
            if self.shuffle:
                np.random.shuffle(idxs)
            last_pos = len(idxs) - 1
            X1, X2, Y = [], [], []
            for pos, i in enumerate(idxs):
                d = self.data[i]
                text = d[0][:maxlen]
                x1, x2 = tokenizer.encode(first=text)
                X1.append(x1)
                X2.append(x2)
                # Each label is wrapped in a 1-element list before padding.
                Y.append([d[1]])
                # Flush on a full batch, or on the final sample of the epoch.
                if len(X1) == self.batch_size or pos == last_pos:
                    yield [seq_padding(X1), seq_padding(X2)], seq_padding(Y)
                    X1, X2, Y = [], [], []
# BERT model setup.
#
# Optional: let TensorFlow grow GPU memory on demand instead of reserving it
# up front. This is the TF 1.x API; under TF 2.x the same classes are
# reachable as tf.compat.v1.ConfigProto / tf.compat.v1.InteractiveSession.
# from tensorflow import ConfigProto
# from tensorflow import InteractiveSession
# config = ConfigProto()
# config.gpu_options.allow_growth = True
# session = InteractiveSession(config=config)

# Load the pre-trained BERT model from its config and checkpoint files.
# seq_len=None presumably leaves the input length unfixed — TODO confirm
# against the keras_bert version in use.
#
# NOTE(review): the reported failure is raised by this call:
# "IndexError: list index out of range" inside keras/engine/base_layer.py
# (_add_inbound_node). That pattern usually points to incompatible installed
# versions of keras, tensorflow and keras_bert rather than to this script.
# Things to check:
#   * set os.environ["TF_KERAS"] = "1" BEFORE keras_bert is first imported,
#     so keras_bert builds on tensorflow.keras instead of standalone keras; or
#   * install a mutually compatible keras / tensorflow / keras_bert set.
bert_model = load_trained_model_from_checkpoint(
    config_path,
    checkpoint_path,
    seq_len=None,
)

# Mark every layer trainable so all BERT weights are updated during training.
for layer in bert_model.layers:
    layer.trainable = True
运行结果
Traceback (most recent call last):
File "F:/pycharm最新/3 第四次数据修改代码运行/bert/run_tnews_classifier.py", line 137, in <module>
bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None) # 加载预训练模型
File "D:\python3.7\lib\site-packages\keras_bert\loader.py", line 169, in load_trained_model_from_checkpoint
**kwargs)
File "D:\python3.7\lib\site-packages\keras_bert\loader.py", line 58, in build_model_from_config
**kwargs)
File "D:\python3.7\lib\site-packages\keras_bert\bert.py", line 84, in get_model
dropout_rate=dropout_rate,
File "D:\python3.7\lib\site-packages\keras_bert\layers\embedding.py", line 37, in get_embedding
)(inputs[0]),
File "D:\python3.7\lib\site-packages\keras\backend\tensorflow_backend.py", line 75, in symbolic_fn_wrapper
return func(*args, **kwargs)
File "D:\python3.7\lib\site-packages\keras\engine\base_layer.py", line 529, in __call__
arguments=user_kwargs)
File "D:\python3.7\lib\site-packages\keras\engine\base_layer.py", line 597, in _add_inbound_node
output_tensors[i]._keras_shape = output_shapes[i]
IndexError: list index out of range
各位挚友，能不能帮忙看看这个报错该怎么解决？