Chinese word segmentation model
import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import BertTokenizer, BertModel, AdamW
from torch.utils.data import DataLoader, Dataset
import random
import pandas as pd
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence
from torch.cuda.amp import GradScaler, autocast
import os
class BERT_CRF(nn.Module):
    def __init__(self, bert, num_tags):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(bert.config.hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, input_ids, attention_mask, tags=None):
        outputs = self.bert(input_ids, attention_mask)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        mask = attention_mask.bool()
        if tags is not None:
            # training: return the negative CRF log-likelihood as the loss
            loss = -self.crf(logits, tags, mask=mask, reduction='mean')
            return loss
        else:
            # inference: return the emission scores for decode()
            return logits

    def decode(self, logits, attention_mask):
        mask = attention_mask.bool()
        return self.crf.decode(logits, mask=mask)
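
# Usage sketch (illustrative only; the real flow is in the training script below):
# with tags the forward pass returns the negative CRF log-likelihood as a scalar loss,
# without tags it returns the emission scores, which decode() turns into tag-id paths.
def _demo_bert_crf_usage(model, input_ids, attention_mask, tags):
    loss = model(input_ids, attention_mask, tags)       # training: scalar CRF loss
    logits = model(input_ids, attention_mask)           # inference: (batch, seq_len, num_tags) emissions
    best_paths = model.decode(logits, attention_mask)   # list of tag-id lists, one per sequence
    return loss, best_paths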
# Mapping from tags to ids
tag2id = {"S": 0, "B": 1, "M": 2, "E": 3, "P": 4, "L": 5, "PAD": 6}

def tags_to_segmented_text(tag_ids, tag2id, text):
    # flatten the batch of decoded tag ids and map each id back to its tag name
    tags = [tag for tag_id_list in tag_ids for tag_id in tag_id_list for tag, id_ in tag2id.items() if tag_id == id_]
    segmented_text = ""
    for idx, (char, tag) in enumerate(zip(text, tags)):
        if tag in ["S", "B"]:
            segmented_text += " / " + char
        elif tag in ["M", "E"]:
            segmented_text += char
        elif tag in ["P", "L"]:
            segmented_text += " / " + char
    return segmented_text.strip(" /")  # drop the leading separator as well as whitespace
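
# Minimal usage sketch (made-up tag ids, never called): shows how a decoded batch of
# tag ids maps back to slash-separated text via tags_to_segmented_text.
def _demo_tags_to_segmented_text():
    sample_text = "我爱北京"
    sample_tag_ids = [[tag2id["S"], tag2id["S"], tag2id["B"], tag2id["E"]]]  # one inner list per batch item
    return tags_to_segmented_text(sample_tag_ids, tag2id, sample_text)  # -> "我 / 爱 / 北京"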
# Build the dataset
class ChineseSegmentationDataset(Dataset):
    def __init__(self, data, tokenizer, tag2id, max_length=256):
        self.data = data
        self.tokenizer = tokenizer
        self.tag2id = tag2id
        self.max_length = max_length
        self.data = self._filter_data(data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text, tags = self.data[idx]
        # encode with the tokenizer's encode_plus method
        inputs = self.tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # adjust the tag sequence length to align with the input tokens
        # ([CLS] and [SEP] occupy two positions, so the tag sequence is max_length - 2 long)
        tags_list = list(tags)
        padded_tags = tags_list[:self.max_length - 2] + ["PAD"] * (self.max_length - len(tags_list) - 2)
        tag_ids = [self.tag2id[tag] for tag in padded_tags]
        return {
            "input_ids": inputs["input_ids"].squeeze(),             # (max_length,)
            "attention_mask": inputs["attention_mask"].squeeze(),   # (max_length,)
            "tags": torch.tensor(tag_ids, dtype=torch.long),        # (max_length - 2,)
        }

    def _filter_data(self, data):
        filtered_data = []
        for text, tags in data:
            if not text.strip():
                continue
            if len(text) != len(tags):
                print(f"Skipping data: {text} due to inconsistent length between tokens and tags")
                continue
            if len(text) > self.max_length - 2:
                print(f"Skipping data: {text} due to length exceeding max length of {self.max_length}")
                continue
            filtered_data.append((text, tags))
        return filtered_data
def collate_fn(batch):
    max_length = max([len(item["input_ids"]) for item in batch])
    input_ids = torch.zeros((len(batch), max_length), dtype=torch.long)
    attention_mask = torch.zeros((len(batch), max_length), dtype=torch.long)
    tags = torch.zeros((len(batch), max_length), dtype=torch.long)
    for i, item in enumerate(batch):
        input_ids[i, :len(item["input_ids"])] = item["input_ids"]
        attention_mask[i, :len(item["attention_mask"])] = item["attention_mask"]
        tags[i, :len(item["tags"])] = item["tags"]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "tags": tags,
    }
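
# Shape sketch (illustrative, never called): collate_fn pads every field to the
# longest input_ids length in the batch, so with the default max_length=256 a batch
# of two dataset items comes out as (2, 256) tensors for input_ids, attention_mask
# and tags (tag rows are zero-padded on the right).
def _demo_collate_shapes(dataset):
    batch = collate_fn([dataset[0], dataset[1]])
    return batch["input_ids"].shape, batch["attention_mask"].shape, batch["tags"].shape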
# Set the random seed for reproducible results
def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Read the data from a CSV file
    csv_file = "data/tweet.js.csv"  # replace with your own CSV file name
    df = pd.read_csv(csv_file, header=None, names=["text", "tags"])
    # Split the data into training and validation sets
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    # Load the pretrained Chinese BERT tokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
    # Convert the DataFrames into the list-of-tuples format the dataset expects
    train_data = list(train_df.itertuples(index=False, name=None))
    val_data = list(val_df.itertuples(index=False, name=None))
    # Create the training and validation datasets from the loaded data
    train_dataset = ChineseSegmentationDataset(train_data, tokenizer, tag2id)
    val_dataset = ChineseSegmentationDataset(val_data, tokenizer, tag2id)
    set_seed(42)
    # Create the BERT_CRF model instance
    bert_model = BertModel.from_pretrained("bert-base-chinese")
    model = BERT_CRF(bert_model, num_tags=len(tag2id)).to(device)
    model.train()
    # Create the DataLoaders
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
    # Set up the optimizer and learning rate
    optimizer = AdamW(model.parameters(), lr=1e-5)
    scaler = GradScaler()
    # Resume from a previous checkpoint if one exists
    if os.path.exists("model_p.pth"):
        model.load_state_dict(torch.load("model_p.pth"))
    # Training loop
    num_epochs = 100
    for epoch in range(num_epochs):
        for batch_idx, batch in enumerate(train_dataloader):
            # collate_fn already returns (batch, seq_len) tensors
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            tags = batch["tags"].to(device)
            with autocast():
                loss = model(input_ids, attention_mask, tags)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{len(train_dataloader)}, Loss: {loss.item()}")
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")
        torch.save(model.state_dict(), "model_p.pth")
        # Validation phase: run every 2 epochs
        if (epoch + 1) % 2 == 0:
            model.eval()
            val_loss = 0
            with torch.no_grad():
                for batch_idx, batch in enumerate(val_dataloader):
                    input_ids = batch["input_ids"].to(device)
                    attention_mask = batch["attention_mask"].to(device)
                    tags = batch["tags"].to(device)
                    batch_loss = model(input_ids, attention_mask, tags)  # keep the last training loss in `loss` for the print below
                    val_loss += batch_loss.item()
            val_loss /= len(val_dataloader)
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {loss.item()}, Val Loss: {val_loss}")
            # switch the model back to training mode
            model.train()
    # Save the final model
    torch.save(model.state_dict(), "model.pth")
from a import *
from transformers import BertTokenizer, BertModel
import torch
import os

bert_model = BertModel.from_pretrained("bert-base-chinese")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERT_CRF(bert_model, num_tags=len(tag2id)).to(device)
# Load the pretrained Chinese BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
# Put the model in evaluation mode
model.eval()
# Load the weights
if os.path.exists("model_p.pth"):
    model.load_state_dict(torch.load("model_p.pth"))
else:
    print("model does not exist")
    os._exit(1)
# Text to predict on
text = """下面这个网址是https://www.google.com/,这是一个搜索引擎"""
# Encode the input text with the tokenizer
inputs = tokenizer(text, return_tensors="pt", padding="max_length", max_length=256, truncation=True)
# Make sure the input dtypes are correct
inputs["input_ids"] = inputs["input_ids"].long().to(device)
inputs["attention_mask"] = inputs["attention_mask"].bool().to(device)
# Get the logits from the model
with torch.no_grad():
    logits = model(inputs["input_ids"], inputs["attention_mask"])
# Use decode to predict tag_ids from the logits
tag_ids = model.decode(logits, inputs["attention_mask"].bool())
# Convert the predicted tag_ids into segmented text
segmented_text = tags_to_segmented_text(tag_ids, tag2id, text)
print(segmented_text)
import jieba
import json
import os
import string
import glob
import csv
# Use tweets exported from Twitter as the segmentation corpus, segmented with jieba
# S: single-character word
# B: beginning of a word
# M: middle of a word
# E: end of a word
# P: punctuation
tag2id = {"S": 0, "B": 1, "M": 2, "E": 3, "P": 4, "U": 5}
# Read the corpus
def read_corpus(corpus_path):
    data = []
    # open the corpus .js file and drop everything on the first line before the leading [
    with open(corpus_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        lines[0] = "[ {"
        lines = "".join(lines)
        # parse the result as JSON
        items = json.loads(lines)
        for item in items:
            text = item["full_text"]
            # collapse newlines, tabs and runs of spaces into single spaces
            text = text.replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')
            text = ' '.join(text.split())  # a regular expression would also work
            text = text.strip()
            tags = segment(text)
            if len(tags) == 0:
                continue
            if len(tags) != len(text):
                print("error")
                print(text)
                print(tags)
                continue
            data.append((text, ''.join(tags)))
    return data
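
# Illustrative sketch (hypothetical segmentation, never called; the real split depends
# on jieba's dictionary): each pair read_corpus collects is a (text, tag-string) of
# equal length, which is what the training script's CSV loader expects.
def _demo_corpus_row():
    text = "我爱北京"
    return (text, ''.join(segment(text)))  # e.g. roughly ("我爱北京", "SSBE")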
# Segmentation
def segment(text):
    words = jieba.cut(text)
    nw = []
    for w in words:
        nw.append(w)
    words = nw
    nw = []
    l = len(words)
    i = 0
    # when @ or # is encountered, merge it with the following tokens up to the next space
    while i < l:
        w = words[i]
        if w == "@" or w == "#":
            s = []
            while i < l and words[i] != ' ':
                s.append(words[i])
                i = i + 1
            nw.append(''.join(s))
            continue
        nw.append(w)
        i = i + 1
    words = nw
    # generate the tags
    tags = []
    i = 0
    l = len(words)
    while i < l:
        w = words[i]
        # a single character that is punctuation (or a space) is tagged P
        if len(w) == 1 and w in string.punctuation or w == ' ':
            tags.append('P')
            i = i + 1
            continue
        # a single character that is not punctuation is tagged S
        if len(w) == 1:
            tags.append('S')
            i = i + 1
            continue
        # detect a URL and tag its start B, its middle M and its end E
        # the check: the next token is ":" and the two tokens after that are both "/"
        if l > i+3 and words[i+1] == ":" and words[i+2] == "/" and words[i+3] == "/":
            # e.g. https://www.google.com becomes BMMMMMMMMMMMMMMMMMME
            first = True
            while i < l and words[i] != '':
                for ch in words[i]:
                    if first:
                        tags.append('B')
                        first = False
                    else:
                        tags.append('M')
                i = i + 1
            tags[-1] = 'E'
            continue
        # a multi-character word is tagged B at the start, M in the middle and E at the end
        first = True
        for ch in w:
            if first:
                tags.append('B')
                first = False
            else:
                tags.append('M')
        tags[-1] = 'E'
        i = i + 1
    return tags
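
# Usage sketch (illustrative, never called; the actual split depends on jieba's
# dictionary): segment returns one tag per character of the input text, e.g.
# segment("我爱北京") would give something like ['S', 'S', 'B', 'E'] if jieba splits
# the text as 我 / 爱 / 北京.
def _demo_segment():
    return segment("我爱北京")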
if __name__ == '__main__':
    # create a data folder in the current directory
    if not os.path.exists('data'):
        os.mkdir('data')
    corpusPath = "E:\\tw"
    # iterate over the tweet-part*.js files under it
    file_list = glob.glob(corpusPath + '\\tweet-part*.js')
    file_list.append("E:\\tw\\tweet.js")
    print(file_list)
    for file in file_list:
        # create the csv file for this part of the corpus
        dataFileName = "data/" + os.path.basename(file) + '.csv'
        with open(dataFileName, 'w', encoding='utf-8', newline='') as f:  # newline='' avoids blank rows from csv.writer on Windows
            data = read_corpus(file)
            writer = csv.writer(f)
            writer.writerows(data)