Skip to content

Commit 60d5d93

Browse files
author
Chris Elion
authored
[MLA-1783] built-in actuator type (#4950)
1 parent ad620ec commit 60d5d93

File tree

15 files changed

+225
-27
lines changed

15 files changed

+225
-27
lines changed

Diff for: com.unity.ml-agents.extensions/Runtime/Input/InputActionActuator.cs

+9-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace Unity.MLAgents.Extensions.Input
1414
/// <see cref="Agent"/>'s <see cref="BehaviorParameters"/> indicate that the Agent is running in Heuristic Mode,
1515
/// this Actuator will write actions from the <see cref="InputSystem"/> to the <see cref="ActionBuffers"/> object.
1616
/// </summary>
17-
public class InputActionActuator : IActuator, IHeuristicProvider
17+
public class InputActionActuator : IActuator, IHeuristicProvider, IBuiltInActuator
1818
{
1919
readonly BehaviorParameters m_BehaviorParameters;
2020
readonly InputAction m_Action;
@@ -35,8 +35,8 @@ public class InputActionActuator : IActuator, IHeuristicProvider
3535
/// <param name="adaptor">The <see cref="IRLActionInputAdaptor"/> that will convert data between ML-Agents
3636
/// and the <see cref="InputSystem"/>.</param>
3737
public InputActionActuator(InputDevice inputDevice, BehaviorParameters behaviorParameters,
38-
InputAction action,
39-
IRLActionInputAdaptor adaptor)
38+
InputAction action,
39+
IRLActionInputAdaptor adaptor)
4040
{
4141
m_BehaviorParameters = behaviorParameters;
4242
Name = $"InputActionActuator-{action.name}";
@@ -83,6 +83,12 @@ public void Heuristic(in ActionBuffers actionBuffersOut)
8383
m_InputAdaptor.WriteToHeuristic(m_Action, actionBuffersOut);
8484
Profiler.EndSample();
8585
}
86+
87+
/// <inheritdoc/>
88+
public BuiltInActuatorType GetBuiltInActuatorType()
89+
{
90+
return BuiltInActuatorType.InputActionActuator;
91+
}
8692
}
8793
}
8894

Diff for: com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs

+7-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace Unity.MLAgents.Extensions.Match3
99
/// Actuator for a Match3 game. It translates valid moves (defined by AbstractBoard.IsMoveValid())
1010
/// in action masks, and applies the action to the board via AbstractBoard.MakeMove().
1111
/// </summary>
12-
public class Match3Actuator : IActuator, IHeuristicProvider
12+
public class Match3Actuator : IActuator, IHeuristicProvider, IBuiltInActuator
1313
{
1414
protected AbstractBoard m_Board;
1515
protected System.Random m_Random;
@@ -92,6 +92,12 @@ public void ResetData()
9292
{
9393
}
9494

95+
/// <inheritdoc/>
96+
public BuiltInActuatorType GetBuiltInActuatorType()
97+
{
98+
return BuiltInActuatorType.Match3Actuator;
99+
}
100+
95101
IEnumerable<int> InvalidMoveIndices()
96102
{
97103
var numValidMoves = m_Board.NumMoves();
@@ -179,6 +185,5 @@ protected virtual int EvalMovePoints(Move move)
179185
{
180186
return 1;
181187
}
182-
183188
}
184189
}
+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
namespace Unity.MLAgents.Actuators
2+
{
3+
/// <summary>
4+
/// Identifiers for "built in" actuator types.
5+
/// These are only used for analytics, and should not be used for any runtime decisions.
6+
///
7+
/// NOTE: Do not renumber these, since the values are used for analytics. Renaming is allowed though.
8+
/// </summary>
9+
public enum BuiltInActuatorType
10+
{
11+
/// <summary>
12+
/// Default Actuator type if it cannot be determined.
13+
/// </summary>
14+
Unknown = 0,
15+
16+
/// <summary>
17+
/// VectorActuator used by the Agent
18+
/// </summary>
19+
AgentVectorActuator = 1,
20+
21+
/// <summary>
22+
/// Corresponds to <see cref="VectorActuator"/>
23+
/// </summary>
24+
VectorActuator = 2,
25+
26+
/// <summary>
27+
/// Corresponds to the Match3Actuator in com.unity.ml-agents.extensions.
28+
/// </summary>
29+
Match3Actuator = 3,
30+
31+
/// <summary>
32+
/// Corresponds to the InputActionActuator in com.unity.ml-agents.extensions.
33+
/// </summary>
34+
InputActionActuator = 4,
35+
}
36+
37+
/// <summary>
38+
/// Interface for actuators that are provided as part of ML-Agents.
39+
/// User-implemented actuators don't need to use this interface.
40+
/// </summary>
41+
internal interface IBuiltInActuator
42+
{
43+
/// <summary>
44+
/// Return the corresponding BuiltInActuatorType for the actuator.
45+
/// </summary>
46+
/// <returns>A BuiltInActuatorType corresponding to the actuator.</returns>
47+
BuiltInActuatorType GetBuiltInActuatorType();
48+
}
49+
}

Diff for: com.unity.ml-agents/Runtime/Actuators/IBuiltInActuator.cs.meta

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs

+7-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ namespace Unity.MLAgents.Actuators
55
/// <summary>
66
/// IActuator implementation that forwards calls to an <see cref="IActionReceiver"/> and an <see cref="IHeuristicProvider"/>.
77
/// </summary>
8-
internal class VectorActuator : IActuator, IHeuristicProvider
8+
internal class VectorActuator : IActuator, IHeuristicProvider, IBuiltInActuator
99
{
1010
IActionReceiver m_ActionReceiver;
1111
IHeuristicProvider m_HeuristicProvider;
@@ -95,5 +95,11 @@ public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
9595

9696
/// <inheritdoc />
9797
public string Name { get; }
98+
99+
/// <inheritdoc />
100+
public virtual BuiltInActuatorType GetBuiltInActuatorType()
101+
{
102+
return BuiltInActuatorType.VectorActuator;
103+
}
98104
}
99105
}

Diff for: com.unity.ml-agents/Runtime/Agent.cs

+20-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,25 @@ public void CopyActions(ActionBuffers actionBuffers)
7070
}
7171
}
7272

73+
/// <summary>
74+
/// Simple wrapper around VectorActuator that overrides GetBuiltInActuatorType
75+
/// so that it can be distinguished from a standard VectorActuator.
76+
/// </summary>
77+
internal class AgentVectorActuator : VectorActuator
78+
{
79+
public AgentVectorActuator(IActionReceiver actionReceiver,
80+
IHeuristicProvider heuristicProvider,
81+
ActionSpec actionSpec,
82+
string name = "VectorActuator"
83+
) : base(actionReceiver, heuristicProvider, actionSpec, name)
84+
{ }
85+
86+
public override BuiltInActuatorType GetBuiltInActuatorType()
87+
{
88+
return BuiltInActuatorType.AgentVectorActuator;
89+
}
90+
}
91+
7392
/// <summary>
7493
/// An agent is an actor that can observe its environment, decide on the
7594
/// best course of action using those observations, and execute those actions
@@ -997,7 +1016,7 @@ void InitializeActuators()
9971016
// Support legacy OnActionReceived
9981017
// TODO don't set this up if the sizes are 0?
9991018
var param = m_PolicyFactory.BrainParameters;
1000-
m_VectorActuator = new VectorActuator(this, this, param.ActionSpec);
1019+
m_VectorActuator = new AgentVectorActuator(this, this, param.ActionSpec);
10011020
m_ActuatorManager = new ActuatorManager(attachedActuators.Length + 1);
10021021
m_LegacyActionCache = new float[m_VectorActuator.TotalNumberOfActions()];
10031022
m_LegacyHeuristicCache = new float[m_VectorActuator.TotalNumberOfActions()];

Diff for: com.unity.ml-agents/Runtime/Analytics/Events.cs

+31
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ internal struct InferenceEvent
2121
public int InferenceDevice;
2222
public List<EventObservationSpec> ObservationSpecs;
2323
public EventActionSpec ActionSpec;
24+
public List<EventActuatorInfo> ActuatorInfos;
2425
public int MemorySize;
2526
public long TotalWeightSizeBytes;
2627
public string ModelHash;
@@ -48,6 +49,35 @@ public static EventActionSpec FromActionSpec(ActionSpec actionSpec)
4849
}
4950
}
5051

52+
/// <summary>
53+
/// Information about an actuator.
54+
/// </summary>
55+
[Serializable]
56+
internal struct EventActuatorInfo
57+
{
58+
public int BuiltInActuatorType;
59+
public int NumContinuousActions;
60+
public int NumDiscreteActions;
61+
62+
public static EventActuatorInfo FromActuator(IActuator actuator)
63+
{
64+
BuiltInActuatorType builtInActuatorType = Actuators.BuiltInActuatorType.Unknown;
65+
if (actuator is IBuiltInActuator builtInActuator)
66+
{
67+
builtInActuatorType = builtInActuator.GetBuiltInActuatorType();
68+
}
69+
70+
var actionSpec = actuator.ActionSpec;
71+
72+
return new EventActuatorInfo
73+
{
74+
BuiltInActuatorType = (int)builtInActuatorType,
75+
NumContinuousActions = actionSpec.NumContinuousActions,
76+
NumDiscreteActions = actionSpec.NumDiscreteActions
77+
};
78+
}
79+
}
80+
5181
/// <summary>
5282
/// Information about one dimension of an observation.
5383
/// </summary>
@@ -101,6 +131,7 @@ internal struct RemotePolicyInitializedEvent
101131
public string BehaviorName;
102132
public List<EventObservationSpec> ObservationSpecs;
103133
public EventActionSpec ActionSpec;
134+
public List<EventActuatorInfo> ActuatorInfos;
104135

105136
/// <summary>
106137
/// This will be the same as TrainingEnvironmentInitializedEvent if available, but

Diff for: com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs

+14-4
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,15 @@ public static bool IsAnalyticsEnabled()
8787
/// <param name="inferenceDevice">Whether inference is being performed on the CPU or GPU</param>
8888
/// <param name="sensors">List of ISensors for the Agent. Used to generate information about the observation space.</param>
8989
/// <param name="actionSpec">ActionSpec for the Agent. Used to generate information about the action space.</param>
90+
/// <param name="actuators">List of IActuators for the Agent. Used to generate information about the action space.</param>
9091
/// <returns></returns>
9192
public static void InferenceModelSet(
9293
NNModel nnModel,
9394
string behaviorName,
9495
InferenceDevice inferenceDevice,
9596
IList<ISensor> sensors,
96-
ActionSpec actionSpec
97+
ActionSpec actionSpec,
98+
IList<IActuator> actuators
9799
)
98100
{
99101
// The event shouldn't be able to report if this is disabled but if we know we're not going to report
@@ -112,9 +114,9 @@ ActionSpec actionSpec
112114
return;
113115
}
114116

115-
var data = GetEventForModel(nnModel, behaviorName, inferenceDevice, sensors, actionSpec);
117+
var data = GetEventForModel(nnModel, behaviorName, inferenceDevice, sensors, actionSpec, actuators);
116118
// Note - to debug, use JsonUtility.ToJson on the event.
117-
//Debug.Log(JsonUtility.ToJson(data, true));
119+
// Debug.Log(JsonUtility.ToJson(data, true));
118120
#if UNITY_EDITOR
119121
if (AnalyticsUtils.s_SendEditorAnalytics)
120122
{
@@ -133,13 +135,15 @@ ActionSpec actionSpec
133135
/// <param name="inferenceDevice"></param>
134136
/// <param name="sensors"></param>
135137
/// <param name="actionSpec"></param>
138+
/// <param name="actuators">List of IActuators for the Agent. Used to populate the event's ActuatorInfos.</param>
136139
/// <returns></returns>
137140
internal static InferenceEvent GetEventForModel(
138141
NNModel nnModel,
139142
string behaviorName,
140143
InferenceDevice inferenceDevice,
141144
IList<ISensor> sensors,
142-
ActionSpec actionSpec
145+
ActionSpec actionSpec,
146+
IList<IActuator> actuators
143147
)
144148
{
145149
var barracudaModel = ModelLoader.Load(nnModel);
@@ -175,6 +179,12 @@ ActionSpec actionSpec
175179
inferenceEvent.ObservationSpecs.Add(EventObservationSpec.FromSensor(sensor));
176180
}
177181

182+
inferenceEvent.ActuatorInfos = new List<EventActuatorInfo>(actuators.Count);
183+
foreach (var actuator in actuators)
184+
{
185+
inferenceEvent.ActuatorInfos.Add(EventActuatorInfo.FromActuator(actuator));
186+
}
187+
178188
inferenceEvent.TotalWeightSizeBytes = GetModelWeightSize(barracudaModel);
179189
inferenceEvent.ModelHash = GetModelHash(barracudaModel);
180190
return inferenceEvent;

Diff for: com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs

+13-4
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ public static void TrainingEnvironmentInitialized(TrainingEnvironmentInitialized
139139
public static void RemotePolicyInitialized(
140140
string fullyQualifiedBehaviorName,
141141
IList<ISensor> sensors,
142-
ActionSpec actionSpec
142+
ActionSpec actionSpec,
143+
IList<IActuator> actuators
143144
)
144145
{
145146
if (!IsAnalyticsEnabled())
@@ -158,7 +159,7 @@ ActionSpec actionSpec
158159
return;
159160
}
160161

161-
var data = GetEventForRemotePolicy(behaviorName, sensors, actionSpec);
162+
var data = GetEventForRemotePolicy(behaviorName, sensors, actionSpec, actuators);
162163
// Note - to debug, use JsonUtility.ToJson on the event.
163164
// Debug.Log(
164165
// $"Would send event {k_RemotePolicyInitializedEventName} with body {JsonUtility.ToJson(data, true)}"
@@ -220,10 +221,12 @@ public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent
220221
#endif
221222
}
222223

223-
static RemotePolicyInitializedEvent GetEventForRemotePolicy(
224+
internal static RemotePolicyInitializedEvent GetEventForRemotePolicy(
224225
string behaviorName,
225226
IList<ISensor> sensors,
226-
ActionSpec actionSpec)
227+
ActionSpec actionSpec,
228+
IList<IActuator> actuators
229+
)
227230
{
228231
var remotePolicyEvent = new RemotePolicyInitializedEvent();
229232

@@ -238,6 +241,12 @@ static RemotePolicyInitializedEvent GetEventForRemotePolicy(
238241
remotePolicyEvent.ObservationSpecs.Add(EventObservationSpec.FromSensor(sensor));
239242
}
240243

244+
remotePolicyEvent.ActuatorInfos = new List<EventActuatorInfo>(actuators.Count);
245+
foreach (var actuator in actuators)
246+
{
247+
remotePolicyEvent.ActuatorInfos.Add(EventActuatorInfo.FromActuator(actuator));
248+
}
249+
241250
remotePolicyEvent.MLAgentsEnvsVersion = s_TrainerPackageVersion;
242251
remotePolicyEvent.TrainerCommunicationVersion = s_TrainerCommunicationVersion;
243252
return remotePolicyEvent;

Diff for: com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs

+9-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ internal class BarracudaPolicy : IPolicy
4848

4949
private string m_BehaviorName;
5050

51+
/// <summary>
52+
/// List of actuators, only used for analytics
53+
/// </summary>
54+
private IList<IActuator> m_Actuators;
55+
5156
/// <summary>
5257
/// Whether or not we've tried to send analytics for this model. We only ever try to send once per policy,
5358
/// and do additional deduplication in the analytics code.
@@ -57,6 +62,7 @@ internal class BarracudaPolicy : IPolicy
5762
/// <inheritdoc />
5863
public BarracudaPolicy(
5964
ActionSpec actionSpec,
65+
IList<IActuator> actuators,
6066
NNModel model,
6167
InferenceDevice inferenceDevice,
6268
string behaviorName
@@ -66,6 +72,7 @@ string behaviorName
6672
m_ModelRunner = modelRunner;
6773
m_BehaviorName = behaviorName;
6874
m_ActionSpec = actionSpec;
75+
m_Actuators = actuators;
6976
}
7077

7178
/// <inheritdoc />
@@ -79,7 +86,8 @@ public void RequestDecision(AgentInfo info, List<ISensor> sensors)
7986
m_BehaviorName,
8087
m_ModelRunner.InferenceDevice,
8188
sensors,
82-
m_ActionSpec
89+
m_ActionSpec,
90+
m_Actuators
8391
);
8492
}
8593
m_AgentId = info.episodeId;

Diff for: com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -221,16 +221,16 @@ internal IPolicy GeneratePolicy(ActionSpec actionSpec, ActuatorManager actuatorM
221221
"Either assign a model, or change to a different Behavior Type."
222222
);
223223
}
224-
return new BarracudaPolicy(actionSpec, m_Model, m_InferenceDevice, m_BehaviorName);
224+
return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName);
225225
}
226226
case BehaviorType.Default:
227227
if (Academy.Instance.IsCommunicatorOn)
228228
{
229-
return new RemotePolicy(actionSpec, FullyQualifiedBehaviorName);
229+
return new RemotePolicy(actionSpec, actuatorManager, FullyQualifiedBehaviorName);
230230
}
231231
if (m_Model != null)
232232
{
233-
return new BarracudaPolicy(actionSpec, m_Model, m_InferenceDevice, m_BehaviorName);
233+
return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName);
234234
}
235235
else
236236
{

0 commit comments

Comments
 (0)