์ž์—ฐ์–ด ์ฒ˜๋ฆฌ/Today I learned :

์–ธ์–ด๋ชจ๋ธ GPT

์ฃผ์˜ ๐Ÿฑ 2023. 1. 17. 11:40
728x90
 
 
 
BERT ๊ฐ€ ํŠธ๋žœ์Šคํฌ๋จธ์˜ ์ธ์ฝ”๋”๋ฅผ ํ™œ์šฉํ–ˆ๋‹ค๋ฉด, GPT๋Š” ํŠธ๋žœ์Šคํฌ๋จธ์˜ ๋””์ฝ”๋”๋งŒ ํ™œ์šฉํ•ฉ๋‹ˆ๋‹ค. ๋””์ฝ”๋” ์ค‘์—์„œ๋„ encoder-decoder attention์ด ๋น ์ง„ ๋””์ฝ”๋”๋งŒ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. 
GPT์˜ ๊ตฌ์กฐ
 
Masked Multi-Head Attention์—์„œ ์ผ์–ด๋‚˜๋Š” ์ผ์„ ๋ณด๋ฉด, 
 
์ œ๊ฐ€ ๊ณ„์† ์˜ˆ์‹œ๋กœ ๋“œ๋Š” ๋ฌธ์žฅ์„ ๊ฐ€์ ธ์™€ ์ ์šฉํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. 
 

'๋‚˜๋Š” ํ† ๋ผ๋ฅผ ํ‚ค์›Œ. ๋ชจ๋“  ์‚ฌ๋žŒ์ด ๊ทธ๋ฅผ ์ข‹์•„ํ•ด'๋ผ๋Š” ๋ฌธ์žฅ์—์„œ ์ฒ˜์Œ์—๋Š” ๋‚˜๋Š”์„ ๋บด๊ณ  ๋ชจ๋‘ ๋งˆ์Šคํ‚น์ฒ˜๋ฆฌํ•ด์ค๋‹ˆ๋‹ค. ๋‚˜๋Š” ๋งŒ๋ณด๊ณ  ํ† ๋ผ๋ฅผ ์„ ์˜ˆ์ธกํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๊ธฐ ์œ„ํ•ด์„œ ํ† ๋ผ๋ฅผ์— ํ™•๋ฅ ์„ ๋†’์ด๋Š” ์‹์œผ๋กœ ์—…๋ฐ์ดํ„ฐํ•˜๋ฉฐ ํ•™์Šต์ด ์ง„ํ–‰๋ฉ๋‹ˆ๋‹ค. 

 
 ๊ทธ๋ฆฌ๊ณ  ๋‚˜๋Š” ํ† ๋ผ๋ฅผ ๋งŒ์œผ๋กœ ํ‚ค์›Œ๋ฅผ ์˜ˆ์ธกํ•  ์ˆ˜ ์žˆ๊ฒŒ , ํ‚ค์›Œ์— ํ™•๋ฅ ์„ ๋†’์ด๋Š” ๋ฐฉ์‹์œผ๋กœ ํ•™์Šต์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค. BERT์—์„œ๋Š” ๊ฐ€์šด๋ฐ ๋‹จ์–ด๋ฅผ [MASK]๋กœ ์ฒ˜๋ฆฌํ•˜๊ณ  ์•ž๊ณผ ๋’ค ๋‹จ์–ด๋“ค์„ ๋ณด๊ณ  ๊ฐ€์šด๋ฐ ๋งˆ์Šคํ‚น์ฒ˜๋ฆฌ๋œ ๋‹จ์–ด๋ฅผ ์˜ˆ์ธกํ•˜๋Š” ์‹์œผ๋กœ ํ”„๋ฆฌํŠธ๋ ˆ์ธ์„ ํ–ˆ์—ˆ์—ˆ๋Š”๋ฐ, GPT์—์„œ๋Š” ์ด๋Ÿฐ ๋ฐฉ์‹์œผ๋กœ ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๊ทธ ์ด์œ ๋Š” ๋ฐ”๋กœ ๊ตฌ์กฐ์ ์ธ ์ฐจ์ด์— ์žˆ๊ธฐ ๋•Œ๋ฌธ์ธ๋ฐ, 
 

BERT๋Š” ์–‘๋ฐฉํ–ฅ ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ์ธ ๋ฐ˜๋ฉด, GPT๋Š” ๊ทธ๋ ‡์ง€ ์•Š๋‹ค๋Š” ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. GPT๋Š” ์ˆœ์ฐจ์ ์œผ๋กœ ํ•™์Šต์ด ์ด๋ฃจ์–ด์ง€๋Š” ๊ฒƒ์ด์ฃ 

 
 GPT๋Š” ์Œ“๋Š” ๋””์ฝ”๋” ๊ฐœ์ˆ˜์— ๋”ฐ๋ผ small ๋ถ€ํ„ฐ large ๋ชจ๋ธ๋กœ ๋ถˆ๋ฆฝ๋‹ˆ๋‹ค. 
 
 
 
 
 
ํ˜„์žฌ๋Š” GPT-3 ๋ชจ๋ธ๊นŒ์ง€ ๋‚˜์™”์œผ๋ฉฐ ์•ž์œผ๋กœ ๋” ๋‚˜์˜ค๊ณ  ์„ฑ๋Šฅ๋„ ์ข‹์•„์งˆ ๊ฒƒ์ด๋ผ ๋ณด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. 
 
 
 
 
 

 
๊ฐ„๋‹จํ•˜๊ฒŒ GPT-2๋ฅผ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒƒ์„ ์ฝ”๋“œ๋กœ ์‚ดํŽด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
huggingface์˜ transformers ๋ชจ๋“ˆ์„ ์‚ฌ์šฉํ•˜๋ฉด ์‰ฝ๊ฒŒ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
 
 
 
pip install transformers

from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

input_ids = tokenizer.encode("Some text to encode", return_tensors='pt')

 

generated_text_samples = model.generate(

    input_ids,

    max_length=150,

    num_return_sequences=5,

    no_repeat_ngram_size=2,

    repetition_penalty=1.5,

    top_p=0.92,

    temperature=0.85,

    do_sample=True,

    top_k=125,

    early_stopping=True

)

 

for i, beam in enumerate(generated_text_samples):

    print("{}: {}".format(i, tokenizer.decode(beam, skip_special_tokens=True)))

    print()

๋ฐ˜์‘ํ˜•