Реализация набора данных видеокадров Pytorch

У меня есть набор данных видеокадров в папке со структурой каталогов, которая выглядит следующим образом:

-parent
 -video_id
  -frame0
  -frame1
  -...

Мой код, сбор кадров в набор данных Pytorch и разделение их на патчи, довольно медленный в __init__ метод / фаза при масштабировании до довольно большого количества кадров / видео. Может быть более эффективный способ создания набора данных?

Код:

class Dataset(torch.utils.data.Dataset):
    def __init__(self, directory='../data/*', get_subdirs=True, size=(16,16), max_ctx_length=4096, dry_run=False):
        print("Loading dataset...")
        self.data = glob.glob(directory)
        if get_subdirs:
            data_temp = []
            for p, i in enumerate(self.data):
                print("Loading data from {0}. {1} more to go...".format(i, len(self.data)-p))
                file_data = glob.glob(i+"/*")
                file_data.sort(key=lambda r: int(''.join(x for x in r if (x.isdigit())))) #This sort...
                data_temp.extend(file_data)
                if dry_run:
                    break
        self.data = data_temp
        self.max_ctx_length = max_ctx_length
        self.size = size
    def __len__(self):
        return len(self.data)*self.size[0]-self.max_ctx_length-1
    def __getitem__(self, key):
        frame_start = int(np.floor(key / self.size[0]))
        patch_start = int(np.mod(key, self.size[0]))
        
        patches = []
        i_frame = frame_start

        while len(patches) <= self.max_ctx_length+1:
            frame = (Tvio.read_image(self.data[i_frame], mode=Tvio.ImageReadMode.RGB).float() / 255).unsqueeze(0)
            if len(patches) == 0:
                patches.extend(F.unfold(frame, self.size, stride=self.size).transpose(1,2).split(1,1)[patch_start:])
            else:
                patches.extend(F.unfold(frame, self.size, stride=self.size).transpose(1,2).split(1, 1))
            i_frame += 1
        patches = patches[:self.max_ctx_length+1]

        data_x = patches[0:-1]
        data_y = patches[1:]

        return torch.cat(data_x, dim=1).squeeze(0), torch.cat(data_y, dim=1).squeeze(0)

0

Добавить комментарий

Ваш адрес email не будет опубликован. Обязательные поля помечены *