Source code for pororo.tasks.machine_translation

"""Machine-translation related modeling class"""

from typing import Optional

from pororo.tasks.utils.base import PororoFactoryBase, PororoGenerationBase
from pororo.tasks.utils.download_utils import download_or_load


class PororoTranslationFactory(PororoFactoryBase):
    """
    Machine translation using Transformer models

    Multi (`transformer.large.multi.mtpg`)

        - dataset: Train (Internal data) / Test (Multilingual TED Talk)
        - metric: BLEU score

        +-----------------+-----------------+------------+
        | Source Language | Target Language | BLEU score |
        +=================+=================+============+
        | Average         | X               | 10.00      |
        +-----------------+-----------------+------------+
        | English         | Korean          | 15         |
        +-----------------+-----------------+------------+
        | English         | Japanese        | 8          |
        +-----------------+-----------------+------------+
        | English         | Chinese         | 8          |
        +-----------------+-----------------+------------+
        | Korean          | English         | 15         |
        +-----------------+-----------------+------------+
        | Korean          | Japanese        | 10         |
        +-----------------+-----------------+------------+
        | Korean          | Chinese         | 4          |
        +-----------------+-----------------+------------+
        | Japanese        | English         | 11         |
        +-----------------+-----------------+------------+
        | Japanese        | Korean          | 13         |
        +-----------------+-----------------+------------+
        | Japanese        | Chinese         | 4          |
        +-----------------+-----------------+------------+
        | Chinese         | English         | 16         |
        +-----------------+-----------------+------------+
        | Chinese         | Korean          | 10         |
        +-----------------+-----------------+------------+
        | Chinese         | Japanese        | 6          |
        +-----------------+-----------------+------------+

        - ref: http://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/
        - note: These scores were measured in an out-of-domain setting;
          TED Talk data was not used during model training.

    Multi (`transformer.large.multi.fast.mtpg`)

        - dataset: Train (Internal data) / Test (Multilingual TED Talk)
        - metric: BLEU score

        +-----------------+-----------------+------------+
        | Source Language | Target Language | BLEU score |
        +=================+=================+============+
        | Average         | X               | 8.75       |
        +-----------------+-----------------+------------+
        | English         | Korean          | 13         |
        +-----------------+-----------------+------------+
        | English         | Japanese        | 6          |
        +-----------------+-----------------+------------+
        | English         | Chinese         | 7          |
        +-----------------+-----------------+------------+
        | Korean          | English         | 15         |
        +-----------------+-----------------+------------+
        | Korean          | Japanese        | 11         |
        +-----------------+-----------------+------------+
        | Korean          | Chinese         | 10         |
        +-----------------+-----------------+------------+
        | Japanese        | English         | 3          |
        +-----------------+-----------------+------------+
        | Japanese        | Korean          | 13         |
        +-----------------+-----------------+------------+
        | Japanese        | Chinese         | 4          |
        +-----------------+-----------------+------------+
        | Chinese         | English         | 15         |
        +-----------------+-----------------+------------+
        | Chinese         | Korean          | 8          |
        +-----------------+-----------------+------------+
        | Chinese         | Japanese        | 4          |
        +-----------------+-----------------+------------+

        - ref: http://www.cs.jhu.edu/~kevinduh/a/multitarget-tedtalks/
        - note: These scores were measured in an out-of-domain setting;
          TED Talk data was not used during model training.

    Args:
        text (str): input text to be translated
        beam (int): beam search size
        temperature (float): temperature scale
        top_k (int): top-K sampling vocabulary size
        top_p (float): top-p sampling ratio
        no_repeat_ngram_size (int): no-repeat n-gram size
        len_penalty (float): length penalty ratio

    Returns:
        str: machine translated sentence

    Examples:
        >>> mt = Pororo(task="translation", lang="multi")
        >>> mt("케빈은 아직도 일을 하고 있다.", src="ko", tgt="en")
        'Kevin is still working.'
        >>> mt("死神は りんごしか食べない。", src="ja", tgt="ko")
        '사신은 사과밖에 먹지 않는다.'
        >>> mt("人生的伟大目标,不是知识而是行动。", src="zh", tgt="ko")
        '인생의 위대한 목표는 지식이 아니라 행동이다.'
    """

    def __init__(
        self,
        task: str,
        lang: str,
        model: Optional[str],
        tgt: Optional[str] = None,
    ):
        super().__init__(task, lang, model)
        self._src = self.config.lang
        self._tgt = tgt
    @staticmethod
    def get_available_langs():
        return ["multi"]
    @staticmethod
    def get_available_models():
        return {
            "multi": [
                "transformer.large.multi.mtpg",
                "transformer.large.multi.fast.mtpg",
            ],
        }
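    # Illustrative usage of the helpers above (not part of the original
    # module): the factory's catalogue can be queried statically, e.g.
    #
    #   PororoTranslationFactory.get_available_langs()
    #   # -> ["multi"]
    #   PororoTranslationFactory.get_available_models()["multi"]
    #   # -> ["transformer.large.multi.mtpg", "transformer.large.multi.fast.mtpg"]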
    def load(self, device: str):
        """
        Load user-selected task-specific model

        Args:
            device (str): device information

        Returns:
            object: User-selected task-specific model
        """
        from pororo.tasks import PororoTokenizationFactory

        sent_tokenizer = (lambda text, lang: PororoTokenizationFactory(
            task="tokenization",
            lang=lang,
            model=f"sent_{lang}",
        ).load(device).predict(text))

        if "multi" in self.config.n_model:
            from fairseq.models.transformer import TransformerModel

            from pororo.tasks.utils.tokenizer import CustomTokenizer

            load_dict = download_or_load(
                f"transformer/{self.config.n_model}",
                self.config.lang,
            )

            model = (TransformerModel.from_pretrained(
                model_name_or_path=load_dict.path,
                checkpoint_file=f"{self.config.n_model}.pt",
                data_name_or_path=load_dict.dict_path,
                source_lang=load_dict.src_dict,
                target_lang=load_dict.tgt_dict,
            ).eval().to(device))

            tokenizer = CustomTokenizer.from_file(
                vocab_filename=f"{load_dict.src_tok}/vocab.json",
                merges_filename=f"{load_dict.src_tok}/merges.txt",
            )

            if "mtpg" in self.config.n_model:
                langtok_style = "mbart"
            elif "m2m" in self.config.n_model:
                langtok_style = "multilingual"
            else:
                langtok_style = "basic"

            return PororoTransformerTransMulti(
                model,
                self.config,
                tokenizer,
                sent_tokenizer,
                langtok_style,
            )
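    # Sketch of what `load` wires together (illustrative, not part of the
    # original module): the `sent_tokenizer` lambda builds a Pororo sentence
    # splitter on the fly, so the returned wrapper can translate
    # multi-sentence input sentence by sentence. Hypothetical call:
    #
    #   sent_tokenizer("첫 문장입니다. 둘째 문장입니다.", "ko")
    #   # -> ["첫 문장입니다.", "둘째 문장입니다."]  (expected shape of output)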
class PororoTransformerTransMulti(PororoGenerationBase):

    def __init__(self, model, config, tokenizer, sent_tokenizer, langtok_style):
        super().__init__(config)
        self._model = model
        self._tokenizer = tokenizer
        self._sent_tokenizer = sent_tokenizer
        self._langtok_style = langtok_style

    def _langtok(self, lang: str, langtok_style: str):
        """
        Build the language token that wraps source and target sentences

        Args:
            lang (str): language code
            langtok_style (str): style of language token

        See Also:
            https://github.com/pytorch/fairseq/blob/master/fairseq/data/multilingual/multilingual_utils.py#L34
        """
        if langtok_style == "basic":
            return f"[{lang.upper()}]"
        elif langtok_style == "mbart":
            mapping = {"en": "_XX", "ja": "_XX", "ko": "_KR", "zh": "_CN"}
            return f"[{lang + mapping[lang]}]"
        elif langtok_style == "multilingual":
            return f"__{lang}__"

    def _preprocess(self, text: str, src: str, tgt: str) -> str:
        """
        Segment the input sentence (BPE pieces for English, characters with
        a ▁ whitespace marker otherwise) and wrap it with source and target
        language tokens

        Args:
            text (str): input sentence
            src (str): source language
            tgt (str): target language

        Returns:
            str: preprocessed sentence
        """
        if src == "en":
            pieces = " ".join(self._tokenizer.segment(text.strip()))
        else:
            pieces = " ".join([c if c != " " else "▁" for c in text.strip()])
        return f"{self._langtok(src, self._langtok_style)} {pieces} {self._langtok(tgt, self._langtok_style)}"

    def _postprocess(self, output: str) -> str:
        """
        Remove sub-token spacing from the model output and restore ▁ markers
        to ordinary whitespace

        Args:
            output (str): output sentence generated by model

        Returns:
            str: postprocessed output sentence
        """
        return output.replace(" ", "").replace("▁", " ").strip()
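    # Worked examples for the helpers above (illustrative, not part of the
    # original module), assuming the "mbart" langtok style selected for the
    # *.mtpg checkpoints:
    #
    #   self._langtok("ko", "mbart")           # -> "[ko_KR]"
    #   self._langtok("en", "mbart")           # -> "[en_XX]"
    #   self._preprocess("케빈", "ko", "en")    # -> "[ko_KR] 케 빈 [en_XX]"
    #   self._postprocess("케 빈 ▁ 입 니 다")   # -> "케빈 입니다"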
    def predict(
        self,
        text: str,
        src: str,
        tgt: str,
        beam: int = 5,
        temperature: float = 1.0,
        top_k: int = -1,
        top_p: float = -1,
        no_repeat_ngram_size: int = 4,
        len_penalty: float = 1.0,
        **kwargs,
    ) -> str:
        """
        Conduct machine translation

        Args:
            text (str): input text to be translated
            src (str): source language
            tgt (str): target language
            beam (int): beam search size
            temperature (float): temperature scale
            top_k (int): top-K sampling vocabulary size
            top_p (float): top-p sampling ratio
            no_repeat_ngram_size (int): no-repeat n-gram size
            len_penalty (float): length penalty ratio

        Returns:
            str: machine translated sentence
        """
        text = self._preprocess(text, src, tgt)

        sampling = False
        if top_k != -1 or top_p != -1:
            sampling = True

        output = self._model.translate(
            text,
            beam=beam,
            sampling=sampling,
            temperature=temperature,
            sampling_topk=top_k,
            sampling_topp=top_p,
            max_len_a=1,
            max_len_b=50,
            no_repeat_ngram_size=no_repeat_ngram_size,
            lenpen=len_penalty,
        )
        output = self._postprocess(output)
        return output
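    # Decoding-mode note (illustrative): with the defaults (top_k == -1 and
    # top_p == -1) `predict` runs pure beam search; setting either knob
    # switches fairseq into sampling mode. Hypothetical call on a loaded
    # `trans` instance:
    #
    #   trans.predict("케빈은 일을 한다.", "ko", "en", top_k=10, temperature=0.8)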
    def __call__(
        self,
        text: str,
        src: str,
        tgt: str,
        beam: int = 5,
        temperature: float = 1.0,
        top_k: int = -1,
        top_p: float = -1,
        no_repeat_ngram_size: int = 4,
        len_penalty: float = 1.0,
    ):
        assert isinstance(text, str), "Input text should be string type"
        assert src in [
            "ko",
            "zh",
            "ja",
            "en",
        ], "Source language must be one of CJKE !"
        assert tgt in [
            "ko",
            "zh",
            "ja",
            "en",
        ], "Target language must be one of CJKE !"

        return " ".join([
            self.predict(
                t,
                src,
                tgt,
                beam,
                temperature,
                top_k,
                top_p,
                no_repeat_ngram_size,
                len_penalty,
            ) for t in self._sent_tokenizer(text, src)
        ])
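# Usage sketch (illustrative; mirrors the docstring examples above):
#
#   from pororo import Pororo
#
#   mt = Pororo(task="translation", lang="multi")
#   mt("케빈은 아직도 일을 하고 있다.", src="ko", tgt="en")
#   # -> 'Kevin is still working.'
#
# `__call__` splits the input into sentences with the sentence tokenizer,
# translates each sentence via `predict`, and joins the results with
# single spaces.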