mirror of
https://github.com/tencentmusic/cube-studio.git
synced 2024-12-21 06:19:31 +08:00
80 lines
2.7 KiB
Python
80 lines
2.7 KiB
Python
|
from models.HMM import HMM
|
||
|
from models.CRF import CRFModel
|
||
|
from bilstm_opration import BiLSTM_opration
|
||
|
from bilstm_crf_opration import BiLSTM_CRF_opration
|
||
|
from evaluating import Metrics
|
||
|
from utils.utils import save_model, add_end_tag
|
||
|
|
||
|
def hmm_train_eval(file_path, train_data,test_data,word2id,tag2id,remove_0=False):
    """Train an HMM tagger on the training split and score it on the test split.

    Args:
        file_path: Forwarded unchanged as the first argument to ``Metrics``.
        train_data: Pair ``(word_lists, tag_lists)`` used to fit the model.
        test_data: Pair ``(word_lists, tag_lists)`` used for evaluation.
        word2id: Word lookup; its ``len()`` is passed to ``HMM`` as ``M`` and
            the object itself is handed to ``train``/``test``.
        tag2id: Tag lookup; its ``len()`` is passed to ``HMM`` as ``N`` and
            the object itself is handed to ``train``/``test``.
        remove_0: Accepted for interface compatibility; this function never
            reads it.

    Returns:
        The trained ``HMM`` instance.
    """
    print("hmm模型的评估与训练...")
    fit_words, fit_tags = train_data
    eval_words, gold_tags = test_data

    # Training: the state/observation counts come from the two lookups.
    tagger = HMM(N=len(tag2id), M=len(word2id))
    tagger.train(fit_words, fit_tags, word2id, tag2id)
    # NOTE: checkpointing is disabled here
    # (formerly: save_model(hmm_model, "./ckpts/hmm.pkl")).

    # Evaluation: predict on the test words, then report against gold tags.
    predicted_tags = tagger.test(eval_words, word2id, tag2id)
    scorer = Metrics(file_path, gold_tags, predicted_tags)
    scorer.report_scores(dtype='HMM')

    return tagger
|
||
|
|
||
|
def crf_train_eval(file_path, train_data, test_data, remove_0=False):
    """Train a CRF tagger on the training split and score it on the test split.

    Args:
        file_path: Forwarded unchanged as the first argument to ``Metrics``.
        train_data: Pair ``(word_lists, tag_lists)`` used to fit the model.
        test_data: Pair ``(word_lists, tag_lists)`` used for evaluation.
        remove_0: Accepted for interface compatibility; this function never
            reads it.

    Returns:
        The trained ``CRFModel`` instance.
    """
    print("crf模型的评估与训练")
    fit_words, fit_tags = train_data
    eval_words, gold_tags = test_data

    # Training.
    tagger = CRFModel()
    tagger.train(fit_words, fit_tags)
    # NOTE: checkpointing is disabled here
    # (formerly: save_model(crf_model, "./ckpts/crf.pkl")).

    # Evaluation: predict on the test words, then report against gold tags.
    predicted_tags = tagger.test(eval_words)
    scorer = Metrics(file_path, gold_tags, predicted_tags)
    scorer.report_scores(dtype='CRF')

    return tagger
|
||
|
|
||
|
def bilstm_train_eval(file_path, train_data, dev_data, test_data, word2id, tag2id):
    """Build, train and evaluate the BiLSTM tagger.

    Args:
        file_path: Passed to the model's ``evaluate`` as ``file_path``.
        train_data: Training split, handed to ``BiLSTM_opration`` as-is.
        dev_data: Development split, handed to ``BiLSTM_opration`` as-is.
        test_data: Test split, handed to ``BiLSTM_opration`` as-is.
        word2id: Word lookup, handed to ``BiLSTM_opration`` as-is.
        tag2id: Tag lookup, handed to ``BiLSTM_opration`` as-is.

    Returns:
        The ``BiLSTM_opration`` instance after ``train()`` and ``evaluate()``
        have been called on it.
    """
    print("BiLSTM模型的评估与训练")

    # Collect the constructor arguments once; all are forwarded untouched.
    setup = {
        "train_data": train_data,
        "dev_data": dev_data,
        "test_data": test_data,
        "word2id": word2id,
        "tag2id": tag2id,
    }
    runner = BiLSTM_opration(**setup)

    runner.train()
    runner.evaluate(file_path=file_path)

    return runner
|
||
|
|
||
|
def bilstm_crf_train_eval(file_path, train_data, dev_data, test_data, word2id, tag2id):
    """Build, train and evaluate the BiLSTM-CRF tagger.

    Each split is run through ``add_end_tag`` before it reaches the model.
    For the test split only the word lists are passed to ``add_end_tag``;
    the test tag lists go to the model unchanged.

    Args:
        file_path: Passed to the model's ``evaluate`` as ``file_path``.
        train_data: Pair ``(word_lists, tag_lists)`` for training.
        dev_data: Pair ``(word_lists, tag_lists)`` for validation.
        test_data: Pair ``(word_lists, tag_lists)`` for testing.
        word2id: Word lookup, handed to ``BiLSTM_CRF_opration`` as-is.
        tag2id: Tag lookup, handed to ``BiLSTM_CRF_opration`` as-is.

    Returns:
        The ``BiLSTM_CRF_opration`` instance after ``train()`` and
        ``evaluate()`` have been called on it.
    """
    print("BiLSTM_CRF模型的评估与训练")

    # Unpack every split up front so a malformed split fails before any
    # preprocessing is done.
    fit_words, fit_tags = train_data
    val_words, val_tags = dev_data
    eval_words, eval_tags = test_data

    fit_pair = add_end_tag(fit_words, fit_tags)
    fit_words, fit_tags = fit_pair
    val_pair = add_end_tag(val_words, val_tags)
    val_words, val_tags = val_pair
    # NOTE(review): single-argument call — presumably returns just the
    # processed word lists; confirm against utils.utils.add_end_tag.
    eval_words = add_end_tag(eval_words)

    runner = BiLSTM_CRF_opration(
        train_data=(fit_words, fit_tags),
        dev_data=(val_words, val_tags),
        test_data=(eval_words, eval_tags),
        word2id=word2id,
        tag2id=tag2id,
    )

    runner.train()
    runner.evaluate(file_path=file_path)

    return runner
|