๋”ฅ๋Ÿฌ๋‹/Today I learned :

PyTorch DataLoader: how to use it

주의 🐱 2023. 1. 17. 10:43

ํŒŒ์ดํ† ์น˜๋กœ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์„ ๋งŒ๋“œ๋ ค๋ฉด ๋ฐ์ดํ„ฐ๋กœ๋”๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ํ•œ๋‹ค. ์ด๋ฅผ ์ •์˜ ํ•˜๋Š” ๋ฒ•์„ ์ •๋ฆฌํ•ด๋ณด๋ ค ํ•œ๋‹ค ,

Earlier, in https://getacherryontop.tistory.com/144, I did sentiment analysis of movie reviews with BERT and mentioned two approaches.

1. Use the Trainer from transformers.

2. Use PyTorch directly.

In case 2, we have to define the data loaders ourselves.

Looking only at the data loader part of that post:

from torch.utils.data import DataLoader

train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)

eval_dataloader = DataLoader(small_eval_dataset, batch_size=32)

--------------------------------------


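# tokenizer and raw_datasets come from the earlier post linked above.
def tokenize_function(examples):
    # Pad every review to the model's maximum length and truncate longer ones.
    return tokenizer(examples['text'], padding="max_length", truncation=True)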

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

small_train_dataset = tokenized_datasets['train'].shuffle(seed=42).select(range(1000))

small_eval_dataset = tokenized_datasets['test'].shuffle(seed=42).select(range(1000))


๋ฐ์ดํ„ฐ๋กœ๋”์˜ ์—ญํ• 

๋ฐ์ดํ„ฐ๋ฅผ ๋ฐฐ์น˜ ๋‹จ์œ„๋กœ ๋ชจ๋ธ์— ๋„ฃ๋Š” ์—ญํ• ์„ ํ•œ๋‹ค. 

๋ฐ์ดํ„ฐ์…‹์€ ๋ฐ์ดํ„ฐ ๋กœ๋”์˜ ๊ตฌ์„ฑ ์š”์†Œ์ค‘ ํ•˜๋‚˜๋กœ, ๋ฐ์ดํ„ฐ์…‹ ์•ˆ์—๋Š” ์—ฌ๋Ÿฌ ์ธ์Šคํ„ด์Šค(๋ฌธ์„œ+๋ ˆ์ด๋ธ”)์ด ํฌํ•จ๋˜์–ด ์žˆ๋‹ค. 

train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)

eval_dataloader = DataLoader(small_eval_dataset, batch_size=32)

์ฒ˜๋Ÿผ

DataLoader์— ๋ฐ์ดํ„ฐ์…‹๊ณผ ๋ฐฐ์น˜ ์‚ฌ์ด์ฆˆ๋ฅผ ์„ค์ •ํ•˜์—ฌ ๋„˜๊ฒจ์ค€๋‹ค. train_dataloader์—์„œ๋Š” train ๋ฐ์ดํ„ฐ์…‹์—์„œ ๋ฌธ์žฅ+๋ ˆ์ด๋ธ”์„ ๋žœ๋ค์œผ๋กœ 8๊ฐœ์”ฉ ๋ฝ‘์•„ ๋ฐฐ์น˜1, ๋‹ค๋ฅธ 8๊ฐœ๋Š” ๋ฐฐ์น˜2 ์ด๋Ÿฐ์‹์œผ๋กœ ๋งŒ๋“œ๋Š” ๊ฒƒ์ด๋‹ค. 
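The batching described above can be checked on a toy dataset. This is only a sketch, not the code from the earlier post: it assumes 1,000 fake samples wrapped in a TensorDataset, and the names toy_inputs, toy_labels and toy_dataset are hypothetical stand-ins for small_train_dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for small_train_dataset: 1000 samples,
# each a vector of 16 token ids plus one binary label.
toy_inputs = torch.randint(0, 100, (1000, 16))
toy_labels = torch.randint(0, 2, (1000,))
toy_dataset = TensorDataset(toy_inputs, toy_labels)

train_dataloader = DataLoader(toy_dataset, shuffle=True, batch_size=8)

# 1000 samples / 8 per batch = 125 batches per epoch.
num_batches = len(train_dataloader)
print(num_batches)

# Each batch stacks 8 randomly chosen samples along a new first dimension.
first_inputs, first_labels = next(iter(train_dataloader))
print(first_inputs.shape, first_labels.shape)
```

With 1,000 samples and batch_size=8, one epoch is 125 batches, and every batch carries its inputs and labels together.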

Note that all sentences in a batch must have the same length (number of tokens), because the default collation stacks them into a single tensor. Padding takes care of this before the data reaches the DataLoader.

See the tokenize_function function above for that step.

 

๋ฐ์ดํ„ฐ ๋กœ๋”๋กœ ๋„˜๊ฒจ์ฃผ๋Š” ๋ณ€์ˆ˜๋“ค์€ ์ €๊ฒƒ๋งŒ ์žˆ๋Š” ๊ฒƒ์€ ์•„๋‹ˆ๋‹ค. 

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)
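As one example of these extra arguments, the sketch below compares the default drop_last=False with drop_last=True. The 10-sample toy_dataset is hypothetical, chosen so that the last batch comes out incomplete.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset of 10 samples, chosen so the last batch is incomplete.
toy_dataset = TensorDataset(torch.arange(10))

# Default drop_last=False keeps the final, smaller batch: sizes 4, 4, 2.
keep_sizes = [len(batch[0]) for batch in DataLoader(toy_dataset, batch_size=4)]

# drop_last=True discards it, so every batch has exactly batch_size samples.
drop_sizes = [
    len(batch[0]) for batch in DataLoader(toy_dataset, batch_size=4, drop_last=True)
]
print(keep_sizes, drop_sizes)
```

Dropping the last batch can be useful when a model or loss expects a fixed batch size, at the cost of skipping a few samples each epoch.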

 

The step that fixes the batch and its shape and turns the samples into the model's final input is called collation (collate). The collate_fn argument in the signature above lets you customize it.
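A custom collate function might look like the sketch below. Instead of padding everything to a fixed maximum length, it pads each batch only to its own longest sequence. This is a different technique from the padding="max_length" used earlier in this post. The samples, the pad id 0 and the name pad_collate are assumptions made for illustration.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical samples: (token ids, label) pairs with different lengths.
samples = [
    (torch.tensor([5, 6, 7]), 1),
    (torch.tensor([8, 9]), 0),
    (torch.tensor([3]), 1),
    (torch.tensor([4, 2, 1, 6]), 0),
]

def pad_collate(batch):
    """Pad each batch to the length of its longest sequence (pad id 0 assumed)."""
    token_ids, labels = zip(*batch)
    padded_ids = pad_sequence(list(token_ids), batch_first=True, padding_value=0)
    return padded_ids, torch.tensor(labels)

loader = DataLoader(samples, batch_size=2, collate_fn=pad_collate)
shapes = [tuple(ids.shape) for ids, _ in loader]
print(shapes)
```

Each batch ends up only as wide as it needs to be, which can save computation when most sentences are much shorter than the maximum length.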

 

++ Addendum

For a more detailed explanation, see the official documentation.

https://pytorch.org/docs/stable/data.html

 

Preparing your data for training with DataLoaders

The Dataset retrieves the dataset's features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce overfitting, and use Python's multiprocessing to speed up data retrieval.

DataLoader is an iterable that abstracts this complexity behind a simple API.

from torch.utils.data import DataLoader

 

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

 

Iterate through the DataLoader

Once the dataset is loaded into the DataLoader, we can iterate through it as needed. Each iteration below returns a batch of train_features and train_labels (containing batch_size=64 features and labels respectively). Because we specified shuffle=True, the data is reshuffled after we iterate over all the batches. (For finer-grained control over the data loading order, take a look at Samplers.)

import matplotlib.pyplot as plt

# Display the image and its label.

train_features, train_labels = next(iter(train_dataloader))

print(f"Feature batch shape: {train_features.size()}")

print(f"Labels batch shape: {train_labels.size()}")

img = train_features[0].squeeze()

label = train_labels[0]

plt.imshow(img, cmap="gray")

plt.show()

print(f"Label: {label}")
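The reshuffling between epochs can be observed directly. The sketch below is an illustration under assumptions: it uses a hypothetical dataset of the integers 0 to 99 rather than the image data above, and a seeded torch.Generator so the run is reproducible.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical dataset: the integers 0..99 (a plain list works as a map-style dataset).
toy_data = list(range(100))

# A seeded generator (an assumption for reproducibility) drives the shuffling.
generator = torch.Generator().manual_seed(0)
loader = DataLoader(toy_data, batch_size=10, shuffle=True, generator=generator)

# Collect the sample order seen in two consecutive epochs.
epoch_orders = []
for epoch in range(2):
    order = [int(x) for batch in loader for x in batch]
    epoch_orders.append(order)

print(epoch_orders[0][:10])
print(epoch_orders[1][:10])
```

Both epochs visit every sample exactly once, but in a different order, because a new permutation is drawn each time iteration over the loader starts.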

๋ฐ˜์‘ํ˜•