# Source code for pororo.tasks.zero_shot_classification

"""Zero-shot Classification related modeling class"""

from typing import Dict, List, Optional

from pororo.tasks.utils.base import PororoBiencoderBase, PororoFactoryBase


class PororoZeroShotFactory(PororoFactoryBase):
    """
    Zero-shot topic classification

    See also: https://joeddav.github.io/blog/2020/05/29/ZSL.html

    Korean (`brainbert.base.ko.kornli`)
        - dataset: KorNLI (Ham et al. 2020)
        - metric: N/A

    English (`roberta.base.en.nli`)
        - dataset: MNLI (Adina Williams et al. 2017)
        - metric: N/A

    Japanese (`jaberta.base.ja.nli`)
        - dataset: XNLI (Alexis Conneau et al. 2018)
        - metric: N/A

    Chinese (`zhberta.base.zh.nli`)
        - dataset: XNLI (Alexis Conneau et al. 2018)
        - metric: N/A

    Examples:
        >>> zsl = Pororo(task="zero-topic")
        >>> zsl("Who are you voting for in 2020?", ["business", "art & culture", "politics"])
        {'business': 33.23, 'art & culture': 8.33, 'politics': 96.12}
        >>> zsl = Pororo(task="zero-topic", lang="ko")
        >>> zsl('''라리가 사무국, 메시 아닌 바르사 지지..."바이 아웃 유효" [공식발표]''', ["스포츠", "사회", "정치", "경제", "생활/문화", "IT/과학"])
        {'스포츠': 94.15, '사회': 37.11, '정치': 74.26, '경제': 39.18, '생활/문화': 71.15, 'IT/과학': 34.71}
        >>> zsl('''장제원, 김종인 당무감사 추진에 “참 잔인들 하다”···정강정책 개정안은 “졸작”''', ["스포츠", "사회", "정치", "경제", "생활/문화", "IT/과학"])
        {'스포츠': 2.18, '사회': 56.1, '정치': 88.24, '경제': 16.17, '생활/문화': 66.13, 'IT/과학': 11.2}
        >>> zsl = Pororo(task="zero-topic", lang="ja")
        >>> zsl("香川 真司は、兵庫県神戸市垂水区出身のプロサッカー選手。元日本代表。ポジションはMF、FW。ボルシア・ドルトムント時代の2010-11シーズンでリーグ前半期17試合で8得点を記録し9シーズンぶりのリーグ優勝に貢献。キッカー誌が選定したブンデスリーガの年間ベスト イレブンに名を連ねた。", ["スポーツ", "政治", "技術"])
        {'スポーツ': 0.2, '政治': 99.71, '技術': 68.9}
        >>> zsl = Pororo(task="zero-topic", lang="zh")
        >>> zsl("商务部14日发布数据显示,今年前10个月,我国累计对外投资904.6亿美元,同比增长5.9%。", ["政治", "经济", "国际化"])
        {'政治': 33.72, '经济': 3.9, '国际化': 13.67}

    """

    def __init__(self, task: str, lang: str, model: Optional[str]):
        super().__init__(task, lang, model)

    @staticmethod
    def get_available_langs():
        return ["en", "ko", "ja", "zh"]

    @staticmethod
    def get_available_models():
        return {
            "ko": ["brainbert.base.ko.kornli"],
            "ja": ["jaberta.base.ja.nli"],
            "zh": ["zhberta.base.zh.nli"],
            "en": ["roberta.base.en.nli"],
        }

    def load(self, device: str):
        """
        Load user-selected task-specific model

        Args:
            device (str): device information

        Returns:
            object: User-selected task-specific model

        """
        from pororo.models.brainbert import (
            BrainRobertaModel,
            CustomRobertaModel,
            JabertaModel,
            ZhbertaModel,
        )

        # Substring of the model name -> loader class. Insertion order
        # preserves the original branch order ("brainbert" checked first).
        loaders = {
            "brainbert": BrainRobertaModel,
            "jaberta": JabertaModel,
            "zhberta": ZhbertaModel,
            "roberta": CustomRobertaModel,
        }

        for key, loader in loaders.items():
            if key in self.config.n_model:
                model = loader.load_model(
                    f"bert/{self.config.n_model}",
                    self.config.lang,
                ).eval().to(device)
                return PororoBertZeroShot(model, self.config)
class PororoBertZeroShot(PororoBiencoderBase):
    """
    NLI-based zero-shot classifier.

    Each candidate label is rendered into a hypothesis sentence via a
    language-specific template; the model's entailment probability for
    (sentence, hypothesis) is used as the label's confidence score.
    """

    def __init__(self, model, config):
        super().__init__(config)
        self._model = model
        # Hypothesis templates per language; {label} is filled in per candidate.
        self._template = {
            "ko": "이 문장은 {label}에 관한 것이다.",
            "ja": "この文は、{label}に関するものである。",
            "zh": "这句话是关于{label}的。",
            "en": "This sentence is about {label}.",
        }

    def predict(
        self,
        sent: str,
        labels: List[str],
        **kwargs,
    ) -> Dict[str, float]:
        """
        Conduct zero-shot classification

        Args:
            sent (str): sentence to be classified
            labels (List[str]): candidate labels

        Returns:
            Dict[str, float]: mapping from each input label to its confidence
            score as a percentage, rounded to two decimal places

        """
        # Loop-invariant: the template depends only on the configured language.
        template = self._template[self.config.lang]

        result = {}
        for label in labels:
            hypothesis = template.format(label=label)
            # The Korean encoder additionally requires explicit special tokens;
            # the other encoders keep their default behavior.
            if self.config.lang == "ko":
                tokens = self._model.encode(
                    sent,
                    hypothesis,
                    add_special_tokens=True,
                    no_separator=False,
                )
            else:
                tokens = self._model.encode(
                    sent,
                    hypothesis,
                    no_separator=False,
                )

            # Throw away "neutral" (dim 1) and take the softmax probability of
            # "entailment" (dim 2) vs "contradiction" (dim 0) as the
            # probability of the label being true.
            pred = self._model.predict(
                "sentence_classification_head",
                tokens,
                return_logits=True,
            )[:, [0, 2]]
            prob = pred.softmax(dim=1)[:, 1].item() * 100
            result[label] = round(prob, 2)
        return result