๊ฐ๋จํ ๊ธ๋ถ์ ์ด์ง ๋ถ๋ฅ ๋ชจ๋ธ์ ๋ง๋ค์๋ค.
์ ์ฒด์ฝ๋๋ ๊นํ์์ ๋ณผ ์ ์๋ค!
https://github.com/Juyoung-b/Improving-the-Performance-of-Sentiment-Classification
์์ด๋ก ๋ ๋ ์คํ ๋ ๋ฆฌ๋ทฐ๋ฅผ ๊ฐ์ง๊ณ , ๊ธ์ (1), ๋ถ์ (0)์ผ๋ก ๋ถ๋ฅํ๋ ๊ฐ๋จํ task ๋ชจ๋ธ์ด๋ค.
์ด๋ฒ ํ๋ก์ ํธ์์ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๊ฐ์ ํ๋๊ฒ์ ์ง์คํ๋ค.
๋ค์ํ pre-trained ๋ชจ๋ธ์ ์จ๋ณด๊ณ , ์ ํ๋๋ฅผ ์ฌ๋ฆฌ๊ธฐ ์ํด ํ์ดํผํ๋ผ๋ฏธํฐ ํ๋๊ณผ ์๋ํ ๋ฐฉ๋ฒ๋ค์ ์ด์ ์ ๋ง์ท๋ค. ๋๋ต์ ์ธ ํ๋ก์ฐ๋ ๋ค์๊ณผ ๊ฐ๋ค.
๊ฐ๋ฐ ํ๊ฒฝ์ python, pytorch, ์ฝ๋ฉ ํ๋กํ๋ฌ์ค ์ด๋ค.
1. ๋ฐ์ดํฐ ์ค๋น
Yelp ๋ฐ์ดํฐ์ ์ ๊ธ๋ถ์ ์ด ๋ ์ด๋ธ๋ง๋ ๋ ์คํ ๋ ๋ฆฌ๋ทฐ ์ด 443,259๊ฐ๋ฅผ ์ค๋นํ๋ค. train ๋ฐ์ดํฐ๋ก๋ ๊ธ์ ๋ฆฌ๋ทฐ 266041๊ฐ ๋ฌธ์ฅ, ๋ถ์ ๋ฆฌ๋ทฐ๋ 177218๊ฐ์ ๋ฌธ์ฅ, validation๋ฌธ์ฅ์ ๊ธ๋ถ์ ๊ฐ๊ฐ 2000๊ฐ์ฉ ์ค๋นํ๋ค.
2. ์ ์ฒ๋ฆฌ
์ ๋ถ ํ ๋ฌธ์ฅ์ผ๋ก ๋ ์์ด ๋ฆฌ๋ทฐ๋ก, ์ ์ฒ๋ฆฌํ ๊ฒ์ด ์์ด ํ์ง ์์๋ค.
3. fine-tuning
huggingface์ transformer ๋ชจ๋๋ก ๊ฐ ๋ชจ๋ธ์ ๋ถ๋ฌ์์ ํ๋ จ์ํค๋ ๊ฒ์ ์ด๋ ต์ง ์๋ค.
BERT, RoBERTa, DistilBERT, GPT2์ ํ ํฌ๋์ด์ , ๋ชจ๋ธ์ ๋ถ๋ฌ์ trainํ๋ค. ๋ถ๋ฅ์ ์ธ ์ ์๋ ๋ชจ๋ธ์ ์ด ๋ง๊ณ ๋ ALBERT, ๋ฑ ๋ค์ํ ๋ชจ๋ธ๋ค์ด ์๋ค.
๊ฐ ๋ชจ๋ธ์ ์ฌ์ฉํ ์ด์ ๋ ์ฐ์ , BERT๋ ๋ฐฉ๋ํ corpus๋ก ํ๋ จ๋ ๋ชจ๋ธ์ด๊ธฐ์ fine tuning์ผ๋ก ๋์ ์ฑ๋ฅ์ ๋ผ ์ ์๋ ๋ชจ๋ธ์ด๋ค. RoBERTa๋ ํ์ด์ค๋ถ์์ ๊ณต๊ฐํ ์ธ์ด๋ชจ๋ธ๋ก, BERT๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ์ง๋ง ํ๋ฆฌํธ๋ ์ธ ๋ฐฉ์์ ๋ณํ์ํค๊ณ , ํ์ต ๋ฐ์ดํฐ๋ฅผ ์ฆ๊ฐ์์ผ ๊ธฐ์กด ๋ฒํธ์ ํ๊ณ์ ์ ๋ณด์ํ ๋ชจ๋ธ์ด๋ค. ๋ ผ๋ฌธ์ ๋ฐ๋ฅด๋ฉด, RoBERTa๊ฐ BERT๋ณด๋ค ์ฐ์ธํ ์ฑ๋ฅ์ ๋ณด์ด๊ธฐ์ ์คํ ๋ชจ๋ธ์ ์ถ๊ฐํ๋ค.(์ค์ ๋ก RoBERTa๊ฐ BERT๋ณด๋ค ๋์ ์ ํ๋๋ฅผ ๋ณด์๋ค.)DistilBERT๋ BERT๋ณด๋ค 40% ์๊ณ , 60% ๋น ๋ฅด๋ฉด์๋ 97%์ capability๋ฅผ ๋ณด์กดํ๋ ๋ชจ๋ธ๋ก, ํจ์จ์ ์ธ ๋ชจ๋ธ์ด๋ผ ํ๋จํ์ฌ ์ฌ์ฉํ๋ค.ํ์ง๋ง ์ฝ๋ฉ GPU ์ ํ๋์ด ์์ด ํ ๋ฒ๋ฐ์ ์คํํด๋ณด์ง ๋ชปํ๋ค..
GPT-2๋ BERT์ ๋ฌ๋ฆฌ transformer ์ ๋์ฝ๋๋ฅผ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ด๊ณ , ์ ๋ ฅ ์ํ์ค์ ๋ง์ง๋ง ํ ํฐ์ ์ ๋ ฅ ๋ค์ ์์ผ ํ๋ ๋ค์ ํ ํฐ์ ์์ธก ํ๋๋ฐ ์ฌ์ฉ๋๋ค. ์ด๋ ์ ๋ ฅ ์ํ์ค์ ๋ง์ง๋ง ํ ํฐ์ ์์ธก์ ํ์ํ ๋ชจ๋ ์ ๋ณด๊ฐ ํฌํจ๋์ด ์์์ ์๋ฏธํ๋ฏ๋ก ๋ถ๋ฅ์์ ๊ทธ ์ ๋ณด๋ฅผ ์ฌ์ฉํ์ฌ ์์ธกํ ์ ์๊ธฐ์ ๋ถ๋ฅ ํ์คํฌ๋ฅผ ์ํํ ์ ์๋ค.
Bert์์์ฒ๋ผ ์ฒซ ๋ฒ์งธ ํ ํฐ ์๋ฒ ๋ฉ์ ์ฌ์ฉํ์ฌ ์์ธกํ๋ ๋์ ๋ง์ง๋ง ํ ํฐ ์๋ฒ ๋ฉ์ ์ฌ์ฉํ์ฌ ์์ธกํด์ผ ํ๊ธฐ์, ๊ธฐ์กด BERT๊ธฐ๋ฐ ๋ชจ๋ธ๊ณผ ๋ค๋ฅด๊ฒ GPT2์์๋ ์์ธก์ ์ํด ๋ง์ง๋ง ํ ํฐ์ ์ฌ์ฉํ๊ณ ์์ผ๋ฏ๋ก ์ผ์ชฝ์ ํจ๋ฉ์ ๋ฃ์ด์ผ ํ๋ค.
4. ๊ฒฐ๊ณผ
- ํ์ดํผํ๋ผ๋ฏธํฐ ํ๋
ํ์ดํผํ๋ผ๋ฏธํฐ๋ ๋ฐฐ์น์ฌ์ด์ฆ, ๋ฌ๋๋ ์ดํธ, ์ํฌํฌ๋ฅผ ์กฐ์ ํ๋๋ฐ ์ฝ๋ฉ GPU ์ ํ์ด ์์ด ๋ฐฐ์น์ฌ์ด์ฆ๋ 512์ด์ ํ ์ ์์๊ณ ์ํฌํฌ๋ 3~4ํ๋ฐ์ ๋ชปํ๋ค. ๋ฌ๋๋ ์ดํธ๋ 1์ํฌํฌ ์ดํ ๋ก์ค๊ฐ ์ ์ค์ด๋ค์ง ์์ ๋ฌ๋๋ ์ดํธ ์ค์ผ์ฅด๋ฌ๋ฅผ ์ถ๊ฐํ๋ค. ๋ฌ๋๋ ์ดํธ ์ค์ผ์ฅด๋ฌ๋ ํ์ดํ ์น์์ ์ ๊ณตํ๋ LamdaLR์ ์ฌ์ฉํ์ฌ 0.6์ ๊ฑฐ๋ญ์ ๊ณฑ์ผ๋ก ์ค์ด๋ค๊ฒ ํ์๋ค. ์ถํ์๋ wandb์์ ์ ๊ณตํ๋ sweepsํด์ ์ฌ์ฉํ๋ค.
๋ค์ ํ๋ ๊ฒฐ๊ณผ๋ฅผ ํ๋์ ๋ณผ ์ ์๊ฒ ์ ๋ฆฌํ ๊ฒ์ด๋ค.
์ ํ๋๋ roBERTa๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ๊ฐ 98.28๋ก ์ ์ผ ๋์์ผ๋ test ์ ํ๋๋ BERT๊ฐ 98.9๋ก ๊ฐ์ฅ ๋์๋ค.
๊ธฐ๋ณธ์ ์ผ๋ก ์ฑ๋ฅ์ด ์ข๋ค๊ณ ์๋ ค์ง ์ข์ ๋ชจ๋ธ์ ์ฐ๊ณ , ๋ฐฐ์น์ฌ์ด์ฆ๋ฅผ ๋๋ฆฌ๋ ๊ฒ์ด ์ฑ๋ฅ์ ๋์ธ๋ค๋ ๊ฒ์ ์ ์ ์๋ค. ๊ทธ๋ ์ง๋ง ๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ ๋ฌด์์ ์์ฒญ ํฌ๊ฒ ํ๋ค๊ณ ํด์ ์ ํ๋๊ฐ ์ฌ๋ผ๊ฐ์ง๋ ์๋๋ค. ๋ชจ๋ธ๋ณ๋ก ๋ฒ ์คํธ ํ์ดํผํ๋ผ๋ฏธํฐ๊ฐ ๋ค ๋ค๋ฅด๊ธฐ ๋๋ฌธ์ด๋ค.
- GPT2์ ํน์ด์
GPT2๋ ๋ค๋ฅธ BERT ๊ธฐ๋ฐ ๋ชจ๋ธ๋ค์ ๋นํด ๋ชจ๋ธ ์ ํ๋๋ณด๋ค ํ ์คํธ ์ ํ๋๊ฐ ๋ฎ์์ง๋ ๊ฒฝ์ฐ๊ฐ ๋ฐ์ํ ๊ฒ์ผ๋ก ๋ณด์,
GPT2๊ฐ ๋ค๋ฅธ BERT๊ธฐ๋ฐ ๋ชจ๋ธ์ ๋นํด classification task ์์๋ ์ฑ๋ฅ์ด ์ข์ง ์์ ๊ฒฝํฅ์ด ์๋ค๊ณ ์๊ฐ๋์ด ์ด์ ๋ฅผ ์ฐพ์๋ณธ ๊ฒฐ๊ณผ,
GPT-2๋ BERT์ ๋ค๋ฅด๊ฒ ์์์ ๋ณธ ๋ฐ์ดํฐ๋ง์ ๊ธฐ๋ฐ์ผ๋ก ๋ค์ ๋จ์ด๋ฅผ ์์ธกํ๋ฉด์ ํ์ตํ ๋ชจ๋ธ์ด๊ธฐ์ ๋ถ๋ฅ ์์ ์ ์ฌ์ฉ๋๋ ๊ฒฝ์ฐ ์ด๋ฏธ ๋ณธ ํจํด์ ๋๋ฌด ๋ง์ด ์์กดํ๋ ๊ฒฝํฅ์ด ์์ด, ์ ๋ฐ์ดํฐ์์ ์ฑ๋ฅ์ด ์ ํ๋ ์ ์์ผ๋ฉฐ ์ด๋ ๊ณผ์ ํฉ์ผ๋ก ์ด์ด์ง ์ ์๋ค๊ณ ํ๋ค. ๋ํ GPT-2๋ ํ ์คํธ ๋ถ๋ฅ๋ณด๋ค๋ ๋ค์ ๋จ์ด๋ ํ ํฐ์ ์์ธกํ๋ ๊ฒ์ด ์ฃผ์ ๋ชฉํ์ธ ์ธ์ด ๋ชจ๋ธ์ด๊ธฐ์ ๋ถ๋ฅ์์๋ ๋ค๋ฅธ ๋ชจ๋ธ์ ๋นํด ๋ฎ์ ์ฑ๋ฅ์ด ๋์ฌ ์ ์๋ค.
๋ ๋ค๋ฅธ ์ด์ ๋ GPT-2์ ํ๋ผ๋ฏธํฐ ์๊ฐ ๋ง๊ธฐ ๋๋ฌธ์ ๋ชจ๋ธ์ด ๊ณผ์ ํฉ๋๊ธฐ ์ฝ๋ค๋ ๊ฒ์ด๋ค. ๋ง์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํ๋ฉด ๋ชจ๋ธ์ด ํ๋ จ ๋ฐ์ดํฐ๋ฅผ ๋ง์ถ ์ ์๋ ์ฉ๋์ด ๋ ํฌ์ง๋ง ์๋ก์ด ๋ฐ์ดํฐ๋ก ์ ์ผ๋ฐํํ์ง ๋ชปํ ์ ์๋ค๊ณ ํ๋ค.
๋ฌผ๋ก ๋ชจ๋ธ๋ง๋ค ๋ฒ ์คํธ ํ์ดํผํ๋ผ๋ฏธํฐ๊ฐ ๋ค ๋ค๋ฅผ ์ ์๋ค๋ ์ ์ ๊ฐ์ํ๊ณ ๋ดค์ ๋, GPT2๋ BERT๊ธฐ๋ฐ ๋ชจ๋ธ๋ค์ ๋นํ๋ฉด ํ ์คํธ ๋ถ๋ฅ๋ฌธ์ ์์๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์ด์ง ์๋๋ค๊ณ ์๊ฐํ๋ค.
- ๋ณด์์
์คํํ๋ฉด์ loss์ acc๋ฅผ Wandb (Weight&Bias)๋ผ๋ ํด๋ก ์๊ฐํํ์ฌ ๊ด๋ฆฌํ๋ค.
์ ๋ฆฌํ์๋ฉด 99%์ ์ ํ๋๋ฅผ ๋ฌ์ฑํ๊ฒ ๋ค๋ ๋ชฉํ๋ก ํ์ดํผ ํ๋ผ๋ฏธํฐ ํ๋์ ํ๊ณ ๋ค๋ฅธ ๋ชจ๋ธ๋ค์ ์ฌ์ฉํ์ง๋ง ์ ํ๋๋ฅผ ํฌ๊ฒ ์ฌ๋ฆฌ์ง ๋ชปํ๋ค. ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ๋ชจ์ํ ๊ฒฐ๊ณผ sweeps๋ฅผ ์ฌ์ฉํ ํ์ดํผํ๋ผ๋ฏธํฐ ์์นญ, ๋ฐ์ดํฐ ์ฆ๊ฐ์ ๊ณํํ์์ง๋ง, ์ฝ๋ฉ์ GPU ์์ง์ผ๋ก ์ค๊ณํ ๋ชจ๋ ์คํ์ ์๋ฃํ ์ ์์๋ค.(์ญ์ ์ธํ๋ผ๊ฐ ์ค์ํ๋ค..) ์ฌ๊ธฐ์๋ถํฐ๋ ์๋ํ์ง๋ง ์ ์ฉ์ ํด๋ณด์ง ๋ชปํ๋ ์์ด๋์ด๋ค์ ๋ํด ์๊ฐํด๋ณด๋ ค ํ๋ค.
-Sweeps ๋ฅผ ํตํ ํ์ดํผํ๋ผ๋ฏธํฐ ์์นญ
ํ์ดํผ ํ๋ผ๋ฏธํฐ๋ฅผ ์ผ์ผ์ด ๋ฐ๊ฟ์ฃผ๋ ๊ณผ์ ์ ์ข ๋ ์ฉ์ดํ๊ฒ ํ๊ธฐ์ํด Sweeps๋ผ๋ ํด์ ์ฌ์ฉํ๋ค. ๋ค์์ Sweeps๋ฅผ ์ ์ฉํ์ฌ BERT ๋ชจ๋ธ๋ก, count=3์ผ๋ก ํ์ฌ ์์ ์ธ๊ฐ์ ํ๋ผ๋ฏธํฐ๋ฅผ ๋ณํ์ํค๋ฉด์ ํด๋ณธ ๊ฒฐ๊ณผ์ด๋ค
์ธ ํ๋ผ๋ฏธํฐ์ค ๊ฐ์ฅ ์ค์ํ ํ๋ผ๋ฏธํฐ๋ learning rate๋ผ๋ ์ , ์ฆ ๊ฐ์ ๋ชจ๋ธ์ ์ฌ์ฉํ๋ค๋ฉด ๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ ๋๋ฆฌ๋ ๊ฒ ์ด์์ผ๋ก learing rate๋ฅผ ์ ์ ํํ๋ ๊ฒ์ด ๋ ์ค์ํ๋ค๋ ๊ฒ์ ๋ณผ ์ ์์๋ค.
์๊ฐ๊ณผ GPU ์ ์ฝ์ผ๋ก ์ถฉ๋ถํ ์๋๋ฅผ ๋ชปํ ์ ์ด ์์ฝ์ง๋ง count ๊ฐ์ ์ถฉ๋ถํ ํฌ๊ฒํ์ฌ์ ์ฌ๋ฌ ํ๋ผ๋ฏธํฐ๋ค์ ๋น๊ตํด ๋ณธ๋ค๋ฉด ์ํ๋ ์ต์ ์ ํ๋ผ๋ฏธํฐ๊ฐ์ ์ฐพ๋๋ฐ ๋์์ด ๋ ๊ฒ์ด๋ค.
- back translation ์ ํตํ ๋ฐ์ดํฐ์ฆ๊ฐ
๋ฐ์ดํฐ ์ฆ๊ฐ์ ๋จ์ด ๊ต์ฒด, ์ฝ์ , ์์น ๋ณ๊ฒฝ, ์ญ์ ๋ฅผ ํ๋ ๋ฐฉ๋ฒ๋ ์์ง๋ง. ํฐ ๋ค์์ฑ์ ํ๋ณดํ์ง๋ ๋ชปํ๋ค๋ ๋จ์ ์ด ์กด์ฌํ๋ค. ์ด๋ฅผ ๊ทน๋ณตํ๊ณ ์ํ ๋ํ์ ์ธ Text Generation ๋ฐฉ๋ฒ์ธ Back Translation์ ๋ฒ์ญ๊ธฐ๋ฅผ ์ด์ฉํ์ฌ Label์ ์ ์งํ ์ฑ๋ก ์๋ณธ Data๋ฅผ ํ ์ธ์ด๋ก ๋ฒ์ญํ ๋ค, ๋ค์ ์๋์ ์ธ์ด๋ก ์ฌ๋ฒ์ญํ๋๋ฐ, ์ด๋ ์ถ๊ฐ๋๋ ๋ ธ์ด์ฆ๊ฐ ์ฆ๊ฐ ์ธก๋ฉด์์ ์ฑ๋ฅ์ ํฌ๊ฒ ์ฌ๋ฆฐ๋ค๊ณ ํ๋ค.
์ ๊ทํํ์์ ํ์ฉํ์ฌ ์ ์ฒ๋ฆฌ๋ฅผ ํ๊ณ , ์์ด → ํ๋์ค์ด → ์์ด๋ก ์ฌ๋ฒ์ญํ์ฌ ๋ถ์ ๋ฆฌ๋ทฐ 20๋ง๊ฐ๊น์ง ์ฆ๊ฐํด๋ณผ ๊ณํ์ ์ธ์ ๋ค. ๊ธ์ ๋ฆฌ๋ทฐ์ ๋ถ์ ๋ฆฌ๋ทฐ ๊ฐ์๊ฐ 7๋ง๊ฐ๋ก ์๋์ ์ผ๋ก ๊ธ์ ์ ์ธ ๋ฆฌ๋ทฐ๊ฐ ํ์ต๊ณผ ๋ถ๋ฅ๊ฐ ๋ ์ ๋๋ค๊ณ ์๊ฐํ๊ธฐ ๋๋ฌธ์ด๋ค. ๊ทธ๋ฌ๋, 3300๊ฐ ๊น์ง ์ฆ๊ฐํ ์ํฉ์์ ์ฝ๋ฉ GPU์ computer units ํ ๋น๋์ ๋ค ์์งํ์ฌ ์ ์ฉํ์ง ๋ชปํ๋ค.
- ๋ ์ข์ ์๊ฐํ๋ฅผ ์ํด
์ธ๋ฒ์งธ๋ก๋, ์๊ฐํํ ๋ชจ์ต์ ๋ณด์๋ฉด train loss์ validation loss๊ฐ ์๋ก ๋ค๋ฅธ time์ ์ฐํ์๋๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. train์ ๊ณ์ฐ๋ ๋๋ง๋ค, val์ ํน์ ๊ตฌ๊ฐ์ ํ๊ท ์ผ๋ก ๊ณ์ฐํ์ฌ ์๊ฐํํ๊ธฐ ๋๋ฌธ์, ์๋ก ๋ช ํํ ๋น๊ต๊ฐ ์ด๋ ค์ ๋ค. train loss๊ฐ ๊ธฐ๋ก๋๋ timestep์ validation loss๊ฐ ๊ธฐ๋ก๋๋ time step๊ณผ ๋ง์ถฐ์ผ ํ ๊ฒ์ด๋ค. ์ด ๋ฐฉ๋ฒ์๋ ์ฌ๋ฌ๊ฐ์ง๊ฐ ์๋๋ฐ, vaildation์ด ์ฐํ๋ train์ ๊ฐ์ด ์ฐ๊ฑฐ๋ log์ ์ ์ฅ๋๋ train loss ๋ง ์ญ ์ ์ฅํ๋ค, ๋์ค์ ๋ง์ถฐ์ ๊ทธ๋ ค์ฃผ๋ ๋ฐฉ๋ฒ์ด ์๋ค.
- ensemble ๋ชจ๋ธ
์ถํ์๋ ๊ฐ ๋ชจ๋ธ ๋ณ๋ก ์ต์ ์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฐพ์์ ๋ง๋ ๋ช ๊ฐ์ ๋ชจ๋ธ๋ค์ ์ด์ฉํด์ ensemble ๋ชจ๋ธ (hard voting classifier) ์ ๋ง๋ค์ด์ ์ฌ์ฉํ๋ ๊ฒ๋ ์ข์ ๋ฐฉ๋ฒ์ด ๋ ๊ฒ ๊ฐ์ต๋๋ค.
'๋ํ ํ๋ก์ ํธ > ํ๋ก์ ํธ' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
BERT๋ก ๋ด์ค ๊ธฐ์ฌ ์นดํ ๊ณ ๋ฆฌ ๋ถ๋ฅํ๊ธฐ (1) | 2022.12.12 |
---|