์ž์—ฐ์–ด ์ฒ˜๋ฆฌ/Today I learned :

Movie Review Classification with BERT

์ฃผ์˜ ๐Ÿฑ 2023. 1. 16. 17:20
728x90

Sentiment classification of movie reviews using a pre-trained BERT

 

Two approaches:

1. Use the Trainer class from the transformers library.

2. Write the training loop yourself in PyTorch.

It's worth learning both. Method 1 is reportedly somewhat awkward to modify or fine-tune, so many people prefer method 2.

# Setup not shown in the original post: the IMDb dataset and a bert-base-uncased
# tokenizer are assumed here, matching the English movie-review task.
from datasets import load_dataset
from transformers import AutoTokenizer

raw_datasets = load_dataset("imdb")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer.model_max_length = 512

def tokenize_function(examples):
    # Pad/truncate every review to the model's 512-token limit
    return tokenizer(examples['text'], padding="max_length", truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

# 1,000-example subsets for quick experiments, plus the full splits
small_train_dataset = tokenized_datasets['train'].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets['test'].shuffle(seed=42).select(range(1000))
full_train_dataset = tokenized_datasets['train']
full_eval_dataset = tokenized_datasets['test']
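A quick look at one tokenized example (an added inspection step, not in the original post):

sample = small_train_dataset[0]
print(sample.keys())             # text, label, input_ids, token_type_ids, attention_mask
print(len(sample['input_ids']))  # 512, because of padding="max_length"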

Method 1

#Training a movie review classifier with the Transformers library

from transformers import BertForSequenceClassification, TrainingArguments, Trainer

# The original post uses `model` before defining it; a plain bert-base-uncased
# classification head (an assumption, not shown in the post) fills the gap.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

training_args = TrainingArguments("test_trainer")
# To train/evaluate on the full dataset, use full_train_dataset and full_eval_dataset instead.
trainer = Trainer(model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset)

trainer.train()

# Swap in a checkpoint already fine-tuned for sentiment analysis and rebuild the
# Trainer around it. (Note: finiteautomata/beto-sentiment-analysis is a Spanish
# three-label model, so it is an odd fit for English IMDb reviews.)
model = BertForSequenceClassification.from_pretrained('finiteautomata/beto-sentiment-analysis')
trainer = Trainer(model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset)

import numpy as np
from datasets import load_metric

# accuracy metric (in recent versions this lives in the separate `evaluate` package)
metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    # Trainer passes (logits, labels); take the argmax class per example
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
    
# Rebuild the Trainer with compute_metrics attached, then evaluate
trainer = Trainer(model=model,
                  args=training_args,
                  train_dataset=small_train_dataset,
                  eval_dataset=small_eval_dataset,
                  compute_metrics=compute_metrics)
trainer.evaluate()
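For reference, trainer.evaluate() returns a plain dict of metrics; the accuracy from compute_metrics appears under an eval_-prefixed key:

metrics = trainer.evaluate()
print(metrics["eval_accuracy"])  # Trainer prefixes metric names with "eval_"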

Method 2

#Training a movie review classifier with the PyTorch library

from torch.optim import AdamW  # transformers.AdamW is deprecated in favor of the PyTorch optimizer

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
optimizer = AdamW(model.parameters(), lr=5e-5)

# Drop the raw text column, rename "label" -> "labels" (the name the model expects),
# and have the dataset return PyTorch tensors
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")

# ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ 1000๊ฐœ์˜ ํ•™์Šต/ํ‰๊ฐ€ ๋ฐ์ดํ„ฐ์…‹๋งŒ์„ ์ด์šฉํ•ด ์ง„ํ–‰ํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

from torch.utils.data import DataLoader

train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=32)
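Not in the original post, but a quick way to confirm the batch layout before training:

batch = next(iter(train_dataloader))
print({k: v.shape for k, v in batch.items()})
# expected: input_ids / token_type_ids / attention_mask of shape [8, 512], labels of shape [8]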

import torch

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)  # total optimizer steps, used for the progress bar

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))

model.train()

for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss  # the model computes the loss itself when "labels" is in the batch
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        progress_bar.update()
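One thing the loop above leaves out is a learning-rate schedule; the standard fine-tuning recipe adds a linear decay via get_scheduler from transformers. A minimal sketch of that variant:

from transformers import get_scheduler

lr_scheduler = get_scheduler("linear", optimizer=optimizer,
                             num_warmup_steps=0,
                             num_training_steps=num_training_steps)
# inside the loop, call lr_scheduler.step() right after optimizer.step()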
        
metric = load_metric("accuracy")
model.eval()
all_pred = []
all_ref = []
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():  # no gradients needed at evaluation time
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)  # predicted class per review
    all_pred.append(predictions.cpu().numpy())
    all_ref.append(batch['labels'].cpu().numpy())
    metric.add_batch(predictions=predictions, references=batch['labels'])

metric.compute()
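The all_pred / all_ref lists collected above are otherwise unused; as a sanity check, the same accuracy can be recomputed from them with NumPy:

all_pred = np.concatenate(all_pred)
all_ref = np.concatenate(all_ref)
print((all_pred == all_ref).mean())  # should match metric.compute()["accuracy"]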
๋ฐ˜์‘ํ˜•