์ž์—ฐ์–ด ์ฒ˜๋ฆฌ/Today I learned :

model.train() ๊ณผ model.eval()์˜ ์Šค์œ„์นญ์€ ํ•„์ˆ˜์ผ๊นŒ?

์ฃผ์˜ ๐Ÿฑ 2023. 1. 20. 21:18
728x90
๋ฐ˜์‘ํ˜•

์ž์—ฐ์–ด์ฒ˜๋ฆฌ์—์„œ ๋‹ค์šด์ŠคํŠธ๋ฆผ ํƒœ์Šคํฌ ์ค‘ ์˜ˆ๋ฅผ ๋“ค๋ฉด ๋ถ„๋ฅ˜ํ•˜๋Š” ๋ฌธ์ œ์—์„œ,

๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ฌ ๋•Œ train ๊ณผ validation ์œผ๋กœ ๋จผ์ € ์„ฑ๋Šฅ์„ ์ฑ„์ ํ•œ ํ›„,๋ ˆ์ด๋ธ”์ด ์—†๋Š” ์ƒˆ๋กœ์šด ์ธํ’‹์œผ๋กœ test์…‹์„ ๋„ฃ์–ด ์˜ˆ์ธก๋œ ๋ ˆ์ด๋ธ” ๊ฐ’์„ ์–ป๋Š”๋‹ค.  train ๊ณผ validation์„ ํ•˜๋Š” ๊ณผ์ •์—์„œ, train์„ ํ•˜๊ธฐ์ „ model.train() ์œผ๋กœ train์ƒํƒœ๋กœ ๋งŒ๋“ค์–ด์ฃผ๊ณ , train์ด ๋๋‚˜๋ฉด model.eval()๋กœ ์Šค์œ„์นญํ•˜์—ฌ ๊ฒ€์ฆ์„ ํ•˜๊ณ  ๋‹ค์‹œ train- eval ํ•˜๋Š” ์‹์œผ๋กœ ์—ํฌํฌ ๋งŒํผ ๋Œ๊ฒŒ ๋œ๋‹ค. 

์ด ๋•Œ , train ํ•  ๋•Œ๋Š” ๋ฌด์กฐ๊ฑด train mode, validation ํ•  ๋•Œ๋Š” ๋ฌด์กฐ๊ฑด validation ๋ชจ๋“œ์— ์žˆ์–ด์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์Šค์œ„์นญ์€ ํ•„์ˆ˜ ์ด๋‹ค. 

์ฝ”๋“œ์—์„œ๋„ ์ด๋ฅผ ์ˆ˜๋™์œผ๋กœ ๋ช…์‹œํ•ด์•ผ ํ•˜๋Š”์ง€ ๊ถ๊ธˆํ–ˆ์—ˆ๋Š”๋ฐ 

๋งŒ์•ฝ eval ์ด with ๋ฌธ ์•ˆ์— ์žˆ์œผ๋ฉด ๊ฒ€์ฆ์ด ๋๋‚˜๋ฉด ์ž๋™์œผ๋กœ train ๋ชจ๋“œ๋กœ ๋ณ€ํ™˜์ด ๋˜์–ด model.train() ๊ณผ model.eval()์„ ํ•œ๋ฒˆ์”ฉ๋งŒ ์“ฐ๋ฉด ๋œ๋‹ค. ํ•˜์ง€๋งŒ ๊ทธ๋ ‡์ง€ ์•Š์„ ๊ฒฝ์šฐ model.train() ๊ณผ model.eval() ์ด ๋๋‚˜๋Š” ๋•Œ์— model.train() ์„ ๋‹ค์‹œ ๋ช…์‹œํ•ด ์ฃผ์–ด ํ›ˆ๋ จ๋ชจ๋“œ๋กœ ๋ณ€ํ™˜์„ ํ•ด์ฃผ์–ด์•ผ ํ•œ๋‹ค. 

๋”ฐ๋ผ์„œ ์Šค์œ„์นญ์€ ํ•„์ˆ˜์ ์ด๋‹ค!


์ฝ”๋“œ ์˜ˆ์‹œ

model.train()
for epoch in range(train_epoch):
........

(์ƒ๋žต)

            optimizer.zero_grad()
            loss = output.loss
            loss.backward()

            optimizer.step()

  
                    model.eval()
                    
                    for .... in tqdm(......):
       

                        output = model(....)

                        logits = output.logits
                        loss = output.loss
      
                    ...

(์ƒ๋žต)
             
            model.train()

 

 

๋ฐ˜์‘ํ˜•