Cache 生成的对象

可持久化对象不如直接多线程dataloader。

想把处理好的对象直接保存下来,然后写了一个Cache

第一次会把对象保存到硬盘, 第二次会直接读取

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
import torch
from abc import ABCMeta, abstractmethod


class DataBase(object):
__metaclass__ = ABCMeta

def __init__(self):
pass

@abstractmethod
def get_data(self):
pass


class Cache:
def __init__(self, path):
self.path = path

def fetch(self, name, database):
save_path = os.path.join(self.path, name)
if os.path.exists(save_path):
return torch.load(save_path)
else:
data = database.get_data()
torch.save(data, save_path)
return data
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class TrainDataBase(DataBase):
def get_data(self):
return [ {'image': data['image'], 'pose': data['pose']}
for i, data in tqdm(enumerate(train_dataloader))]


class TestDataBase(DataBase):
def get_data(self):
return [ {'image': data['image'], 'pose': data['pose']}
for i, data in tqdm(enumerate(test_dataloader)) ]


persistent_path = "/SSD1/save_obj"
cache = Cache(persistent_path)
time1 = time()

train_data_vec = cache.fetch("mobile-net-train-data", TrainDataBase())
time2 = time()

test_data_vec = cache.fetch("mobile-net-test-data", TestDataBase())
time3 = time()
print ("%s, %s\n", time2 - time1, time3 - time1)

然后测试结果是
74.46734094619751s, 75.0271668434143s
本来多线程是不到1分钟。。。
去生成的文件瞅瞅,12G
因为Dataloader是多线程处理的,所以会快很多。

请作者喝一杯咖啡☕️