Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

essiloan/BERT_text_classification

Open more actions menu

Repository files navigation

BERT中文文本分类项目

基于Chinese-RoBERTa-wwm-ext的THUCNews新闻分类系统


📋 项目概述

本项目实现了一个完整的中文新闻文本分类系统,使用预训练的Chinese-RoBERTa-wwm-ext模型对THUCNews数据集进行14类新闻分类。项目包含完整的数据预处理、模型训练、评估和推理流程。

核心特性

  • 预训练模型: Chinese-RoBERTa-wwm-ext (102M参数,Whole Word Masking)
  • 大规模数据集: THUCNews 836K样本,14个新闻类别
  • 类别不平衡处理: 使用sqrt类别权重策略
  • BERT微调: 分层学习率 + Warmup + 早停
  • 完整评估: 准确率、精确率、召回率、F1分数、混淆矩阵
  • 多GPU支持: DataParallel、DDP、Accelerate三种方案
  • 生产就绪: 包含推理接口和完整文档

📊 数据集信息

THUCNews数据集

来源: 清华大学自然语言处理实验室 规模: 836,075个新闻样本 类别: 14个新闻类别

类别 样本数 占比 类别权重(sqrt)
科技 162,929 19.5% 0.605
股票 154,398 18.5% 0.622
体育 131,604 15.7% 0.674
娱乐 92,632 11.1% 0.803
时政 63,086 7.5% 0.973
社会 50,849 6.1% 1.084
教育 41,936 5.0% 1.193
财经 37,098 4.4% 1.269
家居 32,586 3.9% 1.354
游戏 24,373 2.9% 1.565
房产 20,050 2.4% 1.726
时尚 13,368 1.6% 2.114
彩票 7,588 0.9% 2.805
星座 3,578 0.4% 4.085

数据特点:

  • 平均文本长度: 1,209字符
  • 平均token长度: 1,088 tokens
  • 超过512 tokens的样本: 65%
  • 类别不平衡比: 45.5:1 (科技:星座)

🏗️ 项目结构

BERT_text_classification/
├── config.py                          # 配置文件(超参数、路径)
├── model.py                           # BERT分类器模型定义
├── dataset.py                         # PyTorch Dataset类
├── train.py                           # 训练脚本
├── evaluate.py                        # 评估脚本
├── inference.py                       # 推理接口
│
├── data_preprocessing.py              # 数据预处理
├── dataset_split.py                   # 数据集划分
├── data_exploration.py                # 数据探索分析
│
├── requirements.txt                   # Python依赖
├── pyproject.toml                     # uv项目配置
│
├── THUCNews/                          # 原始数据集
│   ├── 体育/
│   ├── 娱乐/
│   └── ...
│
├── pretrained_model/                  # 预训练模型
│   └── chinese-roberta-wwm-ext/
│       ├── config.json
│       ├── pytorch_model.bin
│       └── vocab.txt
│
├── processed_data_balanced.json       # 预处理后的数据(131K样本)
├── train.json                         # 训练集(104,800样本)
├── val.json                           # 验证集(13,100样本)
├── test.json                          # 测试集(13,100样本)
│
├── checkpoints/                       # 模型检查点
│   ├── best_model.pt
│   └── checkpoint_epoch_*.pt
│
├── results/                           # 评估结果
│   ├── test_results.json
│   ├── confusion_matrix.png
│   └── per_class_metrics.png
│
├── logs/                              # 训练日志
│   └── training_history.json
│
└── docs/                              # 文档
    ├── ROBERTA_MODEL_ANALYSIS.md      # RoBERTa模型分析
    ├── MULTI_GPU_TRAINING.md          # 多GPU训练指南
    ├── MODEL_OUTPUT_ANALYSIS.md       # 模型输出层分析
    ├── PYTORCH_2.6_FIX.md             # PyTorch 2.6兼容性修复
    └── ...

🚀 快速开始

1. 环境要求

  • Python 3.12+
  • PyTorch 2.6+
  • CUDA 12.4+ (GPU训练)

2. 安装依赖

使用pip

pip install -r requirements.txt

使用uv (推荐)

# 安装uv
pip install uv

# 创建虚拟环境并安装依赖
uv sync

手动安装

  • python = 3.12
  • numpy
  • pandas
  • scipy
  • matplotlib
  • scikit-learn
  • tqdm
  • seaborn
  • torch+cu124 >= 2.6
  • transformers

3. 准备数据

下载THUCNews数据集

# 将THUCNews.zip解压到项目根目录
unzip THUCNews.zip

数据预处理

# 1. 数据预处理(清洗、标签编码)
python data_preprocessing.py

# 2. 数据集划分(train/val/test)
python dataset_split.py

4. 下载预训练模型

方式1: 自动下载(推荐)

python download_roberta.py

方式2: 手动下载

# 从HuggingFace下载
mkdir -p pretrained_model
cd pretrained_model
git lfs install
git clone https://huggingface.co/hfl/chinese-roberta-wwm-ext

5. 训练模型

单卡训练

python train.py

多卡训练(DDP)

# 使用4个GPU
python -m torch.distributed.launch --nproc_per_node=4 train_ddp.py

# 或使用torchrun (PyTorch 1.10+)
torchrun --nproc_per_node=4 train_ddp.py

6. 评估模型

python evaluate.py

7. 推理预测

from inference import TextClassifier

model_path = os.path.join(CHECKPOINT_DIR, 'best_model.pt')

# 加载模型
classifier = TextClassifier(
    model_path=model_path,
    pretrained_model_path=PRETRAINED_MODEL_PATH
)

# 预测单个文本
text = "中国男篮在世界杯上取得了胜利"
result = classifier.predict(text, return_probs=True)

print(f"预测类别: {result['category']}")
print(f"置信度: {result['confidence']:.4f}")
print(f"所有类别概率: {result['probabilities']}")

⚙️ 配置说明

核心超参数 (config.py)

# 模型配置
PRETRAINED_MODEL_PATH = 'pretrained_model/chinese-roberta-wwm-ext'
MAX_LENGTH = 512              # BERT最大序列长度
DROPOUT_RATE = 0.1            # Dropout比例

# 训练配置
BATCH_SIZE = 16               # 批次大小
LEARNING_RATE = 2e-5          # BERT层学习率
CLASSIFIER_LR = 1e-4          # 分类层学习率(5倍)
NUM_EPOCHS = 5                # 训练轮数
WARMUP_RATIO = 0.1            # Warmup比例
WEIGHT_DECAY = 0.01           # 权重衰减
MAX_GRAD_NORM = 1.0           # 梯度裁剪

# 类别权重
USE_CLASS_WEIGHTS = True      # 是否使用类别权重
CLASS_WEIGHT_METHOD = 'sqrt'  # 'inverse' 或 'sqrt'

# BERT微调
FREEZE_BERT = True           # False表示微调整个BERT

# 早停
EARLY_STOPPING_PATIENCE = 3   # 早停耐心值
BEST_MODEL_METRIC = 'f1'      # 保存最佳模型的指标

🎯 模型架构

Chinese-RoBERTa-wwm-ext

基本信息:

  • 开发者: 哈工大讯飞联合实验室 (HFL)
  • 参数量: 102M
  • 层数: 12层Transformer
  • 隐藏层维度: 768
  • 注意力头数: 12
  • 词汇表大小: 21,128

核心技术:

  • ✅ Whole Word Masking (全词遮罩)
  • ✅ 动态Masking
  • ✅ 扩展预训练语料 (9GB中文文本)
  • ✅ 移除NSP任务
  • ✅ 更大Batch Size

性能提升: 相比BERT-Base-Chinese,在中文NLP任务上平均提升1-3%

分类器架构

Input Text
    ↓
Tokenizer (BertTokenizer)
    ↓
[CLS] token_1 token_2 ... token_n [SEP] [PAD] ...
    ↓
BERT Encoder (12 layers, 768 hidden)
    ↓
[CLS] Representation (pooler_output)
    ↓
Dropout (p=0.1)
    ↓
Linear Layer (768 → 14)
    ↓
Logits (raw scores)
    ↓
CrossEntropyLoss (with built-in softmax)

关键设计:

  • 使用[CLS] token的pooler_output作为句子表示
  • 输出层返回logits(未经softmax)
  • CrossEntropyLoss内置softmax,数值更稳定
  • 分层学习率:BERT层2e-5,分类层1e-4

📈 训练策略

1. 类别不平衡处理

问题: 最大类别(科技)是最小类别(星座)的45.5倍

解决方案: 使用sqrt类别权重

weight_i = sqrt(total_samples / (num_classes * count_i))

效果:

  • 星座类权重: 4.085 (最高)
  • 科技类权重: 0.605 (最低)
  • 平衡了模型对少数类的关注

2. 分层学习率

optimizer = AdamW([
    {'params': bert_params, 'lr': 2e-5},      # BERT层较小学习率
    {'params': classifier_params, 'lr': 1e-4}  # 分类层较大学习率
])

原因:

  • BERT已经预训练好,需要小心微调
  • 分类层随机初始化,需要更快学习

3. 学习率调度

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=total_steps * 0.1,  # 10% warmup
    num_training_steps=total_steps
)

Warmup策略:

  • 前10%步数线性增加学习率
  • 之后线性衰减到0
  • 避免训练初期梯度爆炸

4. 早停机制

patience = 3  # 连续3个epoch无改进则停止
delta = 0.001  # 最小改进阈值

监控指标: F1分数(宏平均)

5. 梯度裁剪

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

防止梯度爆炸,稳定训练


📊 评估指标

主要指标

  • Accuracy: 整体准确率
  • Macro F1: 宏平均F1(所有类别平等权重)
  • Weighted F1: 加权F1(按样本数加权)
  • Per-class Metrics: 每个类别的精确率、召回率、F1

混淆矩阵

评估脚本会自动生成混淆矩阵可视化:

  • results/confusion_matrix.png
  • results/per_class_metrics.png

预期性能

基于Chinese-RoBERTa-wwm-ext在THUCNews上的预期表现:

指标 预期值
Accuracy 96.0-96.5%
Macro F1 95.5-96.0%
Weighted F1 96.0-96.5%

🖥️ 多GPU训练

方式1: DataParallel (简单)

# train.py中添加
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

优点: 代码改动最少 缺点: 性能较差,GPU 0负载更重

方式2: DistributedDataParallel (推荐)

# 单机4卡
python -m torch.distributed.launch --nproc_per_node=4 train_ddp.py

# 多机多卡
# 节点0 (主节点)
python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --nnodes=2 \
    --node_rank=0 \
    --master_addr="192.168.1.100" \
    --master_port=12345 \
    train_ddp.py

# 节点1
python -m torch.distributed.launch \
    --nproc_per_node=4 \
    --nnodes=2 \
    --node_rank=1 \
    --master_addr="192.168.1.100" \
    --master_port=12345 \
    train_ddp.py

优点: 性能最优,线性加速 缺点: 代码改动较多

方式3: Accelerate (最简单)

# 配置
accelerate config

# 训练
accelerate launch train.py

优点: API最简单,自动处理分布式 缺点: 需要额外依赖

性能对比

配置 时间/Epoch 加速比
单卡 120分钟 1.0x
2卡 DDP 60分钟 2.0x
4卡 DDP 30分钟 4.0x
8卡 DDP 15分钟 8.0x

详见: MULTI_GPU_TRAINING.md


🐛 常见问题

Q1: PyTorch 2.6 checkpoint加载错误

错误信息:

_pickle.UnpicklingError: Weights only load failed

解决方案:

# 在torch.load()中添加weights_only=False
checkpoint = torch.load(model_path, map_location=device, weights_only=False)

详见: PYTORCH_2.6_FIX.md

Q2: CUDA Out of Memory

解决方案:

# 1. 减小batch size
BATCH_SIZE = 8  # 从16减到8

# 2. 减小序列长度
MAX_LENGTH = 256  # 从512减到256

# 3. 使用梯度累积
GRADIENT_ACCUMULATION_STEPS = 2

# 4. 使用混合精度训练
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()

Q3: 训练速度慢

优化方案:

# 1. 增加num_workers
DataLoader(..., num_workers=4, pin_memory=True)

# 2. 使用多GPU
python -m torch.distributed.launch --nproc_per_node=4 train_ddp.py

# 3. 使用混合精度
with autocast():
    logits = model(...)

Q4: 模型过拟合

解决方案:

# 1. 增加Dropout
DROPOUT_RATE = 0.3  # 从0.1增加到0.3

# 2. 增加权重衰减
WEIGHT_DECAY = 0.1  # 从0.01增加到0.1

# 3. 使用早停
EARLY_STOPPING_PATIENCE = 2

# 4. 减少训练轮数
NUM_EPOCHS = 3

Q5: 少数类别性能差

解决方案:

# 1. 使用更强的类别权重
CLASS_WEIGHT_METHOD = 'inverse'  # 从'sqrt'改为'inverse'

# 2. 数据增强(针对少数类)
USE_DATA_AUGMENTATION = True

# 3. 过采样少数类
from imblearn.over_sampling import RandomOverSampler

📚 技术文档

核心文档

其他文档


🔧 开发指南

添加新的数据集

  1. 准备数据格式(每个类别一个文件夹)
  2. 修改config.py中的CATEGORIES
  3. 运行数据预处理:
python data_preprocessing.py
python dataset_split.py

切换预训练模型

  1. 下载新模型到pretrained_model/
  2. 修改config.py:
PRETRAINED_MODEL_PATH = 'pretrained_model/your-model'
  1. 确保模型兼容BertModel接口

自定义模型架构

修改model.py中的BertClassifier类:

class BertClassifier(nn.Module):
    def __init__(self, ...):
        # 添加自定义层
        self.lstm = nn.LSTM(768, 256, bidirectional=True)
        self.classifier = nn.Linear(512, num_classes)

📄 许可证

本项目仅供学习和研究使用。

数据集许可

  • THUCNews数据集版权归清华大学所有
  • 仅供学术研究使用

模型许可


🙏 致谢

  • 预训练模型: 哈工大讯飞联合实验室 (HFL)
  • 数据集: 清华大学自然语言处理实验室
  • 框架: PyTorch, HuggingFace Transformers

📧 联系方式

如有问题或建议,请提交Issue或Pull Request。


📝 更新日志

v1.0.0 (2026-03-11)

  • ✅ 完整的数据预处理流程
  • ✅ Chinese-RoBERTa-wwm-ext模型集成
  • ✅ 类别不平衡处理(sqrt权重)
  • ✅ BERT微调训练
  • ✅ 完整评估和推理接口
  • ✅ 多GPU训练支持
  • ✅ PyTorch 2.6兼容性修复
  • ✅ 完整技术文档

🎓 引用

如果本项目对您的研究有帮助,请引用:

@misc{bert_thucnews_classification,
  title={BERT Chinese Text Classification on THUCNews},
  author={Your Name},
  year={2026},
  howpublished={\url{https://github.com/yourusername/BERT_text_classification}}
}

Happy Coding! 🚀

About

Using BERT for chinese text classification with THUCNews dataset.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages

Morty Proxy This is a proxified and sanitized view of the page, visit original site.