cube-studio/job-template/job/ner/models/BERT_BiLSTM_CRF.py

import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, tagset_size, config, tag2id, dropout=0.1, batch_first=True):
        super(BiLSTM_CRF, self).__init__()
        self.embedding_dim = config.input_size
        self.hidden_dim = config.hidden_size
        self.vocab_size = vocab_size
        self.tagset_size = tagset_size
        # self.vocabulary = word2id
        self.tag2id = tag2id
        # BERT encoder used as the contextual embedding layer
        # self.word_embeds = BertModel.from_pretrained('bert-base-chinese')
        self.word_embeds = BertModel.from_pretrained('bert_model/bert-pretrain/Epoch45_Batchsize16_Loss0.580981_DateTime20220601_095801')
        self.bilstm = nn.LSTM(
            input_size=768,
            hidden_size=self.hidden_dim // 2,
            num_layers=1,
            bidirectional=True,
            batch_first=batch_first
        )
        self.hidden2tag = nn.Linear(self.hidden_dim, self.tagset_size, bias=True)
        # transition[i, j]: score of moving from tag i at the previous step to tag j at the current step
        self.transition = nn.Parameter(
            # torch.randn(self.tagset_size, self.tagset_size)
            torch.ones(self.tagset_size, self.tagset_size) / self.tagset_size
        )

    def _get_lstm_features(self, sentences):
        # self.hidden = self._init_hidden()
        # BERT contextual embeddings for the input token ids: [B, L, 768]
        # embeds = self.word_embeds(sentences).last_hidden_state[:, 1:-1, :]
        embeds = self.word_embeds(sentences).last_hidden_state
        # BiLSTM output projected to per-tag emission scores: [B, L, tagset_size]
        lstm_out, _ = self.bilstm(embeds)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def forward(self, sentences):
        # emission scores: [B, L, out_size (tagset_size)]
        emission = self._get_lstm_features(sentences)
        # CRF scores have shape [B, L, out_size, out_size]:
        # every Chinese character maps to a [tagset_size, tagset_size] matrix whose
        # (i, j) entry is the score of the previous tag being i and the current tag being j
        crf_scores = emission.unsqueeze(2).expand(-1, -1, self.tagset_size, -1) + self.transition.unsqueeze(0)
        return crf_scores

    def neg_log_likelihood(self, crf_scores, targets):
        """CRF negative log-likelihood loss.

        Computes the loss of every sample in the batch in parallel and sums the result,
        which takes noticeably less time than looping over samples.

        likelihood       = exp(gold_path_score) / sum(exp(all_path_scores))
        log(likelihood)  = gold_path_score - log(sum(exp(all_path_scores)))
                         = gold_path_score - all_path_score
        -log(likelihood) = all_path_score - gold_path_score
        """
        assert len(crf_scores) == len(targets)
        PAD_ID = self.tag2id['<pad>']
        START_ID = self.tag2id['<start>']
        END_ID = self.tag2id['<end>']
        device = targets.device
        mask = (targets != PAD_ID)
        batch_size, max_len = targets.shape
        lengths = mask.sum(dim=1)  # actual length of each sequence
        # gold_path_score: sum of the scores along the gold tag path.
        # crf_scores has shape [batch_size, max_len, tagset_size, tagset_size];
        # dim 0 indexes sentences and dim 1 indexes characters. Each character is
        # represented by a [i=tagset_size, j=tagset_size] matrix whose (i, j) entry
        # is the score of the current tag being j given that the previous tag is i,
        # so gold_score is the sum over characters of the entry at the gold (i, j).
        # targets holds the gold tag index of each character;
        # former_targets holds the gold tag index of the preceding character.
        former_targets = torch.zeros_like(targets)
        former_targets[:, 0] = START_ID
        former_targets[:, 1:max_len] = targets[:, 0:max_len-1]
        # gather along dim 3 with the gold tag of the current character, i.e. j
        crf_scores_j = crf_scores.gather(
            dim=3,
            index=targets.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.tagset_size, -1)
        )
        # gather along dim 2 with the gold tag of the previous character, i.e. i
        crf_score_i_j = crf_scores_j.gather(
            dim=2,
            index=former_targets.unsqueeze(-1).unsqueeze(-1)
        )
        # gold tag index of the last character of each sample
        last_tag_index = targets.gather(dim=1, index=(lengths.unsqueeze(-1) - 1)).squeeze(1)
        # score of transitioning from each sample's last gold tag to the <end> tag
        end_score = self.transition[:, END_ID].gather(dim=0, index=last_tag_index)
        # gold path score: sum over all non-pad positions of every sample,
        # plus the transition from the last gold tag to <end> for every sample
        gold_score = crf_score_i_j.masked_select(mask.unsqueeze(-1).unsqueeze(-1)).sum() + end_score.sum()
        # all_path_score: log-sum-exp over all possible paths (forward algorithm)
        current_scores = torch.zeros(batch_size, self.tagset_size).to(device)
        for step in range(max_len):
            # effective batch size at this step, since some sequences are shorter
            # (the slicing below assumes sequences are sorted by decreasing length)
            batch_size_step = (lengths > step).sum().item()
            if step == 0:
                current_scores[:batch_size_step] = crf_scores[:batch_size_step, step, START_ID, :]
            else:
                # Add the scores of the current timestep to the scores accumulated up to
                # the previous timestep and log-sum-exp over the previous-tag dimension.
                # The current tag of the previous timestep is the previous tag of this
                # timestep, so broadcast the accumulated scores along the current-tag dimension.
                current_scores[:batch_size_step] = torch.logsumexp(
                    current_scores[:batch_size_step].unsqueeze(2) + crf_scores[:batch_size_step, step, :, :],
                    dim=1
                )
        # extend every possible path at the last character with a transition to the <end> tag
        all_scores = torch.logsumexp(current_scores + self.transition[:, END_ID], dim=1).sum()
        loss = (all_scores - gold_score) / batch_size
        return loss

    def vibiter_decoding(self, crf_score):
        """Viterbi decoding over the CRF scores of a single sequence [seq_len, tagset_size, tagset_size]."""
        start_id = self.tag2id['<start>']
        end_id = self.tag2id['<end>']
        pad = self.tag2id['<pad>']
        device = crf_score.device
        seq_len = crf_score.shape[0]
        viterbi = torch.zeros(seq_len, self.tagset_size).to(device)
        backpointer = (torch.ones(size=(seq_len, self.tagset_size), dtype=torch.long) * end_id).to(device)
        for step in range(seq_len):
            if step == 0:  # first character: paths start from the <start> tag
                viterbi[step, :] = crf_score[step, start_id, :]
                backpointer[step, :] = start_id
            else:
                # best previous tag (and its score) for every current tag
                max_scores, prev_tags_id = torch.max(
                    viterbi[step-1, :].unsqueeze(1) + crf_score[step, :, :],
                    dim=0
                )
                viterbi[step, :] = max_scores
                backpointer[step, :] = prev_tags_id
        # among the last character's tags, the one whose transition to <end> scores
        # highest is the tag of the last character
        best_end_idx = torch.argmax(viterbi[seq_len-1, :] + self.transition[:, end_id]).item()
        best_path = [best_end_idx]
        # follow the backpointers to recover the rest of the path
        for step in range(seq_len-1, 0, -1):
            best_path.append(backpointer[step, best_path[-1]].item())
        best_path.reverse()
        return best_path
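

# --- Minimal usage sketch (editorial addition, not part of the original file) ---
# Illustrates how the class is typically wired together. The tag vocabulary,
# `SimpleNamespace` config, and input sentence below are hypothetical placeholders,
# and __init__ loads the fine-tuned checkpoint path hard-coded above, which must
# exist locally (or be swapped for 'bert-base-chinese').
if __name__ == '__main__':
    from types import SimpleNamespace

    # hypothetical tag vocabulary; the real one comes from the training data
    tag2id = {'<pad>': 0, '<start>': 1, '<end>': 2, 'O': 3, 'B-PER': 4, 'I-PER': 5}
    # `config` only needs the `input_size` and `hidden_size` attributes read in __init__
    config = SimpleNamespace(input_size=768, hidden_size=256)

    model = BiLSTM_CRF(vocab_size=21128, tagset_size=len(tag2id), config=config, tag2id=tag2id)

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    batch = tokenizer(['今天天气很好'], return_tensors='pt', padding=True)  # example Chinese sentence

    # forward pass: CRF scores of shape [B, L, tagset_size, tagset_size]
    crf_scores = model(batch['input_ids'])

    # training: `targets` would be a [B, L] tensor of gold tag ids padded with <pad>
    # loss = model.neg_log_likelihood(crf_scores, targets)

    # inference: Viterbi decoding over a single sequence's scores
    best_path = model.vibiter_decoding(crf_scores[0])
    print(best_path)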