# Source code for pororo.tasks.grapheme_conversion

"""Phoneme to Grapheme related modeling class"""

from typing import Optional

from pororo.tasks.utils.base import (
    PororoFactoryBase,
    PororoGenerationBase,
    PororoSimpleBase,
)
from pororo.tasks.utils.download_utils import download_or_load


class PororoP2gFactory(PororoFactoryBase):
    """
    Conduct phoneme to grapheme conversion

    Japanese (`p2g.ja`)

        - dataset: jawiki-20180420 + romkan
        - metric: TBU

    Chinese (`p2g.zh`)

        - dataset: zhwiki-20180420 + g2pM
        - metric: TBU

    Examples:
        >>> p2g_zh = Pororo(task="p2g", lang="zh")
        >>> p2g_zh(['ran2', 'er2', ',', 'ta1', 'hong2', 'le5', '20', 'nian2', 'yi3', 'hou4', ',', 'ta1', 'jing4', 'tui4', 'chu1', 'le5', 'da4', 'jia1', 'de5', 'shi4', 'xian4', '。'])
        ['然', '而', ',', '他', '红', '了', '20', '年', '乙', '后', ',', '他', '敬', '退', '出', '了', '大', '家', '的', '市', '县', '。']
        >>> p2g_ja = Pororo(task="p2g", lang="ja")
        >>> p2g_ja("python ga daisuki desu。")
        pythonが大好きです。

    """

    def __init__(self, task: str, lang: str, model: Optional[str]):
        super().__init__(task, lang, model)

    @staticmethod
    def get_available_langs():
        # Languages with a released p2g model.
        return ["zh", "ja"]

    @staticmethod
    def get_available_models():
        # Mapping from language code to the model names it supports.
        return {
            "zh": ["p2g.zh"],
            "ja": ["p2g.ja"],
        }

    def load(self, device: str):
        """
        Load user-selected task-specific model

        Args:
            device (str): device information

        Returns:
            object: User-selected task-specific model

        """
        n_model = self.config.n_model
        lang = self.config.lang

        if n_model == "p2g.zh":
            from pororo.models.p2g import P2gM

            # Vocabulary lookup tables and model checkpoint are fetched
            # (or read from the local cache) before constructing the model.
            pinyin_vocab = download_or_load(
                f"misc/pinyin2idx.{lang}.pkl",
                lang,
            )
            char_vocab = download_or_load(
                f"misc/char2idx.{lang}.pkl",
                lang,
            )
            checkpoint = download_or_load(
                f"misc/{n_model}.pt",
                lang,
            )
            zh_model = P2gM(pinyin_vocab, char_vocab, checkpoint, device)
            return PororoP2GZh(zh_model, self.config)

        if n_model == "p2g.ja":
            from fairseq.models.transformer import TransformerModel

            load_dict = download_or_load(
                "transformer/transformer.base.ja.p2g",
                lang,
            )
            ja_model = TransformerModel.from_pretrained(
                model_name_or_path=load_dict.path,
                checkpoint_file="transformer.base.ja.p2g.pt",
                data_name_or_path=load_dict.dict_path,
                source_lang=load_dict.src_dict,
                target_lang=load_dict.tgt_dict,
            )
            # Inference-only usage: switch to eval mode and move to the
            # requested device.
            ja_model = ja_model.eval().to(device)
            return PororoP2GJa(ja_model, self.config)
class PororoP2GZh(PororoSimpleBase):
    """Chinese phoneme (pinyin) to grapheme conversion wrapper around `P2gM`."""

    def __init__(self, model, config):
        super().__init__(config)
        # Underlying P2gM model; called directly in `predict`.
        self._model = model

    def predict(self, sent: str, **kwargs) -> str:
        """
        Conduct phoneme to grapheme conversion

        Args:
            sent: input phonemes to convert
                (NOTE(review): the factory docstring example passes a list of
                pinyin tokens, so the effective type is likely List[str]
                despite the `str` annotation — confirm against P2gM)

        Returns:
            converted grapheme sequence produced by the underlying model

        """
        # The wrapped model performs the whole conversion; no pre/postprocessing
        # is needed for Chinese.
        results = self._model(sent)
        return results
class PororoP2GJa(PororoGenerationBase):
    """Japanese phoneme to grapheme conversion using a fairseq Transformer."""

    def __init__(self, model, config):
        super().__init__(config)
        # fairseq TransformerModel hub interface (provides `.translate`).
        self._model = model

    def _preprocess(self, sent: str) -> str:
        """
        Preprocess non-chinese input sentence to replace whitespace token with
        whitespace

        Args:
            sent (str): non-chinese sentence

        Returns:
            str: preprocessed non-chinese sentence

        """
        # Mark real spaces with the sentencepiece-style "▁" token, then split
        # into space-separated characters as the model expects.
        sent = sent.replace(" ", "▁")
        return " ".join([c for c in sent])

    def _postprocess(self, output: str) -> str:
        """
        Postprocess output sentence to replace whitespace

        Args:
            output (str): output sentence generated by model

        Returns:
            str: postprocessed output sentence

        """
        # Drop the "▁" space markers and re-join the character tokens.
        output = output.replace("▁", "")
        return "".join(output.split())

    def predict(
        self,
        text: str,
        beam: int = 5,
        temperature: float = 1.0,
        top_k: int = -1,
        top_p: float = -1,
        no_repeat_ngram_size: int = 4,
        len_penalty: float = 1.0,
        **kwargs,
    ) -> str:
        """
        Conduct phoneme to grapheme conversion using Transformer Seq2Seq

        Args:
            text (str): input sentence (romanized Japanese phonemes)
            beam (int): beam search size
            temperature (float): temperature scale
            top_k (int): top-K sampling vocabulary size (-1 disables)
            top_p (float): top-p sampling ratio (-1 disables)
            no_repeat_ngram_size (int): no repeat ngram size
            len_penalty (float): length penalty ratio

        Returns:
            str: converted grapheme sentence

        """
        # Sampling is enabled whenever either top-k or top-p is requested;
        # otherwise plain beam search is used.
        sampling = False
        if top_k != -1 or top_p != -1:
            sampling = True

        text = self._preprocess(text)

        output = self._model.translate(
            text,
            beam=beam,
            sampling=sampling,
            temperature=temperature,
            sampling_topk=top_k,
            sampling_topp=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
            lenpen=len_penalty,
        )
        output = self._postprocess(output)
        return output