使用 BERT 和 Hugging Face 打造中文情感分類模型微調(fine-tuning)教學
日期:2024-12-03
使用 Hugging Face 和 BERT 模型進行中文情感分析,涵蓋從環境設置、數據集加載、分詞處理到模型訓練與測試的完整流程。
崴寶之前的教學文章中有使用過 bert-base-uncased
來訓練。
因此這次我們將使用 bert-base-chinese
來進行訓練😝
相關資料
- HuggingFace 資料集: ChnSentiCorp
- GitHub Repository: weitsung50110/Bert_HugginFace_Train_Predict
目錄結構
此教學文件路徑: Bert_HugginFace_Train_Predict/ChnSentiCorp_bert/
ChnSentiCorp_bert.py
→ 訓練模型ChnSentiCorp_bert_predict.py
→ 預測ChnSentiCorp_bert_test_inference.py
→ 測試集推論
ChnSentiCorp 資料集介紹
ChnSentiCorp 是一個中文情感分析資料集,包含來自酒店、筆記型電腦和書籍等商品的網購評論。該資料集共包含:
- 9600 條訓練數據
- 1200 條驗證數據
- 1200 條測試數據
每條評論都被標註為正面(1)或負面(0),用於訓練和評估情感分類模型。
崴寶使用的是 lansinuote/ChnSentiCorp
,它是 Hugging Face 上的用戶 lansinuote
提供的 ChnSentiCorp 資料集版本。該版本以 Parquet 格式存儲。
Docker
- weitsung50110/bert_huggingface: 此為我安裝好的 Docker image 環境。
docker run
等使用說明請進入 Docker Hub 裡面的說明欄查看。
可以使用以下命令來拉取 Docker image:
docker pull weitsung50110/bert_huggingface:1.0
1. 選擇訓練設備
程式碼首先會檢查是否有可用的 GPU(CUDA),如果有的話,會優先使用 GPU,否則退回到 CPU。
import torch
# 自動檢查是否有可用的 GPU,並選擇適當的設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
torch.cuda.is_available()
用來檢查當前系統是否支援 CUDA(NVIDIA 的 GPU 技術)。- 如果有 GPU,設備會設為
"cuda"
;如果沒有,則設為"cpu"
。
2. 加載本地數據集
第一種方式:手動指定文件路徑
這段程式碼會從本地的 .parquet
文件加載數據集,並使用 Hugging Face 的 datasets
庫進行處理。
data_files = {
"train": "./ChnSentiCorp/data/train-00000-of-00001-02f200ca5f2a7868.parquet",
"validation": "./ChnSentiCorp/data/validation-00000-of-00001-405befbaa3bcf1a2.parquet",
"test": "./ChnSentiCorp/data/test-00000-of-00001-5372924f059fe767.parquet",
}
dataset = load_dataset("parquet", data_files=data_files)
print(dataset)
data_files
是一個字典,指定了訓練集(train)、驗證集(validation)和測試集(test)的檔案路徑。load_dataset("parquet", data_files=data_files)
會根據檔案類型(parquet
)加載數據集,並自動分配為訓練集、驗證集和測試集。print(dataset)
用來查看數據集的結構,例如訓練集和測試集的樣本數量。
第二種方式:直接從 Hugging Face Hub 加載
也可以使用 Hugging Face 的 load_dataset
方法直接加載 lansinuote/ChnSentiCorp
數據集。
dataset = load_dataset("lansinuote/ChnSentiCorp")
兩種方式比較
1. 手動指定文件路徑
本地文件:
- 需要手動指定數據集的路徑,通常用於本地保存的數據集。
- 支援多種文件格式(如
parquet
、csv
、json
)。
靈活性:
- 可以任意指定數據集的分割,例如使用不同的文件作為
train
、validation
和test
。 - 適合用於自定義數據集或本地數據處理。
無需網絡連接:
- 本地文件不需要網絡連接,適合離線環境。
2. 直接從 Hugging Face Hub 加載
線上數據集:
- 直接從 Hugging Face 的數據集庫下載,無需本地文件。
- 只需要提供數據集名稱(如
lansinuote/ChnSentiCorp
)。
內置分割:
- 數據集已經按照標準分割(
train
、validation
、test
),用戶無需手動指定。 - 例如,
ChnSentiCorp
的train
、validation
、test
自動匹配到對應的.parquet
文件。
需要網絡連接:
- 需要從 Hugging Face Hub 獲取數據,因此需要網絡支持。
3. 加載分詞器和模型
這段程式碼使用 Hugging Face 的 BertTokenizer
和 BertForSequenceClassification
來加載模型和分詞器。
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=2)
# 將模型移動到指定的設備
model.to(device)
bert-base-chinese
是預訓練的 BERT 中文模型,適合處理中文文本。num_labels=2
表示這是一個二分類問題,例如「正向情感」和「負向情感」。model.to(device)
將模型移動到 GPU 或 CPU,以確保運算與設備一致。
4. 數據預處理
在這部分,程式碼會對數據集進行分詞處理,將文本轉換為 BERT 所需的輸入格式。
def preprocess_function(examples):
return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
# 對數據集進行分詞處理
tokenized_datasets = dataset.map(preprocess_function, batched=True)
preprocess_function
定義了如何處理每條樣本的文本:
truncation=True
:超過 128 個字的文本會被截斷。padding="max_length"
:文本會補齊到固定長度(128)。max_length=128
:設定文本最大長度為 128。
dataset.map()
用來對整個數據集應用分詞函數,並將結果存到 tokenized_datasets
。
5. 訓練參數設置
這段程式碼設定了模型的訓練參數。
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=3,
weight_decay=0.01,
logging_dir="./logs",
save_strategy="epoch",
)
output_dir="./results"
:訓練結束後,模型會保存在./results
目錄下。evaluation_strategy="epoch"
:每個訓練週期(epoch)結束後執行一次驗證。learning_rate=2e-5
:學習率設為 2e-5。per_device_train_batch_size=16
:每個設備(GPU 或 CPU)上的訓練批次大小為 16。num_train_epochs=3
:訓練 3 個週期。save_strategy="epoch"
:每個週期結束後保存模型。
6. 定義評估函數
這段程式碼定義了模型的評估指標,包括準確率、精度、召回率和 F1 分數。
def compute_metrics(pred):
labels = pred.label_ids
preds = pred.predictions.argmax(-1)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary")
acc = accuracy_score(labels, preds)
return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
pred.label_ids
是真實的標籤。pred.predictions.argmax(-1)
是模型的預測標籤。- 使用
precision_recall_fscore_support
和accuracy_score
來計算性能指標。
7. 初始化 Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["validation"],
compute_metrics=compute_metrics,
)
Trainer
是 Hugging Face 提供的高級接口,用於簡化訓練過程。
model=model
:指定要訓練的模型。args=training_args
:設定訓練參數。train_dataset
和eval_dataset
:指定訓練集和驗證集。compute_metrics
:指定評估指標。
8. 開始訓練
這行程式碼會啟動訓練過程。
trainer.train()
train()
方法會自動執行分批次訓練,並在每個週期後執行評估。
root@08d912ebf816:/app/Huggingfacetransformer# python ChnSentiCorp_bert.py
2024-12-02 08:17:14.706343: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-02 08:17:15.668363: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Using device: cpu
Generating train split: 9600 examples [00:00, 357091.84 examples/s]
Generating validation split: 1200 examples [00:00, 308877.86 examples/s]
Generating test split: 1200 examples [00:00, 252706.97 examples/s]
DatasetDict({
train: Dataset({
features: ['text', 'label'],
num_rows: 9600
})
validation: Dataset({
features: ['text', 'label'],
num_rows: 1200
})
test: Dataset({
features: ['text', 'label'],
num_rows: 1200
})
})
/usr/local/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|█████████████████████████████████████████████████████| 9600/9600 [00:10<00:00, 910.55 examples/s]
Map: 100%|█████████████████████████████████████████████████████| 1200/1200 [00:01<00:00, 914.24 examples/s]
Map: 100%|█████████████████████████████████████████████████████| 1200/1200 [00:01<00:00, 919.64 examples/s]
Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
{'loss': 0.2742, 'grad_norm': 13.644034385681152, 'learning_rate': 1.4444444444444446e-05, 'epoch': 0.83}
{'eval_loss': 0.24989934265613556, 'eval_accuracy': 0.9175, 'eval_f1': 0.9151670951156813, 'eval_precision': 0.9303135888501742, 'eval_recall': 0.9005059021922428, 'eval_runtime': 26.6792, 'eval_samples_per_second': 44.979, 'eval_steps_per_second': 2.811, 'epoch': 1.0}
{'loss': 0.1637, 'grad_norm': 13.341170310974121, 'learning_rate': 8.888888888888888e-06, 'epoch': 1.67}
{'eval_loss': 0.22070790827274323, 'eval_accuracy': 0.9375, 'eval_f1': 0.9355116079105761, 'eval_precision': 0.9543859649122807, 'eval_recall': 0.9173693086003373, 'eval_runtime': 26.3692, 'eval_samples_per_second': 45.508, 'eval_steps_per_second': 2.844, 'epoch': 2.0}
{'loss': 0.1153, 'grad_norm': 15.839860916137695, 'learning_rate': 3.3333333333333333e-06, 'epoch': 2.5}
{'eval_loss': 0.27151966094970703, 'eval_accuracy': 0.9366666666666666, 'eval_f1': 0.935374149659864, 'eval_precision': 0.9433962264150944, 'eval_recall': 0.927487352445194, 'eval_runtime': 26.3905, 'eval_samples_per_second': 45.471, 'eval_steps_per_second': 2.842, 'epoch': 3.0}
{'train_runtime': 2447.078, 'train_samples_per_second': 11.769, 'train_steps_per_second': 0.736, 'train_loss': 0.165834772321913, 'epoch': 3.0}
100%|██████████████████████████████████████████████████████████████████| 1800/1800 [40:47<00:00, 1.36s/it]
9. 保存模型
訓練完成後,將模型和分詞器保存到本地。
model.save_pretrained("./sentiment_model")
tokenizer.save_pretrained("./sentiment_model")
- 保存的模型和分詞器可以用來做推理或重新加載。
10. 測試模型
def predict(text):
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
inputs = {key: value.to(device) for key, value in inputs.items()} # 將數據移動到設備
outputs = model(**inputs)
predicted_class = outputs.logits.argmax(-1).item()
label_map = {0: "Negative", 1: "Positive"}
return label_map[predicted_class]
- 將輸入文本轉換為模型的輸入格式,並移動到指定設備(CPU 或 GPU)。
- 使用
outputs.logits.argmax(-1)
獲取預測類別。 - 返回對應的標籤名稱(
Positive
或Negative
)。
ChnSentiCorp_bert_predict.py - 預測結果
root@08d912ebf816:/app/Huggingfacetransformer# python ChnSentiCorp_bert_predict.py
請輸入文本進行情感分析(輸入 'exit' 結束程式):
輸入文本:我喜歡這間飯店
Text: 我喜歡這間飯店
Prediction: Positive
輸入文本:好失望 飯店好髒
Text: 好失望 飯店好髒
Prediction: Negative
輸入文本:我覺得這本書好可愛
Text: 我覺得這本書好可愛
Prediction: Negative
輸入文本:我覺得這飯店超級乾淨的
Text: 我覺得這飯店超級乾淨的
Prediction: Positive
輸入文本:我覺得飯店好可愛
Text: 我覺得飯店好可愛
Prediction: Negative
輸入文本:這飯店好乾淨
Text: 這飯店好乾淨
Prediction: Positive
輸入文本:這飯店好噁
Text: 這飯店好噁
Prediction: Negative
輸入文本:飯店打掃五顆心
Text: 飯店打掃五顆心
Prediction: Positive
輸入文本:飯店打掃0顆心
Text: 飯店打掃0顆心
Prediction: Positive
輸入文本:飯店打掃負評
Text: 飯店打掃負評
Prediction: Negative
輸入文本:飯店打掃超乾淨
Text: 飯店打掃超乾淨
Prediction: Positive
輸入文本:飯店打掃超髒
Text: 飯店打掃超髒
Prediction: Negative
輸入文本:這是一家很棒的飯店
Text: 這是一家很棒的飯店
Prediction: Positive
輸入文本:這是一家很糟糕的飯店
Text: 這是一家很糟糕的飯店
Prediction: Negative
輸入文本:bye
程式結束,再見!
ChnSentiCorp
數據集的標籤主要基於評論的情感(正面或負面),但它的標籤不涵蓋所有的中文語境或表達方式,如「可愛
」、「五顆心
」、「0顆心
」之類的。
ChnSentiCorp_bert_test_inference.py
Run 程式碼後會跑出 1200 個範例,好崴寶先列出前 4 個:
root@08d912ebf816:/app/Huggingfacetransformer# python ChnSentiCorp_bert_test_inference.py
使用的設備: cpu
範例 1:
內容: 这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般
真實標籤: Positive
模型預測: Negative
--------------------------------------------------
範例 2:
內容: 怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一集算什么??简直是画蛇添足!!
真實標籤: Negative
模型預測: Negative
--------------------------------------------------
範例 3:
內容: 还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。
真實標籤: Negative
模型預測: Negative
--------------------------------------------------
範例 4:
內容: 交通方便;环境很好;服务态度很好 房间较小
真實標籤: Positive
模型預測: Positive
.
.
.
範例 1199:
內容: 房间不错,只是上网速度慢得无法忍受,打开一个网页要等半小时,连邮件都无法收。另前台工作人员服务态度是很好,只是效率有得改善。
真實標籤: Positive
模型預測: Positive
--------------------------------------------------
範例 1200:
內容: 挺失望的,还不如买一本张爱玲文集呢,以<色戒>命名,可这篇文章仅仅10多页,且无头无尾的,完全比不上里面的任意一篇其它文章.
真實標籤: Negative
模型預測: Negative
--------------------------------------------------
範例 1 標註錯誤:
內容:**”这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般”**
很明顯,這是一個偏向 負面情感 的評論,因為提到了「比较陈旧了」和「总体来说一般」,這些詞語通常表示不滿意或低於期望。
但真實標籤標記為 Positive,這是 ChnSentiCorp 數據標註的錯誤。
範例 1045:
內容: 优点: 外观够有型,配置很不错,价格合理,非常适合商务使用 不足: 光驱偶尔声音真的很大,底部发热量很大。 总结: 特价4999的价格购买还是很超值的,值得
真實標籤: Negative
模型預測: Positive
範例 900:
內容: 住这个酒店实在是太享受了,不仅可以使用五彩缤纷的白毛巾,还可以免费听赏别人KTV包厢里的高音演奏.!! 以下是某某歌奏家的话,,,,,谢谢 谢谢大家 首先我要感谢我的父母 还要感谢背后支持我的朋友们 让我有了以噪音感化大家的机会 现在我再为大家献上一首!!希望大家夜不能眠!!!!!!!
真實標籤: Negative
模型預測: Positive
範例 772:
內容: 喜欢AMD的芯片,性价比不错。这款散热不错喜欢AMD的芯片,性价比不错。这款散热不错
真實標籤: Negative
模型預測: Positive
崴寶還有找到測試集的幾個標註錯誤…
訓練集崴寶沒有時間去檢查,
之後崴寶可能會再找找有沒有其他資料集可以做教學!!
崴寶總結
- 設備選擇:自動檢查並選擇 GPU 或 CPU。
- 數據處理:將本地數據集分為訓練集、驗證集和測試集,並對文本進行分詞。
- 模型訓練:使用 BERT 中文模型進行訓練,並通過
Trainer
簡化訓練流程。 - 性能評估:計算模型的準確率、精度、召回率和 F1 分數。
- 模型保存與測試:保存訓練好的模型,並對新文本進行推理。
- 未來:預計會使用
roberta-chinese
模型進行訓練。
喜歡 好崴寶 Weibert Weiberson 的文章嗎?在這裡留下你的評論!本留言區支援 Markdown 語法。