分析PyTorch Dataloader报错ValueError:num_samples的另一种可能原因
作者:阳光素描
先粘报错信息
Traceback (most recent call last): File “train.py”, line 169, in
train_test() File “train.py”, line 29, in train_test
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=False)
File “/data3/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py”,
line 270, in init
sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
File “/data3/anaconda3/lib/python3.8/site-packages/torch/utils/data/sampler.py”,
line 102, in init
raise ValueError("num_samples should be a positive integer "
ValueError: num_samples should be a positive integer value, but got num_samples=0
在使用pytorch训练模型时,同样的代码在Windows下可以正常训练,但在Linux下却会出现以上报错。
网上查阅相关资料,解决办法是完全相同的
出现的问题的地方可能是如下的地方
调用DataLoder时注意参数
self.train_dataloader = DataLoader(train_dataset, batch_size=TrainOption.train_batch_size, shuffle=TRUE, num_workers=TrainOption.data_load_worker_num)
shuffle的参数设置错误导致,
因为已经有batch_sample了,就不需要shuffle来进行随机的sample了,所以在这里的shuffle应该设置为FALSE才对。
但我这里并未使用batch_sample,因此不是上述原因。
经查发现
由于两系统下目录地址的格式不同,
因此直接从windows移植过来的代码不能在指定目录下正常读取数据,
且代码未设置sample读取个数为0时报错,
导致dataset返回长度为0,小于batch_size,因此出现上述报错。
出现上述问题时,如未使用batch_sample,可首先检查dataset.len()是否正常。
总结
因多次出现上述问题,故记录。
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。