def get_split_movies(self, split):
    """Return the movies that belong to the requested split.

    Args:
      split: 'train' OR 'val' OR 'test' OR 'full'.

    Returns:
      For 'train', 'val' and 'test', the corresponding entry of
      self.data_split (the stored object itself, not a copy). For 'full',
      a new list holding the train, val and test movies, in that order.

    Raises:
      ValueError: If input split type is unrecognized.
    """
    single_splits = ('train', 'val', 'test')

    # A single split maps directly onto one entry of data_split.
    for name in single_splits:
        if split == name:
            return self.data_split[name]

    if split == 'full':
        # Concatenate into a fresh list so the stored splits are not modified.
        combined = []
        for name in single_splits:
            combined.extend(self.data_split[name])
        return combined

    raise ValueError('Invalid split type. Use "train", "val", "test", or "full"')
def get_story_qa_data(self, split='train', story_type='plot'):
    """Provide stories and QAs for a particular split and story type.

    Args:
      split: 'train' OR 'val' OR 'test' OR 'full'.
      story_type: 'plot', 'split_plot', 'subtitle', 'dvs', 'script'.

    Returns:
      story: Story for each movie indexed by imdb_key.
      qa: The list of QAs in this split whose movie has an entry in story.

    Raises:
      ValueError: If input split type is unrecognized (raised by
        get_split_movies).
    """
    # Build the set once so each membership test below is O(1) instead of
    # scanning the split's movie list for every movie and every QA.
    split_movies = set(self.get_split_movies(split))

    # items() exists on both Python 2 and Python 3 dicts; iteritems() is
    # Python 2 only and raises AttributeError on Python 3.
    movies_in_split = {
        imdb_key: movie
        for imdb_key, movie in self.movies_map.items()
        if imdb_key in split_movies
    }
    story = self.load_me_stories.load_story(movies_in_split, story_type)

    # Keep only QAs of movies for which the loader returned a story entry.
    movies_with_story = {m for m in split_movies if m in story}
    qa = [one_qa for one_qa in self.qa_list
          if one_qa.imdb_key in movies_with_story]
    return story, qa
def get_video_list(self, split='train', video_type='qa_clips'):
    """Provide video clips and QAs for a particular split and video type.

    Args:
      split: 'train' OR 'val' OR 'test' OR 'full'.
      video_type: 'qa_clips', 'all_clips'.

    Returns:
      video_list: Dict of video clips. For 'qa_clips' it maps each qid to
        that QA's video_clips; for 'all_clips' it maps each imdb_key to the
        de-duplicated clips of all its QAs, in first-seen order.
      qa: The list of QAs in this split that have at least one video clip.

    Raises:
      ValueError: If input video type is unrecognized, or if the split type
        is unrecognized (raised by get_split_movies).
    """
    # Set membership keeps the QA filter linear in the number of QAs.
    split_movies = set(self.get_split_movies(split))
    qa = [one_qa for one_qa in self.qa_list
          if one_qa.video_clips and one_qa.imdb_key in split_movies]

    if video_type == 'qa_clips':
        video_list = {one_qa.qid: one_qa.video_clips for one_qa in qa}
    elif video_type == 'all_clips':
        video_list = {}
        seen_by_movie = {}
        for one_qa in qa:
            movie_clips = video_list.setdefault(one_qa.imdb_key, [])
            seen = seen_by_movie.setdefault(one_qa.imdb_key, set())
            # De-duplicate while preserving first-seen order; going through
            # list(set(...)) would give an arbitrary, run-dependent order.
            for clip in one_qa.video_clips:
                if clip not in seen:
                    seen.add(clip)
                    movie_clips.append(clip)
    else:
        raise ValueError('Invalid video type. Use "qa_clips" or "all_clips"')
    return video_list, qa