shiter
2022-02-03 10:23
采纳率: 100%
浏览 136

问下使用albert 进行多标签的文本分类[12个类别],为啥 acc 才这么点,是哪里没写对么?

import tensorflow as tf
 
tf.__version__
'2.2.2'
import keras
print(keras.__version__)
2.3.1


Using TensorFlow backend.

from tensorflow import keras
print(keras.__version__)
2.3.0-tf
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split

MAX_LENGTH = 1024
#载入训练集和测试集数据
# train = pd.read_csv(r'D:\code\python\csdn_nlp\data\文本分类数据\Texts_Classification\train2.csv')
# test = pd.read_csv(r'D:\code\python\csdn_nlp\data\文本分类数据\Texts_Classification\test2.csv')

train = pd.read_csv(r'/mnt/d/code/python/csdn_nlp/data/文本分类数据/Texts_Classification/train2.csv')
test = pd.read_csv(r'/mnt/d/code/python/csdn_nlp/data/文本分类数据/Texts_Classification/test2.csv')
train.head(10)

Unnamed: 0分类文章字数分词label
010000体育欧洲杯观察:足坛格局地震 东南欧已能占半边天 搜狐体育讯 土耳其与克罗地亚在本届欧洲杯1/4...1818['欧洲杯', '观察', '足坛', '格局', '地震', '东南欧', '已能', '...8
110001体育来源:搜狐体育 作者:雷欧 主队 比分 客队 进球视频青岛 1-1 武汉浙江 1-1 北京搜...2190['来源', '搜狐', '体育', '作者', '雷欧', '主队', '比分', '客队...8
210002体育欧洲杯-帕夫柳琴科范尼破门 荷兰1-1俄罗斯加时 搜狐体育讯 北京时间6月22日,2008欧...2078['欧洲杯', '帕夫', '柳琴', '科范尼', '破门', '荷兰', '俄罗斯', ...8
310003体育郅联璧合配合日臻完善 斯杯之前尤纳斯再谈裁员 7年后重登上海滩,中国男篮来去匆匆。队长刘炜回...1810['郅', '联璧', '合', '配合', '日臻完善', '斯杯', '之前', '尤纳...8
410004体育盘踞世界第一逼近五百周 伍兹创造高尔夫里程碑 搜狐体育讯 下周一,老虎-伍兹将第499周稳坐...1893['盘踞', '世界', '第一', '逼近', '五百', '周伍', '兹', '创造'...8
510005体育第五日看点:郑洁挑战NO.1 费天王冲击男单16强 搜狐体育讯 北京时间6月27日晚,200...2082['第五日', '看点', '郑洁', '挑战', '费', '天王', '冲击', '男单...8
610007体育王德显改变双姝命运 邢慧娜孙英杰成败均因一人 2007年年底,邢慧娜来到了北京。田管中心已经...1855['王德显', '改变', '双姝', '命运', '邢慧娜', '孙', '英杰', '成...8
710008体育半场实录:皮尔斯异军突起 湖人队领先凯尔特人 搜狐直播员:搜狐网友:赛式是232制的,还是主...8550['半场', '实录', '皮尔斯', '异军突起', '湖人队', '领先', '凯尔特人...8
810009体育温格:德国可成英格兰榜样 三狮军团差在意志力 晨报特派记者 甘慧(奥地利维也纳6月29日电)...2162['温格', '德国', '可成', '英格兰', '榜样', '三狮', '军团', '差...8
910010体育欧洲杯-范尼斯内德建功 荷兰3-0狂屠意大利 搜狐体育讯 北京时间6月10日凌晨2:45,2...2643['欧洲杯', '范', '尼斯', '德', '建功', '荷兰', '狂屠', '意大利...8
temp_train_data_list = train[["文章","label"]].values.tolist()
train_data_list = [tuple(x) for x in temp_train_data_list]

temp_test_data_list = test[["文章","label"]].values.tolist()
test_data_list = [tuple(x) for x in temp_test_data_list]
train_data_list[0]
print(type(train_data_list))
<class 'list'>
# from sklearn.model_selection import train_test_split
# '''不采取分层抽样时的数据集分割'''
# valid_data_list = train_test_split(train_data_list,test_size=0.3)
# len(valid_data_list)
# valid_data_list[0]
import random
baifenbi = 80
length = len(train_data_list)
print(length)
len(train_data_list[int((length/100)*baifenbi):])
23549





4710
# 加载数据集 这么取 值 标签训练会有问题,所以要随机打乱,然后选取
import random
random.shuffle(train_data_list)
train_data = train_data_list[0:int((length/100)*baifenbi)]
valid_data = train_data_list[int((length/100)*baifenbi):]


# random.seed(10)
# valid_data = random.sample(train_data_list, 5) 

test_data = test_data_list
import numpy as np
from sklearn import metrics
from bert4keras.tokenizers import Tokenizer
from bert4keras.backend import keras, set_gelu
from bert4keras.models import build_transformer_model
from bert4keras.snippets import DataGenerator, sequence_padding
from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr
from keras.layers import Lambda, Dense
from contextlib import redirect_stdout
set_gelu('tanh')  # 切换gelu版本

# 文本共计 12 个类别
num_classes = 12
# 长文本,长一些,好像bert 类,最长 512? 所以编码的文字 应该去掉停用词
maxlen = 64

batch_size = 8

# config_path = r"D:/code/python/csdn_nlp/model/albert_xlarge_zh/albert_xlarge/albert_config.json"
# checkpoint_path =tf.train.latest_checkpoint(r"D:/code/python/csdn_nlp/model/albert_xlarge_zh/albert_xlarge/")
# #checkpoint_path = r"D:/code/python/csdn_nlp/model/albert_xlarge_zh/albert_xlarge/model"
# dict_path = r"D:/code/python/csdn_nlp/model/albert_xlarge_zh/albert_xlarge/vocab_chinese.txt"

# linux 路径
config_path = r"/mnt/d/code/python/csdn_nlp/model/albert_base/albert_config.json"
# albert-base-v2 的压缩包里面没有ckpt,windows 下面加载报错,所以应该在linux 下面使用
checkpoint_path =tf.train.latest_checkpoint(r"/mnt/d/code/python/csdn_nlp/model/albert_base/model.ckpt-best.data-00000-of-00001")
#checkpoint_path = r"/mnt/d/code/python/csdn_nlp/model/albert_base/"
dict_path = r"/mnt/d/code/python/csdn_nlp/model/albert_base/vocab_chinese.txt"

# 建立分词器
tokenizer = Tokenizer(dict_path, do_lower_case=True)


class data_generator(DataGenerator):
    """数据生成器
    """
    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids, batch_labels = [], [], []
        for is_end, (text, label) in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_labels.append([label])
            if len(batch_token_ids) == self.batch_size or is_end:
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_labels = sequence_padding(batch_labels)
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids, batch_segment_ids, batch_labels = [], [], []


# 加载预训练模型
bert = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    model='albert',
    return_keras_model=False,
)

output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
output = Dense(
    units=num_classes,
    activation='softmax',
    kernel_initializer=bert.initializer
)(output)

model = keras.models.Model(bert.model.input, output)
model.summary()

# 派生为带分段线性学习率的优化器。
# 其中name参数可选,但最好填入,以区分不同的派生优化器。
AdamLR = extend_with_piecewise_linear_lr(Adam, name='AdamLR')

model.compile(
    loss='sparse_categorical_crossentropy',
    # optimizer=Adam(1e-5),  # 用足够小的学习率
    optimizer=AdamLR(learning_rate=1e-4, lr_schedule={
        1000: 1,
        2000: 0.1
    }),
    metrics=['accuracy'],
)

# 转换数据集
train_generator = data_generator(train_data, batch_size)
valid_generator = data_generator(valid_data, batch_size)
test_generator = data_generator(test_data, batch_size)


def evaluate(data):
    total, right = 0., 0.
    for x_true, y_true in data:
        y_pred = model.predict(x_true).argmax(axis=1)
        y_true = y_true[:, 0]
        total += len(y_true)
        right += (y_true == y_pred).sum()
    return right / total


class Evaluator(keras.callbacks.Callback):
    """评估与保存
    """
    def __init__(self):
        self.best_val_acc = 0.

    def on_epoch_end(self, epoch, logs=None):
        val_acc = evaluate(valid_generator)
        if val_acc > self.best_val_acc:
            self.best_val_acc = val_acc
            model.save_weights('best_model.weights')
        test_acc = evaluate(test_generator)
        print(
            u'val_acc: %.5f, best_val_acc: %.5f, test_acc: %.5f\n' %
            (val_acc, self.best_val_acc, test_acc)
        )
Model: "model_6"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
Input-Token (InputLayer)        (None, None)         0                                            
__________________________________________________________________________________________________
Input-Segment (InputLayer)      (None, None)         0                                            
__________________________________________________________________________________________________
Embedding-Token (Embedding)     (None, None, 128)    2704384     Input-Token[0][0]                
__________________________________________________________________________________________________
Embedding-Segment (Embedding)   (None, None, 128)    256         Input-Segment[0][0]              
__________________________________________________________________________________________________
Embedding-Token-Segment (Add)   (None, None, 128)    0           Embedding-Token[0][0]            
                                                                 Embedding-Segment[0][0]          
__________________________________________________________________________________________________
Embedding-Position (PositionEmb (None, None, 128)    65536       Embedding-Token-Segment[0][0]    
__________________________________________________________________________________________________
Embedding-Norm (LayerNormalizat (None, None, 128)    256         Embedding-Position[0][0]         
__________________________________________________________________________________________________
Embedding-Mapping (Dense)       (None, None, 768)    99072       Embedding-Norm[0][0]             
__________________________________________________________________________________________________
Transformer-MultiHeadSelfAttent (None, None, 768)    2362368     Embedding-Mapping[0][0]          
                                                                 Embedding-Mapping[0][0]          
                                                                 Embedding-Mapping[0][0]          
                                                                 Transformer-FeedForward-Norm[0][0
                                                                 Transformer-FeedForward-Norm[0][0
                                                                 Transformer-FeedForward-Norm[0][0
                                                                 Transformer-FeedForward-Norm[1][0
                                                                 Transformer-FeedForward-Norm[1][0
                                                                 Transformer-FeedForward-Norm[1][0
                                                                 Transformer-FeedForward-Norm[2][0
                                                                 Transformer-FeedForward-Norm[2][0
                                                                 Transformer-FeedForward-Norm[2][0
                                                                 Transformer-FeedForward-Norm[3][0
                                                                 Transformer-FeedForward-Norm[3][0
                                                                 Transformer-FeedForward-Norm[3][0
                                                                 Transformer-FeedForward-Norm[4][0
                                                                 Transformer-FeedForward-Norm[4][0
                                                                 Transformer-FeedForward-Norm[4][0
                                                                 Transformer-FeedForward-Norm[5][0
                                                                 Transformer-FeedForward-Norm[5][0
                                                                 Transformer-FeedForward-Norm[5][0
                                                                 Transformer-FeedForward-Norm[6][0
                                                                 Transformer-FeedForward-Norm[6][0
                                                                 Transformer-FeedForward-Norm[6][0
                                                                 Transformer-FeedForward-Norm[7][0
                                                                 Transformer-FeedForward-Norm[7][0
                                                                 Transformer-FeedForward-Norm[7][0
                                                                 Transformer-FeedForward-Norm[8][0
                                                                 Transformer-FeedForward-Norm[8][0
                                                                 Transformer-FeedForward-Norm[8][0
                                                                 Transformer-FeedForward-Norm[9][0
                                                                 Transformer-FeedForward-Norm[9][0
                                                                 Transformer-FeedForward-Norm[9][0
                                                                 Transformer-FeedForward-Norm[10][
                                                                 Transformer-FeedForward-Norm[10][
                                                                 Transformer-FeedForward-Norm[10][
__________________________________________________________________________________________________
Transformer-MultiHeadSelfAttent (None, None, 768)    0           Embedding-Mapping[0][0]          
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward-Norm[0][0
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward-Norm[1][0
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward-Norm[2][0
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward-Norm[3][0
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward-Norm[4][0
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward-Norm[5][0
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward-Norm[6][0
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward-Norm[7][0
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward-Norm[8][0
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward-Norm[9][0
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward-Norm[10][
                                                                 Transformer-MultiHeadSelfAttentio
__________________________________________________________________________________________________
Transformer-MultiHeadSelfAttent (None, None, 768)    1536        Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
__________________________________________________________________________________________________
Transformer-FeedForward (FeedFo (None, None, 768)    4722432     Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-MultiHeadSelfAttentio
__________________________________________________________________________________________________
Transformer-FeedForward-Add (Ad (None, None, 768)    0           Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward[0][0]    
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward[1][0]    
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward[2][0]    
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward[3][0]    
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward[4][0]    
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward[5][0]    
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward[6][0]    
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward[7][0]    
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward[8][0]    
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward[9][0]    
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward[10][0]   
                                                                 Transformer-MultiHeadSelfAttentio
                                                                 Transformer-FeedForward[11][0]   
__________________________________________________________________________________________________
Transformer-FeedForward-Norm (L (None, None, 768)    1536        Transformer-FeedForward-Add[0][0]
                                                                 Transformer-FeedForward-Add[1][0]
                                                                 Transformer-FeedForward-Add[2][0]
                                                                 Transformer-FeedForward-Add[3][0]
                                                                 Transformer-FeedForward-Add[4][0]
                                                                 Transformer-FeedForward-Add[5][0]
                                                                 Transformer-FeedForward-Add[6][0]
                                                                 Transformer-FeedForward-Add[7][0]
                                                                 Transformer-FeedForward-Add[8][0]
                                                                 Transformer-FeedForward-Add[9][0]
                                                                 Transformer-FeedForward-Add[10][0
                                                                 Transformer-FeedForward-Add[11][0
__________________________________________________________________________________________________
CLS-token (Lambda)              (None, 768)          0           Transformer-FeedForward-Norm[11][
__________________________________________________________________________________________________
dense_21 (Dense)                (None, 12)           9228        CLS-token[0][0]                  
==================================================================================================
Total params: 9,966,604
Trainable params: 9,966,604
Non-trainable params: 0
__________________________________________________________________________________________________
evaluator = Evaluator()
model.fit(
        train_generator.forfit(),
        steps_per_epoch=len(train_generator),
        epochs=1,
        callbacks=[evaluator]
    )

model.load_weights('best_model.weights')
print(u'final test acc: %05f\n' % (evaluate(test_generator)))
Epoch 1/1
2355/2355 [==============================] - 1663s 706ms/step - loss: 2.5211 - accuracy: 0.0876
val_acc: 0.08493, best_val_acc: 0.08493, test_acc: 0.08333

final test acc: 0.083333

  • 写回答
  • 好问题 提建议
  • 追加酬金
  • 关注问题
  • 邀请回答

2条回答 默认 最新

相关推荐 更多相似问题