전 포스트에서 벨만 기대 방정식에 대해서 다루어 보았다.

Vπ(s) = E[Rt+1 + γVπ(St+1) | St = s]

이는 벨만 기대 방정식인데 현재의 정책 π을 따라갔을때 받을 보상에 대한 기댓값이라고 할 수 있다.

우리는 벨만 기대 방정식을 통해 한 timestep의 reward, value 그리고 policy를 고려해 각 state에 대한 timestep + 1시점의 value값을 구할 수 있다.

그러나 그걸가지고 어떻게 강화학습의 목적인 최적인 action을 선택 하게 학습을 시킬것인가?

우리는 앞에서 정책이라는 것을 배웠다.

정책 무엇이였냐면 “모든 상태에 대해 agent가 어떤 행동을 해야 하는지 정해놓은 규칙” 이였다.

그럼 우리는 정책을 얼마나 잘 정하는지에 따라 최적인 action을 선택할 것임을 깨달을 수 있다.

따라서 이번엔 벨만 기대 방정식을 통해 정책과 value값을 계산하는 과정을 배울 것이다.


Policy Iteration

2021-10-17-rlpost6-01.png

정책 이터레이션은 위의 그림의 형태랑 같다고 생각하면 된다.

좀 더 자세하게 설명해보자면,

agent가 현재 정책에 따라 행동 했을때의 value값을 얻는 정책 평가(policy evaluation)와

정책평가를 해서 얻은 value값을 바탕으로 다시 정책을 업데이트 시키는 정책 발전(policy improvement)을

계속해서 반복을 하는게 policy iteration(정책 반복)이다.

이를 통해 최적의 정책을 찾아 agent가 최대한 효율적으로 움직이게 끔 학습을 시키는것이 목적이다.(강화학습의 목적)




Policy Iteration 순서

앞서 말한 부분에서 policy iteration가 이루어지는 과정이 정리가 좀 안되었다.

따라서 순서도를 통해 policy iteration을 설명할것이다.

1. 처음에는 random policy(무작위 정책)로 정책을 설정한다.
2. 모든 state에 대해서 정해진 업데이트 된 현재 policy를 통해 모든 state의 value값을 정하는 policy evaluation(정책 평가)를 한다.
3. 모든 state의 value를 바탕으로 모든 state의 policy를 업데이트를 하는 policy improvement(정책 발전)을 한다.
4. 2 ~ 3을 반복한다.

여기서 중요한점은 policy 에 대해서만 improvement가 되는 것은 아니다.

value값 또한 policy가 업데이트 되면서 동시에 서로서로 에게 영향을 주며 업데이트를 해 나간다.

우리는 1번 2번까지는 할 수 있다.

어떻게?

policy evaluation은 우리가 아까 계속 언급했던 벨만 기대 방정식을 쓰면 되기 때문이다.

그런데.. 3번 policy imporvement는??

우리는 이 policy improvement를 greedy policy improvement 를 이용해서 구할 것이다.

이 부분은 밑에 예제를 설명하면서 설명하겠다.

grid world 예제를 통해 위 1번에서 4번까지를 직접 반복해 보겠다.




Policy Iteration in grid-world

현재 grid world의 상황이다.

v = 0 v = 0 v = 0
v = 0 v = 0 v = 0 Trap(R = -1)
v = 0 v = 0 v = 0 Goal(R = 1)

1. 처음에는 random policy(무작위 정책)로 정책을 설정한다.

이 policy는 모든 state에 각각 해당되는 policy임을 기억하자

π(up | s) = 0.25

π(down | s) = 0.25

π(left | s) = 0.25

π(right | s) = 0.25


2. 모든 state에 대해서 정해진 업데이트 된 현재 policy를 통해 모든 state의 value값을 정하는 policy evaluation(정책 평가)를 한다. (γ = 0.8)

Vπ(s) = Σa∈A π(At | St = s) (r(St = s, At = a) + γVπ(s'))

1st iteration

s1 (v = 0) s2 (v = 0) s3 (v = -0.25)
s4 (v = 0) s5 (v = - 0.25) s6 (v = 0.25)
s7 (v = 0) s8 (v = 0.25) s9 (v = 0)
Vπ(s1) = [0.25 * (0 + 0.8 * 0)] + [0.25 * (0 + 0.8 * 0)] = 0
Vπ(s2) = [0.25 * (0 + 0.8 * 0)] + [0.25 * (0 + 0.8 * 0)] + [0.25 * (0 + 0.8 * 0)] = 0
Vπ(s3) = [0.25 * (-1 + 0.8 * 0)] + [0.25 * (0 + 0.8 * 0)] = - 0.25
Vπ(s4) = [0.25 * (0 + 0.8 * 0)] + [0.25 * (0 + 0.8 * 0)] + [0.25 * (0 + 0.8 * 0)] = 0
Vπ(s5) = [0.25 * (0 + 0.8 * 0)] + [0.25 * (0 + 0.8 * 0)] + [0.25 * (0 + 0.8 * 0)] + [0.25 * (-1 + 0.8 * 0)]= - 0.25
Vπ(s6) = [0.25 * (0 + 0.8 * 0)] + [0.25 * (1 + 0.8 * 0)] + [0.25 * (0 + 0.8 * 0)] = 0.25
Vπ(s7) = [0.25 * (0 + 0.8 * 0)] + [0.25 * (0 + 0.8 * 0)] = 0
Vπ(s8) = [0.25 * (0 + 0.8 * 0)] + [0.25 * (0 + 0.8 * 0)] + [0.25 * (1 + 0.8 * 0)] = 0.25


3. 모든 state의 value를 바탕으로 모든 state의 policy를 업데이트를 하는 policy improvement(정책 발전)을 한다.

그런데 아까부터 계속 언급했듯이 어떻게 정책을 발전 시켜나가는 것일까?

우리는 policy improvement 를 Greedy Policy Improvement를 사용하여 할것이다.

말 그대로 탐욕 정책 발전인데 자세히 설명해보겠다.

우리는 value값을 policy evaluation으로 부터 구해왔다.

value값을 알면 q값도 알것이다.

그 q값을 이용하는데, 가장 높은 q 값을 가지고 그에 해당하는 action 을 선택하는 정책이라고 할 수 있다.

만약 가장 높은 q값이 여러개가 있다고 하면 1에다 여러개를 나누어서 정책을 구한다.

식으로 설명해보았다.

argmaxa∈A qπ(s, a)

그럼 위에서 해놓은 policy evaluation을 토대로 geedy policy improvement를 해볼까?

위에서 계산해둔 value값의 식에서 [ ] 사이의 값들이 q값들이다.

1st iteration

π(up ,s)=0, π(down ,s)=0.5, π(left ,s)=0, π(right ,s)=0.5 π(up ,s)=0, π(down ,s)=0.33, π(left ,s)=0.33, π(right ,s)=0.33 π(up ,s)=0, π(down ,s)=0, π(left ,s)=1, π(right ,s)=0
π(up ,s)=0.33, π(down ,s)=0.33, π(left ,s)=0, π(right ,s)=0.33 π(up ,s)=0.33, π(down ,s)=0.33, π(left ,s)=0.33, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0
π(up ,s)=0.5, π(down ,s)=0, π(left ,s)=0, π(right ,s)=0.5 π(up ,s)=0, π(down ,s)=0, π(left ,s)=0, π(right ,s)=1  
↓ → ← ↓ →
↑ ↓ → ← ↑ ↓
↑ →  



4. 2 ~ 3을 반복한다.

2rd iteration

현재 grid world의 상황이다.

s1 (v = 0) s2 (v = 0) s3 (v = -0.25)
s4 (v = 0) s5 (v = - 0.25) s6 (v = 0.25) Trap(R = -1)
s7 (v = 0) s8 (v = 0.25) s9 (v = 0) Goal(R = 1)
↓ → ← ↓ →
↑ ↓ → ← ↑ ↓
↑ →  

2. 모든 state에 대해서 정해진 업데이트 된 현재 policy를 통해 모든 state의 value값을 정하는 policy evaluation(정책 평가)를 한다. (γ = 0.8)

Vπ(s) = Σa∈A π(At | St = s) (r(St = s, At = a) + γVπ(s'))

2nd iteration

s1 (v = 0) s2 (v = - 0.132) s3 (v = 0)
s4 (v = - 0.066) s5 (v = 0.066) s6 (v = 1)
s7 (v = 0.1) s8 (v = 1) s9 (v = 0)
Vπ(s1) = [0.5 * (0 + 0.8 * 0)] + [0.5 * (0 + 0.8 * 0)] = 0
Vπ(s2) = [0.33 * (0 + 0.8 * (- 0.25))] + [0.33 * (0 + 0.8 * 0)] + [0.33 * (0 + 0.8 * (- 0.25))] = - 0.132
Vπ(s3) = [1 * (0 + 0.8 * 0)] = 0
Vπ(s4) = [0.33 * (0 + 0.8 * 0)] + [0.33 * (0 + 0.8 * 0)] + [0.33 * (0 + 0.8 * (- 0.25))] = - 0.066
Vπ(s5) = [0.33 * (0 + 0.8 * 0)] + [0.33 * (0 + 0.8 * 0.25)] + [0.33 * (0 + 0.8 * 0)] = 0.066
Vπ(s6) = [1 * (1 + 0.8 * 0)] = 1
Vπ(s7) = [0.5 * (0 + 0.8 * 0)] + [0.5 * (0 + 0.8 * 0.25)] = 0.1
Vπ(s8) = [1 * (1 + 0.8 * 0)] = 1

3. 모든 state의 value를 바탕으로 모든 state의 policy를 업데이트를 하는 policy improvement(정책 발전)을 한다.

2nd iteration

π(up ,s)=0, π(down ,s)=0.5, π(left ,s)=0, π(right ,s)=0.5 π(up ,s)=0, π(down ,s)=0, π(left ,s)=1, π(right ,s)=0 π(up ,s)=0, π(down ,s)=0.5, π(left ,s)=0.5, π(right ,s)=0
π(up ,s)=0.5, π(down ,s)=0.5, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0
π(up ,s)=0, π(down ,s)=0, π(left ,s)=0, π(right ,s)=1 π(up ,s)=0, π(down ,s)=0, π(left ,s)=0, π(right ,s)=1  
↓ → ← ↓
↑ ↓
 


3rd iteration

현재 grid world의 상황이다.

s1 (v = 0) s2 (v = - 0.132) s3 (v = 0)
s4 (v = - 0.066) s5 (v = 0.066) s6 (v = 1) Trap(R = -1)
s7 (v = 0.1) s8 (v = 1) s9 (v = 0) Goal(R = 1)
↓ → ← ↓
↑ ↓
 

2. 모든 state에 대해서 정해진 업데이트 된 현재 policy를 통해 모든 state의 value값을 정하는 policy evaluation(정책 평가)를 한다. (γ = 0.8)

Vπ(s) = Σa∈A π(At | St = s) (r(St = s, At = a) + γVπ(s'))

3rd iteration

s1 (v = - 0.0792) s2 (v = 0) s3 (v = - 0.1528)
s4 (v = 0.04) s5 (v = 0.8) s6 (v = 1)
s7 (v = 0.8) s8 (v = 1) s9 (v = 0)
Vπ(s1) = [0.5 * (0 + 0.8 * (- 0.066))] + [0.5 * (0 + 0.8 * (- 0.132))] = - 0.0792
Vπ(s2) = [1 * (0 + 0.8 * 0)]= 0
Vπ(s3) = [0.5 * (-1 + 0.8 * 1)] + [0.5 * (0 + 0.8 * (- 0.132))] = - 0.1528
Vπ(s4) = [0.5 * (0 + 0.8 * 0)] + [0.5 * (0 + 0.8 * 0.1)] = 0.04
Vπ(s5) = [1 * (0 + 0.8 * 1)] = 0.8
Vπ(s6) = [1 * (1 + 0.8 * 0)] = 1
Vπ(s7) = [1 * (0 + 0.8 * 1)] = 0.8
Vπ(s8) = [1 * (1 + 0.8 * 0)] = 1

3. 모든 state의 value를 바탕으로 모든 state의 policy를 업데이트를 하는 policy improvement(정책 발전)을 한다.

3rd iteration

π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=0.33, π(left ,s)=0.33, π(right ,s)=0.33 π(up ,s)=0, π(down ,s)=0, π(left ,s)=1, π(right ,s)=0
π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0
π(up ,s)=0, π(down ,s)=0, π(left ,s)=0, π(right ,s)=1 π(up ,s)=0, π(down ,s)=0, π(left ,s)=0, π(right ,s)=1  
← ↓ →
 


4th iteration

현재 grid world의 상황이다.

s1 (v = - 0.0792) s2 (v = 0) s3 (v = - 0.1528)
s4 (v = 0.04) s5 (v = 0.8) s6 (v = 1) Trap(R = -1)
s7 (v = 0.8) s8 (v = 1) s9 (v = 0) Goal(R = 1)
← ↓ →
 

2. 모든 state에 대해서 정해진 업데이트 된 현재 policy를 통해 모든 state의 value값을 정하는 policy evaluation(정책 평가)를 한다. (γ = 0.8)

Vπ(s) = Σa∈A π(At | St = s) (r(St = s, At = a) + γVπ(s'))

4th iteration

s1 (v = 0.512) s2 (v = 0.150952) s3 (v = 0.8)
s4 (v = 0.64) s5 (v = 0.8) s6 (v = 1)
s7 (v = 0.8) s8 (v = 1) s9 (v = 0)
Vπ(s1) = [1 * (0 + 0.8 * (0.64))] = 0.512
Vπ(s2) = [0.33 * (0 + 0.8 * 0.8] + [0.33 * (0 + 0.8 * (- 0.0792))] + [0.33 * (0 + 0.8 * (- 0.1528))]= 0.150952
Vπ(s3) = [1 * (0 + 0.8 * 1)] = 0.8
Vπ(s4) = [1 * (0 + 0.8 * 0.8)]= 0.64
Vπ(s5) = [1 * (0 + 0.8 * 1)] = 0.8
Vπ(s6) = [1 * (1 + 0.8 * 0)] = 1
Vπ(s7) = [1 * (0 + 0.8 * 1)] = 0.8
Vπ(s8) = [1 * (1 + 0.8 * 0)] = 1

3. 모든 state의 value를 바탕으로 모든 state의 policy를 업데이트를 하는 policy improvement(정책 발전)을 한다.

4th iteration

π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=0, π(left ,s)=1, π(right ,s)=0
π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0
π(up ,s)=0, π(down ,s)=0, π(left ,s)=0, π(right ,s)=1 π(up ,s)=0, π(down ,s)=0, π(left ,s)=0, π(right ,s)=1  
 


5th iteration

현재 grid world의 상황이다.

s1 (v = 0.512) s2 (v = 0.150952) s3 (v = 0.8)
s4 (v = 0.64) s5 (v = 0.8) s6 (v = 1) Trap(R = -1)
s7 (v = 0.8) s8 (v = 1) s9 (v = 0) Goal(R = 1)
 

2. 모든 state에 대해서 정해진 업데이트 된 현재 policy를 통해 모든 state의 value값을 정하는 policy evaluation(정책 평가)를 한다. (γ = 0.8)

Vπ(s) = Σa∈A π(At | St = s) (r(St = s, At = a) + γVπ(s'))

5th iteration

s1 (v = 0.512) s2 (v = 0.64) s3 (v = 0.1207616)
s4 (v = 0.64) s5 (v = 0.8) s6 (v = 1)
s7 (v = 0.8) s8 (v = 1) s9 (v = 0)
Vπ(s1) = [1 * (0 + 0.8 * (0.64))] = 0.512
Vπ(s2) = [1 * (0 + 0.8 * 0.8] = 0.64
Vπ(s3) = [1 * (0 + 0.8 * 0.150952)] = 0.1207616
Vπ(s4) = [1 * (0 + 0.8 * 0.8)]= 0.64
Vπ(s5) = [1 * (0 + 0.8 * 1)] = 0.8
Vπ(s6) = [1 * (1 + 0.8 * 0)] = 1
Vπ(s7) = [1 * (0 + 0.8 * 1)] = 0.8
Vπ(s8) = [1 * (1 + 0.8 * 0)] = 1

3. 모든 state의 value를 바탕으로 모든 state의 policy를 업데이트를 하는 policy improvement(정책 발전)을 한다.

5th iteration

π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=0, π(left ,s)=1, π(right ,s)=0
π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0
π(up ,s)=0, π(down ,s)=0, π(left ,s)=0, π(right ,s)=1 π(up ,s)=0, π(down ,s)=0, π(left ,s)=0, π(right ,s)=1  
 


6th iteration

현재 grid world의 상황이다.

s1 (v = 0.512) s2 (v = 0.64) s3 (v = 0.1207616)
s4 (v = 0.64) s5 (v = 0.8) s6 (v = 1) Trap(R = -1)
s7 (v = 0.8) s8 (v = 1) s9 (v = 0) Goal(R = 1)
 

2. 모든 state에 대해서 정해진 업데이트 된 현재 policy를 통해 모든 state의 value값을 정하는 policy evaluation(정책 평가)를 한다. (γ = 0.8)

Vπ(s) = Σa∈A π(At | St = s) (r(St = s, At = a) + γVπ(s'))

6th iteration

s1 (v = 0.512) s2 (v = 0.64) s3 (v = 0.512)
s4 (v = 0.64) s5 (v = 0.8) s6 (v = 1)
s7 (v = 0.8) s8 (v = 1) s9 (v = 0)
Vπ(s1) = [1 * (0 + 0.8 * (0.64))] = 0.512
Vπ(s2) = [1 * (0 + 0.8 * 0.8] = 0.64
Vπ(s3) = [1 * (0 + 0.8 * 0.64)] = 0.512
Vπ(s4) = [1 * (0 + 0.8 * 0.8)]= 0.64
Vπ(s5) = [1 * (0 + 0.8 * 1)] = 0.8
Vπ(s6) = [1 * (1 + 0.8 * 0)] = 1
Vπ(s7) = [1 * (0 + 0.8 * 1)] = 0.8
Vπ(s8) = [1 * (1 + 0.8 * 0)] = 1

3. 모든 state의 value를 바탕으로 모든 state의 policy를 업데이트를 하는 policy improvement(정책 발전)을 한다.

6th iteration

π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=0, π(left ,s)=1, π(right ,s)=0
π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0 π(up ,s)=0, π(down ,s)=1, π(left ,s)=0, π(right ,s)=0
π(up ,s)=0, π(down ,s)=0, π(left ,s)=0, π(right ,s)=1 π(up ,s)=0, π(down ,s)=0, π(left ,s)=0, π(right ,s)=1  
 




policy iteration 결론

방금 6번의 iteration을 직접 해보니 신기하게도 policy가 goal 있는 지점으로 action을 하는 모습을 보였다.

이런식으로 한 timestep마다 모든 state의 value값과 q값을 구하는 방식이 dynamic programming 이라고 하고

policy iteration은 그걸 사용하여 policy를 구하였다.

다음 포스트에는 value iteration을 설명할 것이다.



제가 올린 글에서 잘못된 부분이 있으면 제 메일로 연락주세요!

Reference : 파이썬과 케라스로 배우는 강화학습

크리에이티브 커먼즈 라이선스
이승수의 저작물인 이 저작물은(는) 크리에이티브 커먼즈 저작자표시-비영리-동일조건변경허락 4.0 국제 라이선스에 따라 이용할 수 있습니다.