"""Phoneme to Grapheme related modeling class"""
from typing import Optional
from pororo.tasks.utils.base import (
PororoFactoryBase,
PororoGenerationBase,
PororoSimpleBase,
)
from pororo.tasks.utils.download_utils import download_or_load
class PororoP2gFactory(PororoFactoryBase):
    """
    Conduct phoneme to grapheme conversion

    Japanese (`p2g.ja`)
        - dataset: jawiki-20180420 + romkan
        - metric: TBU

    Chinese (`p2g.zh`)
        - dataset: zhwiki-20180420 + g2pM
        - metric: TBU

    Examples:
        >>> p2g_zh = Pororo(task="p2g", lang="zh")
        >>> p2g_zh(['ran2', 'er2', ',', 'ta1', 'hong2', 'le5', '20', 'nian2', 'yi3', 'hou4', ',', 'ta1', 'jing4', 'tui4', 'chu1', 'le5', 'da4', 'jia1', 'de5', 'shi4', 'xian4', '。'])
        ['然', '而', ',', '他', '红', '了', '20', '年', '乙', '后', ',', '他', '敬', '退', '出', '了', '大', '家', '的', '市', '县', '。']
        >>> p2g_ja = Pororo(task="p2g", lang="ja")
        >>> p2g_ja("python ga daisuki desu。")
        pythonが大好きです。
    """

    def __init__(self, task: str, lang: str, model: Optional[str]):
        super().__init__(task, lang, model)

    @staticmethod
    def get_available_langs():
        # Supported language codes for the p2g task.
        return ["zh", "ja"]

    @staticmethod
    def get_available_models():
        # One model per supported language, named "p2g.<lang>".
        return {lang: [f"p2g.{lang}"] for lang in ("zh", "ja")}

    def load(self, device: str):
        """
        Load user-selected task-specific model

        Args:
            device (str): device information

        Returns:
            object: User-selected task-specific model
        """
        n_model = self.config.n_model
        lang = self.config.lang

        if n_model == "p2g.zh":
            # Chinese: custom P2gM model built from pinyin/char vocab pickles
            # plus a checkpoint, all fetched (or cached) per language.
            from pororo.models.p2g import P2gM

            pinyin_vocab = download_or_load(f"misc/pinyin2idx.{lang}.pkl", lang)
            char_vocab = download_or_load(f"misc/char2idx.{lang}.pkl", lang)
            checkpoint = download_or_load(f"misc/{n_model}.pt", lang)
            zh_model = P2gM(pinyin_vocab, char_vocab, checkpoint, device)
            return PororoP2GZh(zh_model, self.config)

        if n_model == "p2g.ja":
            # Japanese: fairseq Transformer seq2seq loaded from a downloaded
            # bundle describing checkpoint, data path, and dictionaries.
            from fairseq.models.transformer import TransformerModel

            load_dict = download_or_load(
                "transformer/transformer.base.ja.p2g",
                lang,
            )
            transformer = TransformerModel.from_pretrained(
                model_name_or_path=load_dict.path,
                checkpoint_file="transformer.base.ja.p2g.pt",
                data_name_or_path=load_dict.dict_path,
                source_lang=load_dict.src_dict,
                target_lang=load_dict.tgt_dict,
            )
            ja_model = transformer.eval().to(device)
            return PororoP2GJa(ja_model, self.config)
class PororoP2GZh(PororoSimpleBase):
    """Chinese phoneme-to-grapheme wrapper around a P2gM model."""

    def __init__(self, model, config):
        super().__init__(config)
        # Callable P2gM model; invoked directly on the input in predict().
        self._model = model

    def predict(self, sent: str, **kwargs) -> str:
        """
        Conduct phoneme to grapheme conversion

        Args:
            sent: input phonemes — per the factory doctest this is a list of
                pinyin token strings, despite the ``str`` annotation
                (NOTE(review): annotation looks inaccurate; kept for
                interface compatibility)

        Returns:
            converted grapheme token list, as produced by the P2gM model

        """
        # The model itself is callable and performs the full conversion.
        results = self._model(sent)
        return results
class PororoP2GJa(PororoGenerationBase):
    """Japanese phoneme-to-grapheme wrapper around a fairseq Transformer."""

    def __init__(self, model, config):
        super().__init__(config)
        # fairseq TransformerModel hub interface; used via .translate().
        self._model = model

    def _preprocess(self, sent: str) -> str:
        """
        Preprocess the input sentence for the character-level model:
        protect literal spaces with the "▁" marker, then separate every
        character with a space so each becomes its own token.

        Args:
            sent (str): input sentence (romanized Japanese phonemes)

        Returns:
            str: space-separated character sequence
        """
        sent = sent.replace(" ", "▁")
        # str is iterable by character; join inserts a space between each.
        return " ".join(sent)

    def _postprocess(self, output: str) -> str:
        """
        Undo the preprocessing on the model output: drop the token
        separators and restore "▁" markers back to nothing (the model's
        graphemes need no spacing).

        Args:
            output (str): output sentence generated by model

        Returns:
            str: postprocessed output sentence
        """
        output = output.replace("▁", "")
        return "".join(output.split())

    def predict(
        self,
        text: str,
        beam: int = 5,
        temperature: float = 1.0,
        top_k: int = -1,
        top_p: float = -1,
        no_repeat_ngram_size: int = 4,
        len_penalty: float = 1.0,
        **kwargs,
    ) -> str:
        """
        Conduct phoneme to grapheme conversion using Transformer Seq2Seq

        Args:
            text (str): input sentence
            beam (int): beam search size
            temperature (float): temperature scale
            top_k (int): top-K sampling vocabulary size (-1 disables)
            top_p (float): top-p sampling ratio (-1 disables)
            no_repeat_ngram_size (int): no repeat ngram size
            len_penalty (float): length penalty ratio

        Returns:
            str: converted grapheme sentence
        """
        # Sampling is enabled only if the caller opts into top-k or top-p.
        sampling = top_k != -1 or top_p != -1

        text = self._preprocess(text)
        output = self._model.translate(
            text,
            beam=beam,
            sampling=sampling,
            temperature=temperature,
            sampling_topk=top_k,
            sampling_topp=top_p,
            no_repeat_ngram_size=no_repeat_ngram_size,
            lenpen=len_penalty,
        )
        return self._postprocess(output)