# Finite_LQR.py
import gym
from scipy import linalg
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
class Cartpole:
    """Finite-horizon, discrete-time LQR controller for a Gym cart-pole.

    Construction builds the linear model (``A``, ``B``), the quadratic cost
    weights (``Q``, ``R``) and runs the backward Riccati recursion so that
    ``K`` and ``P`` are ready before any rollout.

    The state ordering used throughout is the one the plot labels name:
    cart position, cart velocity, pole angle, pole angular velocity.

    ``K`` and ``P`` are indexed by *steps-to-go* (index 0 is the end of the
    horizon), whereas ``U`` and ``states`` are indexed by forward time.
    """

    def __init__(self):
        # Physical parameters used by the model matrices.
        # Presumably pole mass, gravity, pole length and cart mass -- confirm
        # against the environment being controlled.
        self.m = 0.1
        self.g = 9.8
        self.l = 0.5
        self.M = 1
        self.N = 150  # planning horizon
        self.env = gym.make('CartPole-v1', render_mode="rgb_array")
        self.J = np.zeros(self.N)  # cost per each time step
        self.K = np.zeros((self.N, 1, 4))  # k for each time step
        self.P = np.zeros((self.N, 4, 4))  # P for each time step
        self.U = np.zeros(self.N)  # applied input, indexed by forward time
        self.dt = 0.02
        self.terminal_state = np.array([[0], [0], [0], [0]])
        self.state_space_parameters()
        self.backwardpass()

    def state_space_parameters(self):
        """Set the model matrices ``A``, ``B`` and cost weights ``Q``, ``R``.

        ``A`` and ``B`` have the form "identity plus ``dt`` times a
        continuous-time model", i.e. a forward-Euler style discretisation.
        NOTE(review): presumably the cart-pole linearised about the upright
        equilibrium -- confirm against the derivation.
        """
        dt = self.dt
        cart_from_angle = dt * (-self.m * self.g / self.M)
        pole_from_angle = dt * (self.M + self.m) * self.g / (self.l * self.M)
        self.A = np.array([
            [1, dt, 0, 0],
            [0, 1, cart_from_angle, 0],
            [0, 0, 1, dt],
            [0, 0, pole_from_angle, 1],
        ])
        self.B = np.array([
            [0],
            [dt * 1 / self.M],
            [0],
            [dt * -1 / (self.M * self.l)],
        ])
        self.Q = np.array([
            [10, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 10, 0],
            [0, 0, 0, 1],
        ])
        self.R = np.array([[1]])

    def backwardpass(self):
        """Run the backward Riccati recursion, filling ``self.K`` and ``self.P``.

        ``P[0]`` is initialised to ``Q``; for ``n >= 1``::

            K[n] = -(R + B' P[n-1] B)^-1 B' P[n-1] A
            P[n] = Q + K[n]' R K[n] + (A + B K[n])' P[n-1] (A + B K[n])

        ``K[0]`` is left at zero. Also stores ``self.terminal_cost``, the
        quadratic form ``x_T' Q x_T`` of ``self.terminal_state``.
        """
        A, B, Q, R = self.A, self.B, self.Q, self.R
        self.terminal_cost = (
            np.transpose(self.terminal_state) @ Q @ self.terminal_state
        )
        self.P[0] = Q
        B_t = np.transpose(B)  # loop invariant
        for n in range(1, self.N):
            P_prev = self.P[n - 1]
            # Solve the linear system rather than forming an explicit inverse.
            self.K[n] = -np.linalg.solve(R + B_t @ P_prev @ B, B_t @ P_prev @ A)
            gain = self.K[n]
            closed_loop = A + B @ gain
            self.P[n] = (
                Q
                + np.transpose(gain) @ R @ gain
                + np.transpose(closed_loop) @ P_prev @ closed_loop
            )

    def lqr(self, state, n):
        """Return the feedback input for ``state`` at forward time step ``n``.

        Uses the gain with ``N - n - 1`` steps to go, records the input in
        ``self.U[n]`` and the quadratic cost ``x' P x`` in ``self.J[N-n-1]``
        (which is also printed).

        Args:
            state: the four state values; any array-like with 4 elements
                (a flat vector or a column vector).
            n: forward time index, ``0 <= n < N``.

        Returns:
            The control input as a Python ``float``.
        """
        steps_to_go = self.N - n - 1
        x = np.asarray(state, dtype=float).reshape(-1)
        # K[...] is (1, 4), so the product is a one-element array; take the
        # scalar explicitly instead of relying on implicit array->scalar
        # conversion on assignment.
        self.U[n] = (self.K[steps_to_go] @ x).item()
        self.J[steps_to_go] = x @ self.P[steps_to_go] @ x
        print(self.J[steps_to_go])
        return float(self.U[n])

    def forward_pass(self):
        """Roll the controller out for ``N`` steps while recording a video.

        Wraps ``self.env`` in ``RecordVideo`` (output folder ``video``),
        resets it, applies ``lqr`` at every step and stores the visited
        states in ``self.states`` (shape ``(N, 4)``). Prints
        ``self.terminal_cost`` when the rollout finishes.
        """
        self.env = gym.wrappers.RecordVideo(
            self.env, 'video', name_prefix="finite_lqr_cartpole"
        )
        self.states = np.zeros((self.N, 4))
        self.u = np.zeros(self.N)
        try:
            self.env.reset()
            # Read the raw state from the innermost env; attribute access
            # through the wrapper stack is not reliable across gym versions.
            state = np.array(self.env.unwrapped.state)
            for n in range(self.N):
                action = self.lqr(state, n)
                # NOTE(review): ``action`` is a float; this assumes the
                # environment in use accepts a continuous input -- confirm.
                observation, *_ = self.env.step(action)
                state = np.array(observation)
                self.states[n, :] = state
        finally:
            # Closing the wrapper lets the recorder finalise the video file.
            self.env.close()
        print(self.terminal_cost)

    def plotting(self):
        """Plot the recorded cost/input curves and the state trajectories."""
        with plt.style.context('seaborn-v0_8'):
            plt.figure(figsize=(8, 4))
            # J is stored by steps-to-go (see ``lqr``), so this curve runs in
            # the opposite direction to U along the x axis.
            plt.plot(self.J[:-1], linestyle="--")
            plt.plot(self.U, linestyle="--")
            plt.ylabel("input N")
            plt.xlabel("time")
            plt.figure(figsize=(8, 4))
            plt.plot(
                self.states,
                linestyle="--",
                label=('CartPosition', "CartVelocity", "PoleAngle", "PoleAngularVelocity"),
            )
            plt.ylabel("states")
            plt.xlabel("time")
            # Without a legend the per-state labels above are never displayed.
            plt.legend()
            plt.show()
if __name__ == "__main__":
    # Script entry point: build the controller (this also runs the backward
    # pass), perform one recorded rollout, then show the plots.
    cartpole = Cartpole()
    cartpole.forward_pass()
    cartpole.plotting()