Python / PyTorch Fundamentals Interview Questions
What is the purpose of torch.utils.data.random_split() and how do you create train/validation/test splits in PyTorch?
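Splitting a dataset into training, validation, and test subsets is a fundamental step before training. PyTorch's random_split() randomly partitions a single Dataset into non-overlapping subsets of the requested lengths. Each returned subset is a Subset object that stores only a reference to the original Dataset plus a list of indices, so no data is copied and samples are still loaded lazily through the original Dataset's __getitem__.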
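To make "non-overlapping" and "lazy" concrete, the sketch below inspects the objects random_split returns. It assumes PyTorch 1.13 or newer, where random_split also accepts fractions that sum to 1 instead of absolute lengths. The toy TensorDataset stands in for a real dataset.

```python
import torch
from torch.utils.data import TensorDataset, Subset, random_split

# Toy dataset: 1000 samples with 20 features and labels in {0, 1, 2}
features = torch.randn(1000, 20)
labels = torch.randint(0, 3, (1000,))
full = TensorDataset(features, labels)

# Fractions summing to 1 (assumes PyTorch >= 1.13, which accepts fractional lengths)
gen = torch.Generator().manual_seed(42)
train_ds, val_ds, test_ds = random_split(full, [0.7, 0.15, 0.15], generator=gen)

# Each split is a Subset: the parent dataset plus a list of indices into it
for split in (train_ds, val_ds, test_ds):
    assert isinstance(split, Subset)
    assert split.dataset is full  # same object, no data copied

# The index lists are disjoint and together cover every sample exactly once
all_idx = train_ds.indices + val_ds.indices + test_ds.indices
assert sorted(all_idx) == list(range(len(full)))

# Indexing a Subset simply forwards to the parent dataset
x0, y0 = train_ds[0]
assert torch.equal(x0, full[train_ds.indices[0]][0])

print(len(train_ds), len(val_ds), len(test_ds))
```

Passing fractions saves the manual size arithmetic; passing integer lengths (as in the example below) works on older versions too.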
```python
import torch
from torch.utils.data import Dataset, DataLoader, Subset, random_split

class MyDataset(Dataset):
    def __init__(self, n=1000):
        self.data = torch.randn(n, 20)           # n samples, 20 features each
        self.labels = torch.randint(0, 3, (n,))  # 3 classes

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

full_dataset = MyDataset(n=1000)

# Split: 70% train, 15% val, 15% test
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size  # remainder, so the sizes always sum to the dataset length

# Use a seeded generator for reproducible splits
generator = torch.Generator().manual_seed(42)
train_ds, val_ds, test_ds = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=generator,
)
print(len(train_ds), len(val_ds), len(test_ds))  # 700 150 150

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)  # no shuffle needed for evaluation
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

# IMPORTANT GOTCHA: if your Dataset applies different transforms
# (e.g. data augmentation only for training), random_split alone
# does NOT let you apply different transforms per split, because
# all splits reference the SAME underlying Dataset object.
# Common workaround: split the INDICES, then wrap separate Dataset
# instances (each with its own transforms) in Subset.
split_gen = torch.Generator().manual_seed(42)  # fresh generator, independent of earlier random calls
indices = torch.randperm(len(full_dataset), generator=split_gen).tolist()
train_idx = indices[:train_size]
val_idx = indices[train_size:train_size + val_size]
test_idx = indices[train_size + val_size:]
# MyDatasetWithAugmentation and MyDatasetPlain are placeholders for your own classes:
# train_dataset_aug = Subset(MyDatasetWithAugmentation(...), train_idx)
# val_dataset_plain = Subset(MyDatasetPlain(...), val_idx)
```
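The index-splitting workaround can be fleshed out as follows. This is a minimal sketch of one variation, not the only approach: instead of building two separate Dataset instances, it uses a hypothetical `TransformedSubset` wrapper that applies a per-split transform on top of one shared base dataset. `add_noise` is a made-up augmentation used only for illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset

class TransformedSubset(Dataset):
    """Hypothetical helper: keeps only `indices` of `base` and applies
    `transform` to each feature tensor when a sample is accessed."""

    def __init__(self, base, indices, transform=None):
        self.subset = Subset(base, indices)
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        x, y = self.subset[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

def add_noise(x):
    # Hypothetical augmentation: small Gaussian noise, used for training only
    return x + 0.1 * torch.randn_like(x)

base = TensorDataset(torch.randn(1000, 20), torch.randint(0, 3, (1000,)))

# Split the INDICES once, reproducibly
n = len(base)
train_size, val_size = int(0.7 * n), int(0.15 * n)
perm = torch.randperm(n, generator=torch.Generator().manual_seed(42)).tolist()
train_idx = perm[:train_size]
val_idx = perm[train_size:train_size + val_size]
test_idx = perm[train_size + val_size:]

train_ds = TransformedSubset(base, train_idx, transform=add_noise)  # augmented
val_ds = TransformedSubset(base, val_idx)    # untouched samples
test_ds = TransformedSubset(base, test_idx)  # untouched samples

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)  # torch.Size([32, 20]) torch.Size([32])
```

Sharing one base dataset also sidesteps a pitfall with datasets like MyDataset that generate their data in `__init__`: two separate instances would hold different random tensors, so the same indices would not point at the same samples.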
