cube-studio/job-template/job/ner/train_evaluate.py

80 lines
2.7 KiB
Python
Raw Normal View History

2022-08-13 17:00:41 +08:00
from models.HMM import HMM
from models.CRF import CRFModel
from bilstm_opration import BiLSTM_opration
from bilstm_crf_opration import BiLSTM_CRF_opration
from evaluating import Metrics
from utils.utils import save_model, add_end_tag
def hmm_train_eval(file_path, train_data,test_data,word2id,tag2id,remove_0=False):
"""hmm模型的评估与训练"""
print("hmm模型的评估与训练...")
train_word_lists,train_tag_lists = train_data
test_word_lists,test_tag_lists = test_data
# 模型训练
hmm_model = HMM(N=len(tag2id), M=len(word2id))
hmm_model.train(train_word_lists, train_tag_lists, word2id, tag2id)
# save_model(hmm_model,"./ckpts/hmm.pkl")
# 模型评估
pred_tag_lists = hmm_model.test(test_word_lists, word2id, tag2id)
metrics = Metrics(file_path, test_tag_lists,pred_tag_lists)
metrics.report_scores(dtype='HMM')
return hmm_model
def crf_train_eval(file_path, train_data, test_data, remove_0=False):
"""crf模型的评估与训练"""
print("crf模型的评估与训练")
train_word_lists, train_tag_lists = train_data
test_word_lists, test_tag_lists = test_data
crf_model = CRFModel()
crf_model.train(train_word_lists, train_tag_lists)
# save_model(crf_model, "./ckpts/crf.pkl")
pred_tag_lists = crf_model.test(test_word_lists)
metrics = Metrics(file_path, test_tag_lists, pred_tag_lists)
metrics.report_scores(dtype='CRF')
return crf_model
def bilstm_train_eval(file_path, train_data, dev_data, test_data, word2id, tag2id):
"""BiLSTM模型的评估与训练"""
print("BiLSTM模型的评估与训练")
bilstm_model = BiLSTM_opration(
train_data=train_data,
dev_data=dev_data,
test_data=test_data,
word2id=word2id,
tag2id=tag2id
)
bilstm_model.train()
bilstm_model.evaluate(file_path=file_path)
return bilstm_model
def bilstm_crf_train_eval(file_path, train_data, dev_data, test_data, word2id, tag2id):
"""BiLSTM_CRF模型的评估与训练"""
print("BiLSTM_CRF模型的评估与训练")
train_word_lists, train_tag_lists = train_data
dev_word_lists, dev_tag_lists = dev_data
test_word_lists, test_tag_lists = test_data
train_word_lists, train_tag_lists = add_end_tag(train_word_lists, train_tag_lists)
dev_word_lists, dev_tag_lists = add_end_tag(dev_word_lists, dev_tag_lists)
test_word_lists = add_end_tag(test_word_lists)
bilstm_crf_model = BiLSTM_CRF_opration(
train_data=(train_word_lists, train_tag_lists),
dev_data=(dev_word_lists, dev_tag_lists),
test_data=(test_word_lists, test_tag_lists),
word2id=word2id,
tag2id=tag2id
)
bilstm_crf_model.train()
bilstm_crf_model.evaluate(file_path=file_path)
return bilstm_crf_model