mirror of
https://github.com/tencentmusic/cube-studio.git
synced 2024-12-21 06:19:31 +08:00
114 lines
4.5 KiB
Python
114 lines
4.5 KiB
Python
import argparse
|
|
from minio import Minio
|
|
|
|
from data import build_corpus
|
|
from train_evaluate import hmm_train_eval, crf_train_eval, bilstm_train_eval, bilstm_crf_train_eval
|
|
from utils.preprocessing import Preprocessing
|
|
from utils.utils import save_model
|
|
from config import TrainingConfig, BiLSTMCRFTrainConfig
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the NER training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default, matching the original behavior) argparse reads
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with attributes ``model``, ``path``, ``name``,
        ``bucketname``, ``objectname``, ``datarate`` (list of three floats),
        ``epochs``, ``batch_size``, ``learning_rate`` and ``pklpath``.
    """
    # Runtime description text (Chinese): "You are learning how to use the
    # argparse module to pass command-line arguments..."
    description = "你正在学习如何使用argparse模块进行命令行传参..."
    parser = argparse.ArgumentParser(description=description)

    # Restrict to the models the __main__ dispatch handles; any other name
    # would otherwise run preprocessing and then fail with no matching branch.
    parser.add_argument("-m", "--model", type=str, default='HMM',
                        choices=['HMM', 'CRF', 'BiLSTM', 'BiLSTM_CRF'],
                        help="There are four models of NER, they are HMM, CRF, BiLSTM and BiLSTM_CRF.")
    parser.add_argument("-p", "--path", type=str, default='./zdata/', help="data path")
    parser.add_argument('-n', '--name', type=str, default='annotated_data.txt', help='data file name')
    parser.add_argument('-bn', '--bucketname', type=str, default='data', help='MinIO bucket name')
    parser.add_argument('-on', '--objectname', type=str, default='people_daily_BIO.txt', help='MinIO object name')
    # `type=list` would split a CLI value such as "0.8,0.1,0.1" into single
    # characters. Accept exactly three floats instead, e.g. `-dr 0.7 0.1 0.2`.
    parser.add_argument('-dr', '--datarate', type=float, nargs=3, default=[0.7, 0.1, 0.2],
                        metavar=('TRAIN', 'DEV', 'TEST'),
                        help='The rate of train_data, dev_data and test_data')
    parser.add_argument('-e', '--epochs', type=int, default=10, help='train epoch')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='batch size')
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.0005, help='learning rate')
    parser.add_argument('-pp', "--pklpath", type=str, default='./model.pkl', help='the path and filename to save .pkl file')

    return parser.parse_args(argv)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI options, preprocess the annotated corpus,
    # build train/dev/test corpora, train + evaluate the selected NER model,
    # and save it to args.pklpath.
    args = parse_args()
    print(args)

    # Models dispatched below. Validate before any preprocessing/training so an
    # unsupported name fails fast with a clear message instead of a NameError
    # on `model` at save time.
    supported_models = ('HMM', 'CRF', 'BiLSTM', 'BiLSTM_CRF')
    if args.model not in supported_models:
        raise SystemExit(
            f"Unsupported model {args.model!r}; choose one of: {', '.join(supported_models)}"
        )

    # Optional: fetch the raw data from MinIO (currently disabled).
    # Endpoint and credentials must come from the environment, never be
    # hard-coded in source; re-enabling this also requires `import os`.
    # minio_client = Minio(
    #     os.environ['MINIO_ENDPOINT'],
    #     access_key=os.environ['MINIO_ACCESS_KEY'],
    #     secret_key=os.environ['MINIO_SECRET_KEY'],
    #     secure=False,
    # )
    # try:
    #     minio_client.fget_object(
    #         bucket_name=args.bucketname,
    #         object_name=args.objectname,
    #         file_path=os.path.join(args.path, args.name),
    #     )
    # except Exception as err:
    #     print(err)

    # Preprocess: split the annotated file by args.datarate, then build the
    # vocabulary/labels. Presumably this writes the split files under
    # args.path that build_corpus reads below -- confirm in Preprocessing.
    data_preprocessing = Preprocessing(
        file_path=args.path,
        file_name=args.name
    )
    data_preprocessing.train_test_dev_split(data_rate=args.datarate)
    data_preprocessing.construct_vocabulary_labels()

    # Load the corpora; only the training split builds word2id/tag2id.
    print('load data ...')
    train_word_lists, train_tag_lists, word2id, tag2id = build_corpus("train", data_dir=args.path)
    dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False, data_dir=args.path)
    test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False, data_dir=args.path)

    if args.model == 'HMM':
        # train and evaluate HMM model
        model = hmm_train_eval(
            file_path=args.path,
            train_data=(train_word_lists, train_tag_lists),
            test_data=(test_word_lists, test_tag_lists),
            word2id=word2id,
            tag2id=tag2id
        )
    elif args.model == 'CRF':
        # train and evaluate CRF model
        model = crf_train_eval(
            file_path=args.path,
            train_data=(train_word_lists, train_tag_lists),
            test_data=(test_word_lists, test_tag_lists)
        )
    elif args.model == 'BiLSTM':
        # BiLSTM: override the shared training config with CLI hyper-parameters.
        TrainingConfig.batch_size = args.batch_size
        TrainingConfig.epochs = args.epochs
        TrainingConfig.lr = args.learning_rate
        model = bilstm_train_eval(
            file_path=args.path,
            train_data=(train_word_lists, train_tag_lists),
            dev_data=(dev_word_lists, dev_tag_lists),
            test_data=(test_word_lists, test_tag_lists),
            word2id=word2id,
            tag2id=tag2id
        )
    elif args.model == 'BiLSTM_CRF':
        # BiLSTM CRF: override its dedicated config with CLI hyper-parameters.
        BiLSTMCRFTrainConfig.batch_size = args.batch_size
        BiLSTMCRFTrainConfig.epochs = args.epochs
        BiLSTMCRFTrainConfig.lr = args.learning_rate
        model = bilstm_crf_train_eval(
            file_path=args.path,
            train_data=(train_word_lists, train_tag_lists),
            dev_data=(dev_word_lists, dev_tag_lists),
            test_data=(test_word_lists, test_tag_lists),
            word2id=word2id,
            tag2id=tag2id
        )

    save_model(model, args.pklpath)

    # Optional: upload the trained model to MinIO (currently disabled).
    # Upload the file actually written above (args.pklpath), not a fixed path.
    # try:
    #     minio_client.fput_object(
    #         bucket_name=args.bucketname,
    #         object_name=f'{args.model}_model.pkl',
    #         file_path=args.pklpath,
    #     )
    # except Exception as err:
    #     print(err)