基于Chinese-RoBERTa-wwm-ext的THUCNews新闻分类系统
本项目实现了一个完整的中文新闻文本分类系统,使用预训练的Chinese-RoBERTa-wwm-ext模型对THUCNews数据集进行14类新闻分类。项目包含完整的数据预处理、模型训练、评估和推理流程。
- ✅ 预训练模型: Chinese-RoBERTa-wwm-ext (102M参数,Whole Word Masking)
- ✅ 大规模数据集: THUCNews 836K样本,14个新闻类别
- ✅ 类别不平衡处理: 使用sqrt类别权重策略
- ✅ BERT微调: 分层学习率 + Warmup + 早停
- ✅ 完整评估: 准确率、精确率、召回率、F1分数、混淆矩阵
- ✅ 多GPU支持: DataParallel、DDP、Accelerate三种方案
- ✅ 生产就绪: 包含推理接口和完整文档
来源: 清华大学自然语言处理实验室 规模: 836,075个新闻样本 类别: 14个新闻类别
| 类别 | 样本数 | 占比 | 类别权重(sqrt) |
|---|---|---|---|
| 科技 | 162,929 | 19.5% | 0.605 |
| 股票 | 154,398 | 18.5% | 0.622 |
| 体育 | 131,604 | 15.7% | 0.674 |
| 娱乐 | 92,632 | 11.1% | 0.803 |
| 时政 | 63,086 | 7.5% | 0.973 |
| 社会 | 50,849 | 6.1% | 1.084 |
| 教育 | 41,936 | 5.0% | 1.193 |
| 财经 | 37,098 | 4.4% | 1.269 |
| 家居 | 32,586 | 3.9% | 1.354 |
| 游戏 | 24,373 | 2.9% | 1.565 |
| 房产 | 20,050 | 2.4% | 1.726 |
| 时尚 | 13,368 | 1.6% | 2.114 |
| 彩票 | 7,588 | 0.9% | 2.805 |
| 星座 | 3,578 | 0.4% | 4.085 |
数据特点:
- 平均文本长度: 1,209字符
- 平均token长度: 1,088 tokens
- 超过512 tokens的样本: 65%
- 类别不平衡比: 45.5:1 (科技:星座)
BERT_text_classification/
├── config.py # 配置文件(超参数、路径)
├── model.py # BERT分类器模型定义
├── dataset.py # PyTorch Dataset类
├── train.py # 训练脚本
├── evaluate.py # 评估脚本
├── inference.py # 推理接口
│
├── data_preprocessing.py # 数据预处理
├── dataset_split.py # 数据集划分
├── data_exploration.py # 数据探索分析
│
├── requirements.txt # Python依赖
├── pyproject.toml # uv项目配置
│
├── THUCNews/ # 原始数据集
│ ├── 体育/
│ ├── 娱乐/
│ └── ...
│
├── pretrained_model/ # 预训练模型
│ └── chinese-roberta-wwm-ext/
│ ├── config.json
│ ├── pytorch_model.bin
│ └── vocab.txt
│
├── processed_data_balanced.json # 预处理后的数据(131K样本)
├── train.json # 训练集(104,800样本)
├── val.json # 验证集(13,100样本)
├── test.json # 测试集(13,100样本)
│
├── checkpoints/ # 模型检查点
│ ├── best_model.pt
│ └── checkpoint_epoch_*.pt
│
├── results/ # 评估结果
│ ├── test_results.json
│ ├── confusion_matrix.png
│ └── per_class_metrics.png
│
├── logs/ # 训练日志
│ └── training_history.json
│
└── docs/ # 文档
├── ROBERTA_MODEL_ANALYSIS.md # RoBERTa模型分析
├── MULTI_GPU_TRAINING.md # 多GPU训练指南
├── MODEL_OUTPUT_ANALYSIS.md # 模型输出层分析
├── PYTORCH_2.6_FIX.md # PyTorch 2.6兼容性修复
└── ...
- Python 3.12+
- PyTorch 2.6+
- CUDA 12.4+ (GPU训练)
pip install -r requirements.txt# 安装uv
pip install uv
# 创建虚拟环境并安装依赖
uv sync- python = 3.12
- numpy
- pandas
- scipy
- matplotlib
- scikit-learn
- tqdm
- seaborn
- torch+cu124 >= 2.6
- transformers
# 将THUCNews.zip解压到项目根目录
unzip THUCNews.zip# 1. 数据预处理(清洗、标签编码)
python data_preprocessing.py
# 2. 数据集划分(train/val/test)
python dataset_split.pypython download_roberta.py# 从HuggingFace下载
mkdir -p pretrained_model
cd pretrained_model
git lfs install
git clone https://huggingface.co/hfl/chinese-roberta-wwm-extpython train.py# 使用4个GPU
python -m torch.distributed.launch --nproc_per_node=4 train_ddp.py
# 或使用torchrun (PyTorch 1.10+)
torchrun --nproc_per_node=4 train_ddp.pypython evaluate.pyfrom inference import TextClassifier
model_path = os.path.join(CHECKPOINT_DIR, 'best_model.pt')
# 加载模型
classifier = TextClassifier(
model_path=model_path,
pretrained_model_path=PRETRAINED_MODEL_PATH
)
# 预测单个文本
text = "中国男篮在世界杯上取得了胜利"
result = classifier.predict(text, return_probs=True)
print(f"预测类别: {result['category']}")
print(f"置信度: {result['confidence']:.4f}")
print(f"所有类别概率: {result['probabilities']}")# 模型配置
PRETRAINED_MODEL_PATH = 'pretrained_model/chinese-roberta-wwm-ext'
MAX_LENGTH = 512 # BERT最大序列长度
DROPOUT_RATE = 0.1 # Dropout比例
# 训练配置
BATCH_SIZE = 16 # 批次大小
LEARNING_RATE = 2e-5 # BERT层学习率
CLASSIFIER_LR = 1e-4 # 分类层学习率(5倍)
NUM_EPOCHS = 5 # 训练轮数
WARMUP_RATIO = 0.1 # Warmup比例
WEIGHT_DECAY = 0.01 # 权重衰减
MAX_GRAD_NORM = 1.0 # 梯度裁剪
# 类别权重
USE_CLASS_WEIGHTS = True # 是否使用类别权重
CLASS_WEIGHT_METHOD = 'sqrt' # 'inverse' 或 'sqrt'
# BERT微调
FREEZE_BERT = True # False表示微调整个BERT
# 早停
EARLY_STOPPING_PATIENCE = 3 # 早停耐心值
BEST_MODEL_METRIC = 'f1' # 保存最佳模型的指标基本信息:
- 开发者: 哈工大讯飞联合实验室 (HFL)
- 参数量: 102M
- 层数: 12层Transformer
- 隐藏层维度: 768
- 注意力头数: 12
- 词汇表大小: 21,128
核心技术:
- ✅ Whole Word Masking (全词遮罩)
- ✅ 动态Masking
- ✅ 扩展预训练语料 (9GB中文文本)
- ✅ 移除NSP任务
- ✅ 更大Batch Size
性能提升: 相比BERT-Base-Chinese,在中文NLP任务上平均提升1-3%
Input Text
↓
Tokenizer (BertTokenizer)
↓
[CLS] token_1 token_2 ... token_n [SEP] [PAD] ...
↓
BERT Encoder (12 layers, 768 hidden)
↓
[CLS] Representation (pooler_output)
↓
Dropout (p=0.1)
↓
Linear Layer (768 → 14)
↓
Logits (raw scores)
↓
CrossEntropyLoss (with built-in softmax)
关键设计:
- 使用[CLS] token的pooler_output作为句子表示
- 输出层返回logits(未经softmax)
- CrossEntropyLoss内置softmax,数值更稳定
- 分层学习率:BERT层2e-5,分类层1e-4
问题: 最大类别(科技)是最小类别(星座)的45.5倍
解决方案: 使用sqrt类别权重
weight_i = sqrt(total_samples / (num_classes * count_i))效果:
- 星座类权重: 4.085 (最高)
- 科技类权重: 0.605 (最低)
- 平衡了模型对少数类的关注
optimizer = AdamW([
{'params': bert_params, 'lr': 2e-5}, # BERT层较小学习率
{'params': classifier_params, 'lr': 1e-4} # 分类层较大学习率
])原因:
- BERT已经预训练好,需要小心微调
- 分类层随机初始化,需要更快学习
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=total_steps * 0.1, # 10% warmup
num_training_steps=total_steps
)Warmup策略:
- 前10%步数线性增加学习率
- 之后线性衰减到0
- 避免训练初期梯度爆炸
patience = 3 # 连续3个epoch无改进则停止
delta = 0.001 # 最小改进阈值监控指标: F1分数(宏平均)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)防止梯度爆炸,稳定训练
- Accuracy: 整体准确率
- Macro F1: 宏平均F1(所有类别平等权重)
- Weighted F1: 加权F1(按样本数加权)
- Per-class Metrics: 每个类别的精确率、召回率、F1
评估脚本会自动生成混淆矩阵可视化:
results/confusion_matrix.pngresults/per_class_metrics.png
基于Chinese-RoBERTa-wwm-ext在THUCNews上的预期表现:
| 指标 | 预期值 |
|---|---|
| Accuracy | 96.0-96.5% |
| Macro F1 | 95.5-96.0% |
| Weighted F1 | 96.0-96.5% |
# train.py中添加
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)优点: 代码改动最少 缺点: 性能较差,GPU 0负载更重
# 单机4卡
python -m torch.distributed.launch --nproc_per_node=4 train_ddp.py
# 多机多卡
# 节点0 (主节点)
python -m torch.distributed.launch \
--nproc_per_node=4 \
--nnodes=2 \
--node_rank=0 \
--master_addr="192.168.1.100" \
--master_port=12345 \
train_ddp.py
# 节点1
python -m torch.distributed.launch \
--nproc_per_node=4 \
--nnodes=2 \
--node_rank=1 \
--master_addr="192.168.1.100" \
--master_port=12345 \
train_ddp.py优点: 性能最优,线性加速 缺点: 代码改动较多
# 配置
accelerate config
# 训练
accelerate launch train.py优点: API最简单,自动处理分布式 缺点: 需要额外依赖
| 配置 | 时间/Epoch | 加速比 |
|---|---|---|
| 单卡 | 120分钟 | 1.0x |
| 2卡 DDP | 60分钟 | 2.0x |
| 4卡 DDP | 30分钟 | 4.0x |
| 8卡 DDP | 15分钟 | 8.0x |
错误信息:
_pickle.UnpicklingError: Weights only load failed
解决方案:
# 在torch.load()中添加weights_only=False
checkpoint = torch.load(model_path, map_location=device, weights_only=False)解决方案:
# 1. 减小batch size
BATCH_SIZE = 8 # 从16减到8
# 2. 减小序列长度
MAX_LENGTH = 256 # 从512减到256
# 3. 使用梯度累积
GRADIENT_ACCUMULATION_STEPS = 2
# 4. 使用混合精度训练
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()优化方案:
# 1. 增加num_workers
DataLoader(..., num_workers=4, pin_memory=True)
# 2. 使用多GPU
python -m torch.distributed.launch --nproc_per_node=4 train_ddp.py
# 3. 使用混合精度
with autocast():
logits = model(...)解决方案:
# 1. 增加Dropout
DROPOUT_RATE = 0.3 # 从0.1增加到0.3
# 2. 增加权重衰减
WEIGHT_DECAY = 0.1 # 从0.01增加到0.1
# 3. 使用早停
EARLY_STOPPING_PATIENCE = 2
# 4. 减少训练轮数
NUM_EPOCHS = 3解决方案:
# 1. 使用更强的类别权重
CLASS_WEIGHT_METHOD = 'inverse' # 从'sqrt'改为'inverse'
# 2. 数据增强(针对少数类)
USE_DATA_AUGMENTATION = True
# 3. 过采样少数类
from imblearn.over_sampling import RandomOverSampler- ROBERTA_MODEL_ANALYSIS.md - RoBERTa模型详细分析
- MODEL_OUTPUT_ANALYSIS.md - 模型输出层分析(Logits vs Softmax)
- MULTI_GPU_TRAINING.md - 多GPU训练完整指南
- PYTORCH_2.6_FIX.md - PyTorch 2.6兼容性修复
- BERT_FINETUNING_ANALYSIS.md - BERT微调分析
- MODEL_ARCHITECTURE_ANALYSIS.md - 模型架构分析
- MODEL_IMPROVEMENT_RECOMMENDATION.md - 模型改进建议
- 准备数据格式(每个类别一个文件夹)
- 修改
config.py中的CATEGORIES - 运行数据预处理:
python data_preprocessing.py
python dataset_split.py- 下载新模型到
pretrained_model/ - 修改
config.py:
PRETRAINED_MODEL_PATH = 'pretrained_model/your-model'- 确保模型兼容BertModel接口
修改model.py中的BertClassifier类:
class BertClassifier(nn.Module):
def __init__(self, ...):
# 添加自定义层
self.lstm = nn.LSTM(768, 256, bidirectional=True)
self.classifier = nn.Linear(512, num_classes)本项目仅供学习和研究使用。
- THUCNews数据集版权归清华大学所有
- 仅供学术研究使用
- Chinese-RoBERTa-wwm-ext: Apache License 2.0
- 详见: https://github.com/ymcui/Chinese-BERT-wwm
- 预训练模型: 哈工大讯飞联合实验室 (HFL)
- 数据集: 清华大学自然语言处理实验室
- 框架: PyTorch, HuggingFace Transformers
如有问题或建议,请提交Issue或Pull Request。
- ✅ 完整的数据预处理流程
- ✅ Chinese-RoBERTa-wwm-ext模型集成
- ✅ 类别不平衡处理(sqrt权重)
- ✅ BERT微调训练
- ✅ 完整评估和推理接口
- ✅ 多GPU训练支持
- ✅ PyTorch 2.6兼容性修复
- ✅ 完整技术文档
如果本项目对您的研究有帮助,请引用:
@misc{bert_thucnews_classification,
title={BERT Chinese Text Classification on THUCNews},
author={Your Name},
year={2026},
howpublished={\url{https://github.com/yourusername/BERT_text_classification}}
}Happy Coding! 🚀