Datasets
Standard Dataset
Lightning arrester point cloud segmentation dataset
- Citation Author(s):
- Submitted by:
- Haoyu Song
- Last updated:
- Fri, 05/17/2024 - 06:09
- DOI:
- 10.21227/2gr7-vz15
- License:
- Categories:
- Keywords:
Abstract
This is a lightning arrester point cloud segmentation dataset saved as TXT files. Each file has shape (8192, 7): 8192 points per file, where columns 1-3 are the spatial coordinates (xyz), columns 4-6 are color information (RGB), and the last column is the part-segmentation label of the lightning arrester. It can be used for point cloud segmentation tasks.
An example of a file read is as follows:
# Example: build train/test datasets and DataLoaders from the unzipped archive.
# NOTE(review): `args` (npoint, normal, batch_size) and `torch` are assumed to be
# defined by the surrounding training script — confirm before running.
root = 'File Unzip Path'
# Training split: shuffled, fixed-size batches (drop_last avoids a ragged final batch).
TRAIN_DATASET = InsulatorDataset(root=root, npoints=args.npoint, split='train', normal_channel=args.normal)
trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=6, drop_last=True)
# Test split: no shuffling so evaluation order is deterministic.
TEST_DATASET = InsulatorDataset(root=root, npoints=args.npoint, split='test', normal_channel=args.normal)
testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=6)
class InsulatorDataset(Dataset):
    """Point cloud part-segmentation dataset for insulator / lightning-arrester scans.

    Each sample file is a whitespace-separated text file of shape (N, 7):
    columns 0-2 are xyz coordinates, columns 3-5 are RGB color values, and
    column 6 is the per-point part label.
    """

    def __init__(self, root='./data/insulator-8192/', npoints=8192,
                 split='train', normal_channel=False, seed=0):
        """
        Args:
            root: dataset root; contains one subdirectory per class, each
                holding the per-sample .txt files.
            npoints: number of points resampled (with replacement) per item.
            split: 'train' (~90% of files) or 'test' (remaining ~10%).
            normal_channel: if True, items carry xyz+rgb (6 columns);
                otherwise xyz only (3 columns).
            seed: seed for the train/test split shuffle. Defaults to 0 so
                separately constructed 'train' and 'test' instances see the
                same shuffle and therefore get disjoint file sets.

        Raises:
            ValueError: if `split` is neither 'train' nor 'test'.
        """
        self.npoints = npoints
        self.root = root
        self.split = split
        self.normal_channel = normal_channel
        # Class and part-segmentation labels.
        self.classes = {'Insulator': 0}
        self.seg_classes = {'Insulator': [0, 1]}

        # Collect every file under each class directory as (class, path).
        # os.listdir order is platform-dependent, so sort before shuffling
        # to make the split reproducible across machines.
        all_files = [(cat, os.path.join(self.root, cat, fn))
                     for cat in self.classes
                     for fn in sorted(os.listdir(os.path.join(self.root, cat)))]

        # Shuffle with a private, seeded RNG. The original code used the
        # unseeded global random.shuffle, so the 'train' and 'test'
        # instances shuffled differently and the test set could overlap
        # the training set (data leakage).
        random.Random(seed).shuffle(all_files)

        num_train = int(len(all_files) * 0.9)  # ~90% of files for training
        if self.split == 'train':
            self.datapath = all_files[:num_train]
        elif self.split == 'test':
            self.datapath = all_files[num_train:]
        else:
            raise ValueError('Invalid split argument. It should be "train" or "test".')

    def __getitem__(self, index):
        """Return (points, cls, seg) for sample `index`.

        Returns:
            points: float32 array of shape (npoints, 3) or (npoints, 6)
                depending on `normal_channel`.
            cls: int32 array of shape (1,) holding the class index.
            seg: int32 array of shape (npoints,) of per-point part labels.
        """
        cat, fn = self.datapath[index]
        cls = np.array([self.classes[cat]]).astype(np.int32)
        data = np.loadtxt(fn).astype(np.float32)
        if self.normal_channel:
            points = data[:, :6]  # xyz + rgb
        else:
            points = data[:, :3]  # xyz only
        seg = data[:, -1].astype(np.int32)
        # Resample (with replacement) to a fixed number of points so every
        # item in a batch has identical shape.
        choice = np.random.choice(len(seg), self.npoints, replace=True)
        points = points[choice, :]
        seg = seg[choice]
        return points, cls, seg

    def __len__(self):
        """Number of files in this split."""
        return len(self.datapath)