def get_split_movies(self, split):
        """Return the movies that belong to the requested split.

        Args:
          split:        'train' OR 'val' OR 'test' OR 'full'.
        Returns:
          For a single split, the object stored in self.data_split under that
          name (the stored object itself, not a copy). For 'full', a new list
          holding the train, val and test movies, in that order.
        Raises:
          ValueError:   If input split type is unrecognized.
        """
        single_splits = ('train', 'val', 'test')

        # A single split is handed back exactly as stored.
        for name in single_splits:
            if split == name:
                return self.data_split[name]

        if split != 'full':
            raise ValueError('Invalid split type. Use "train", "val", "test", or "full"')

        # 'full' concatenates the three splits into a fresh list so the
        # stored per-split lists are never modified.
        combined = []
        for name in single_splits:
            combined.extend(self.data_split[name])
        return combined


    def get_story_qa_data(self, split='train', story_type='plot'):
        """Provide data based on a particular split and story-type.

        Args:
          split:        'train' OR 'val' OR 'test' OR 'full'.
          story_type:   'plot', 'split_plot', 'subtitle', 'dvs', 'script'.
        Returns:
          story:        Story for each movie indexed by imdb_key.
          qa:           The list of QAs in this split, restricted to movies
                        for which a story was loaded.
        Raises:
          ValueError:   If input split type is unrecognized
                        (raised by get_split_movies).
        """
        # A set gives constant-time membership tests below; the split is
        # otherwise scanned once per movie and once per QA.
        split_movies = frozenset(self.get_split_movies(split))

        # Load story. dict.items() is used instead of iteritems() because it
        # exists on both Python 2 and Python 3 dictionaries.
        this_movies_map = {k: v for k, v in self.movies_map.items()
                           if k in split_movies}
        story = self.load_me_stories.load_story(this_movies_map, story_type)

        # Restrict this split's movies to ones which have a story,
        # then keep only the QAs that belong to those movies.
        movies_with_story = {m for m in split_movies if m in story}
        qa = [one_qa for one_qa in self.qa_list
              if one_qa.imdb_key in movies_with_story]

        return story, qa


    def get_video_list(self, split='train', video_type='qa_clips'):
        """Provide data for a particular split and video type.

        Args:
          split:        'train' OR 'val' OR 'test' OR 'full'.
          video_type:   'qa_clips', 'all_clips'.
        Returns:
          video_list:   Dict of video clips. For 'qa_clips' it maps each qid
                        to that QA's video_clips; for 'all_clips' it maps each
                        imdb_key to the movie's clips with duplicates removed,
                        in order of first appearance.
          qa:           The list of QAs in this split that have video clips.
        Raises:
          ValueError:   If input video type is unrecognized, or if the split
                        type is unrecognized (raised by get_split_movies).
        """
        # A set gives constant-time membership tests while filtering QAs.
        split_movies = frozenset(self.get_split_movies(split))

        # Get all video QAs
        qa = [one_qa for one_qa in self.qa_list
              if one_qa.video_clips and one_qa.imdb_key in split_movies]

        video_list = {}
        if video_type == 'qa_clips':
            # map each qa to its own video clips
            for qa_one in qa:
                video_list[qa_one.qid] = qa_one.video_clips

        elif video_type == 'all_clips':
            # Collect clips by movie, skipping repeats as they are met so the
            # result keeps a stable first-seen order rather than the
            # arbitrary order a set round-trip would produce.
            seen_clips = {}
            for qa_one in qa:
                movie_clips = video_list.setdefault(qa_one.imdb_key, [])
                movie_seen = seen_clips.setdefault(qa_one.imdb_key, set())
                for clip in qa_one.video_clips:
                    if clip not in movie_seen:
                        movie_seen.add(clip)
                        movie_clips.append(clip)

        else:
            raise ValueError('Invalid video type. Use "qa_clips" or "all_clips"')

        return video_list, qa

results matching ""

    No results matching ""