pytorch dataloader 传入不定长数据

1.先把所有数据放在一起

2.使用idx标记每段数据的开始和结束位置

核心代码

1
2
3
flip_idx = [len(x) for x in data]
flip_idx = np.cumsum(flip_idx)
flip_idx = np.insert(flip_idx, 0, 0)
请作者喝一杯咖啡☕️