Mirror of https://github.com/tencentmusic/cube-studio.git (synced 2025-01-18 13:53:59 +08:00)

fix ner service

This commit is contained in:
parent 89b00f2e3d
commit 8f86437bb9
@@ -6,6 +6,6 @@ COPY . /SVC
RUN python3 -m pip install pickle-mixin -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN python3 -m pip install FastAPI -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN python3 -m pip install uvicorn -i https://pypi.tuna.tsinghua.edu.cn/simple
WORKDIR /SVC

-ENTRYPOINT ["python3", "launcher.py"]
+ENTRYPOINT ["python3", "/SVC/launcher.py"]
@@ -1,4 +1,18 @@
-# download_minio template
-Image: ccr.ccs.tencentyun.com/cube-studio/ner-service:20220812
+# ner service template
+Image: ccr.ccs.tencentyun.com/cube-studio/ner:service-20220812
+
+# Parameters
+
+`--service_type`: the service type; for a web-service image this is normally `serving`.
+
+`--images`: the service image, i.e. the image built in step 2 above.
+
+`--ports`: the port of the REST service inside the web image; enter it here so that it gets mapped out.
+
+# Using the service
+
+* Click the IP to access the service.
+
+> Append `docs` to the access address, e.g. `http://xx.xx.xx.xx:xx/docs`, to call the service through FastAPI's interactive API documentation.
+
+* Click "Try it out" and enter the text to be analyzed.
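For reference, a minimal client sketch (host and port are placeholders) for calling the deployed service directly instead of going through the `docs` page; it assumes the GET route with query parameter `s` and port 8123 from the FastAPI launcher shown further down in this commit:

```python
import requests

# Placeholder address: replace with the IP and mapped port shown on the service page.
BASE_URL = "http://xx.xx.xx.xx:8123"

resp = requests.get(f"{BASE_URL}/", params={"s": "1962年1月出生,南京工学院毕业。"})
resp.raise_for_status()
print(resp.json())  # {"result": "<comma-separated spans>\n<comma-separated tags>"}
```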
@@ -1,17 +0,0 @@
service: "service:svc"  # Same as the argument passed to `bentoml serve`
labels:
  owner: zjlab
  stage: predict
include:
  - "*.py"  # A pattern for matching which files to include in the bento
  - "BiLSTM_CRF_tags.txt"
  - "BiLSTM_CRF_voc.txt"
python:
  packages:  # Additional pip packages required by the service
    - torch>=1.7.1+cu110
    - numpy>=1.19.5
    - transformers>=4.20.1
    - matplotlib>=3.3.4
    - tqdm>=4.64.0
    - bentoml>=1.0.0
    - minio>=7.1.10
@@ -204,6 +204,7 @@ class BiLSTM_CRF_opration:
        wordlist = []
        for i in sentence:
            wordlist.append(i)
        wordlist.append('<END>')
        with torch.no_grad():
            # wordlist = ['1', '9', '6', '2', '年', '1', '月', '出', '生', ',', '南', '京', '工', '学', '院', '毕', '业', '。']
            word_id_list = self._tokenizer([wordlist])
@@ -1,48 +0,0 @@
import os


def build_corpus(split, make_vocab=True, data_dir='./zdata/'):
    """Read the data set."""

    assert split.lower() in ["train", "dev", "test"]

    word_lists = []
    tag_lists = []
    with open(os.path.join(data_dir, split + '.txt'), 'r') as f:
        word_list = []
        tag_list = []
        for line in f:
            if line != '\n':
                word, tag = line.strip('\n').split()
                word_list.append(word)
                tag_list.append(tag)
            else:
                word_lists.append(word_list)
                tag_lists.append(tag_list)
                word_list = []
                tag_list = []

    if make_vocab:
        word2id = build_map(word_lists)
        tag2id = build_map(tag_lists)
        return word_lists, tag_lists, word2id, tag2id
    else:
        return word_lists, tag_lists


def build_map(lists):
    maps = {}
    for list_ in lists:
        for e in list_:
            if e not in maps:
                maps[e] = len(maps)
    return maps


def token_to_str(word_lists):
    s = ' '
    sentence_list = []
    for word_list in word_lists:
        sentence_list.append(s.join(word_list))
    return sentence_list
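For reference, a minimal sketch of the corpus format `build_corpus` above expects: one `char tag` pair per line, with a blank line separating sentences. The sentences and labels below are toy values, and it assumes `build_corpus` is importable:

```python
import os

os.makedirs("./zdata/", exist_ok=True)
# Two toy sentences; the labels are illustrative only.
sample = "南 B-LOC\n京 I-LOC\n大 O\n\n你 O\n好 O\n\n"
with open("./zdata/train.txt", "w", encoding="utf-8") as f:
    f.write(sample)

word_lists, tag_lists, word2id, tag2id = build_corpus("train", data_dir="./zdata/")
print(word_lists)  # [['南', '京', '大'], ['你', '好']]
print(tag2id)      # {'B-LOC': 0, 'I-LOC': 1, 'O': 2}
```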
@@ -163,4 +163,4 @@ class Metrics:
        for metric in weighted_average.keys():
            weighted_average[metric] /= total

-            return weighted_average
+        return weighted_average
@@ -1,7 +1,7 @@
import uvicorn
from fastapi import FastAPI
import pickle

+import os

fmt = "\n=== {:30} ===\n"
search_latency_fmt = "search latency = {:.4f}s"
@@ -13,8 +13,8 @@ def load_model(filename='./ckpts/model_BiLSTM_CRF.pkl'):
    with open(filename, 'rb') as f:
        model = pickle.load(f)
    return model

-model = load_model("/mnt/admin/model.pkl")
+MODEL_PATH = os.getenv("MODEL_PATH", '/mnt/admin/model.pkl')
+model = load_model(MODEL_PATH)

def serve(model, sentence):
    # sentence = '1962年1月出生,南京工学院毕业。'
@@ -25,45 +25,44 @@ def serve(model, sentence):
    pos = 0
    output = []
    tags = []
    count = 0
    while pos < n:
        count += 1
        if count >= 100 * n:
            return ans
        tmp = ''
        if pos < n and ans[pos][0] == 'B':
            tags.append(ans[pos][2:])
            tmp += sentence[pos]
            pos += 1
            while pos < n and ans[pos][0] != 'B' and ans[pos][0] != 'O':
                count += 1
                if count >= 100 * n:
                    return ans
                tmp += sentence[pos]
                pos += 1
            output.append(tmp)
        tmp = ''
        while pos < n and ans[pos][0] == 'O':
            count += 1
            if count >= 100 * n:
                return ans
            tmp += sentence[pos]
            pos += 1
        if tmp:
            tags.append('O')
            output.append(tmp)
    # print('Scheme 1')
    # for i in output:
    #     print(i + ',')
    # for i in tags:
    #     print(i + ',')
    # print('Scheme 2')
    outputs = ','.join(output)
    tagsStr = ','.join(tags)
    return outputs + '\n' + tagsStr


@app.get("/items/{item_id}")
async def read_item(item_id: int):
    return {"item_id": item_id}

@app.get("/ner")
@app.get("/")
async def serve_api(s: str):
    res = serve(model, s)
    return {"result": res}


# uvicorn main:app --host '0.0.0.0' --port 8123 --reload
if __name__ == "__main__":
    uvicorn.run(app=app, host='0.0.0.0', port=8123)
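To make the span grouping in `serve()` concrete, here is a small self-contained sketch that mirrors its loop; the tag sequence and entity labels are made up for illustration, not model output:

```python
# Toy BIO tags for the sample sentence from the comments above.
sentence = "1962年1月出生,南京工学院毕业。"
ans = ["B-TIME", "I-TIME", "I-TIME", "I-TIME", "I-TIME", "B-TIME", "I-TIME",
       "O", "O", "O", "B-ORG", "I-ORG", "I-ORG", "I-ORG", "I-ORG", "O", "O", "O"]

output, tags, pos = [], [], 0
while pos < len(ans):
    if ans[pos][0] == 'B':              # start of an entity: collect until the next 'B' or 'O'
        tags.append(ans[pos][2:])
        start = pos
        pos += 1
        while pos < len(ans) and ans[pos][0] not in ('B', 'O'):
            pos += 1
        output.append(sentence[start:pos])
    else:                               # run of 'O' tags: collect as one plain span
        start = pos
        while pos < len(ans) and ans[pos][0] == 'O':
            pos += 1
        tags.append('O')
        output.append(sentence[start:pos])

print(','.join(output) + '\n' + ','.join(tags))
# 1962年,1月,出生,,南京工学院,毕业。
# TIME,TIME,O,ORG,O
```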
@@ -1,37 +0,0 @@
from data import build_corpus
from train_evaluate import hmm_train_eval, crf_train_eval, bilstm_train_eval


def main():

    # load data
    print('load data ...')
    train_word_lists, train_tag_lists, word2id, tag2id = build_corpus("train")
    dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False)
    test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False)

    # train and evaluate HMM model
    hmm_pred = hmm_train_eval(
        train_data=(train_word_lists, train_tag_lists),
        test_data=(test_word_lists, test_tag_lists),
        word2id=word2id,
        tag2id=tag2id
    )

    # train and evaluate CRF model
    crf_pred = crf_train_eval(
        train_data=(train_word_lists, train_tag_lists),
        test_data=(test_word_lists, test_tag_lists)
    )

    # BiLSTM
    bilstm_pred = bilstm_train_eval(
        train_data=(train_word_lists, train_tag_lists),
        dev_data=(dev_word_lists, dev_tag_lists),
        test_data=(test_word_lists, test_tag_lists),
        word2id=word2id,
        tag2id=tag2id
    )


main()
@@ -1,112 +0,0 @@
import argparse
from minio import Minio

from data import build_corpus
from train_evaluate import hmm_train_eval, crf_train_eval, bilstm_train_eval, bilstm_crf_train_eval
from utils.preprocessing import Preprocessing
from utils.utils import save_model
from config import TrainingConfig, BiLSTMCRFTrainConfig


def parse_args():
    description = "You are learning how to use the argparse module to pass command-line arguments..."
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("-m", "--model", type=str, default='HMM',
                        help="There are five models of NER, they are HMM, CRF, BiLSTM, BiLSTM_CRF and Bert_BiLSTM_CRF.")
    parser.add_argument("-p", "--path", type=str, default='./zdata/', help="data path")
    parser.add_argument('-n', '--name', type=str, default='annotated_data.txt', help='data file name')
    parser.add_argument('-bn', '--bucketname', type=str, default='data', help='MinIO bucket name')
    parser.add_argument('-on', '--objectname', type=str, default='people_daily_BIO.txt', help='MinIO object name')
    parser.add_argument('-dr', '--datarate', type=list, default=[0.7, 0.1, 0.2],
                        help='The rate of train_data, dev_data and test_data')
    parser.add_argument('-e', '--epochs', type=int, default=10, help='train epoch')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='batch size')
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.0005, help='learning rate')
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    print(args)

    # connect to MinIO
    minio_client = Minio(
        '10.101.32.11:9000',
        access_key='admin',
        secret_key='root123456',
        secure=False
    )

    # download data from MinIO
    try:
        minio_client.fget_object(
            bucket_name=args.bucketname,
            object_name=args.objectname,
            file_path=args.path + args.name
        )
    except BaseException as err:
        print(err)

    # preprocessing data
    data_preprocessing = Preprocessing(
        file_path=args.path,
        file_name=args.name
    )
    data_preprocessing.train_test_dev_split(data_rate=args.datarate)
    data_preprocessing.construct_vocabulary_labels()

    # load data
    print('load data ...')
    train_word_lists, train_tag_lists, word2id, tag2id = build_corpus("train", data_dir=args.path)
    dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False, data_dir=args.path)
    test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False, data_dir=args.path)
    model = None
    if args.model == 'HMM':
        # train and evaluate HMM model
        model = hmm_train_eval(
            train_data=(train_word_lists, train_tag_lists),
            test_data=(test_word_lists, test_tag_lists),
            word2id=word2id,
            tag2id=tag2id
        )
    elif args.model == 'CRF':
        # train and evaluate CRF model
        model = crf_train_eval(
            train_data=(train_word_lists, train_tag_lists),
            test_data=(test_word_lists, test_tag_lists)
        )
    elif args.model == 'BiLSTM':
        # BiLSTM
        TrainingConfig.batch_size = args.batch_size
        TrainingConfig.epochs = args.epochs
        TrainingConfig.lr = args.learning_rate
        model = bilstm_train_eval(
            train_data=(train_word_lists, train_tag_lists),
            dev_data=(dev_word_lists, dev_tag_lists),
            test_data=(test_word_lists, test_tag_lists),
            word2id=word2id,
            tag2id=tag2id
        )
    elif args.model == 'BiLSTM_CRF':
        # BiLSTM CRF
        BiLSTMCRFTrainConfig.batch_size = args.batch_size
        BiLSTMCRFTrainConfig.epochs = args.epochs
        BiLSTMCRFTrainConfig.lr = args.learning_rate
        model = bilstm_crf_train_eval(
            train_data=(train_word_lists, train_tag_lists),
            dev_data=(dev_word_lists, dev_tag_lists),
            test_data=(test_word_lists, test_tag_lists),
            word2id=word2id,
            tag2id=tag2id
        )

    save_model(model, './ckpts/model_bilstm_crf.pkl')
    # upload the model to MinIO
    try:
        minio_client.fput_object(
            bucket_name=args.bucketname,
            object_name=f'{args.model}_model.pkl',
            file_path='./ckpts/model_bilstm_crf.pkl'
        )
    except BaseException as err:
        print(err)
@@ -1,121 +0,0 @@
import argparse
from minio import Minio

from data import build_corpus
from train_evaluate import hmm_train_eval, crf_train_eval, bilstm_train_eval, bilstm_crf_train_eval
from utils.preprocessing import Preprocessing, save2id
from utils.utils import save_model
from config import TrainingConfig, BiLSTMCRFTrainConfig


def parse_args():
    description = "You are learning how to use the argparse module to pass command-line arguments..."
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("-m", "--model", type=str, default='BiLSTM_CRF',
                        help="There are five models of NER, they are HMM, CRF, BiLSTM, BiLSTM_CRF and Bert_BiLSTM_CRF.")
    parser.add_argument("-p", "--path", type=str, default='./zdata/', help="data path")
    parser.add_argument('-n', '--name', type=str, default='annotated_data.txt', help='data file name')
    parser.add_argument('-bn', '--bucketname', type=str, default='data', help='MinIO bucket name')
    parser.add_argument('-on', '--objectname', type=str, default='people_daily_BIO.txt', help='MinIO object name')
    parser.add_argument('-dr', '--datarate', type=list, default=[0.7, 0.1, 0.2],
                        help='The rate of train_data, dev_data and test_data')
    parser.add_argument('-e', '--epochs', type=int, default=10, help='train epoch')
    parser.add_argument('-b', '--batch_size', type=int, default=16, help='batch size')
    parser.add_argument('-lr', '--learning_rate', type=float, default=0.0005, help='learning rate')
    parser.add_argument('-d', '--device', type=str, default='cpu')
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    print(args)

    # connect to MinIO
    minio_client = Minio(
        '10.101.32.11:9000',
        access_key='admin',
        secret_key='root123456',
        secure=False
    )

    # download data from MinIO
    try:
        minio_client.fget_object(
            bucket_name=args.bucketname,
            object_name=args.objectname,
            file_path=args.path + args.name
        )
    except BaseException as err:
        print(err)

    # preprocessing data
    data_preprocessing = Preprocessing(
        file_path=args.path,
        file_name=args.name
    )
    data_preprocessing.train_test_dev_split(data_rate=args.datarate)
    data_preprocessing.construct_vocabulary_labels()

    # load data
    print('load data ...')
    train_word_lists, train_tag_lists, word2id, tag2id = build_corpus("train", data_dir=args.path)
    dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False, data_dir=args.path)
    test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False, data_dir=args.path)
    model = None
    if args.model == 'HMM':
        # train and evaluate HMM model
        model = hmm_train_eval(
            train_data=(train_word_lists, train_tag_lists),
            test_data=(test_word_lists, test_tag_lists),
            word2id=word2id,
            tag2id=tag2id
        )
    elif args.model == 'CRF':
        # train and evaluate CRF model
        model = crf_train_eval(
            train_data=(train_word_lists, train_tag_lists),
            test_data=(test_word_lists, test_tag_lists)
        )
    elif args.model == 'BiLSTM':
        # BiLSTM
        TrainingConfig.batch_size = args.batch_size
        TrainingConfig.epochs = args.epochs
        TrainingConfig.lr = args.learning_rate
        TrainingConfig.device = args.device
        model = bilstm_train_eval(
            train_data=(train_word_lists, train_tag_lists),
            dev_data=(dev_word_lists, dev_tag_lists),
            test_data=(test_word_lists, test_tag_lists),
            word2id=word2id,
            tag2id=tag2id
        )
    elif args.model == 'BiLSTM_CRF':
        # BiLSTM CRF
        BiLSTMCRFTrainConfig.batch_size = args.batch_size
        BiLSTMCRFTrainConfig.epochs = args.epochs
        BiLSTMCRFTrainConfig.lr = args.learning_rate
        BiLSTMCRFTrainConfig.device = args.device
        model = bilstm_crf_train_eval(
            train_data=(train_word_lists, train_tag_lists),
            dev_data=(dev_word_lists, dev_tag_lists),
            test_data=(test_word_lists, test_tag_lists),
            word2id=word2id,
            tag2id=tag2id
        )

    model_bilstm_crf_path = './ckpts/model_{}.pkl'.format(args.model)
    print('complete model training')

    save_model(model, model_bilstm_crf_path)
    save2id(word2id, tag2id, args.model, args.path)
    # upload the model to MinIO
    print('upload model to MinIO')
    try:
        minio_client.fput_object(
            bucket_name=args.bucketname,
            object_name=f'{args.model}_model.pkl',
            file_path=model_bilstm_crf_path
        )
    except BaseException as err:
        print(err)
    print('all tasks finished')
@@ -1,7 +0,0 @@
torch>=1.7.1+cu110
numpy>=1.19.5
transformers>=4.20.1
matplotlib>=3.3.4
tqdm>=4.64.0
bentoml>=1.0.0
minio>=7.1.10
@@ -1,100 +0,0 @@
import bentoml
import config
from data import build_corpus
from minio import Minio
import os
from train_evaluate import hmm_train_eval, crf_train_eval, bilstm_train_eval, bilstm_crf_train_eval
from utils.preprocessing import Preprocessing, save2id
from utils.utils import save_model
from config import TrainingConfig, BiLSTMCRFTrainConfig
from loguru import logger


if __name__ == "__main__":
    if not os.path.exists(config.data_path + config.data_name):
        # connect to MinIO
        minio_client = Minio(
            '10.101.32.11:9000',
            access_key='admin',
            secret_key='root123456',
            secure=False
        )
        try:
            minio_client.fget_object(
                bucket_name='data',
                object_name='people_daily_BIO.txt',
                file_path=config.data_path + config.data_name
            )
        except BaseException as err:
            print(err)

    # preprocessing data
    data_preprocessing = Preprocessing(
        file_path=config.data_path,
        file_name=config.data_name
    )
    data_preprocessing.train_test_dev_split(data_rate=config.data_rate)
    data_preprocessing.construct_vocabulary_labels()

    # load data
    logger.info('load data ...')
    train_word_lists, train_tag_lists, word2id, tag2id = build_corpus("train", data_dir=config.data_path)
    dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False, data_dir=config.data_path)
    test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False, data_dir=config.data_path)
    model = None
    if config.model_name == 'HMM':
        # train and evaluate HMM model
        logger.info("HMM_train_eval start")
        model = hmm_train_eval(
            train_data=(train_word_lists, train_tag_lists),
            test_data=(test_word_lists, test_tag_lists),
            word2id=word2id,
            tag2id=tag2id
        )
        save2id(word2id, tag2id, config.model_name)
        save_model(model, './ckpts/{}model.pkl'.format(config.model_name))
        saved_model = bentoml.picklable_model.save_model(
            config.model_name,
            model,
            signatures={"__call__": {"batchable": True}}
        )
        print(f"Model saved: {saved_model}")
        logger.info("HMM_train_eval end")

    elif config.model_name == 'CRF':
        # train and evaluate CRF model
        model = crf_train_eval(
            train_data=(train_word_lists, train_tag_lists),
            test_data=(test_word_lists, test_tag_lists)
        )
    elif config.model_name == 'BiLSTM':
        # BiLSTM
        TrainingConfig.batch_size = config.batch_size
        TrainingConfig.epochs = config.epochs
        TrainingConfig.lr = config.lr
        model = bilstm_train_eval(
            train_data=(train_word_lists, train_tag_lists),
            dev_data=(dev_word_lists, dev_tag_lists),
            test_data=(test_word_lists, test_tag_lists),
            word2id=word2id,
            tag2id=tag2id
        )
        save2id(word2id, tag2id, config.model_name)
        bentoml.pytorch.save_model('BiLSTM', model.model)
        save_model(model, './ckpts/{}model.pkl'.format(config.model_name))
    elif config.model_name == 'BiLSTM_CRF':
        # BiLSTM CRF
        BiLSTMCRFTrainConfig.batch_size = config.batch_size
        BiLSTMCRFTrainConfig.epochs = config.epochs
        BiLSTMCRFTrainConfig.lr = config.lr
        model = bilstm_crf_train_eval(
            train_data=(train_word_lists, train_tag_lists),
            dev_data=(dev_word_lists, dev_tag_lists),
            test_data=(test_word_lists, test_tag_lists),
            word2id=word2id,
            tag2id=tag2id
        )
        save2id(word2id, tag2id, config.model_name)
        bentoml.pytorch.save_model('BiLSTM_CRF', model.model)
        save_model(model, './ckpts/{}model.pkl'.format(config.model_name))
@@ -1,36 +0,0 @@
import bentoml
import torch
import config
import pickle
from utils.bilstm_crf_token import readfile, token, idlist2tag

word2id = readfile(config.model_name + '_voc.txt')
tag2id = readfile(config.model_name + '_tags.txt')
id2tag = dict((ids, tag) for tag, ids in tag2id.items())
device = config.device
model = bentoml.pytorch.load_model("bilstm_crf:latest")
# model = bentoml.pytorch.load_model("bilstm:latest")

sentence = '1962年1月出生,南京工学院毕业'

# with open('./ckpts/BiLSTM_CRFmodel.pkl', 'rb') as f:
#     modela = pickle.load(f)
# tags = modela.predict_sentence(sentence)
# print(tags)


word_list = []
for i in sentence:
    word_list.append(i)
word_list = ['1', '9', '6', '2', '年', '1', '月', '出', '生', ',', '南', '京', '工', '学', '院', '毕', '业', '。']
inputs = token([word_list], word2id, device)
# output = model.forward(inputs)
# tags = idlist2tag(output, tag2id, id2tag)
# print(tags)

runner = bentoml.pytorch.get('bilstm_crf:latest').to_runner()
runner.init_local()
outputs = runner.__call__.run(inputs)
tags = idlist2tag(outputs, tag2id, id2tag)
print(tags)
@@ -1,96 +0,0 @@
# Read the data set
from re import S


with open('data/ThePeoplesDaily/raw/source_BIO_2014_cropus.txt', 'r') as source_file:
    source = source_file.read().split('\n')

with open('data/ThePeoplesDaily/raw/target_BIO_2014_cropus.txt', 'r') as target_file:
    target = target_file.read().split('\n')


# Count the sentence length of each sample
max_len = 0
seq_len_dict = {}
for sentence in source:
    seq_len = (len(sentence) + 1) / 2
    if max_len < seq_len:
        max_len = seq_len
    if seq_len in seq_len_dict: seq_len_dict[seq_len] += 1
    else: seq_len_dict[seq_len] = 1
print(max_len)
seq_len_dict = sorted(seq_len_dict.items(), key=lambda x: x[0], reverse=True)
print(seq_len_dict)


# # Count the number of samples
# assert len(source) == len(target)
# sample_count = len(source)
# print('Number of lines:', sample_count)

# # Count characters, vocabulary, and label types
# vocabulary = dict()
# labels = dict()
# char_count = 0
# for sentence, tags in zip(source, target):
#     char_count += round((len(sentence) + 0.49) / 2)
#     for char, tag in zip(sentence.split(' '), tags.split(' ')):
#         if char in vocabulary: vocabulary[char] += 1
#         else: vocabulary[char] = 1
#         if tag in labels: labels[tag] += 1
#         else: labels[tag] = 1
# print('Number of characters:', char_count)

# # Sort by count in descending order and write vocabulary.txt and labels.txt
# vocabulary = dict(sorted(vocabulary.items(), key=lambda x: x[1], reverse=True))
# labels = dict(sorted(labels.items(), key=lambda x: x[1], reverse=True))

# with open('data/ThePeoplesDaily/vocabulary.txt', 'w') as vocabulary_file:
#     vocabulary_file.write('\n'.join(vocabulary))

# with open('data/ThePeoplesDaily/labels.txt', 'w') as labels_file:
#     labels_file.write('\n'.join(labels))

# # Split into train / dev / test sets
# train_data_size = 0.7
# dev_data_size = 0.2
# test_data_size = 0.1

# train_source, train_target = source[0 : round(train_data_size*sample_count)], target[0 : round(train_data_size*sample_count)]
# with open('data/ThePeoplesDaily/train.txt', 'w') as train_file:
#     for sentence, tags in zip(train_source, train_target):
#         for char, tag in zip(sentence.split(' '), tags.split(' ')):
#             train_file.write(char + ' ' + tag + '\n')
#         train_file.write('\n')
# print('train_data_sample_size: ', len(train_source))

# dev_source, dev_target = source[round(train_data_size*sample_count) : round((train_data_size+dev_data_size)*sample_count)], target[round(train_data_size*sample_count) : round((train_data_size+dev_data_size)*sample_count)]
# with open('data/ThePeoplesDaily/dev.txt', 'w') as dev_file:
#     for sentence, tags in zip(dev_source, dev_target):
#         for char, tag in zip(sentence.split(' '), tags.split(' ')):
#             dev_file.write(char + ' ' + tag + '\n')
#         dev_file.write('\n')
# print('dev_data_sample_size: ', len(dev_source))

# test_source, test_target = source[round((train_data_size+dev_data_size)*sample_count) : -1], target[round((train_data_size+dev_data_size)*sample_count) : -1]
# with open('data/ThePeoplesDaily/test.txt', 'w') as test_file:
#     for sentence, tags in zip(test_source, test_target):
#         for char, tag in zip(sentence.split(' '), tags.split(' ')):
#             test_file.write(char + ' ' + tag + '\n')
#         test_file.write('\n')
# print('test_data_sample_size: ', len(test_source))


with open('data/ThePeoplesDaily/raw/source_BIO.txt', 'w') as file:
    for sentence, tags in zip(source, target):
        for char, tag in zip(sentence.split(' '), tags.split(' ')):
            file.write(char + ' ' + tag + '\n')
        file.write('\n')
print('data_sample_size: ', len(source))
@@ -1,117 +0,0 @@
# from ..data import build_corpus

# files = ['train', 'dev', 'test']

# words = ['PAD']
# for file_name in files:
#     with open(f'./data/{file_name}.char', 'r') as file_object:
#         words_tags = file_object.read().split('\n')
#         for word_tag in words_tags:
#             if word_tag is None: continue
#             word = word_tag.split(' ')[0]
#             if word not in words:
#                 words.append(word)

# # print(words)

# with open('data/vocabulary.char', 'w') as file_object:
#     for word in words:
#         file_object.write(f'{word}\n')


import os

def build_corpus(split, make_vocab=True, data_dir='./data'):
    """Read the data set."""

    assert split.lower() in ["train", "dev", "test"]

    word_lists = []
    tag_lists = []
    with open(os.path.join(data_dir, split + '.char'), 'r', encoding='utf-8') as f:
        word_list = []
        tag_list = []
        for line in f:
            if line != '\n':
                word, tag = line.strip('\n').split()
                word_list.append(word)
                tag_list.append(tag)
            else:
                word_lists.append(word_list)
                tag_lists.append(tag_list)
                word_list = []
                tag_list = []

    if make_vocab:
        word2id = build_map(word_lists)
        tag2id = build_map(tag_lists)
        return word_lists, tag_lists, word2id, tag2id
    else:
        return word_lists, tag_lists

def build_map(lists):
    maps = {}
    for list_ in lists:
        for e in list_:
            if e not in maps:
                maps[e] = len(maps)
    return maps


# Find the maximum sequence length
# train_word_lists, train_tag_lists, word2id, tag2id = build_corpus("train")
# dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False)
# test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False)

# max_length = 0
# for word_lists in [train_word_lists, dev_word_lists, test_word_lists]:
#     for word_list in word_lists:
#         if len(word_list) > max_length:
#             max_length = len(word_list)
# print(max_length)


# train_word_lists, train_tag_lists, word2id, tag2id = build_corpus("train")
# dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False)
# test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False)

# seq_len_frequence = dict()

# for word_lists in [train_word_lists, dev_word_lists, test_word_lists]:
#     for word_list in word_lists:
#         if len(word_list) in seq_len_frequence:
#             seq_len_frequence[len(word_list)] += 1
#         else:
#             seq_len_frequence[len(word_list)] = 1
# print(seq_len_frequence)

# seq_len_frequence = sorted(seq_len_frequence.items(), key=lambda x: x[0])
# print(seq_len_frequence)

# print(dict(seq_len_frequence))
# X = dict(seq_len_frequence).keys()
# Y = dict(seq_len_frequence).values()

import matplotlib.pyplot as plt

# a = plt.figure(figsize=(5, 10))
# plt.bar(x=X, height=Y)
# plt.show()

import pandas as pd
import numpy as np

df2 = pd.DataFrame(np.random.rand(10, 4), columns=["a", "b", "c", "d"])
df2.plot.bar()

fig, ax = plt.subplots()  # Create a figure containing a single axes.
ax.plot([1, 2, 3, 4], [1, 4, 2, 3])  # Plot some data on the axes.
@@ -1,84 +0,0 @@
from models.HMM import HMM
from models.CRF import CRFModel
from bilstm_opration import BiLSTM_opration
from bilstm_crf_opration import BiLSTM_CRF_opration
from evaluating import Metrics
from utils.utils import save_model, add_end_tag


def hmm_train_eval(train_data, test_data, word2id, tag2id, remove_0=False):
    """Train and evaluate the HMM model."""
    print("Training and evaluating the HMM model...")
    train_word_lists, train_tag_lists = train_data
    test_word_lists, test_tag_lists = test_data

    # train the model
    hmm_model = HMM(N=len(tag2id), M=len(word2id))
    hmm_model.train(train_word_lists, train_tag_lists, word2id, tag2id)
    # save_model(hmm_model, "./ckpts/hmm.pkl")

    # evaluate the model
    pred_tag_lists = hmm_model.test(test_word_lists, word2id, tag2id)
    metrics = Metrics(test_tag_lists, pred_tag_lists)
    metrics.report_scores(dtype='HMM')

    return hmm_model


def crf_train_eval(train_data, test_data, remove_0=False):
    """Train and evaluate the CRF model."""
    print("Training and evaluating the CRF model")
    train_word_lists, train_tag_lists = train_data
    test_word_lists, test_tag_lists = test_data
    crf_model = CRFModel()
    crf_model.train(train_word_lists, train_tag_lists)
    # save_model(crf_model, "./ckpts/crf.pkl")

    pred_tag_lists = crf_model.test(test_word_lists)
    metrics = Metrics(test_tag_lists, pred_tag_lists)
    metrics.report_scores(dtype='CRF')

    return crf_model


def bilstm_train_eval(train_data, dev_data, test_data, word2id, tag2id):
    """Train and evaluate the BiLSTM model."""
    print("Training and evaluating the BiLSTM model")

    bilstm_model = BiLSTM_opration(
        train_data=train_data,
        dev_data=dev_data,
        test_data=test_data,
        word2id=word2id,
        tag2id=tag2id
    )
    bilstm_model.train()
    bilstm_model.evaluate()

    return bilstm_model


def bilstm_crf_train_eval(train_data, dev_data, test_data, word2id, tag2id):
    """Train and evaluate the BiLSTM_CRF model."""
    print("Training and evaluating the BiLSTM_CRF model")

    train_word_lists, train_tag_lists = train_data
    dev_word_lists, dev_tag_lists = dev_data
    test_word_lists, test_tag_lists = test_data

    train_word_lists, train_tag_lists = add_end_tag(train_word_lists, train_tag_lists)
    dev_word_lists, dev_tag_lists = add_end_tag(dev_word_lists, dev_tag_lists)
    test_word_lists = add_end_tag(test_word_lists)

    bilstm_crf_model = BiLSTM_CRF_opration(
        train_data=(train_word_lists, train_tag_lists),
        dev_data=(dev_word_lists, dev_tag_lists),
        test_data=(test_word_lists, test_tag_lists),
        word2id=word2id,
        tag2id=tag2id
    )

    bilstm_crf_model.train()
    bilstm_crf_model.evaluate()

    return bilstm_crf_model
@@ -1,99 +0,0 @@
import config
import torch


def build_maps(lists) -> dict:
    maps = {}
    for e in lists:
        if e not in maps:
            maps[e] = len(maps)
    return maps


def readfile(filename):
    with open(config.data_path + filename, 'r', encoding='utf-8') as f:
        data = f.read().split('\n')
    return build_maps(data)


def token(word_lists, word2id, device):
    """
    :param word_lists: list of words
    :param word2id: dictionary of word2id
    """
    for idx in range(len(word_lists)):
        word_lists[idx].append('<end>')

    word_id_lists = (torch.ones(size=(len(word_lists), len(word_lists[0])), dtype=torch.long) * word2id[
        '<pad>']).to(device if torch.cuda.is_available() else 'cpu')
    for i in range(len(word_lists)):
        for j in range(len(word_lists[i])):
            word_id_lists[i][j] = word2id.get(word_lists[i][j], word2id['<unk>'])  # characters missing from the vocabulary fall back to <unk>
        word_id_lists[i][-1] = word2id.get(word_lists[i][-1], word2id['<unk>'])
    return word_id_lists


# def token(word_lists, word2id, device):
#     """
#     :param word_lists: list of words
#     :param word2id: dictionary of word2id
#     """
#     word_lists.append('<end>')
#
#     word_id_lists = (torch.ones(size=len(word_lists), dtype=torch.long) * word2id[
#         '<pad>']).to(device if torch.cuda.is_available() else 'cpu')
#
#     for j in range(len(word_id_lists)):
#         word_id_lists[j] = word2id.get(word_lists[j], word2id['<unk>'])  # characters missing from the vocabulary fall back to <unk>
#     word_id_lists[-1] = word2id.get(word_lists[-1], word2id['<unk>'])
#     return word_id_lists


def viterbi_decoding(crf_score, tag2id):
    """Viterbi decoding (batching is not supported)."""
    start_id = tag2id['<start>']
    end_id = tag2id['<end>']

    device = crf_score.device
    seq_len = crf_score.shape[0]
    viterbi = torch.zeros(seq_len, len(tag2id)).to(device)
    backpointer = (torch.ones(size=(seq_len, len(tag2id)), dtype=torch.long) * end_id).to(device)

    for step in range(seq_len):
        if step == 0:  # first character
            viterbi[step, :] = crf_score[step, start_id, :]
            backpointer[step, :] = start_id
        else:
            max_scores, prev_tags_id = torch.max(
                viterbi[step - 1, :].unsqueeze(1) + crf_score[step, :, :],
                dim=0
            )
            viterbi[step, :] = max_scores
            backpointer[step, :] = prev_tags_id

    best_end_idx = end_id
    best_path = []
    for step in range(seq_len - 1, 0, -1):
        if step == seq_len - 1:
            best_path.append(backpointer[step, best_end_idx].item())
        else:
            best_path.append(backpointer[step, best_path[-1]].item())
    best_path.reverse()

    return best_path


def predtion_to_tags(prediction, id2tag):
    return [id2tag.get(id, 'O') for id in prediction]


def idlist2tag(lists, tag2id, id2tag):
    pred_tag_lists = []
    best_path = viterbi_decoding(lists[0], tag2id)
    pred_tag_lists.append(
        predtion_to_tags(
            best_path, id2tag
        )
    )
    return pred_tag_lists
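A minimal sketch of how `token()` and `viterbi_decoding()` above fit together; the vocabulary, tag set, and random CRF scores are toy stand-ins for the model's real inputs and outputs, and the helpers are assumed to be importable:

```python
import torch

# Toy vocabulary and tag set; '<pad>', '<unk>', '<end>', '<start>' are the keys the helpers expect.
word2id = {'<pad>': 0, '<unk>': 1, '南': 2, '京': 3, '好': 4, '<end>': 5}
tag2id = {'B-LOC': 0, 'I-LOC': 1, 'O': 2, '<start>': 3, '<end>': 4}
id2tag = {v: k for k, v in tag2id.items()}

word_lists = [list('南京好')]
inputs = token(word_lists, word2id, 'cpu')   # appends '<end>' and maps characters to ids
print(inputs)                                # tensor([[2, 3, 4, 5]])

# Random scores stand in for the model's (seq_len, n_tags, n_tags) CRF output.
crf_score = torch.randn(inputs.shape[1], len(tag2id), len(tag2id))
print(predtion_to_tags(viterbi_decoding(crf_score, tag2id), id2tag))
```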
@@ -1,35 +0,0 @@
import config
import torch


def build_maps(lists) -> dict:
    maps = {}
    for e in lists:
        if e not in maps:
            maps[e] = len(maps)
    return maps


def readfile(filename):
    with open(config.data_path + filename, 'r', encoding='utf-8') as f:
        data = f.read().split('\n')
    return build_maps(data)


def token(word_lists, word2id, device='cpu'):
    word_id_lists = (
        torch.ones(size=(len(word_lists), len(word_lists[0])), dtype=torch.long) * word2id['<pad>']).to(
        device if torch.cuda.is_available() else 'cpu')

    for i in range(len(word_lists)):
        for j in range(len(word_lists[i])):
            word_id_lists[i][j] = word2id.get(word_lists[i][j], word2id['<unk>'])  # characters missing from the vocabulary fall back to <unk>

    return word_id_lists


def idlist2tag(prediction, tag2id, id2tag):
    pred_tag_lists = []
    pred_tag_lists.append([id2tag[ids.item()] for ids in torch.argmax(prediction, dim=2)[0]])
    return pred_tag_lists
@@ -1,44 +0,0 @@
from inspect import classify_class_attrs
from numpy import dtype
from torch.utils.data import Dataset
import torch

class WordTagDataset(Dataset):

    def __init__(self, word_lists, tag_lists, vocabulary, tag2id) -> None:
        super(WordTagDataset).__init__()
        assert len(word_lists) == len(tag_lists)

        # pairs = list(zip(word_lists, tag_lists))
        # indices = sorted(range(len(pairs)), key=lambda x: len(pairs[x][0]), reverse=True)
        # pairs = [pairs[i] for i in indices]
        # self.word_lists, self.tag_lists = list(zip(*pairs))

        self.word_lists = word_lists
        self.tag_lists = tag_lists
        self.vocabulary = vocabulary
        self.tag2id = tag2id

    def __getitem__(self, index):
        wordID_list = [self.vocabulary.get(word, self.vocabulary['<unk>']) for word in self.word_lists[index]]
        tagID_list = [self.tag2id.get(tag, self.tag2id['<unk>']) for tag in self.tag_lists[index]]
        MAX_PADDING = 64
        seq_len = len(wordID_list)
        if seq_len < MAX_PADDING:
            for i in range(MAX_PADDING - seq_len):
                wordID_list.append(self.vocabulary['<pad>'])
                tagID_list.append(self.tag2id['<pad>'])
        else:
            wordID_list = wordID_list[0:MAX_PADDING]
            tagID_list = tagID_list[0:MAX_PADDING]
        # print(torch.tensor(wordID_list, dtype=torch.long))
        # print(torch.tensor(tagID_list, dtype=torch.long))

        return torch.tensor(wordID_list, dtype=torch.long), torch.tensor(tagID_list, dtype=torch.long)

    def __len__(self):
        return len(self.word_lists)
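A minimal usage sketch for `WordTagDataset` above with a toy vocabulary and tag set (the `<pad>` and `<unk>` entries are required by the class); every item is padded or truncated to 64 positions, so batches stack cleanly:

```python
from torch.utils.data import DataLoader

# Toy mappings; real runs would use the vocabulary/labels built by Preprocessing.
vocabulary = {'<pad>': 0, '<unk>': 1, '南': 2, '京': 3}
tag2id = {'<pad>': 0, '<unk>': 1, 'B-LOC': 2, 'I-LOC': 3, 'O': 4}

dataset = WordTagDataset([['南', '京'], ['好']], [['B-LOC', 'I-LOC'], ['O']], vocabulary, tag2id)
words, tags = next(iter(DataLoader(dataset, batch_size=2)))
print(words.shape, tags.shape)  # torch.Size([2, 64]) torch.Size([2, 64])
```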
@@ -1,92 +0,0 @@
import config


class Preprocessing:
    """
    Data preprocessing:
        split into train / dev / test sets
        build the vocabulary
    """

    def __init__(self, file_path: str, file_name: str) -> None:
        self.file_path = file_path
        self.file_name = file_name
        with open(file_path + file_name, 'r', encoding='utf-8') as file:
            self.item_list = file.read().split('\n\n')
        print('Number of samples:', len(self.item_list))

    def train_test_dev_split(self, data_rate: list):
        assert len(data_rate) == 3 and sum(data_rate) == 1

        # split into train / dev / test sets
        train_data_size = data_rate[0]
        dev_data_size = data_rate[1]
        test_data_size = data_rate[2]

        train_data = self.item_list[0: round(train_data_size * len(self.item_list))]
        with open(f'{self.file_path}train.txt', 'w') as train_file:
            train_file.write('\n\n'.join(train_data))
            train_file.write('\n')
        print('train_data_sample_size: ', len(train_data))

        dev_data = self.item_list[round(train_data_size * len(self.item_list)): round(
            (train_data_size + dev_data_size) * len(self.item_list))]
        with open(f'{self.file_path}dev.txt', 'w') as dev_file:
            dev_file.write('\n\n'.join(dev_data))
            dev_file.write('\n')
        print('dev_data_sample_size: ', len(dev_data))

        test_data = self.item_list[round((train_data_size + dev_data_size) * len(self.item_list)): -1]
        with open(f'{self.file_path}test.txt', 'w') as test_file:
            test_file.write('\n\n'.join(test_data))
            test_file.write('\n')
        print('test_data_sample_size: ', len(test_data))

    def construct_vocabulary_labels(self):

        char_count = 0
        vocabulary = dict()
        labels = dict()

        for item in self.item_list:
            for char_tag in item.split('\n'):
                try:
                    char = char_tag.split(' ')[0]
                    tag = char_tag.split(' ')[1]
                    if char in vocabulary:
                        vocabulary[char] += 1
                    else:
                        vocabulary[char] = 1
                    if tag in labels:
                        labels[tag] += 1
                    else:
                        labels[tag] = 1
                    char_count += 1
                except:
                    print(char_tag)
        print('Number of characters:', char_count)

        # sort by count in descending order and write vocabulary.txt and labels.txt
        vocabulary = dict(sorted(vocabulary.items(), key=lambda x: x[1], reverse=True))
        labels = dict(sorted(labels.items(), key=lambda x: x[1], reverse=True))

        with open(f'{self.file_path}vocabulary.txt', 'w') as vocabulary_file:
            vocabulary_file.write('\n'.join(vocabulary))
        print('vocabulary.txt constructed')

        with open(f'{self.file_path}labels.txt', 'w') as labels_file:
            labels_file.write('\n'.join(labels))
        print('labels.txt constructed')


def save2id(word2id, tag2id, model_name, path):
    with open(path + model_name + '_voc.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(word2id))
    with open(path + model_name + '_tags.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(tag2id))


if __name__ == '__main__':
    p = Preprocessing(file_path='./data/', file_name='annotated_data.txt')
    p.train_test_dev_split(data_rate=[0.7, 0.1, 0.2])
    p.construct_vocabulary_labels()