"""Dependency Parsing related modeling class"""
from typing import List, Optional, Tuple
from pororo.tasks.utils.base import PororoFactoryBase, PororoSimpleBase
class PororoDpFactory(PororoFactoryBase):
    """
    Conduct dependency parsing

    Korean (`posbert.base.ko.dp`)

        - dataset: https://corpus.korean.go.kr/ 구문 분석 말뭉치
        - metric: UAS (90.57), LAS (95.96)

    Args:
        sent: (str) sentence to be parsed dependency

    Returns:
        List[Tuple[int, str, int, str]]: token index, token label, token head and its relation

    Examples:
        >>> dp = Pororo(task="dep_parse", lang="ko")
        >>> dp("분위기도 좋고 음식도 맛있었어요. 한 시간 기다렸어요.")
        [(1, '분위기도', 2, 'NP_SBJ'), (2, '좋고', 4, 'VP'), (3, '음식도', 4, 'NP_SBJ'), (4, '맛있었어요.', 7, 'VP'), (5, '한', 6, 'DP'), (6, '시간', 7, 'NP_OBJ'), (7, '기다렸어요.', -1, 'VP')]
        >>> dp("한시간 기다렸어요.")
        [(1, '한시간', 2, 'NP_OBJ'), (2, '기다렸어요.', -1, 'VP')]
    """

    def __init__(self, task: str, lang: str, model: Optional[str]):
        super().__init__(task, lang, model)

    @staticmethod
    def get_available_langs() -> List[str]:
        """Return the language codes this factory supports."""
        return ["ko"]

    @staticmethod
    def get_available_models():
        """Return the model names available per language."""
        return {"ko": ["posbert.base.ko.dp"]}

    def load(self, device: str):
        """
        Load user-selected task-specific model

        Args:
            device (str): device information

        Returns:
            object: User-selected task-specific model

        Raises:
            ValueError: if the configured model name is not supported
        """
        from pororo.tasks import PororoPosFactory

        if "posbert" in self.config.n_model:
            from pororo.models.brainbert import RobertaSegmentModel

            model = (RobertaSegmentModel.load_model(
                f"bert/{self.config.n_model}",
                self.config.lang,
            ).eval().to(device))

            # The segment model needs POS tags as auxiliary input;
            # reuse the mecab-ko tagger from the POS task.
            tagger = PororoPosFactory(
                task="pos",
                model="mecab-ko",
                lang=self.config.lang,
            ).load(device)

            return PororoSegmentBertDP(model, tagger, self.config)

        # Previously this fell through and implicitly returned None,
        # deferring the failure to an opaque AttributeError at call time.
        raise ValueError(f"Unsupported model: {self.config.n_model}")
class PororoSegmentBertDP(PororoSimpleBase):
    """BERT segment-model based dependency parser (operates on eojeol units)."""

    def __init__(self, model, tagger, config):
        super().__init__(config)
        self._tagger = tagger  # POS tagger producing (surface, tag) pairs
        self._model = model    # RobertaSegmentModel exposing predict_dependency

    def _preprocess(self, sent: str) -> Tuple[List[str], List[str]]:
        """
        Preprocess dependency parsing input

        Args:
            sent (str): input sentence to be preprocessed

        Returns:
            Tuple[List[str], List[str]]: model input tokens and their
            (possibly truncated) POS tags, aligned one-to-one
        """
        pairs = self._tagger(sent, return_surface=True)

        # "<s>" is the BOS token; "▃" marks an eojeol (space) boundary.
        # yapf: disable
        tokens = ["<s>", "▃"] + [pair[0] if pair[0] != " " else "▃" for pair in pairs]
        # yapf: enable

        # A space token borrows the tag of the following token so every
        # position carries a meaningful POS tag.
        tags = [
            pair[1] if pair[0] != " " else pairs[i + 1][1]
            for i, pair in enumerate(pairs)
        ]
        # Pad tags for the two prepended tokens ("<s>" gets dummy "XX").
        prefix = ["XX", tags[0]]
        tags = prefix + tags

        # Keep only the first sub-tag of compound tags like "NNG+JKS".
        res_tags = []
        for tag in tags:
            if "+" in tag:
                tag = tag[:tag.find("+")]
            res_tags.append(tag)

        return tokens, res_tags

    def _postprocess(
        self,
        ori: str,
        tokens: List[str],
        heads: List[int],
        labels: List[str],
    ):
        """
        Postprocess dependency parsing output

        Args:
            ori (str): original sentence
            tokens (List[str]): model input tokens produced by _preprocess
            heads (List[int]): dependency heads generated by model
            labels (List[str]): tag labels generated by model

        Returns:
            List[Tuple[int, str, int, str]]: token index, token label, token head and its relation
        """
        eojeols = ori.split()

        # Positions of the eojeol-boundary markers; only those carry the
        # per-eojeol head/label predictions we report.
        indices = [i for i, token in enumerate(tokens) if token == "▃"]
        real_heads = [head for i, head in enumerate(heads) if i in indices]
        real_labels = [label for i, label in enumerate(labels) if i in indices]

        result = []
        for i, (head, label,
                eojeol) in enumerate(zip(
                    real_heads,
                    real_labels,
                    eojeols,
                )):
            curr = i + 1  # 1-based eojeol index
            try:
                # Map the head's token position back to its eojeol number.
                head_eojeol = indices.index(head) + 1
            except ValueError:
                # Head does not point at a boundary marker -> treat as root.
                head_eojeol = -1
            # A self-loop is also treated as the root of the sentence.
            if head_eojeol == curr:
                head_eojeol = -1
            result.append((curr, eojeol, head_eojeol, label))
        return result

    def predict(self, sent: str, **kwargs):
        """
        Conduct dependency parsing

        Args:
            sent: (str) sentence to be parsed dependency

        Returns:
            List[Tuple[int, str, int, str]]: token index, token label, token head and its relation
        """
        tokens, tags = self._preprocess(sent)
        heads, labels = self._model.predict_dependency(tokens, tags)
        heads = [int(head) - 1 for head in heads]  # due to default <s> token
        return self._postprocess(sent, tokens, heads, labels)