Source code for pororo.tasks.question_generation

"""Question generation related modeling class"""

import re
import string
from collections import Counter, OrderedDict
from typing import List, Optional, Union

from whoosh.qparser import QueryParser

from pororo.tasks.utils.base import PororoFactoryBase, PororoGenerationBase
from pororo.tasks.utils.download_utils import download_or_load


[docs]class PororoQuestionGenerationFactory(PororoFactoryBase): """ Question generation using BART model Korean (`kobart.base.ko.qg`) - dataset: KorQuAD 1.0 (Lim et al. 2019) + AI hub Reading Comprehension corpus + AI hub Commonsense corpus - metric: Model base evaluation using PororoMrc (`brainbert.base`) - EM (82.59), F1 (94.06) - ref: https://www.aihub.or.kr/aidata/86 - ref: https://www.aihub.or.kr/aidata/84 Args: answer (str): answer text context (str): source article beam (int): beam search size temperature (float): temperature scale top_k (int): top-K sampling vocabulary size top_p (float): top-p sampling ratio no_repeat_ngram_size (int): no repeat ngram size len_penalty (float): length penalty ratio n_wrong (int): number of wrong answer candidate return_context (bool): return context together or not Returns: str : question (if `n_wrong` < 1) Tuple[str, List[str]] : question, wrong_answers (if `n_wrong` > 1) Examples: >>> # question generation has argument named `n_wrong` which creates wrong answers together >>> qg = Pororo(task="qg", lang="ko") >>> qg( ... "카카오톡", ... "카카오톡은 스마트폰의 데이터 통신 기능을 이용하여, 문자 과금 없이 사람들과 메시지를 주고받을 수 있는 애플리케이션이다. 스마트폰 대중화 이 후 기존 인스턴트 메신저 앱의 번거로운 친구 추가 절차 없이, 스마트폰 주소록의 전 화번호만으로 손쉽게 메시지를 주고받을 수 있는 것이 특징이다.", ... n_wrong=3 ... ) (('스마트폰의 데이터 통신 기능을 이용해서 문자 과금 없이 사람들과 메시지를 주고받을 수 있는 앱을 뭐라고 해?', ['텔레그램', '챗온', '디스코드']) >>> # question generation task supports batchwise inference (1:N, N:1, N:N) >>> # answer : context = 1 : N >>> qg( ... "카카오톡", ... ["카카오톡은 스마트폰의 데이터 통신 기능을 이용하여, 문자 과금 없이 사람들과 메시지를 주고받을 수 있는 애플리케이션이다. 스마트폰 대중화 이 후 기존 인스턴트 메신저 앱의 번거로운 친구 추가 절차 없이, 스마트폰 주소록의 전화번호만으로 손쉽게 메시지를 주고받을 수 있는 것이 특징이다.", ... "메시징 서비스와 사회관계망서비스(SNS) 등 소셜미디어를 사용하는 사람들이 뉴스를 볼 때 가장 선호하는 플랫폼은 카카오톡인 것으로 나타났다. 30일 한국언론진흥재단이 최근 공개한 '2017 소셜미디어 이용자 조사' 결과를 보면 소셜미디어로 뉴스를 본 경험이 있는 우리나라 국민 1천747명에게 뉴스 이용 플랫폼을 중복으로 선택하게 한 결과 50.4%가 카카오톡을 사용했다고 답했다. 카카오톡 다음으로는 페이스북(42%) 사용률이 높았으며 유튜브(31.8%)가 뒤를 이었다."], ... n_wrong=3 ... ) [('스마트폰의 데이터 통신 기능을 이용해서 문자 과금 없이 사람들과 메시지를 주고받을 수 있는 앱을 뭐라고 해?', ['텔레그램', '챗온', '디스코드']), ('메시징 서비스와 사회관계망서비스(SNS) 등 소셜미디어를 사용하는 사람들이 뉴스를 볼 때 가장 선호하는 플랫폼은 뭐야?', ['텔레그램', '챗온', '디스코드'])] >>> # answer : context = N : 1 >>> qg( ... ["토트넘", "손흥민", "인스타그램"], ... "‘2020년 더 베스트 국제축구연맹(FIFA) 풋볼 어워드’에서 FIFA 푸스카스를 수상한 잉글랜드 프로축구 1부리그 프리미어리 그(EPL) 소속 토트넘 홋스퍼 FC의 손흥민이 기쁜 마음을 드러냈다. 푸스카스는 1년간 전 세계 축구경기에 서 나온 골 중 가장 멋진 골을 뽑는다. 손흥민은 스위스 취리히 소재 FIFA 본부에서 온라인으로 열린 ‘2020년 더 베스트 FIFA 풋볼 어워드’ 시 상식에서 푸스카스를 받았다. 손흥민의 이번 수상은 한국 선수로는 최초이자 아시아에서는 2016년 모하메 드 파이즈 수브리(말레이시아)에 이어 두 번째다. 상을 받은 손흥민은 이날 오전 인스타그램에 “아주 특별한 밤이었다. 저를 지지해주고 제게 투표해 주어서 감사하다”며 셀 프 카메라가 담긴 사진 한 장을 게시했 다. 공개된 사진 속 손흥민은 환한 미소와 엄지를 들어 보이고 있 다. 여기에 손흥민은 이 순간을 절대 잊 지 않겠다고 덧붙였다. 앞서 손흥민은 지난해 12월8일 번리 FC와 가진 EPL 경기에서 환상적인 골을 터뜨 렸다. 당시 손흥민은 토트넘 진영에서 얀 베르통언(벨기에)의 패 스를 받고 공을 잡고 약 70m를 혼자 내달리며 무려 번리 선수 6명을 따돌린 뒤 페널티 지역에서 오른발 슈팅으로 골망을 흔들었다. 이후 이 골은 EPL ‘12월의 골’을 시작으로 영국 공영방송 BBC의 ‘올해의 골’, 영국 스포츠 매체 디 애슬레틱의 ‘올해의 골’에 이어 EPL 사무국이 선정하는 2019∼20시즌 ‘올해의 골’ 등으로 선정되며 최고의 골로 인정받은 바 있다.", ... n_wrong=3 ... ) [('손흥민은 어느 팀 소속이야?', ['이즐링턴', '웸블리', '첼시']), ('누가 ‘2020년 더 베스트 국제축구연맹(FIFA) 풋볼 어워드’에서 FIFA 푸스카스를 수상했어?', ['지동원', '구자철', '이청용']), ('손흥민은 셀프 카메라가 담긴 사진 한 장을 어디에 게시했어?', ['트위터', '페이스북', '사운드클라우드'])] >>> # answer : context = N : N (answers list and contexts list must be same length. not N : M) >>> qg( ... ["카카오톡", "손흥민"], ... ["메시징 서비스와 사회관계망서비스(SNS) 등 소셜미디어를 사용하는 사 람들이 뉴스를 볼 때 가장 선호하는 플랫폼은 카카오톡인 것으로 나타났다. 30일 한국언론진흥재단이 최근 공개한 '2017 소셜미디어 이용자 조사' 결과를 보면 소셜미디어로 뉴스를 본 경험이 있는 우리나라 국민 1천747명에게 뉴스 이용 플랫폼을 중복으로 선택하게 한 결과 50.4%가 카카오톡을 사용했다고 답했다. 카 카오톡 다음으로는 페이스북(42%) 사용률이 높았으며 유튜브(31.8%)가 뒤를 이었다.", ... "‘2020년 더 베스트 국제축구연맹(FIFA) 풋볼 어워드’에서 FIFA 푸스카스를 수상한 잉글랜드 프로축구 1부리그 프리미어리그(EPL) 소속 토트넘 홋스퍼 FC의 손흥민이 기쁜 마음을 드러냈다. 푸스카스는 1년간 전 세계 축구경기에서 나온 골 중 가장 멋진 골을 뽑는다. 손흥민은 스위스 취리히 소 재 FIFA 본부에서 온라인으로 열린 ‘2020년 더 베스트 FIFA 풋볼 어워드’ 시 상식에서 푸스카스를 받았다. 손흥민의 이번 수상은 한국 선수로는 최초이자 아시아에서는 2016년 모하메 드 파이즈 수브리(말레이시아)에 이어 두 번째다. 상을 받은 손흥민은 이날 오전 인스타그램에 “아주 특별한 밤이었다. 저를 지지해주 고 제게 투표해 주어서 감사하다”며 셀프 카메라가 담긴 사진 한 장을 게시했 다. 공개된 사진 속 손흥민 은 환한 미소와 엄지를 들어 보이고 있다. 여기에 손흥민은 이 순간을 절대 잊 지 않겠다고 덧붙였다. 앞서 손흥민은 지난해 12월8일 번리 FC와 가진 EPL 경기에서 환상적인 골을 터뜨 렸다. 당시 손흥민은 토트 넘 진영에서 얀 베르통언(벨기에)의 패스를 받고 공을 잡고 약 70m를 혼자 내달리며 무려 번리 선수 6명 을 따돌린 뒤 페널티 지역에서 오른발 슈팅으로 골망을 흔들었다. 이후 이 골은 EPL ‘12월의 골’을 시작으로 영국 공영방송 BBC의 ‘올해의 골’, 영국 스포츠 매체 디 애슬레틱의 ‘올해의 골’에 이어 EPL 사무국이 선정하는 2019∼20시즌 ‘올해의 골’ 등으로 선정되며 최고의 골로 인정받은 바 있다."], ... n_wrong=3 ... ) [('메시징 서비스와 사회관계망서비스(SNS) 등 소셜미디어를 사용하는 사 람들이 뉴스를 볼 때 가장 선호하는 플랫폼은 뭐야?', ['텔레그램', '챗온', '디스코드']), ('누가 ‘2020년 더 베스트 국제축구연맹(FIFA) 풋볼 어워드’에서 FIFA 푸스카스를 수상했어?', ['지동원', '구자철', '이청용'])] """ def __init__(self, task: str, lang: str, model: Optional[str]): super().__init__(task, lang, model)
[docs] @staticmethod def get_available_langs(): return ["ko"]
[docs] @staticmethod def get_available_models(): return { "ko": ["kobart.base.ko.qg"], }
[docs] def load(self, device: str): """ Load user-selected task-specific model Args: device (str): device information Returns: object: User-selected task-specific model """ if "bart" in self.config.n_model: from whoosh import index from pororo.models.bart.KoBART import KoBartModel from pororo.models.wikipedia2vec import Wikipedia2Vec from pororo.tasks import PororoTokenizationFactory vec_map = { "ko": "kowiki_20200720_100d.pkl", "en": "enwiki_20180420_100d.pkl", "ja": "jawiki_20180420_100d.pkl", "zh": "zhwiki_20180420_100d.pkl", } f_wikipedia2vec = download_or_load( f"misc/{vec_map[self.config.lang]}", self.config.lang, ) f_index = download_or_load( f"misc/{self.config.lang}_indexdir.zip", self.config.lang, ) model = Wikipedia2Vec(model_file=f_wikipedia2vec, device=device) idx = index.open_dir(f_index) sim_words = SimilarWords(model, idx) model_path = download_or_load( f"bart/{self.config.n_model}", self.config.lang, ) model = KoBartModel.from_pretrained( device=device, model_path=model_path, ) sent_tok = (lambda text: PororoTokenizationFactory( task="tokenization", lang=self.config.lang, model=f"sent_{self.config.lang}", ).load(device).predict(text)) return PororoKoBartQuestionGeneration( model, sim_words, sent_tok, self.config, )
[docs]class PororoKoBartQuestionGeneration(PororoGenerationBase): def __init__(self, model, sim_words, sent_tok, config): super(PororoKoBartQuestionGeneration, self).__init__(config) self._model = model self._max_len = 1024 self._start_token = "<unused0>" self._end_token = "<unused1>" self._sim_words = sim_words self._sent_tok = sent_tok def _focus_answer( self, context: str, answer: str, truncate: bool = True, ) -> str: """ add answer start and end token and truncate context text to make inference speed fast Args: context (str): context string answer (str): answer string truncate (bool): truncate or not Returns: context (str): preprocessed context string """ answer_start_idx = context.find(answer) answer_end_idx = answer_start_idx + len(answer) + len(self._start_token) context = self._insert_token(context, answer_start_idx, self._start_token) context = self._insert_token(context, answer_end_idx, self._end_token) if (len(context) < self._max_len) or (not truncate): return context sentences = self._sent_tok(context) answer_sentence_idx = None for i in range(len(sentences)): if self._start_token in sentences[i]: answer_sentence_idx = i i, j = answer_sentence_idx, answer_sentence_idx truncated_context = [sentences[answer_sentence_idx]] while len(" ".join(truncated_context)) < self._max_len: prev_context_length = len(" ".join(truncated_context)) i -= 1 j += 1 if i > 0: truncated_context = [sentences[i]] + truncated_context if j < len(sentences): truncated_context = truncated_context + [sentences[j]] if len(" ".join(truncated_context)) == prev_context_length: break return " ".join(truncated_context) def _insert_token(self, src: str, index: int, token: str) -> str: """ insert token to context Args: src (str): source string index (int): index that you want token (str): token string Returns: token_added_string (str): token added string """ return src[:index] + token + src[index:] def _postprocess(self, text: str): text = text.strip() if text[-1] != "?": text = text + "?" return text
[docs] def predict( self, answer: Union[str, List[str]], context: Union[str, List[str]], beam: int = 5, temperature: float = 1.0, top_k: int = -1, top_p: float = -1, no_repeat_ngram_size: int = 4, len_penalty: float = 1.0, n_wrong: int = 0, ): """ Conduct question generation Args: answer (str): answer text context (str): source article beam (int): beam search size temperature (float): temperature scale top_k (int): top-K sampling vocabulary size top_p (float): top-p sampling ratio no_repeat_ngram_size (int): no repeat ngram size len_penalty (float): length penalty ratio n_wrong (int): number of wrong answer candidate """ sampling = False if top_k != -1 or top_p != -1: sampling = True output = self._model.translate( text=context, beam=beam, sampling=sampling, temperature=temperature, sampling_topk=top_k, sampling_topp=top_p, max_len_a=1, max_len_b=50, no_repeat_ngram_size=no_repeat_ngram_size, length_penalty=len_penalty, ) if isinstance(output, str): output = self._postprocess(output) else: output = [self._postprocess(o) for o in output] if n_wrong > 0: if isinstance(context, str) and isinstance(answer, str): wrong_answers = self._sim_words._extract_wrongs(answer) output = output, wrong_answers[:n_wrong] elif isinstance(context, list) and isinstance(answer, str): wrong_answers = self._sim_words._extract_wrongs(answer)[:n_wrong] output = [(o, wrong_answers) for o in output] elif isinstance(context, str) and isinstance(answer, list): wrong_answers = [ self._sim_words._extract_wrongs(a)[:n_wrong] for a in answer ] output = [(output, w) for w in wrong_answers] else: wrong_answers = [ self._sim_words._extract_wrongs(a)[:n_wrong] for a in answer ] output = [(o, w) for o, w in zip(output, wrong_answers)] return output
def __call__( self, answer: Union[str, List[str]], context: Union[str, List[str]], beam: int = 5, temperature: float = 1.0, top_k: int = -1, top_p: float = -1, no_repeat_ngram_size: int = 4, len_penalty: float = 1.0, n_wrong: int = 0, return_context: bool = False, ): """ Conducnt question generation Args: answer (str): answer text context (str): source article beam (int): beam search size temperature (float): temperature scale top_k (int): top-K sampling vocabulary size top_p (float): top-p sampling ratio no_repeat_ngram_size (int): no repeat ngram size len_penalty (float): length penalty ratio n_wrong (int): number of wrong answer candidate return_context (bool): return context together or not """ if isinstance(answer, str) and isinstance(context, str): context = self._focus_answer(context, answer) elif isinstance(answer, list) and isinstance(context, str): context = [self._focus_answer(context, a) for a in answer] elif isinstance(answer, str) and isinstance(context, list): context = [self._focus_answer(c, answer) for c in context] elif isinstance(answer, list) and isinstance(context, list): assert len(answer) == len( context), "length of answer list and context list must be same." context = [ self._focus_answer(c, a) for c, a in zip(context, answer) ] result = self.predict( answer, context, beam, temperature, top_k, top_p, no_repeat_ngram_size, len_penalty, n_wrong, ) if return_context: result = result, context return result
[docs]class SimilarWords(object): def __init__(self, model, idx): self._wikipedia2vec = model self._ix = idx self._searcher = self._ix.searcher() def _normalize(self, word: str): """ normalize input string Args: word (str): input string Returns: normalized string """ return word.lower().replace(" ", "_") def _entity(self, entity: str): """ find entity in entity dictionary Args: entity (str): entity string Returns: wikipedia2vec entity """ entity = self._normalize(entity) entity = self._wikipedia2vec.get_entity(entity) return entity def _similar(self, entity: str): """ find similar word with given entity Args: entity (str): answer that user inputted Returns: dict: category to entity list """ entity_hit = None entity_ = self._entity(entity) headword2relatives = {} if not entity_: return headword2relatives from_searchterms = QueryParser( "searchterms", self._ix.schema, ).parse(entity) hits = self._searcher.search(from_searchterms) for hit in hits: wiki_title = hit["wiki_title"] if wiki_title == entity: entity_hit = hit if not entity_hit: return headword2relatives results = self._wikipedia2vec.most_similar(entity_) categories = entity_hit["categories"].split(";") category2entities = {category: [] for category in categories} if not results: return category2entities for result in results: if hasattr(result[0], "text"): continue if result[0].title == entity_.title or "분류" in result[0].title: continue idx = result[0].index.item() from_idx = QueryParser("entity_idx", self._ix.schema).parse(str(idx)) hits2 = self._searcher.search(from_idx) if hits2: categories2 = hits2[0]["categories"].split(";") for each in categories2: if each in category2entities: category2entities[each].append(result[0].title) return category2entities def _extract_wrongs(self, entity: str) -> List[str]: """ extract wrong answers candidates Args: entity: entity string Returns: wrong_list (List[str]): wrong answer candidates list """ entity_list = [] answer = self._normalize_answer(entity) sims = self._similar(answer) sims = list(sims.items()) if len(sims) == 0: return entity_list for key, val in sims: if key.lower() != "word": entity_list += self._compare_with_answer(val, answer) return list(OrderedDict.fromkeys(entity_list)) def _compare_with_answer( self, entity_list: List[str], answer: str, ) -> List[str]: """ add wrong answer candidate to list after compare with answer using n-gram (f1 score) Args: entity_list (List[str]): wrong answers candidates answer (str): answer that will be compared wrong answer Returns: result_list (List[str]): wrong answer candidates list """ result_list = [] for e in entity_list: if "분류" not in e: e = re.sub(r"\([^)]*\)", "", e).strip() if self._f1_score(e, answer) < 0.5: result_list.append(e) return result_list def _normalize_answer(self, s): """ normalize answer string Args: s: sentence Returns: normalized sentence References: https://korquad.github.io/ """ def remove_(text): text = re.sub("'", " ", text) text = re.sub('"', " ", text) text = re.sub("《", " ", text) text = re.sub("》", " ", text) text = re.sub("<", " ", text) text = re.sub(">", " ", text) text = re.sub("〈", " ", text) text = re.sub("〉", " ", text) text = re.sub("\\(", " ", text) text = re.sub("\\)", " ", text) text = re.sub("‘", " ", text) text = re.sub("’", " ", text) return text def white_space_fix(text): return " ".join(text.split()) def remove_punc(text): exclude = set(string.punctuation) return "".join(ch for ch in text if ch not in exclude) def lower(text): return text.lower() return white_space_fix(remove_punc(lower(remove_(s)))) def _f1_score(self, prediction, ground_truth): """ compute F1 score Args: prediction: prediction answer ground_truth: ground truth answer Returns: F1 score between prediction and ground truth References: https://korquad.github.io/ """ prediction_tokens = self._normalize_answer(prediction).split() ground_truth_tokens = self._normalize_answer(ground_truth).split() # F1 by character prediction_Char = [] for tok in prediction_tokens: now = [a for a in tok] prediction_Char.extend(now) ground_truth_Char = [] for tok in ground_truth_tokens: now = [a for a in tok] ground_truth_Char.extend(now) common = Counter(prediction_Char) & Counter(ground_truth_Char) num_same = sum(common.values()) if num_same == 0: return 0 precision = 1.0 * num_same / len(prediction_Char) recall = 1.0 * num_same / len(ground_truth_Char) f1 = (2 * precision * recall) / (precision + recall) return f1