"""Constituency Parsing related modeling class"""
import re
from typing import List, Optional, Tuple
from lxml import etree
from pororo.tasks.utils.base import PororoFactoryBase, PororoTaskBase
from pororo.tasks.utils.download_utils import download_or_load
class PororoConstFactory(PororoFactoryBase):
    """
    Constituency parsing using Transformer model

    English (`transformer.base.en.const`)

        - dataset: OntoNotes 5.0
        - metric: TBU

    Korean (`transformer.base.ko.const`)

        - dataset: Sejong Corpus
        - metric: TBU

    Chinese (`transformer.base.zh.const`)

        - dataset: OntoNotes 5.0
        - metric: TBU

    Args:
        text (str): input text
        beam (int): size of beam search
        pos (bool): contains PoS tagging or not

    Returns:
        result: result of constituency parsing

    Examples:
        >>> const = Pororo(task="const", lang="en")
        >>> const("I love this place")
        <TOP>
          <S>
            <NP>I</NP>
            <VP>
              love
              <NP>this place</NP>
            </VP>
          </S>
        </TOP>
        >>> const = Pororo(task="const", lang="zh")
        >>> const("我喜欢饼干")
        <TOP>
          <IP>
            <NP>我</NP>
            <VP>
              喜欢
              <NP>饼干</NP>
            </VP>
          </IP>
        </TOP>
        >>> const = Pororo(task="const", lang="ko")
        >>> const("미국에서도 같은 우려가 나오고 있다.")
        <S>
          <NP_AJT>미국/NNP+에서/JKB+도/JX</NP_AJT>
          <S>
            <NP_SBJ>
              <VP_MOD>같/VA+은/ETM</VP_MOD>
              <NP_SBJ>우려/NNG+가/JKS</NP_SBJ>
            </NP_SBJ>
            <VP>
              <VP>나오/VV+고/EC</VP>
              <VP>있/VX+다/EF+./SF</VP>
            </VP>
          </S>
        </S>

    """

    def __init__(self, task: str, lang: str, model: Optional[str]):
        super().__init__(task, lang, model)

    @staticmethod
    def get_available_langs():
        return ["en", "ko", "zh"]

    @staticmethod
    def get_available_models():
        return {
            "en": ["transformer.base.en.const"],
            "ko": ["transformer.base.ko.const"],
            "zh": ["transformer.base.zh.const"],
        }

    def load(self, device: str):
        """
        Load user-selected task-specific model

        Args:
            device (str): device information

        Returns:
            object: User-selected task-specific model
        """
        if "transformer" in self.config.n_model:
            from fairseq.models.transformer import TransformerModel
            from pororo.tasks import PororoPosFactory

            load_dict = download_or_load(
                f"transformer/{self.config.n_model}",
                self.config.lang,
            )

            model = (TransformerModel.from_pretrained(
                model_name_or_path=load_dict.path,
                checkpoint_file=f"{self.config.n_model}.pt",
                data_name_or_path=load_dict.dict_path,
                source_lang=load_dict.src_dict,
                target_lang=load_dict.tgt_dict,
            ).eval().to(device))

            # Each language pairs the seq2seq parser with its own PoS tagger
            if self.config.lang == "ko":
                tagger = PororoPosFactory(
                    task="pos",
                    model="mecab-ko",
                    lang=self.config.lang,
                ).load(device)
                return PororoTransConstKo(model, tagger, self.config)

            if self.config.lang == "en":
                tagger = PororoPosFactory(
                    task="pos",
                    model="nltk",
                    lang=self.config.lang,
                ).load(device)
                return PororoTransConstEn(model, tagger, self.config)

            if self.config.lang == "zh":
                tagger = PororoPosFactory(
                    task="pos",
                    model="jieba",
                    lang=self.config.lang,
                ).load(device)
                return PororoTransConstZh(model, tagger, self.config)
class PororoConstBase(PororoTaskBase):
    """Constituency Parsing base class containing various methods related to Const. Parsing"""

    def _fix_tree(self, output: str):
        """
        Fix tree when XML conversion is not conducted

        Args:
            output (str): string to fix

        Returns:
            text: fixed tree string
        """
        tag_ptn = "[A-Z][A-Z_]*"
        output = re.sub(r"\s", "", output)
        # Rewrite XML-ish tags into bracket form: <TAG>...</TAG> -> [TAG ... ]
        xml = re.sub(f"<({tag_ptn})>", r"[\1 ", output)
        xml = re.sub(f"</{tag_ptn}>", r"] ", xml)

        def _convert_to_xml(text):
            # Innermost-first: each pass turns bracket pairs that contain no
            # nested brackets back into well-formed XML elements
            for _ in range(max(text.count("["), text.count("]"))):
                text = re.sub(
                    rf"(?s)\[({tag_ptn})([^\[\]]+?)\]",
                    r"<\1>\2 </\1>",
                    text,
                )
            return text

        xml = _convert_to_xml(xml)
        # Drop unmatched opening/closing tags left over from a malformed tree
        xml = re.sub(rf"\[{tag_ptn}", "", xml)
        xml = re.sub(rf"{tag_ptn}\]", "", xml)
        xml = re.sub(r"[\[\]\s]", "", xml)
        return xml

    def _prettify(self, output: str):
        """
        Prettify model result using XML tree

        Args:
            output (str): string to make tree

        Returns:
            pretty: tree style output
        """
        output = re.sub("> +", ">", output)
        output = re.sub(" +<", "<", output)
        # Wrap bare text sitting between sibling elements in <temp> markers so
        # the string parses as strict XML; the markers are stripped afterwards
        output = re.sub(
            r"(<[A-Za-z_\d]+>) *([^< ]+) *(<[^/])",
            r"\1<temp>\2</temp>\3",
            output,
        )
        output = re.sub(
            r"(</[A-Za-z_\d]+>) *([^< ]+) *(</)",
            r"\1<temp>\2</temp>\3",
            output,
        )

        try:
            root = etree.fromstring(output)
        except Exception:
            # Model output is not valid XML; repair it before parsing again
            root = etree.fromstring(self._fix_tree(output))

        tree = etree.ElementTree(root)
        pretty = etree.tostring(tree, pretty_print=True, encoding="unicode")
        pretty = pretty.replace("<temp>", "").replace("</temp>", "")
        return pretty.replace(" ", "\t")

    def __call__(
        self,
        text: str,
        beam: int = 5,
        pos: bool = False,
        **kwargs,
    ):
        """
        Conduct constituency parsing

        Args:
            text (str): input text
            beam (int): size of beam search
            pos (bool): contains PoS tagging or not

        Returns:
            result: result of constituency parsing
        """
        assert isinstance(text, str), "Input text should be string type"
        text = self._normalize(text)
        return self.predict(text, beam, pos, **kwargs)
class PororoTransConstKo(PororoConstBase):
    """Transformer-based Korean constituency parser built on a mecab-ko PoS tagger."""

    def __init__(self, model, tagger, config):
        super().__init__(config)
        # fairseq seq2seq model: PoS-tag sequence -> bracketed parse tree
        self._model = model
        # PoS tagger; assumed to yield (surface, tag) pairs where the tag
        # "SPACE" marks eojeol (word) boundaries -- TODO confirm against
        # PororoPosFactory's mecab-ko wrapper
        self._tagger = tagger

    def _postprocess(
        self,
        result: List[str],
        eojeols: List[str],
        poses: List[str],
    ):
        """
        Postprocess method to make XML format

        Args:
            result (List[str]): constituency parsing result
            eojeols (List): list of eojeol
            poses (List): list of pos tag

        Returns:
            str: result of postprocess
        """
        # Group consecutive terminal (non-tag) token indices; each run of
        # adjacent indices corresponds to one eojeol of the input sentence.
        token_indices = []
        temp_group = []
        for i, res in enumerate(result):
            if ("<" in res) or (">" in res):
                # Skip structural tokens such as <NP_SBJ> / </NP_SBJ>
                continue
            if not temp_group:
                temp_group.append(i)
            else:
                if i == (temp_group[-1] + 1):
                    temp_group.append(i)
                else:
                    # Gap found: close the current run, start a new one
                    token_indices.append(temp_group)
                    temp_group = [i]
        token_indices.append(temp_group)

        # Replace each run of terminal tokens with its eojeol (optionally
        # annotated with PoS tags). Each replacement shrinks the list, so
        # `lucrative` accumulates the shrinkage to keep indices aligned.
        lucrative = 0
        for i, li_index in enumerate(token_indices):
            if poses:
                # Pair each morpheme with its tag: "morph/POS+morph/POS"
                eojeol = eojeols[i].split("+")
                pos = poses[i].split("+")
                tagged = []
                for e, p in zip(eojeol, pos):
                    tagged.append(f"{e}/{p}")
                result[li_index[0] - lucrative:li_index[-1] + 1 -
                       lucrative] = ["+".join(tagged)]
            else:
                result[li_index[0] - lucrative:li_index[-1] + 1 -
                       lucrative] = [eojeols[i]]
            lucrative += len(li_index) - 1
        return result

    def _check_sanity(self, cands: List[str], n_space: int):
        """
        Check sanity for valid xml structure

        Args:
            cands (List[str]): candidates
            n_space (int): number of space

        Returns:
            return valid or not
        """
        for cand in cands:
            # Count the number of space special character
            if cand.count("▁") != n_space:
                continue
            # Check whether candidate XML is valid
            try:
                etree.fromstring(cand)
                return cand
            except:
                continue
        return False

    def predict(
        self,
        text: str,
        beam: int = 5,
        pos: bool = False,
        **kwargs,
    ):
        """
        Conduct constituency parsing

        Args:
            text (str): input text
            beam (int): size of beam search
            pos (bool): contains PoS tagging or not

        Returns:
            result of constituency parsing
        """
        eojeols = self._tagger(text)
        # Spaces in the input must be reproduced as "▁" markers in the output;
        # their count is used to validate the model's candidate below.
        n_space = len([m for m in eojeols if m[1] == "SPACE"])

        # Second tagger call with return_string=False -- presumably returns a
        # different pair format than the default call above; verify against
        # the tagger implementation.
        pairs = self._tagger(text, return_string=False)
        # Model input is the PoS-tag sequence, with "▁" standing in for spaces
        src = " ".join(
            [pair[1] if pair[1] != "SPACE" else "▁" for pair in pairs])

        outputs = self._model.translate(
            src,
            beam=beam,
            max_len_a=1,
            max_len_b=50,
        )
        result = self._check_sanity([outputs], n_space)
        if not result:
            return f"<ERROR> {text} </ERROR>"
        result = [res for res in result.split() if res != "▁"]

        # Rebuild eojeol strings (morphemes joined by "+") and their matching
        # PoS strings, splitting on SPACE boundaries.
        words = []
        poses = []
        tmp_word = ""
        tmp_pos = ""
        for eojeol in eojeols:
            if eojeol[1] != "SPACE":
                tmp_word += f"{eojeol[0]}+"
                tmp_pos += f"{eojeol[1]}+"
            else:
                words.append(tmp_word[:-1])  # strip trailing "+"
                poses.append(tmp_pos[:-1])
                tmp_word = ""
                tmp_pos = ""
        words.append(tmp_word[:-1])
        poses.append(tmp_pos[:-1])

        if not pos:
            poses = None

        result = " ".join(self._postprocess(result, words, poses))
        return self._prettify(result).strip()
class PororoTransConstEn(PororoConstBase):
    """Transformer-based English constituency parser built on an NLTK PoS tagger."""

    def __init__(self, model, tagger, config):
        super().__init__(config)
        # fairseq seq2seq model: PoS-tag sequence -> bracketed parse tree
        self._model = model
        # PoS tagger producing (word, tag) pairs
        self._tagger = tagger

    def _check_sanity(self, tags: List[str], n_words: int) -> bool:
        """
        Check sanity for valid xml structure

        Args:
            tags (List[str]): list of tags
            n_words (int): number of words

        Returns:
            bool: whether the number of terminal tokens matches the input length
        """
        # Terminal tokens are those that are not structural tags like <NP>
        n_out = sum(1 for tag in tags if ("<" not in tag) and (">" not in tag))
        return n_out == n_words

    def _preprocess(self, tagged: List[Tuple]) -> Tuple[str, str]:
        """
        Preprocess input sentence to replace whitespace token with whitespace

        Args:
            tagged (List[Tuple]): list of (word, tag) pairs

        Returns:
            Tuple[str, str]: preprocessed sentence (PoS tags), original input words
        """
        ori = " ".join([tag[0] for tag in tagged if tag[1] != "SPACE"])
        sent = " ".join([tag[1] for tag in tagged if tag[1] != "SPACE"])
        # Restore Penn-Treebank bracket tags to literal parentheses
        sent = sent.replace("-LRB-", "(")
        sent = sent.replace("-RRB-", ")")
        return sent, ori

    def _postprocess(
        self,
        tags: List[str],
        words: List[str],
        pos: Optional[List[str]] = None,
    ):
        """
        Postprocess result of parsing

        Args:
            tags (List[str]): list of parsing tag
            words (List[str]): list of word
            pos (Optional[List[str]]): list of PoS tag, or None to omit tags

        Returns:
            postprocessed result string
        """
        result = list()
        i = 0
        for tag in tags:
            if ("<" not in tag) and (">" not in tag):
                # Terminal position: substitute the i-th original word back in
                if pos:
                    result.append(f"{words[i]}/{pos[i]}")
                else:
                    result.append(words[i])
                i += 1
            else:
                result.append(tag)
        return " ".join(result)

    def predict(
        self,
        text: str,
        beam: int = 5,
        pos: bool = False,
        **kwargs,
    ):
        """
        Conduct constituency parsing

        Args:
            text (str): input sentence
            beam (int): size of beam search
            pos (bool): contains PoS tagging or not

        Returns:
            result of constituency parsing
        """
        tags, ori = self._preprocess(self._tagger(text))
        n_words = len(tags.split())

        outputs = self._model.translate(
            tags,
            beam=beam,
            max_len_a=1,
            max_len_b=50,
        )

        result = self._check_sanity(outputs.split(), n_words)
        if not result:
            return f"<ERROR> {text} </ERROR>"

        poses = None
        if pos:
            poses = tags.split()

        outputs = self._postprocess(outputs.split(), ori.split(), poses)
        return self._prettify(outputs).strip()
class PororoTransConstZh(PororoConstBase):
    """Transformer-based Chinese constituency parser built on a jieba PoS tagger."""

    def __init__(self, model, tagger, config):
        super().__init__(config)
        # fairseq seq2seq model: PoS-tag sequence -> bracketed parse tree
        self._model = model
        # PoS tagger producing (word, tag) pairs with jieba's tagset
        self._tagger = tagger
        # Mapping from jieba's fine-grained tagset to the coarse tagset the
        # parser was trained on; unknown tags fall back to "X" in _preprocess
        self._map = {
            "a": "ADJ",
            "ad": "ADJ",
            "ag": "ADJ",
            "an": "ADJ",
            "b": "NOUN",
            "c": "CONJ",
            "d": "ADV",
            "df": "ADV",
            "dg": "ADV",
            "e": "INTJ",
            "f": "NOUN",
            "g": "MORPHEME",
            "h": "PREFIX",
            "i": "IDIOM",
            "j": "NOUN",
            "k": "SUFFIX",
            "l": "IDIOM",
            "m": "NUM",
            "mg": "NUM",
            "mq": "NUM",
            "n": "NOUN",
            "ng": "NOUN",
            "nr": "NOUN",
            "nrfg": "NOUN",
            "nrt": "NOUN",
            "ns": "NOUN",
            "nt": "NOUN",
            "nz": "NOUN",
            "o": "ONOM",
            "p": "PREP",
            "q": "CLASSIFIER",
            "r": "PRON",
            "rg": "PRON",
            "rr": "PRON",
            "rz": "PRON",
            "s": "NOUN",
            "t": "NOUN",
            "tg": "NOUN",
            "u": "PART",
            "ud": "PART",
            "ug": "PART",
            "uj": "PART",
            "ul": "PART",
            "uv": "PART",
            "uz": "PART",
            "v": "VERB",
            "vd": "VERB",
            "vg": "VERB",
            "vi": "VERB",
            "vn": "VERB",
            "vq": "VERB",
            "x": "X",
            "y": "PART",
            "z": "ADJ",
            "zg": "ADJ",
            "eng": "X",
        }

    def _check_sanity(self, tags: List[str], n_words: int) -> bool:
        """
        Check sanity for valid xml structure

        Args:
            tags (List[str]): list of tag
            n_words (int): number of word

        Returns:
            bool: whether the number of terminal tokens matches the input length
        """
        # Terminal tokens are those that are not structural tags like <NP>
        n_out = sum(1 for tag in tags if ("<" not in tag) and (">" not in tag))
        return n_out == n_words

    def _preprocess(self, tagged: List[Tuple]) -> Tuple:
        """
        Preprocess input sentence to replace whitespace token with whitespace

        Args:
            tagged (List[Tuple]): list of (word, tag) pairs

        Returns:
            Tuple: space-joined coarse PoS tags, space-joined original words
        """
        ori = " ".join([tag[0] for tag in tagged])
        tags = [tag[1] for tag in tagged]
        # Mapping into general tagset
        tags = [self._map[tag] if tag in self._map else "X" for tag in tags]
        return " ".join(tags), ori

    def _postprocess(
        self,
        tags: List[str],
        words: List[str],
        pos: Optional[List[str]] = None,
    ):
        """
        Postprocess result of parsing

        Args:
            tags (List[str]): list of parsing tag
            words (List[str]): list of word
            pos (Optional[List[str]]): list of PoS tag, or None to omit tags

        Returns:
            postprocessed result string
        """
        result = list()
        i = 0
        for tag in tags:
            if ("<" not in tag) and (">" not in tag):
                # Terminal position: substitute the i-th original word back in
                if pos:
                    result.append(f"{words[i]}/{pos[i]}")
                else:
                    result.append(words[i])
                i += 1
            else:
                result.append(tag)
        return " ".join(result)

    def predict(
        self,
        text: str,
        beam: int = 5,
        pos: bool = False,
        **kwargs,
    ):
        """
        Conduct constituency parsing

        Args:
            text (str): input sentence
            beam (int): size of beam search
            pos (bool): contains PoS tagging or not

        Returns:
            result of constituency parsing
        """
        tags, ori = self._preprocess(self._tagger(text))
        n_words = len(tags.split())

        outputs = self._model.translate(
            tags,
            beam=beam,
            max_len_a=1,
            max_len_b=50,
        )

        result = self._check_sanity(outputs.split(), n_words)
        if not result:
            return f"<ERROR> {text} </ERROR>"

        poses = None
        if pos:
            poses = tags.split()

        outputs = self._postprocess(outputs.split(), ori.split(), poses)
        return self._prettify(outputs).strip()