ํ์ดํ ์น๋ก ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ๋ง๋๋ ค๋ฉด ๋ฐ์ดํฐ๋ก๋๋ฅผ ์ฌ์ฉํด์ผ ํ๋ค. ์ด๋ฅผ ์ ์ ํ๋ ๋ฒ์ ์ ๋ฆฌํด๋ณด๋ ค ํ๋ค ,
์ด์ ์ https://getacherryontop.tistory.com/144 ์์ ๋ฒํธ๋ก ์ํ๋ฆฌ๋ทฐ ๊ฐ์ ๋ถ์์ ํ๋ฉฐ ๋ ๊ฐ์ง ๋ฐฉ๋ฒ์ ์ธ๊ธํ์๋ค.
1. transformers ์ Trainer๋ฅผ ํ์ฉํ๋ค.
2. pytorch๋ฅผ ์ฌ์ฉํ๋ค.
2๋ฒ์ ๊ฒฝ์ฐ์ ์ฐ๋ฆฌ๋ ๋ฐ์ดํฐ๋ก๋๋ฅผ ์ ์ํด์ผ ํ๋ค.
์ ๊ธ์์ ๋ฐ์ดํฐ๋ก๋๋ง ๋ณด๋ฉด
from torch.utils.data import DataLoader
train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=32)
--------------------------------------
def tokenize_function(examples):
return tokenizer(examples['text'], padding="max_length", truncation=True)
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
small_train_dataset = tokenized_datasets['train'].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets['test'].shuffle(seed=42).select(range(1000))
๋ฐ์ดํฐ๋ก๋์ ์ญํ
๋ฐ์ดํฐ๋ฅผ ๋ฐฐ์น ๋จ์๋ก ๋ชจ๋ธ์ ๋ฃ๋ ์ญํ ์ ํ๋ค.
๋ฐ์ดํฐ์ ์ ๋ฐ์ดํฐ ๋ก๋์ ๊ตฌ์ฑ ์์์ค ํ๋๋ก, ๋ฐ์ดํฐ์ ์์๋ ์ฌ๋ฌ ์ธ์คํด์ค(๋ฌธ์+๋ ์ด๋ธ)์ด ํฌํจ๋์ด ์๋ค.
train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=32)
์ฒ๋ผ
DataLoader์ ๋ฐ์ดํฐ์ ๊ณผ ๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ ์ค์ ํ์ฌ ๋๊ฒจ์ค๋ค. train_dataloader์์๋ train ๋ฐ์ดํฐ์ ์์ ๋ฌธ์ฅ+๋ ์ด๋ธ์ ๋๋ค์ผ๋ก 8๊ฐ์ฉ ๋ฝ์ ๋ฐฐ์น1, ๋ค๋ฅธ 8๊ฐ๋ ๋ฐฐ์น2 ์ด๋ฐ์์ผ๋ก ๋ง๋๋ ๊ฒ์ด๋ค.
์ฌ๊ธฐ์ ์ฃผ์ํ ์ ์ ๋ฐฐ์น ์์ ์๋ ๋ฌธ์ฅ๋ค์ ๋ชจ๋ ๊ธธ์ด(ํ ํฐ ์)๊ฐ ๊ฐ์์ผ ํ๋ค. ์ด๋ ์ด์ ์ ํจ๋ฉ์ผ๋ก ๋ง์ถฐ์ค๋ค .
์ด ๊ณผ์ ์ ์์์ tokenize_function ํจ์๋ฅผ ์ฐธ๊ณ ํ๋ฉด ๋๋ค!
๋ฐ์ดํฐ ๋ก๋๋ก ๋๊ฒจ์ฃผ๋ ๋ณ์๋ค์ ์ ๊ฒ๋ง ์๋ ๊ฒ์ ์๋๋ค.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)
์ด์ฒ๋ผ ๋ฐฐ์น์ ๋ชจ์๋ฑ์ ์ ํด ๋ชจ๋ธ์ ์ต์ข ์ ๋ ฅ์ผ๋ก ๋ค๋ฌ์ด์ฃผ๋ ๊ณผ์ ์ collate , ์ปฌ๋ ์ดํธ๋ผ๊ณ ํ๋ค.
++์ถ๊ฐ
๋ณด๋ค ์์ธํ ์ค๋ช ์ ๊ณต์๋ฌธ์๋ฅผ ์ฐธ์กฐํ๊ธธ ๋ฐ๋๋ค.
https://pytorch.org/docs/stable/data.html
DataLoader๋ก ํ์ต์ฉ ๋ฐ์ดํฐ ์ค๋นํ๊ธฐ
Dataset ์ ๋ฐ์ดํฐ์ ์ ํน์ง(feature)์ ๊ฐ์ ธ์ค๊ณ ํ๋์ ์ํ์ ์ ๋ต(label)์ ์ง์ ํ๋ ์ผ์ ํ ๋ฒ์ ํฉ๋๋ค. ๋ชจ๋ธ์ ํ์ตํ ๋, ์ผ๋ฐ์ ์ผ๋ก ์ํ๋ค์ “๋ฏธ๋๋ฐฐ์น(minibatch)”๋ก ์ ๋ฌํ๊ณ , ๋งค ์ํญ(epoch)๋ง๋ค ๋ฐ์ดํฐ๋ฅผ ๋ค์ ์์ด์ ๊ณผ์ ํฉ(overfit)์ ๋ง๊ณ , Python์ multiprocessing ์ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ ๊ฒ์ ์๋๋ฅผ ๋์ด๋ ค๊ณ ํฉ๋๋ค.
DataLoader ๋ ๊ฐ๋จํ API๋ก ์ด๋ฌํ ๋ณต์กํ ๊ณผ์ ๋ค์ ์ถ์ํํ ์ํ ๊ฐ๋ฅํ ๊ฐ์ฒด(iterable)์ ๋๋ค.
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
DataLoader๋ฅผ ํตํด ์ํํ๊ธฐ(iterate)
DataLoader ์ ๋ฐ์ดํฐ์ ์ ๋ถ๋ฌ์จ ๋ค์๋ ํ์์ ๋ฐ๋ผ ๋ฐ์ดํฐ์ ์ ์ํ(iterate)ํ ์ ์์ต๋๋ค. ์๋์ ๊ฐ ์ํ(iteration)๋ (๊ฐ๊ฐ batch_size=64 ์ ํน์ง(feature)๊ณผ ์ ๋ต(label)์ ํฌํจํ๋) train_features ์ train_labels ์ ๋ฌถ์(batch)์ ๋ฐํํฉ๋๋ค. shuffle=True ๋ก ์ง์ ํ์ผ๋ฏ๋ก, ๋ชจ๋ ๋ฐฐ์น๋ฅผ ์ํํ ๋ค ๋ฐ์ดํฐ๊ฐ ์์ ๋๋ค. (๋ฐ์ดํฐ ๋ถ๋ฌ์ค๊ธฐ ์์๋ฅผ ๋ณด๋ค ์ธ๋ฐํ๊ฒ(finer-grained) ์ ์ดํ๋ ค๋ฉด Samplers ๋ฅผ ์ดํด๋ณด์ธ์.)
# ์ด๋ฏธ์ง์ ์ ๋ต(label)์ ํ์ํฉ๋๋ค.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
'๋ฅ๋ฌ๋ > Today I learned :' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
numpy argsort ์ ์๋ฏธ์ ์ฌ์ฉ๋ฒ ์ ๋ฆฌ (0) | 2023.01.20 |
---|---|
end-to-end ๋ชจ๋ธ์ด๋ (0) | 2023.01.13 |
ํ์ดํ ์น๋ก ๊ฐ๋จํ ์ธ๊ณต์ ๊ฒฝ๋ง ๊ตฌํํ๊ธฐ (๋ถ๋ฅ) (0) | 2022.12.29 |
RNN (0) | 2022.12.28 |
python pytorch ํ ์ rank, un squeeze, view, ํ๋ ฌ๊ณฑ (0) | 2022.12.28 |