Source code for pororo.tasks.image_captioning

"""Image Captioning related modeling class"""

import os
from typing import Optional

import torch

from pororo.tasks.utils.base import PororoFactoryBase, PororoSimpleBase
from pororo.tasks.utils.download_utils import download_or_load


class PororoCaptionFactory(PororoFactoryBase):
    """
    Factory for the image captioning task (textual description of an image).

    English (`transformer.base.en.caption`)

        - dataset: MS-COCO 2017 (Tsung-Yi Lin et al. 2014)
        - metric: TBU

    Every supported language is served by the English caption model; for
    non-English languages a translation model is loaded in addition
    (see `load`).

    Examples:
        >>> caption = Pororo(task="caption", lang="en")
        >>> caption("https://i.pinimg.com/originals/b9/de/80/b9de803706fb2f7365e06e688b7cc470.jpg")
        'Two men sitting at a table with plates of food.'

    """

    def __init__(self, task: str, lang: str, model: Optional[str]):
        super().__init__(task, lang, model)

    @staticmethod
    def get_available_langs():
        """Return the language codes accepted by this task."""
        return ["en", "ko", "zh", "ja"]

    @staticmethod
    def get_available_models():
        """Return a mapping from language code to the list of usable model names."""
        # Each language gets its own list object holding the single English model.
        return {
            code: ["transformer.base.en.caption"]
            for code in ("en", "ko", "zh", "ja")
        }

    def load(self, device: str):
        """
        Build the captioning pipeline for the configured model.

        Args:
            device (str): device on which the models are placed

        Returns:
            object: `PororoCaptionBrainCaption` wrapping the feature
                extractor, caption generator, tokenizer and (for
                non-English languages) a translation model

        """
        translator = None
        model_name = self.config.n_model

        if "transformer" in model_name:
            # Heavy dependencies are imported lazily, only when this model is requested.
            from transformers import BertTokenizer

            from pororo.models.caption import Caption, Detr

            # Weights are always fetched from the English ("en") model store.
            load_dict = download_or_load(f"transformer/{model_name}", "en")

            tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
            generator = Caption(tokenizer.pad_token_id, tokenizer.vocab_size)

            checkpoint_path = os.path.join(load_dict.path, f"{model_name}.pt")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            generator.load_state_dict(checkpoint["model"])
            generator.eval().to(device)

            extractor = Detr(device)

            target_lang = self.config.lang
            if target_lang != "en":
                assert target_lang in (
                    "ko",
                    "ja",
                    "zh",
                ), "Unsupported language code is selected!"
                from pororo.tasks import PororoTranslationFactory

                # Captions are produced in English and translated afterwards.
                mt_factory = PororoTranslationFactory(
                    task="mt",
                    lang="multi",
                    model="transformer.large.multi.mtpg",
                )
                translator = mt_factory.load(device)

            return PororoCaptionBrainCaption(
                extractor,
                generator,
                tokenizer,
                translator,
                device,
                self.config,
            )
class PororoCaptionBrainCaption(PororoSimpleBase):
    """
    Captioning pipeline: extract image features, greedily decode a caption
    with the generator, decode token ids to text and optionally translate.
    """

    def __init__(
        self,
        extractor,
        generator,
        tokenizer,
        translator,
        device,
        config,
    ):
        """
        Args:
            extractor: object exposing `extract_feature(image)`
            generator: callable caption model invoked as
                `generator(features, boxes, caption, caption_mask)`
            tokenizer: tokenizer used for special-token lookup and decoding
            translator: callable translation model, or None for no translation
            device: device on which tensors are placed
            config: task configuration forwarded to the base class

        """
        super().__init__(config)
        self._extractor = extractor
        self._generator = generator
        self._tokenizer = tokenizer
        self._translator = translator

        # Use the tokenizer's public special-token properties instead of the
        # private `_cls_token` / `_sep_token` attributes.
        self._start_token = tokenizer.convert_tokens_to_ids(
            tokenizer.cls_token)
        self._end_token = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)

        self._device = device
        self._max_len = 128

    def _create_caption_and_mask(self):
        """
        Create dummy caption and mask templates

        The caption is a (1, max_len) tensor of zeros whose first position
        holds the start token; the mask is a (1, max_len) boolean tensor of
        True whose first position is False.

        Returns:
            torch.tensor : template tensors (caption, mask) on the target device

        """
        caption_template = torch.zeros((1, self._max_len), dtype=torch.long)
        mask_template = torch.ones((1, self._max_len), dtype=torch.bool)

        caption_template[:, 0] = self._start_token
        mask_template[:, 0] = False

        return caption_template.to(self._device), mask_template.to(self._device)

    # TODO : Add beam search logic
    def _generate(self, features, boxes, caption, caption_mask):
        """
        Generate caption using greedy decoding steps

        `caption` and `caption_mask` are updated in place: each step writes
        the predicted token at the next position and sets that mask position
        to False. Decoding stops when the end token is predicted or the
        maximum length is reached.

        Args:
            features (torch.tensor): image feature tensor
            boxes (torch.tensor): bounding box features
            caption (torch.tensor): dummy caption template
            caption_mask (torch.tensor): mask template

        Returns:
            torch.tensor : generated token tensor

        """
        # Inference only: disable autograd so no graph is retained across
        # the decoding steps.
        with torch.no_grad():
            for i in range(self._max_len - 1):
                pred = self._generator(
                    features,
                    boxes,
                    caption,
                    caption_mask,
                )
                # Scores for the position that was filled in most recently.
                pred = pred[:, i, :]
                pred_id = torch.argmax(pred, dim=-1)

                if pred_id[0] == self._end_token:
                    return caption

                caption[:, i + 1] = pred_id[0]
                caption_mask[:, i + 1] = False

        return caption

    def predict(self, image: str, **kwargs):
        """
        Predict caption using image features

        Args:
            image (str): image path

        Returns:
            str: generated caption corresponding to input image

        """
        output = self._extractor.extract_feature(image)
        # A leading dimension is added to both tensors (presumably the batch
        # dimension expected by the generator).
        features = output["features"].unsqueeze(0).to(self._device)
        boxes = output["boxes"].unsqueeze(0).to(self._device)

        caption, caption_mask = self._create_caption_and_mask()
        caption = self._generate(
            features,
            boxes,
            caption,
            caption_mask,
        )

        caption = self._tokenizer.decode(
            caption[0].tolist(),
            skip_special_tokens=True,
        ).capitalize()

        # apply translation if needed (translator is None for English)
        if self._translator is not None:
            caption = self._translator(caption, src="en", tgt=self.config.lang)

        return caption