From 9c880172bc26f126f2faf381f41a81253670e034 Mon Sep 17 00:00:00 2001 From: JulioJerez Date: Tue, 24 Sep 2024 21:17:52 -0700 Subject: [PATCH] ok try conditional reward, not sure how will go. --- .../demos/ndAdvancedIndustrialRobot.cpp | 71 ++++++++++--------- 1 file changed, 38 insertions(+), 33 deletions(-) diff --git a/newton-4.00/applications/ndSandbox/demos/ndAdvancedIndustrialRobot.cpp b/newton-4.00/applications/ndSandbox/demos/ndAdvancedIndustrialRobot.cpp index 808f095f0..630d4641d 100644 --- a/newton-4.00/applications/ndSandbox/demos/ndAdvancedIndustrialRobot.cpp +++ b/newton-4.00/applications/ndSandbox/demos/ndAdvancedIndustrialRobot.cpp @@ -24,7 +24,7 @@ namespace ndAdvancedRobot { - //#define ND_TRAIN_MODEL + #define ND_TRAIN_MODEL #define CONTROLLER_NAME "ndRobotArmReach" //#define CONTROLLER_RESUME_TRAINING @@ -32,7 +32,7 @@ namespace ndAdvancedRobot class ndActionVector { public: - ndBrainFloat m_actions[5]; + ndBrainFloat m_actions[6]; }; class ndObservationVector @@ -53,8 +53,10 @@ namespace ndAdvancedRobot ndBrainFloat m_delta_Azimuth; ndBrainFloat m_target_Azimuth; - ndBrainFloat m_sourcePin[3]; - ndBrainFloat m_targetPin[3]; + ndBrainFloat m_sourceSidePin[3]; + ndBrainFloat m_targetSidePin[3]; + ndBrainFloat m_sourceFrontPin[3]; + ndBrainFloat m_targetFrontPin[3]; }; class ndControlParameters @@ -532,17 +534,6 @@ namespace ndAdvancedRobot ndFloat32 CalculateDeltaTargetRotation(const ndMatrix& currentEffectorMatrix) const { - //const ndMatrix targetMatrix(ndPitchMatrix(m_targetLocation.m_pitch) * ndYawMatrix(m_targetLocation.m_yaw) * ndRollMatrix(m_targetLocation.m_roll)); - //const ndQuaternion targetRotation(targetMatrix); - //ndQuaternion currentRotation(currentEffectorMatrix); - //if (currentRotation.DotProduct(targetRotation).GetScalar() < 0.0f) - //{ - // currentRotation = currentRotation.Scale(-1.0f); - //} - // - //const ndVector omega(currentRotation.CalcAverageOmega(targetRotation, 1.0f)); - //return omega; - const ndMatrix targetMatrix(ndPitchMatrix(m_targetLocation.m_pitch) * ndYawMatrix(m_targetLocation.m_yaw) * ndRollMatrix(m_targetLocation.m_roll)); const ndMatrix relativeRotation(currentEffectorMatrix * targetMatrix.OrthoInverse()); ndFloat32 angleCos = currentEffectorMatrix.m_front.DotProduct(targetMatrix.m_front).GetScalar(); @@ -576,15 +567,28 @@ namespace ndAdvancedRobot return param * param * param * param; }; - ndFloat32 rewardWeigh = 1.0f / 4.0f; + ndFloat32 rewardWeigh = 1.0f / 5.0f; ndFloat32 posit_xReward = rewardWeigh * ScalarReward(positError2.m_x); ndFloat32 posit_zReward = rewardWeigh * ScalarReward(positError2.m_z); ndFloat32 azimuthReward = rewardWeigh * ScalarReward(positError2.m_w); - ndFloat32 angleError = CalculateDeltaTargetRotation(currentEffectorMatrix); - ndFloat32 angularReward = rewardWeigh * GaussianReward((angleError + 1.0f) * 0.5f); - return angularReward + posit_xReward + posit_zReward + azimuthReward; + //ndFloat32 angleError = CalculateDeltaTargetRotation(currentEffectorMatrix); + const ndMatrix targetMatrix(ndPitchMatrix(m_targetLocation.m_pitch) * ndYawMatrix(m_targetLocation.m_yaw) * ndRollMatrix(m_targetLocation.m_roll)); + const ndMatrix relativeRotation(currentEffectorMatrix * targetMatrix.OrthoInverse()); + ndFloat32 sideCos = currentEffectorMatrix.m_up.DotProduct(targetMatrix.m_up).GetScalar(); + ndFloat32 frontCos = currentEffectorMatrix.m_front.DotProduct(targetMatrix.m_front).GetScalar(); + + ndFloat32 angularReward0 = rewardWeigh * GaussianReward((sideCos + 1.0f) * 0.5f); + ndFloat32 angularReward1 = rewardWeigh * GaussianReward((frontCos + 1.0f) * 0.5f); + + ndFloat32 reward = angularReward0 + angularReward1; + if ((angularReward0 > 0.195) && (angularReward1 > 0.195f)) + { + reward = reward + posit_xReward + posit_zReward + azimuthReward; + } + //return angularReward + posit_xReward + posit_zReward + azimuthReward; //return GaussianReward((angleError + 1.0f) * 0.5f);; + return reward; } #pragma optimize( "", off ) @@ -618,16 +622,21 @@ namespace ndAdvancedRobot observation->m_target_z = ndBrainFloat(m_targetLocation.m_z); observation->m_target_Azimuth = ndBrainFloat(m_targetLocation.m_azimuth); - //ndFloat32 angleError(CalculateDeltaTargetRotation(currentEffectorMatrix)); - //observation->m_deltaRotation = ndBrainFloat(angleError); - observation->m_sourcePin[0] = ndBrainFloat(currentEffectorMatrix.m_front.m_x); - observation->m_sourcePin[1] = ndBrainFloat(currentEffectorMatrix.m_front.m_y); - observation->m_sourcePin[2] = ndBrainFloat(currentEffectorMatrix.m_front.m_z); + observation->m_sourceFrontPin[0] = ndBrainFloat(currentEffectorMatrix.m_front.m_x); + observation->m_sourceFrontPin[1] = ndBrainFloat(currentEffectorMatrix.m_front.m_y); + observation->m_sourceFrontPin[2] = ndBrainFloat(currentEffectorMatrix.m_front.m_z); + observation->m_sourceSidePin[0] = ndBrainFloat(currentEffectorMatrix.m_up.m_x); + observation->m_sourceSidePin[1] = ndBrainFloat(currentEffectorMatrix.m_up.m_y); + observation->m_sourceSidePin[2] = ndBrainFloat(currentEffectorMatrix.m_up.m_z); const ndMatrix targetMatrix(ndPitchMatrix(m_targetLocation.m_pitch) * ndYawMatrix(m_targetLocation.m_yaw) * ndRollMatrix(m_targetLocation.m_roll)); - observation->m_targetPin[0] = ndBrainFloat(targetMatrix.m_front.m_x); - observation->m_targetPin[1] = ndBrainFloat(targetMatrix.m_front.m_y); - observation->m_targetPin[2] = ndBrainFloat(targetMatrix.m_front.m_z); + observation->m_targetFrontPin[0] = ndBrainFloat(targetMatrix.m_front.m_x); + observation->m_targetFrontPin[1] = ndBrainFloat(targetMatrix.m_front.m_y); + observation->m_targetFrontPin[2] = ndBrainFloat(targetMatrix.m_front.m_z); + observation->m_targetSidePin[0] = ndBrainFloat(targetMatrix.m_up.m_x); + observation->m_targetSidePin[1] = ndBrainFloat(targetMatrix.m_up.m_y); + observation->m_targetSidePin[2] = ndBrainFloat(targetMatrix.m_up.m_z); + } //#pragma optimize( "", off ) @@ -641,16 +650,12 @@ namespace ndAdvancedRobot hinge->SetTargetAngle(targetAngle); }; - //SetParamter(m_arm_2, 2); - //SetParamter(m_arm_3, 3); - SetParamter(m_arm_0, 0); SetParamter(m_arm_1, 1); SetParamter(m_arm_2, 2); SetParamter(m_arm_3, 3); - //SetParamter(m_arm_4, 4); - //SetParamter(m_base_rotator, 5); - SetParamter(m_base_rotator, 4); + SetParamter(m_arm_4, 4); + SetParamter(m_base_rotator, 5); } void CheckModelStability()