-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathutils.py
141 lines (123 loc) · 4.38 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from matplotlib import pyplot as plt
import matplotlib.colors as mcolors
from itertools import cycle
from config import PlotConfig, DynamicsConfig
import numpy as np
import torch
def cm2inch(*tupl):
    """Convert lengths from centimetres to inches.

    The lengths may be given either as separate arguments
    (``cm2inch(8, 6)``) or packed in one tuple/list passed as the first
    argument (``cm2inch((8, 6))``). In the packed form only the first
    argument is read.

    Parameters
    ----------
    *tupl
        Lengths in centimetres, or a single tuple/list holding them.

    Returns
    -------
    tuple
        Every length divided by 2.54. An empty tuple when called without
        arguments.
    """
    cm_per_inch = 2.54
    if not tupl:
        # Nothing to convert; also avoids an IndexError on ``tupl[0]`` below.
        return ()
    values = tupl[0] if isinstance(tupl[0], (tuple, list)) else tupl
    return tuple(value / cm_per_inch for value in values)
def smooth(data, a=0.5):
    """Exponentially smooth a sequence and return it as a column vector.

    The first sample is kept as is; every later sample is replaced by
    ``previous_smoothed * (1 - a) + current * a``.

    Parameters
    ----------
    data
        Array-like of numbers. It is copied (the caller's object is not
        modified) and reshaped to ``(n, 1)``.
    a : float, optional
        Weight of the current sample. ``a = 1`` leaves the data unchanged;
        smaller values give more weight to the running smoothed value.

    Returns
    -------
    numpy.ndarray
        Smoothed values with shape ``(n, 1)``.
    """
    smoothed = np.array(data).reshape(-1, 1)
    if smoothed.dtype.kind in "biu":
        # Integer/bool storage would silently truncate the blended values
        # when they are written back into the array, so promote to float.
        smoothed = smoothed.astype(float)
    for ind in range(1, smoothed.shape[0]):
        smoothed[ind, 0] = smoothed[ind - 1, 0] * (1 - a) + smoothed[ind, 0] * a
    return smoothed
def numpy2torch(input, size):
    """Build a float32 torch tensor of a given shape from array-like data.

    Parameters
    ----------
    input
        Array-like values; they are copied into a new float32 NumPy array.
    size
        Target shape handed to ``reshape``.

    Returns
    -------
    torch.Tensor
        Tensor obtained with ``torch.from_numpy`` from the reshaped float32
        copy.
    """
    as_float32 = np.array(input, dtype=np.float32)
    reshaped = as_float32.reshape(size)
    return torch.from_numpy(reshaped)
def step_relative(statemodel, state, u):
    """Advance the model one step and return absolute and relative next states.

    Parameters
    ----------
    statemodel
        Object providing ``reference_trajectory(x)`` and ``step(state, u)``.
        ``step`` must return a 7-tuple whose first item is the next state.
    state
        Absolute state tensor. Its last column is passed to
        ``reference_trajectory`` and its first four columns are offset by
        the reference.
    u
        Control input, forwarded unchanged to ``statemodel.step``.

    Returns
    -------
    tuple
        ``(state_next, state_r_next, x_ref)``: detached copies of the next
        absolute state, the next state relative to the reference evaluated
        at ``state_next``, and the reference evaluated at ``state``.
    """
    # Reference for the current state, and the current state expressed
    # relative to it (first four columns only).
    x_ref = statemodel.reference_trajectory(state[:, -1])
    relative_state = state.detach().clone()
    relative_state[:, :4] = relative_state[:, :4] - x_ref

    # Propagate the absolute state first, then the relative one, with the
    # same input. Only the first of the seven return values is needed.
    state_next, _, _, _, _, _, _ = statemodel.step(state, u)
    relative_step, _, _, _, _, _, _ = statemodel.step(relative_state, u)

    # Keep a detached copy of the relative propagation before columns 0 and 2
    # are overwritten with their absolute values (the original note calls
    # these "y psi" -- presumably lateral position and heading; confirm).
    state_r_next = relative_step.detach().clone()
    relative_step[:, [0, 2]] = state_next[:, [0, 2]]

    # Re-express the first four columns relative to the reference at the
    # next state.
    x_ref_next = statemodel.reference_trajectory(state_next[:, -1])
    state_r_next[:, :4] = relative_step[:, :4] - x_ref_next

    return (
        state_next.detach().clone(),
        state_r_next.detach().clone(),
        x_ref.detach().clone(),
    )
def recover_absolute_state(state_r_predict, x_ref, length=None):
    """Turn predicted relative states back into absolute states.

    The reference ``x_ref`` is held constant over the horizon: it is
    repeated ``length`` times and added to the first four columns of
    ``state_r_predict``.

    Parameters
    ----------
    state_r_predict
        2-D array of relative states, one row per step. Only columns 0-3
        are used.
    x_ref
        Reference values added to every row.
    length : int, optional
        Number of reference rows to build. Defaults to
        ``state_r_predict.shape[0]``. The stacked reference must be
        broadcast-compatible with ``state_r_predict[:, 0:4]``.

    Returns
    -------
    tuple
        ``(state, ref_predict)``: the absolute states and the stacked
        reference as a NumPy array.
    """
    if length is None:
        length = state_r_predict.shape[0]
    # The first entry is x_ref itself and the rest are copies, so at least
    # one reference row is always present.
    ref_predict = [x_ref] + [np.copy(x_ref) for _ in range(length - 1)]
    ref_array = np.array(ref_predict)
    state = state_r_predict[:, 0:4] + ref_array
    return state, ref_array
def idplot(data,
           figure_num=1,
           mode="xy",
           fname=None,
           xlabel=None,
           ylabel=None,
           legend=None,
           legend_loc="best",
           color_list=None,
           xlim=None,
           ylim=None,
           ncol=1,
           figsize_scalar=1):
    """Plot one or more data series on a single square figure.

    Parameters
    ----------
    data
        A single series when ``figure_num == 1`` (it is wrapped in a list),
        otherwise an iterable of series. In ``"xy"`` and ``"scatter"`` mode
        each series is read as ``series[0]`` (x) and ``series[1]`` (y); in
        ``"y"`` mode the series is passed to ``plt.plot`` as is.
    figure_num : int, optional
        Number of series expected in ``data``.
    mode : {"xy", "y", "scatter"}, optional
        Plot type. Any other value draws nothing.
    fname : optional
        Target passed to ``plt.savefig``. When ``None`` the figure is shown
        with ``plt.show()`` instead.
    xlabel, ylabel : optional
        Axis labels, styled with ``PlotConfig.label_font``.
    legend : optional
        Legend entries; no legend is drawn when ``None``.
    legend_loc : optional
        Legend location passed as ``loc``.
    color_list : optional
        One colour per series. When missing, empty or shorter than
        ``figure_num``, Tableau colours are used instead. If there are more
        series than colours, the colours wrap around.
    xlim, ylim : optional
        Axis limits; left automatic when ``None``.
    ncol : int, optional
        Number of legend columns.
    figsize_scalar : optional
        Multiplier applied to ``PlotConfig.fig_size`` (converted from cm to
        inches via ``cm2inch``) for both width and height.
    """
    if color_list is None or len(color_list) == 0 or len(color_list) < figure_num:
        tableau_colors = cycle(mcolors.TABLEAU_COLORS)
        color_list = [next(tableau_colors) for _ in range(max(figure_num, 1))]
    n_colors = len(color_list)

    # In "xy" mode the 4th and 5th series are drawn with a distinct line
    # style; all others use the default style.
    special_linestyles = {3: '-.', 4: 'dotted'}

    fig_size = (PlotConfig.fig_size * figsize_scalar, PlotConfig.fig_size * figsize_scalar)
    _, ax = plt.subplots(figsize=cm2inch(*fig_size), dpi=PlotConfig.dpi)

    if figure_num == 1:
        data = [data]

    for i, d in enumerate(data):
        # Wrap around instead of raising IndexError when there are more
        # series than colours.
        color = color_list[i % n_colors]
        if mode == "xy":
            if i in special_linestyles:
                plt.plot(d[0], d[1], linestyle=special_linestyles[i], color=color)
            else:
                plt.plot(d[0], d[1], color=color)
        elif mode == "y":
            plt.plot(d, color=color)
        elif mode == "scatter":
            plt.scatter(d[0], d[1], color=color, marker=".", s=5.)

    plt.tick_params(labelsize=PlotConfig.tick_size)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontname(PlotConfig.tick_label_font)

    if legend is not None:
        plt.legend(legend, loc=legend_loc, ncol=ncol, prop=PlotConfig.legend_font)
    plt.xlabel(xlabel, PlotConfig.label_font)
    plt.ylabel(ylabel, PlotConfig.label_font)
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    plt.tight_layout(pad=PlotConfig.pad)

    if fname is None:
        plt.show()
    else:
        plt.savefig(fname)