Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: cda work queue #132

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package com.aws.greengrass.integrationtests.ipc;

import com.aws.greengrass.clientdevices.auth.infra.CDAExecutor;
import com.aws.greengrass.dependency.State;
import com.aws.greengrass.clientdevices.auth.ClientDevicesAuthService;
import com.aws.greengrass.clientdevices.auth.exception.AuthenticationException;
Expand All @@ -27,6 +28,7 @@
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.io.TempDir;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.aws.greengrass.GetClientDeviceAuthTokenResponseHandler;
import software.amazon.awssdk.aws.greengrass.GreengrassCoreIPCClient;
Expand All @@ -47,21 +49,19 @@
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import static com.aws.greengrass.componentmanager.KernelConfigResolver.CONFIGURATION_CONFIG_KEY;
import static com.aws.greengrass.clientdevices.auth.ClientDevicesAuthService.CLOUD_REQUEST_QUEUE_SIZE_TOPIC;
import static com.aws.greengrass.clientdevices.auth.ClientDevicesAuthService.PERFORMANCE_TOPIC;
import static com.aws.greengrass.testcommons.testutilities.ExceptionLogProtector.ignoreExceptionOfType;
import static com.aws.greengrass.testcommons.testutilities.TestUtils.asyncAssertOnConsumer;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

@ExtendWith({GGExtension.class, UniqueRootPathExtension.class, MockitoExtension.class})
Expand Down Expand Up @@ -159,8 +159,15 @@ void GIVEN_brokerWithValidCredentials_WHEN_GetClientDeviceAuthToken_THEN_returns
@Test
void GIVEN_brokerWithInvalidCredentials_WHEN_GetClientDeviceAuthToken_THEN_throwsInvalidArgumentsError_and_WHEN_queueIsFull_THEN_throwsServiceError()
throws Exception {
// Inject work queue which will reject any work added
LinkedBlockingQueue<Runnable> mockQueue = Mockito.mock(LinkedBlockingQueue.class);
when(mockQueue.offer(any())).thenReturn(false);
ThreadPoolExecutor executor = new ThreadPoolExecutor(1, 1, 1, TimeUnit.SECONDS, mockQueue);
kernel.getContext().put(CDAExecutor.class, new CDAExecutor(executor));
kernel.getContext().put(SessionManager.class, sessionManager);

startNucleusWithConfig("cda.yaml");

try (EventStreamRPCConnection connection = IPCTestUtils.getEventStreamRpcConnection(kernel,
"BrokerWithGetClientDeviceAuthTokenPermission")) {
GreengrassCoreIPCClient ipcClient = new GreengrassCoreIPCClient(connection);
Expand All @@ -174,11 +181,6 @@ void GIVEN_brokerWithInvalidCredentials_WHEN_GetClientDeviceAuthToken_THEN_throw
assertThat(err.getCause().getMessage(), containsString("Invalid client device credentials"));
}

// Update the cloud queue size to 1 so that we'll just reject the second request
kernel.findServiceTopic(ClientDevicesAuthService.CLIENT_DEVICES_AUTH_SERVICE_NAME)
.lookup(CONFIGURATION_CONFIG_KEY, PERFORMANCE_TOPIC, CLOUD_REQUEST_QUEUE_SIZE_TOPIC).withValue(1);
kernel.getContext().waitForPublishQueueToClear();

// Verify that we get a good error that the request couldn't be queued
try (EventStreamRPCConnection connection = IPCTestUtils.getEventStreamRpcConnection(kernel,
"BrokerWithGetClientDeviceAuthTokenPermission")) {
Expand All @@ -187,27 +189,10 @@ void GIVEN_brokerWithInvalidCredentials_WHEN_GetClientDeviceAuthToken_THEN_throw
new CredentialDocument().withMqttCredential(
new MQTTCredential().withClientId("some-client-id").withCertificatePem("VALID PEM")));

CountDownLatch cdl = new CountDownLatch(1);
when(sessionManager.createSession(anyString(), anyMap())).thenAnswer((a) -> {
cdl.countDown();
Thread.sleep(1_000); // slow down the first request so that the second will be rejected
return "uuid";
});
// Request 1 (immediately runs)
CompletableFuture<GetClientDeviceAuthTokenResponse> fut1 =
ipcClient.getClientDeviceAuthToken(request, Optional.empty()).getResponse();
// Ensure the threadpool is actively blocked before we send the next requests to fill the queue and then
// overflow the queue.
cdl.await(2, TimeUnit.SECONDS);
// Request 2 (queued so that queue size is 1)
CompletableFuture<GetClientDeviceAuthTokenResponse> fut2 =
ipcClient.getClientDeviceAuthToken(request, Optional.empty()).getResponse();
// Request 3 (expect rejection)
// Expect rejection
Exception err = assertThrows(Exception.class, () -> clientDeviceAuthToken(ipcClient, request, (r) -> {}));
assertThat(err.getCause().getMessage(), containsString("Unable to queue request"));
assertEquals(ServiceError.class, err.getCause().getClass());
fut1.get(2, TimeUnit.SECONDS);
fut2.get(2, TimeUnit.SECONDS);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import com.aws.greengrass.clientdevices.auth.configuration.CDAConfiguration;
import com.aws.greengrass.clientdevices.auth.configuration.GroupConfiguration;
import com.aws.greengrass.clientdevices.auth.configuration.GroupManager;
import com.aws.greengrass.clientdevices.auth.configuration.InfrastructureConfiguration;
import com.aws.greengrass.clientdevices.auth.connectivity.CISShadowMonitor;
import com.aws.greengrass.clientdevices.auth.infra.CDAExecutor;
import com.aws.greengrass.clientdevices.auth.infra.NetworkState;
import com.aws.greengrass.clientdevices.auth.session.MqttSessionFactory;
import com.aws.greengrass.clientdevices.auth.session.SessionConfig;
Expand All @@ -30,7 +32,6 @@
import com.aws.greengrass.ipc.SubscribeToCertificateUpdatesOperationHandler;
import com.aws.greengrass.ipc.VerifyClientDeviceIdentityOperationHandler;
import com.aws.greengrass.lifecyclemanager.PluginService;
import com.aws.greengrass.util.Coerce;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import software.amazon.awssdk.aws.greengrass.GreengrassCoreIPCService;
Expand All @@ -45,6 +46,8 @@
import java.util.concurrent.TimeUnit;
import javax.inject.Inject;

import static com.aws.greengrass.clientdevices.auth.configuration.InfrastructureConfiguration.DEFAULT_THREAD_POOL_SIZE;
import static com.aws.greengrass.clientdevices.auth.configuration.InfrastructureConfiguration.DEFAULT_WORK_QUEUE_DEPTH;
import static com.aws.greengrass.componentmanager.KernelConfigResolver.CONFIGURATION_CONFIG_KEY;
import static software.amazon.awssdk.aws.greengrass.GreengrassCoreIPCService.AUTHORIZE_CLIENT_DEVICE_ACTION;
import static software.amazon.awssdk.aws.greengrass.GreengrassCoreIPCService.GET_CLIENT_DEVICE_AUTH_TOKEN;
Expand All @@ -55,21 +58,11 @@
public class ClientDevicesAuthService extends PluginService {
public static final String CLIENT_DEVICES_AUTH_SERVICE_NAME = "aws.greengrass.clientdevices.Auth";

private CDAConfiguration cdaConfiguration;
private InfrastructureConfiguration infrastructureConfig;

// TODO: Move configuration related constants to appropriate configuration class
public static final String DEVICE_GROUPS_TOPICS = "deviceGroups";
public static final String PERFORMANCE_TOPIC = "performance";
public static final String MAX_ACTIVE_AUTH_TOKENS_TOPIC = "maxActiveAuthTokens";
public static final String CLOUD_REQUEST_QUEUE_SIZE_TOPIC = "cloudRequestQueueSize";
public static final String MAX_CONCURRENT_CLOUD_REQUESTS_TOPIC = "maxConcurrentCloudRequests";
// Limit the queue size before we start rejecting requests
private static final int DEFAULT_CLOUD_CALL_QUEUE_SIZE = 100;
private static final int DEFAULT_THREAD_POOL_SIZE = 1;
public static final int DEFAULT_MAX_ACTIVE_AUTH_TOKENS = 2500;

// Create a threadpool for calling the cloud. Single thread will be used by default.
private ThreadPoolExecutor cloudCallThreadPool;
private int cloudCallQueueSize;
private CDAConfiguration cdaConfiguration;


/**
Expand All @@ -88,32 +81,10 @@ protected void install() throws InterruptedException {

context.get(UseCases.class).init(context);
context.get(CertificateManager.class).updateCertificatesConfiguration(new CertificatesConfig(getConfig()));
initializeInfrastructure();
initializeHandlers();
subscribeToConfigChanges();
}

private int getValidCloudCallQueueSize(Topics topics) {
int newSize = Coerce.toInt(
topics.findOrDefault(DEFAULT_CLOUD_CALL_QUEUE_SIZE,
CONFIGURATION_CONFIG_KEY, PERFORMANCE_TOPIC, CLOUD_REQUEST_QUEUE_SIZE_TOPIC));
if (newSize <= 0) {
logger.atWarn().log("{} illegal size, will not change the queue size from {}",
CLOUD_REQUEST_QUEUE_SIZE_TOPIC, cloudCallQueueSize);
return cloudCallQueueSize; // existing size
}
return newSize;
}

private void initializeInfrastructure() {
cloudCallQueueSize = DEFAULT_CLOUD_CALL_QUEUE_SIZE;
cloudCallQueueSize = getValidCloudCallQueueSize(config);
cloudCallThreadPool = new ThreadPoolExecutor(1,
DEFAULT_THREAD_POOL_SIZE, 60, TimeUnit.SECONDS,
new ResizableLinkedBlockingQueue<>(cloudCallQueueSize));
cloudCallThreadPool.allowCoreThreadTimeOut(true); // act as a cached threadpool
}

private void initializeHandlers() {
// Register auth session handlers
context.get(SessionManager.class).setSessionConfig(new SessionConfig(getConfig()));
Expand Down Expand Up @@ -146,34 +117,20 @@ private void configChangeHandler(WhatHappened whatHappened, Node node) {
return;
}
logger.atDebug().kv("why", whatHappened).kv("node", node).log();

// NOTE: This should not live here. The service doesn't have to have knowledge about where/how
// keys are stored
Topics deviceGroupTopics = this.config.lookupTopics(CONFIGURATION_CONFIG_KEY, DEVICE_GROUPS_TOPICS);

try {
// NOTE: Extract this to a method these are infrastructure concerns.
int threadPoolSize = Coerce.toInt(this.config.findOrDefault(DEFAULT_THREAD_POOL_SIZE,
CONFIGURATION_CONFIG_KEY, PERFORMANCE_TOPIC, MAX_CONCURRENT_CLOUD_REQUESTS_TOPIC));
if (threadPoolSize >= cloudCallThreadPool.getCorePoolSize()) {
cloudCallThreadPool.setMaximumPoolSize(threadPoolSize);
}
} catch (IllegalArgumentException e) {
logger.atWarn().log("Unable to update CDA threadpool size due to {}", e.getMessage());
}

if (whatHappened != WhatHappened.initialized && node != null && node.childOf(CLOUD_REQUEST_QUEUE_SIZE_TOPIC)) {
// NOTE: Extract this to a method these are infrastructure concerns.
BlockingQueue<Runnable> q = cloudCallThreadPool.getQueue();
if (q instanceof ResizableLinkedBlockingQueue) {
cloudCallQueueSize = getValidCloudCallQueueSize(this.config);
((ResizableLinkedBlockingQueue) q).resize(cloudCallQueueSize);
}
}

if (whatHappened == WhatHappened.initialized || node == null || node.childOf(DEVICE_GROUPS_TOPICS)) {
Topics deviceGroupTopics = this.config.lookupTopics(CONFIGURATION_CONFIG_KEY, DEVICE_GROUPS_TOPICS);
updateDeviceGroups(whatHappened, deviceGroupTopics);
}

InfrastructureConfiguration newInfraConfig = InfrastructureConfiguration.from(getConfig());
if (infrastructureConfig == null || !newInfraConfig.equals(infrastructureConfig)) {
updateInfrastructure(newInfraConfig);
infrastructureConfig = newInfraConfig;
}

onConfigurationChanged();
}

Expand All @@ -189,10 +146,24 @@ protected void shutdown() throws InterruptedException {
context.get(CertificateManager.class).stopMonitors();
}

@Override
public void postInject() {
super.postInject();
private void updateInfrastructure(InfrastructureConfiguration infraConfig) {
context.get(CDAExecutor.class).accept(infraConfig);
}

private void initializeInfrastructure() {
// Don't re-inject this if it is already present
CDAExecutor cdaExecutor = context.getIfExists(CDAExecutor.class, null);
if (cdaExecutor == null) {
BlockingQueue<Runnable> queue = new ResizableLinkedBlockingQueue<>(DEFAULT_WORK_QUEUE_DEPTH);
ThreadPoolExecutor executor = new ThreadPoolExecutor(DEFAULT_THREAD_POOL_SIZE,
DEFAULT_THREAD_POOL_SIZE, 60, TimeUnit.SECONDS, queue);
context.put(CDAExecutor.class, new CDAExecutor(executor));
}
}

private void initializeIPC() {
AuthorizationHandler authorizationHandler = context.get(AuthorizationHandler.class);

try {
authorizationHandler.registerComponent(this.getName(),
new HashSet<>(Arrays.asList(SUBSCRIBE_TO_CERTIFICATE_UPDATES,
Expand All @@ -212,17 +183,20 @@ public void postInject() {
new SubscribeToCertificateUpdatesOperationHandler(context, certificateManager, authorizationHandler));
greengrassCoreIPCService.setVerifyClientDeviceIdentityHandler(context ->
new VerifyClientDeviceIdentityOperationHandler(context, serviceApi,
authorizationHandler, cloudCallThreadPool));
authorizationHandler, this.context.get(CDAExecutor.class)));
greengrassCoreIPCService.setGetClientDeviceAuthTokenHandler(context ->
new GetClientDeviceAuthTokenOperationHandler(context, serviceApi, authorizationHandler,
cloudCallThreadPool));
this.context.get(CDAExecutor.class)));
greengrassCoreIPCService.setAuthorizeClientDeviceActionHandler(context ->
new AuthorizeClientDeviceActionOperationHandler(context, serviceApi,
authorizationHandler));
}

public CertificateManager getCertificateManager() {
return context.get(CertificateManager.class);
@Override
public void postInject() {
super.postInject();
initializeInfrastructure();
initializeIPC();
}

private void updateDeviceGroups(WhatHappened whatHappened, Topics deviceGroupsTopics) {
Expand Down Expand Up @@ -250,7 +224,7 @@ void updateCACertificateConfig(List<String> caCerts) {
protected CompletableFuture<Void> close(boolean waitForDependers) {
// shutdown the threadpool in close, not in shutdown() because it is created
// and injected in the constructor and we won't be able to restart it after it stops.
cloudCallThreadPool.shutdown();
context.get(CDAExecutor.class).shutdown();
return super.close(waitForDependers);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package com.aws.greengrass.clientdevices.auth.configuration;

import com.aws.greengrass.config.Topics;
import com.aws.greengrass.util.Coerce;
import lombok.Value;

/**
* Represents client device infrastructure configuration.
* </p>
* NOTE: currently we're shoving some unrelated things under the `performance` key.
* Things like maxActiveAuthTokens and refresh periods should be grouped separately.
* <p>
* |---- configuration
* | |---- performance:
* | |---- cloudRequestQueueSize: "..."
* | |---- maxConcurrentCloudRequests: [...]
* </p>
*/
@Value
public final class InfrastructureConfiguration {
public static final int DEFAULT_WORK_QUEUE_DEPTH = 100;
public static final int DEFAULT_THREAD_POOL_SIZE = 1;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love this configuration - might make sense to re-organize a bit. But for now, we'll keep it how it is.

Note: I didn't want to change the existing SessionConfig, so this performance key is split between the two config classes. I don't love it, but I also wanted to keep this PR as contained as possible.

public static final String PERFORMANCE_TOPIC = "performance";
// TODO: Need to determine if this is useful. We may want different numbers for internal
// vs external usage - e.g. IPC work queue throttling vs internal cert refreshes
public static final String WORK_QUEUE_DEPTH = "cloudRequestQueueSize"; // Deprecate?
public static final String THREAD_POOL_SIZE = "maxConcurrentCloudRequests"; // Deprecate?

int workQueueDepth;
int threadPoolSize;

private InfrastructureConfiguration(int workQueueDepth, int threadPoolSize) {
this.workQueueDepth = workQueueDepth;
this.threadPoolSize = threadPoolSize;
}

/**
* Factory method for creating an immutable InfrastructureConfiguration from the service configuration.
*
* @param configurationTopics the configuration key of the service configuration
*/
public static InfrastructureConfiguration from(Topics configurationTopics) {
Topics infraTopics = configurationTopics.lookupTopics(PERFORMANCE_TOPIC);

return new InfrastructureConfiguration(
getWorkQueueDepthFromConfiguration(infraTopics),
getThreadPoolSizeFromConfiguration(infraTopics)
);
}

private static int getWorkQueueDepthFromConfiguration(Topics infraTopics) {
return Coerce.toInt(infraTopics.findOrDefault(DEFAULT_WORK_QUEUE_DEPTH, WORK_QUEUE_DEPTH));
}

private static int getThreadPoolSizeFromConfiguration(Topics infraTopics) {
return Coerce.toInt(infraTopics.findOrDefault(DEFAULT_THREAD_POOL_SIZE, THREAD_POOL_SIZE));
}
}
Loading