diff --git a/src/main/java/com/aws/greengrass/clientdevices/auth/connectivity/CISShadowMonitor.java b/src/main/java/com/aws/greengrass/clientdevices/auth/connectivity/CISShadowMonitor.java index 32000315c..85d9d96f7 100644 --- a/src/main/java/com/aws/greengrass/clientdevices/auth/connectivity/CISShadowMonitor.java +++ b/src/main/java/com/aws/greengrass/clientdevices/auth/connectivity/CISShadowMonitor.java @@ -15,6 +15,9 @@ import com.aws.greengrass.mqttclient.WrapperMqttClientConnection; import com.aws.greengrass.util.Coerce; import com.aws.greengrass.util.RetryUtils; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.Setter; import software.amazon.awssdk.crt.mqtt.MqttClientConnection; import software.amazon.awssdk.crt.mqtt.MqttException; import software.amazon.awssdk.crt.mqtt.QualityOfService; @@ -33,6 +36,7 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -48,6 +52,7 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import javax.inject.Inject; @@ -56,7 +61,6 @@ public class CISShadowMonitor implements Consumer getShadowTask; + AtomicReference> shadowReceived = new AtomicReference<>(); + private MqttClientConnection connection; private IotShadowClient iotShadowClient; @@ -113,7 +133,7 @@ public void startMonitor() { subscribeTaskFuture = executorService.submit(() -> { try { subscribeToShadowTopics(); - publishToGetCISShadowTopic(); + fetchCISShadowWithRetriesAsync(); } catch (InterruptedException e) { LOGGER.atWarn().cause(e).log("Interrupted while subscribing to CIS shadow topics"); Thread.currentThread().interrupt(); @@ -158,21 +178,40 @@ private void subscribeToShadowTopics() throws InterruptedException { ShadowDeltaUpdatedSubscriptionRequest shadowDeltaUpdatedSubscriptionRequest = new ShadowDeltaUpdatedSubscriptionRequest(); shadowDeltaUpdatedSubscriptionRequest.thingName = shadowName; - iotShadowClient.SubscribeToShadowDeltaUpdatedEvents(shadowDeltaUpdatedSubscriptionRequest, + iotShadowClient.SubscribeToShadowDeltaUpdatedEvents( + shadowDeltaUpdatedSubscriptionRequest, QualityOfService.AT_LEAST_ONCE, - this::processCISShadow, + resp -> { + reportShadowReceived(); + processCISShadow(resp); + }, (e) -> LOGGER.atError() .log("Error processing shadowDeltaUpdatedSubscription Response", e)) - .get(TIMEOUT_FOR_SUBSCRIBING_TO_TOPICS_SECONDS, TimeUnit.SECONDS); + .get(mqttOperationTimeoutSeconds, TimeUnit.SECONDS); LOGGER.info("Subscribed to shadow update delta topic"); GetShadowSubscriptionRequest getShadowSubscriptionRequest = new GetShadowSubscriptionRequest(); getShadowSubscriptionRequest.thingName = shadowName; - iotShadowClient.SubscribeToGetShadowAccepted(getShadowSubscriptionRequest, - QualityOfService.AT_LEAST_ONCE, this::processCISShadow, + iotShadowClient.SubscribeToGetShadowAccepted( + getShadowSubscriptionRequest, + QualityOfService.AT_LEAST_ONCE, + resp -> { + reportShadowReceived(); + processCISShadow(resp); + }, (e) -> LOGGER.atError().log("Error processing getShadowSubscription Response", e)) - .get(TIMEOUT_FOR_SUBSCRIBING_TO_TOPICS_SECONDS, TimeUnit.SECONDS); + .get(mqttOperationTimeoutSeconds, TimeUnit.SECONDS); LOGGER.info("Subscribed to shadow get accepted topic"); + + GetShadowSubscriptionRequest getShadowRejectedSubscriptionRequest = new GetShadowSubscriptionRequest(); + getShadowRejectedSubscriptionRequest.thingName = shadowName; + iotShadowClient.SubscribeToGetShadowRejected( + getShadowRejectedSubscriptionRequest, + QualityOfService.AT_LEAST_ONCE, + err -> reportShadowReceived(), + (e) -> LOGGER.atError().log("Error processing get shadow rejected response", e)) + .get(mqttOperationTimeoutSeconds, TimeUnit.SECONDS); + LOGGER.info("Subscribed to shadow get rejected topic"); return; } catch (ExecutionException e) { @@ -281,6 +320,7 @@ private synchronized void processCISShadow(String version, Map d for (CertificateGenerator cg : monitoredCertificateGenerators) { cg.generateCertificate(() -> new ArrayList<>(cachedHostAddresses), "connectivity info was updated"); } + LOGGER.atDebug().log("certificates rotated"); } catch (CertificateGenerationException e) { LOGGER.atError().kv(VERSION, version).cause(e).log("Failed to generate new certificates"); return; @@ -315,14 +355,60 @@ private void updateCISShadowReportedState(Map reportedState) { }); } - private void publishToGetCISShadowTopic() { + private void reportShadowReceived() { + CompletableFuture shadowReceived = this.shadowReceived.get(); + if (shadowReceived != null) { + shadowReceived.complete(null); + } + } + + @SuppressWarnings("PMD.AvoidCatchingGenericException") + private void fetchCISShadowWithRetriesAsync() { + synchronized (getShadowLock) { + if (getShadowTask != null && !getShadowTask.isDone()) { + // operation already in progress + return; + } + getShadowTask = executorService.submit(() -> { + try { + RetryUtils.runWithRetry( + GET_CIS_SHADOW_RETRY_CONFIG, + () -> { + CompletableFuture shadowReceived = + this.shadowReceived.updateAndGet(ignore -> new CompletableFuture<>()); + publishToGetCISShadowTopic().get(mqttOperationTimeoutSeconds, TimeUnit.SECONDS); + // await shadow get accepted, rejected, or update delta + shadowReceived.get(mqttOperationTimeoutSeconds, TimeUnit.SECONDS); + return null; + }, + "get-cis-shadow", + LOGGER); + } catch (InterruptedException ignore) { + Thread.currentThread().interrupt(); + } catch (Exception e) { + LOGGER.atError().cause(e).log("unable to get CIS shadow"); + } + }); + } + } + + private void cancelGetCISShadow() { + synchronized (getShadowLock) { + if (getShadowTask != null) { + getShadowTask.cancel(true); + } + } + } + + private CompletableFuture publishToGetCISShadowTopic() { LOGGER.atDebug().log("Publishing to get shadow topic"); GetShadowRequest getShadowRequest = new GetShadowRequest(); getShadowRequest.thingName = shadowName; - iotShadowClient.PublishGetShadow(getShadowRequest, QualityOfService.AT_LEAST_ONCE).exceptionally(e -> { - LOGGER.atWarn().cause(e).log("Unable to retrieve CIS shadow"); - return null; - }); + return iotShadowClient.PublishGetShadow(getShadowRequest, QualityOfService.AT_LEAST_ONCE) + .exceptionally(e -> { + LOGGER.atWarn().cause(e).log("Unable to retrieve CIS shadow"); + return null; + }); } private void unsubscribeFromShadowTopics() { @@ -339,7 +425,9 @@ private void unsubscribeFromShadowTopics() { @Override public void accept(NetworkStateProvider.ConnectionState state) { if (state == NetworkStateProvider.ConnectionState.NETWORK_UP) { - publishToGetCISShadowTopic(); + fetchCISShadowWithRetriesAsync(); + } else if (state == NetworkStateProvider.ConnectionState.NETWORK_DOWN) { + cancelGetCISShadow(); } } } diff --git a/src/test/java/com/aws/greengrass/clientdevices/auth/connectivity/CISShadowMonitorTest.java b/src/test/java/com/aws/greengrass/clientdevices/auth/connectivity/CISShadowMonitorTest.java index a197930cc..f7d29ae2b 100644 --- a/src/test/java/com/aws/greengrass/clientdevices/auth/connectivity/CISShadowMonitorTest.java +++ b/src/test/java/com/aws/greengrass/clientdevices/auth/connectivity/CISShadowMonitorTest.java @@ -8,6 +8,7 @@ import com.aws.greengrass.clientdevices.auth.certificate.CertificateGenerator; import com.aws.greengrass.clientdevices.auth.exception.CertificateGenerationException; import com.aws.greengrass.clientdevices.auth.helpers.TestHelpers; +import com.aws.greengrass.logging.impl.config.LogConfig; import com.aws.greengrass.testcommons.testutilities.GGExtension; import com.aws.greengrass.util.Pair; import com.aws.greengrass.util.Utils; @@ -26,6 +27,7 @@ import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.slf4j.event.Level; import software.amazon.awssdk.crt.mqtt.MqttClientConnection; import software.amazon.awssdk.crt.mqtt.MqttMessage; import software.amazon.awssdk.crt.mqtt.QualityOfService; @@ -42,6 +44,7 @@ import javax.annotation.Nonnull; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.Date; @@ -53,11 +56,11 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -81,7 +84,6 @@ class CISShadowMonitorTest { private static final ObjectMapper MAPPER = new ObjectMapper(); private static final String SHADOW_NAME = "testThing-gci"; private static final String UPDATE_SHADOW_TOPIC = String.format("$aws/things/%s/shadow/update", SHADOW_NAME); - private static final String GET_SHADOW_TOPIC = String.format("$aws/things/%s/shadow/get", SHADOW_NAME); private final FakeIotShadowClient shadowClient = spy(new FakeIotShadowClient()); private final MqttClientConnection shadowClientConnection = shadowClient.getConnection(); private final ExecutorService executor = Executors.newCachedThreadPool(); @@ -94,6 +96,7 @@ class CISShadowMonitorTest { @BeforeEach void setup() { + LogConfig.getRootLogConfig().setLevel(Level.DEBUG); cisShadowMonitor = new CISShadowMonitor( shadowClientConnection, shadowClient, @@ -101,6 +104,8 @@ void setup() { SHADOW_NAME, connectivityInfoProvider ); + // avoid unnecessary waiting + cisShadowMonitor.setMqttOperationTimeoutSeconds(1L); } @AfterEach @@ -133,10 +138,9 @@ static class Scenario { int numShadowUpdatePublishFailures; /** - * Amount of times to fail monitor's attempts - * to get the CIS shadow. + * If present, shadow will be created after the monitor starts. */ - int numGetRequestFailures; + Duration createShadowAfterDelay; /** * If true, simulate monitor receiving duplicate @@ -164,7 +168,7 @@ public static Stream cisShadowMonitorScenarios() { // when monitor can't get shadow on startup, // it'll recover on subsequent shadow updates Arguments.of(Scenario.builder() - .numGetRequestFailures(1) + .createShadowAfterDelay(Duration.ofSeconds(2L)) .serialShadowUpdates(true) .build()), // if shadow is never updated, @@ -172,15 +176,14 @@ public static Stream cisShadowMonitorScenarios() { Arguments.of(Scenario.builder() .numShadowUpdates(0) .build()), - // TODO add support in CISShadowMonitor // if shadow is never updated, // monitor still works because it fetches shadow on startup. // if shadow fetching fails, it will be retried -// Arguments.of(Scenario.builder() -// .numShadowUpdates(0) -// .numGetRequestFailures(1) -// .serialShadowUpdates(true) -// .build()), + Arguments.of(Scenario.builder() + .numShadowUpdates(0) + .createShadowAfterDelay(Duration.ofSeconds(2L)) + .serialShadowUpdates(true) + .build()), Arguments.of(Scenario.builder() .numShadowUpdatePublishFailures(1) .serialShadowUpdates(true) @@ -208,6 +211,7 @@ public static Stream cisShadowMonitorScenarios() { void GIVEN_monitor_WHEN_cis_shadow_changes_THEN_monitor_updates_certificates(Scenario scenario, ExtensionContext context) throws Exception { ignoreExceptionOfType(context, ValidationException.class); ignoreExceptionOfType(context, RuntimeException.class); + ignoreExceptionOfType(context, TimeoutException.class); connectivityInfoProvider.setMode(scenario.getConnectivityProviderMode()); @@ -216,26 +220,40 @@ void GIVEN_monitor_WHEN_cis_shadow_changes_THEN_monitor_updates_certificates(Sce // and is used by this test to (optionally) feed messages to the monitor serially AtomicReference updateProcessedByMonitor = updateProcessedByMonitor(); - // set initial shadow state. - // the shadow delta triggered by this update will be ignored - // since the monitor is not listening at this point - // TODO handle case where shadow doesn't exist on startup - updateShadowDesiredState( + Runnable putInitialShadowState = () -> updateShadowDesiredState( Utils.immutableMap("version", "-1"), scenario.isReceiveDuplicateShadowDeltaUpdates() ); - shadowClient.failOnGet(GET_SHADOW_TOPIC, scenario.getNumGetRequestFailures()); + // optional, handle case where shadow exists before monitor starts up + if (scenario.getCreateShadowAfterDelay() == null) { + putInitialShadowState.run(); + } cisShadowMonitor.addToMonitor(certificateGenerator); cisShadowMonitor.startMonitor(); + // optional, handle case where shadow is created AFTER monitor is started + if (scenario.getCreateShadowAfterDelay() != null) { + if (!scenario.isSerialShadowUpdates()) { + fail("initializing shadow after monitor startup " + + "only supported when shadow updates are serialized, otherwise there's no point"); + } + executor.submit(() -> { + try { + Thread.sleep(scenario.getCreateShadowAfterDelay().toMillis()); + } catch (InterruptedException e) { + return; + } + putInitialShadowState.run(); + }); + } + // on startup, the monitor directly requests a shadow and processes it. // optionally wait for the monitor to process the get shadow response. if (scenario.isSerialShadowUpdates()) { boolean monitorExpectedToUpdateReportedState = - scenario.getConnectivityProviderMode() != FakeConnectivityInformation.Mode.FAIL_ONCE - && scenario.getNumGetRequestFailures() == 0; + scenario.getConnectivityProviderMode() != FakeConnectivityInformation.Mode.FAIL_ONCE; waitForMonitorToProcessUpdate(updateProcessedByMonitor, monitorExpectedToUpdateReportedState); } @@ -286,7 +304,7 @@ private void waitForMonitorToProcessUpdate(AtomicReference updat if (monitorExpectedToUpdateReportedState) { // wait for monitor to update shadow state, which // means that it has finished processing that particular shadow version - assertTrue(updateProcessedByMonitor.get().await(5L, TimeUnit.SECONDS)); + assertTrue(updateProcessedByMonitor.get().await(15L, TimeUnit.SECONDS)); } else { // monitor will not update the shadow state, // so we don't have a way to know when the monitor has completed its work @@ -320,7 +338,7 @@ private void assertShadowEventuallyEquals(Map desired, Map> CONNECTIVITY_INFO_SAMPLE = new AtomicReference<>(Collections.singletonList(connectivityInfoWithRandomHost())); - private final Set responseHashes = new CopyOnWriteArraySet<>(); + + @Getter + private int numConnectivityInfoChanges; + private Set prevAddresses; private final AtomicReference mode = new AtomicReference<>(Mode.RANDOM); private final AtomicBoolean failed = new AtomicBoolean(); @@ -631,15 +652,6 @@ void setMode(Mode mode) { this.mode.set(mode); } - /** - * Get the number of unique responses to getConnectivityInfo provided by this fake. - * - * @return number of unique connectivity info responses generated by this fake - */ - int getNumUniqueConnectivityInfoResponses() { - return responseHashes.size(); - } - @Override public Optional> getConnectivityInfo() { List connectivityInfo = doGetConnectivityInfo(); @@ -648,7 +660,11 @@ public Optional> getConnectivityInfo() { .map(HostAddress::of) .collect(Collectors.toSet()); getConnectivityInfoCache().put("source", addresses); - responseHashes.add(addresses.hashCode()); + Set prevAddresses = this.prevAddresses; + this.prevAddresses = addresses; + if (!Objects.equals(addresses, prevAddresses)) { + numConnectivityInfoChanges++; + } } return Optional.ofNullable(connectivityInfo); }