์์ฐ์ด์ฒ๋ฆฌ์์ ๋ค์ด์คํธ๋ฆผ ํ์คํฌ ์ค ์๋ฅผ ๋ค๋ฉด ๋ถ๋ฅํ๋ ๋ฌธ์ ์์,
๋ชจ๋ธ์ ํ์ต์ํฌ ๋ train ๊ณผ validation ์ผ๋ก ๋จผ์ ์ฑ๋ฅ์ ์ฑ์ ํ ํ,๋ ์ด๋ธ์ด ์๋ ์๋ก์ด ์ธํ์ผ๋ก test์ ์ ๋ฃ์ด ์์ธก๋ ๋ ์ด๋ธ ๊ฐ์ ์ป๋๋ค. train ๊ณผ validation์ ํ๋ ๊ณผ์ ์์, train์ ํ๊ธฐ์ model.train() ์ผ๋ก train์ํ๋ก ๋ง๋ค์ด์ฃผ๊ณ , train์ด ๋๋๋ฉด model.eval()๋ก ์ค์์นญํ์ฌ ๊ฒ์ฆ์ ํ๊ณ ๋ค์ train- eval ํ๋ ์์ผ๋ก ์ํฌํฌ ๋งํผ ๋๊ฒ ๋๋ค.
์ด ๋ , train ํ ๋๋ ๋ฌด์กฐ๊ฑด train mode, validation ํ ๋๋ ๋ฌด์กฐ๊ฑด validation ๋ชจ๋์ ์์ด์ผ ํ๊ธฐ ๋๋ฌธ์ ์ค์์นญ์ ํ์ ์ด๋ค.
์ฝ๋์์๋ ์ด๋ฅผ ์๋์ผ๋ก ๋ช ์ํด์ผ ํ๋์ง ๊ถ๊ธํ์๋๋ฐ
๋ง์ฝ eval ์ด with ๋ฌธ ์์ ์์ผ๋ฉด ๊ฒ์ฆ์ด ๋๋๋ฉด ์๋์ผ๋ก train ๋ชจ๋๋ก ๋ณํ์ด ๋์ด model.train() ๊ณผ model.eval()์ ํ๋ฒ์ฉ๋ง ์ฐ๋ฉด ๋๋ค. ํ์ง๋ง ๊ทธ๋ ์ง ์์ ๊ฒฝ์ฐ model.train() ๊ณผ model.eval() ์ด ๋๋๋ ๋์ model.train() ์ ๋ค์ ๋ช ์ํด ์ฃผ์ด ํ๋ จ๋ชจ๋๋ก ๋ณํ์ ํด์ฃผ์ด์ผ ํ๋ค.
๋ฐ๋ผ์ ์ค์์นญ์ ํ์์ ์ด๋ค!
์ฝ๋ ์์
model.train()
for epoch in range(train_epoch):
........
(์๋ต)
optimizer.zero_grad()
loss = output.loss
loss.backward()
optimizer.step()
model.eval()
for .... in tqdm(......):
output = model(....)
logits = output.logits
loss = output.loss
...
(์๋ต)
model.train()
'์์ฐ์ด ์ฒ๋ฆฌ > Today I learned :' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
์ฝ๋ฉ ํ๋ก, ํ๋ก ํ๋ฌ์ค ์จ๋ณธ ํ๊ธฐ (0) | 2023.01.21 |
---|---|
์์ฐ์ด ์ฒ๋ฆฌ์์์ ํ์ดํผ ํ๋ผ๋ฏธํฐ ์ข ๋ฅ, ์ค์ (0) | 2023.01.20 |
์์ฐ์ด์ฒ๋ฆฌ ๋ชจ๋ธ์ด ํ์คํฌ๋ฅผ ์ํํ๋ ๋ฐฉ๋ฒ์? (์ธ ์ปจํ ์คํธ ๋ฌ๋, ์ ๋ก์ท, ์์ท ํจ์ท ๋ฌ๋) (0) | 2023.01.17 |
์ธ์ด๋ชจ๋ธ GPT (1) | 2023.01.17 |
๋ฒํธ๋ฅผ ํ์ฉํ ์ํ๋ฆฌ๋ทฐ ๋ถ๋ฅ (0) | 2023.01.16 |