[Deep Learning] Gradient Descent

주의 🐱 2021. 3. 15. 20:40

Gradient descent: plot how the error changes as a quadratic curve, set an appropriate learning rate, and find the point where the derivative is 0.

Instead of the method of least squares, the desired values can be found with the mean squared error and gradient descent.
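As a small illustration of the mean squared error, the sketch below computes it for two candidate lines y = a*x + b on the four data points used later in this post. The helper name `mse` and the trial values of a and b are assumptions picked for illustration, not part of the original code.

```python
import numpy as np

# The (x, y) data points used later in this post
xdata = np.array([2, 4, 6, 8])
ydata = np.array([81, 93, 91, 97])

def mse(a, b, x, y):
    """Mean squared error of the line y = a*x + b (hypothetical helper)."""
    y_pred = a * x + b
    return np.mean((y - y_pred) ** 2)

# A poor guess (a=0, b=0) gives a much larger error than a better one
print("MSE at a=0,   b=0 :", mse(0.0, 0.0, xdata, ydata))
print("MSE at a=2.3, b=79:", mse(2.3, 79.0, xdata, ydata))
```

The smaller the value, the better the line fits, so the goal is to find the a and b that make it as small as possible.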

 

  • Keep moving the slope a of the function toward the side where the gradient is smaller, repeating until the minimum m is reached.
  • At the minimum m the instantaneous slope is 0, so the task is to find the point where the derivative is 0.
  • The same method is used to find the optimal b, the y-intercept.
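For reference, here is a sketch of the quantities the bullets above refer to, assuming the line is written y = ax + b and the error is the mean squared error over n data points. The symbols n, x_i and y_i are notation introduced here.

```latex
\begin{aligned}
\mathrm{MSE}(a, b) &= \frac{1}{n}\sum_{i=1}^{n}\bigl(y_i - (a x_i + b)\bigr)^2 \\
\frac{\partial\,\mathrm{MSE}}{\partial a} &= -\frac{2}{n}\sum_{i=1}^{n} x_i \bigl(y_i - (a x_i + b)\bigr) \\
\frac{\partial\,\mathrm{MSE}}{\partial b} &= -\frac{2}{n}\sum_{i=1}^{n} \bigl(y_i - (a x_i + b)\bigr)
\end{aligned}
```

Gradient descent looks for the a and b at which both derivatives are 0. The code later in this post uses 1/n in place of 2/n, which corresponds to minimizing half the MSE and has the same minimum.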

 

1. Differentiate at a1.

2. Move some distance in the direction opposite to the slope just found (the negative direction if the slope is positive) to reach a2, and differentiate there.

3. Repeat until the derivative is 0.
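The three steps above can be sketched on a one-variable example. The function f(a) = (a - 3)^2, the starting point, the tolerance, and the name `gradient_descent_1d` are all assumptions chosen for illustration. In practice the derivative rarely hits exactly 0, so the sketch stops once it is close enough.

```python
def gradient_descent_1d(grad, a_start, lr=0.1, tol=1e-8, max_steps=10_000):
    """Minimal 1-D gradient descent (illustrative sketch)."""
    a = a_start
    steps_taken = 0
    for _ in range(max_steps):
        g = grad(a)          # step 1: differentiate at the current a
        if abs(g) < tol:     # step 3: stop once the derivative is (nearly) 0
            break
        a = a - lr * g       # step 2: move opposite to the slope
        steps_taken += 1
    return a, steps_taken

# Assumed example: f(a) = (a - 3)**2, whose derivative is 2 * (a - 3)
a_min, steps = gradient_descent_1d(lambda a: 2 * (a - 3), a_start=10.0)
print("a at the minimum:", a_min, "after", steps, "steps")
```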

 

 

 

When moving in the direction opposite to the slope, moving too far overshoots the minimum, and the value of a shoots upward instead of settling.

 

Learning rate: the value that sets how far each step moves. A suitable learning rate has to be found.
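To see the effect of the learning rate, the sketch below runs the same update with a small and a large value. The function f(a) = (a - 3)^2, the two learning rates, and the name `descend` are assumptions for illustration only.

```python
def descend(lr, steps=20, a=10.0):
    """Run `steps` gradient-descent updates on f(a) = (a - 3)**2."""
    for _ in range(steps):
        a = a - lr * 2 * (a - 3)  # derivative of f is 2 * (a - 3)
    return a

a_small = descend(lr=0.1)  # each step shrinks the distance to 3
a_large = descend(lr=1.1)  # each step overshoots, so the distance grows

print("lr=0.1 ->", a_small)
print("lr=1.1 ->", a_large)
```

With the smaller rate the value approaches the minimum at 3. With the larger rate every step jumps past the minimum by more than it started with, so the value moves further away on each iteration.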

 


 

 

import numpy as np
import matplotlib.pyplot as plt

# x, y data values
data = [[2, 81], [4, 93], [6, 91], [8, 97]]
x = [i[0] for i in data]
y = [i[1] for i in data]

# Plot the data
plt.figure(figsize=(8, 5))
plt.scatter(x, y)
plt.show()

# Convert the x, y lists to NumPy arrays (so the arithmetic below is applied element by element)
xdata = np.array(x)
ydata = np.array(y)

# Initialize the slope and the intercept
a = 0
b = 0

# Set the learning rate
lr = 0.05

# Number of iterations (counting starts at 0, so add 1)
epochs = 2001

# Gradient descent
for i in range(epochs):
    y_pred = a * xdata + b
    error = ydata - y_pred  # error
    # Note: the factor here is 1/n, not the 2/n of the plain MSE derivative,
    # so this is the gradient of half the MSE (same minimum, smaller step).
    a_diff = -(1 / len(xdata)) * sum(xdata * error)  # error function differentiated with respect to a
    b_diff = -(1 / len(xdata)) * sum(error)          # error function differentiated with respect to b

    a = a - lr * a_diff  # update a and b by the learning rate times the derivative
    b = b - lr * b_diff

    if i % 100 == 0:
        # Print the current a and b every 100 iterations
        print("epoch=%d, slope=%.04f, intercept=%.04f" % (i, a, b))
 
 
# Redraw the graph with the updated values
y_pred = a * xdata + b
plt.scatter(x, y)
plt.plot([min(xdata), max(xdata)], [min(y_pred), max(y_pred)])
plt.show()
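As a sanity check, the result of gradient descent can be compared with a direct least-squares fit. The sketch below repeats the same update rule in a self-contained form and compares it with `np.polyfit`. Using `np.polyfit` as the reference is a choice made here, not something the original post does.

```python
import numpy as np

xdata = np.array([2, 4, 6, 8])
ydata = np.array([81, 93, 91, 97])

# Same update rule, learning rate, and iteration count as in the post
a, b, lr = 0.0, 0.0, 0.05
n = len(xdata)
for _ in range(2001):
    error = ydata - (a * xdata + b)
    a_diff = -(1 / n) * np.sum(xdata * error)
    b_diff = -(1 / n) * np.sum(error)
    a = a - lr * a_diff
    b = b - lr * b_diff

# Least-squares line; polyfit returns coefficients from highest degree down
a_ls, b_ls = np.polyfit(xdata, ydata, deg=1)

print("gradient descent: slope=%.4f, intercept=%.4f" % (a, b))
print("least squares   : slope=%.4f, intercept=%.4f" % (a_ls, b_ls))
```

If the learning rate and the number of iterations are adequate, the two lines should agree closely.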