# gradio/demo/question_answer/files/bert.py

from __future__ import absolute_import, division, print_function

import collections

import torch
from pytorch_transformers import BertConfig, BertForQuestionAnswering, BertTokenizer
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from utils import (
    get_answer,
    input_to_squad_example,
    squad_examples_to_features,
    to_list,
)

# Raw start/end logits produced for a single feature, keyed by the feature's
# unique_id; this is the shape the SQuAD post-processing in utils expects.
RawResult = collections.namedtuple(
    "RawResult", ["unique_id", "start_logits", "end_logits"]
)


class QA:
    """Extractive question answering with a fine-tuned BertForQuestionAnswering model."""

    def __init__(self, model_path: str):
        # Inference settings for SQuAD-style answer extraction.
        self.max_seq_length = 384
        self.doc_stride = 128
        self.do_lower_case = True
        self.max_query_length = 64
        self.n_best_size = 20
        self.max_answer_length = 30
        self.model, self.tokenizer = self.load_model(model_path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()

    def load_model(self, model_path: str, do_lower_case=False):
        # model_path is a directory containing bert_config.json, the vocab
        # file, and the PyTorch weights of a fine-tuned QA model.
        config = BertConfig.from_pretrained(model_path + "/bert_config.json")
        tokenizer = BertTokenizer.from_pretrained(
            model_path, do_lower_case=do_lower_case
        )
        model = BertForQuestionAnswering.from_pretrained(
            model_path, from_tf=False, config=config
        )
        return model, tokenizer
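
    # predict() runs the full pipeline: the passage/question pair is converted
    # into (possibly several) windowed features, every window is scored by the
    # model, and utils.get_answer selects the best span across all windows.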
    def predict(self, passage: str, question: str):
        example = input_to_squad_example(passage, question)
        features = squad_examples_to_features(
            example,
            self.tokenizer,
            self.max_seq_length,
            self.doc_stride,
            self.max_query_length,
        )
        all_input_ids = torch.tensor(
            [f.input_ids for f in features], dtype=torch.long
        )
        all_input_mask = torch.tensor(
            [f.input_mask for f in features], dtype=torch.long
        )
        all_segment_ids = torch.tensor(
            [f.segment_ids for f in features], dtype=torch.long
        )
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        dataset = TensorDataset(
            all_input_ids, all_input_mask, all_segment_ids, all_example_index
        )
        eval_sampler = SequentialSampler(dataset)
        eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=1)

        all_results = []
        for batch in eval_dataloader:
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "token_type_ids": batch[2],
                }
                example_indices = batch[3]
                outputs = self.model(**inputs)
            # pytorch_transformers returns a (start_logits, end_logits) tuple.
            for i, example_index in enumerate(example_indices):
                eval_feature = features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                result = RawResult(
                    unique_id=unique_id,
                    start_logits=to_list(outputs[0][i]),
                    end_logits=to_list(outputs[1][i]),
                )
                all_results.append(result)

        answer = get_answer(
            example,
            features,
            all_results,
            self.n_best_size,
            self.max_answer_length,
            self.do_lower_case,
        )
        return answer
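

# A minimal usage sketch, not part of the original demo. The "model" directory
# path is a placeholder: it must contain bert_config.json, a vocab file, and
# fine-tuned BertForQuestionAnswering weights, as load_model expects.
if __name__ == "__main__":
    qa = QA("model")  # hypothetical path to a fine-tuned QA model directory
    passage = (
        "Gradio is a Python library that lets you build a web demo around a "
        "machine learning model with only a few lines of code."
    )
    question = "What is Gradio?"
    print(qa.predict(passage, question))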