Write a Model Adapter

The interface between your model and the evaluation on our metrics and datasets is provided as a vqa_benchmarking_backend.datasets.dataset.DatasetModelAdapter . An adapter wraps around a model and is required to return a probability distribution from its _forward() function. During metric calculation, the _forward() function recieves a list of ``DataSample``s that need to be transformed to fit your models expected input.

Some general getter functions are required, e.g. its name get_name(), output size get_output_size(), and the model itself get_torch_module() . The functions get_question_embedding() and get_image_embedding() should fill the properties question_features and image_features of a DataSample object respectively in order to enable caching and appyling noise onto the feature representations.

Now that we have a created a DatasetModelAdapter, we can start evaluating (see Evaluate Metrics) .

from vqa_benchmarking_backend.datasets.dataset import DatasetModelAdapter

class MyModelAdapter(DatasetModelAdapter):
    """
    NOTE: when inheriting from this class, make sure to

        * move the model to the intended device
        * move the data to the intended device inside the _forward method
    """
    def __init__(self,
                 device,
                 vocab: Vocabulary,
                 ckpt_file: str = '',
                 name: str,
                 n_classes: int) -> None:

        self.device = device
        self.vocab = vocab
        self.name = name
        self.n_classes = n_classes
        self.vqa_model = myModel().to(device) # the pytorch instance of the VQA model
        self.vqa_model.load_state_dict(torch.load(ckpt_file, map_location=device)['state_dict'])

        gpu_id = int(device.split(':')[1]) # cuda:ID -> ID
        self.img_feat_extractor, self.img_feat_cfg = setup("bottomupattention/configs/bua-caffe/extract-bua-caffe-r101.yaml", 10, 100, gpu_id)  # in this example, we load an external image feature extractor


    def get_name(self) -> str:
        # Needed for file caching, has to be overriden
        return self.name

    def get_output_size(self) -> int:
        # number of classes in prediction, has to be overriden
        return self.n_classes

    def get_torch_module(self) -> torch.nn.Module:
        # return the pytorch VQA model, has to be overriden
        return self.vqa_model

    def question_token_ids(self, question_tokenized: List[str]) -> torch.LongTensor:
        # helper function to get token ids as input to our VQA model, custom to this example
        return torch.tensor([self.vocab.stoi(token) if self.vocab.exists(token) else self.vocab.stoi('UNK') for token in question_tokenized], dtype=torch.long)

    def get_question_embedding(self, sample: DataSample) -> torch.FloatTensor:
        # embed questions without full model forward-pass, has to be overriden
        if isinstance(sample.question_features, type(None)):
            sample.question_features = self.vqa_model.embedding(self.question_token_ids(sample.question_tokenized).to(self.device)).cpu()
        return sample.question_features

    def get_image_embedding(self, sample: DataSample) -> torch.FloatTensor:
        # embed images without full model forward-pass, has to be overriden
        # in this example, the feature extractor is external
        if isinstance(sample.image_features, type(None)):
            sample.image_features = extract_feat_in_memory(self.img_feat_extractor, sample._image_path, self.img_feat_cfg)['x'].cpu()
        return sample.image_features

    def _forward(self, samples: List[DataSample]) -> torch.FloatTensor:
        """
        Overwrite this function to run a forward-pass of a list of samples using your model.
        IMPORTANT:
            * Make sure that the outputs are probabilities, not logits!
            * Make sure that the data samples are using the samples' question embedding field, if assigned (instead of re-calculating them, they could be modified from feature space methods)
            * Make sure that the data samples are moved to the intended device here
        """
        q_feats = pad_sequence(sequences=[self.get_question_embedding(sample).to(self.device) for sample in samples], batch_first=True) # extract question features
        img_feats = pad_sequence(sequences=[self.get_image_embedding(sample).to(self.device) for sample in samples], batch_first=True) # extract image features

        logits = self.vqa_model.forward(img_feats, q_feats) # run forward-pass for our VQA model
        probs = logits.softmax(dim=-1) # convert model outputs to probability distribution across answer space

        return probs