Skip to content

Commit

Permalink
refactor: session creator is no longer singleton
Browse files Browse the repository at this point in the history
  • Loading branch information
jbutler committed Oct 7, 2022
1 parent 133807b commit 3f0b018
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import com.aws.greengrass.clientdevices.auth.session.MqttSessionFactory;
import com.aws.greengrass.clientdevices.auth.session.SessionConfig;
import com.aws.greengrass.clientdevices.auth.session.SessionCreator;
import com.aws.greengrass.clientdevices.auth.session.SessionManager;
import com.aws.greengrass.clientdevices.auth.util.ResizableLinkedBlockingQueue;
import com.aws.greengrass.config.Node;
import com.aws.greengrass.config.Topics;
Expand Down Expand Up @@ -93,6 +92,7 @@ protected void install() throws InterruptedException {
initializeInfrastructure();
initializeHandlers();
subscribeToConfigChanges();
initializeIpc();
}

private int getValidCloudCallQueueSize(Topics topics) {
Expand Down Expand Up @@ -121,8 +121,8 @@ private void initializeInfrastructure() {

private void initializeHandlers() {
// Register auth session handlers
context.get(SessionManager.class).setSessionConfig(new SessionConfig(getConfig()));
SessionCreator.registerSessionFactory("mqtt", context.get(MqttSessionFactory.class));
context.put(SessionConfig.class, new SessionConfig(getConfig()));
context.get(SessionCreator.class).registerSessionFactory("mqtt", context.get(MqttSessionFactory.class));

// Register domain event handlers
context.get(CACertificateChainChangedHandler.class).listen();
Expand All @@ -134,6 +134,36 @@ private void initializeHandlers() {
networkState.registerHandler(context.get(CISShadowMonitor.class));
}

private void initializeIpc() {
AuthorizationHandler authorizationHandler = context.get(AuthorizationHandler.class);
try {
authorizationHandler.registerComponent(this.getName(),
new HashSet<>(Arrays.asList(SUBSCRIBE_TO_CERTIFICATE_UPDATES,
VERIFY_CLIENT_DEVICE_IDENTITY,
GET_CLIENT_DEVICE_AUTH_TOKEN,
AUTHORIZE_CLIENT_DEVICE_ACTION)));
} catch (com.aws.greengrass.authorization.exceptions.AuthorizationException e) {
logger.atError("initialize-cda-service-authorization-error", e)
.log("Failed to initialize the client device auth service with the Authorization module.");
}

GreengrassCoreIPCService greengrassCoreIPCService = context.get(GreengrassCoreIPCService.class);
ClientDevicesAuthServiceApi serviceApi = context.get(ClientDevicesAuthServiceApi.class);
CertificateManager certificateManager = context.get(CertificateManager.class);

greengrassCoreIPCService.setSubscribeToCertificateUpdatesHandler(context ->
new SubscribeToCertificateUpdatesOperationHandler(context, certificateManager, authorizationHandler));
greengrassCoreIPCService.setVerifyClientDeviceIdentityHandler(context ->
new VerifyClientDeviceIdentityOperationHandler(context, serviceApi,
authorizationHandler, cloudCallThreadPool));
greengrassCoreIPCService.setGetClientDeviceAuthTokenHandler(context ->
new GetClientDeviceAuthTokenOperationHandler(context, serviceApi, authorizationHandler,
cloudCallThreadPool));
greengrassCoreIPCService.setAuthorizeClientDeviceActionHandler(context ->
new AuthorizeClientDeviceActionOperationHandler(context, serviceApi,
authorizationHandler));
}

private void subscribeToConfigChanges() {
onConfigurationChanged();
config.lookupTopics(CONFIGURATION_CONFIG_KEY).subscribe(this::configChangeHandler);
Expand Down Expand Up @@ -195,38 +225,6 @@ protected void shutdown() throws InterruptedException {
context.get(CertificateManager.class).stopMonitors();
}

@Override
public void postInject() {
super.postInject();
AuthorizationHandler authorizationHandler = context.get(AuthorizationHandler.class);
try {
authorizationHandler.registerComponent(this.getName(),
new HashSet<>(Arrays.asList(SUBSCRIBE_TO_CERTIFICATE_UPDATES,
VERIFY_CLIENT_DEVICE_IDENTITY,
GET_CLIENT_DEVICE_AUTH_TOKEN,
AUTHORIZE_CLIENT_DEVICE_ACTION)));
} catch (com.aws.greengrass.authorization.exceptions.AuthorizationException e) {
logger.atError("initialize-cda-service-authorization-error", e)
.log("Failed to initialize the client device auth service with the Authorization module.");
}

GreengrassCoreIPCService greengrassCoreIPCService = context.get(GreengrassCoreIPCService.class);
ClientDevicesAuthServiceApi serviceApi = context.get(ClientDevicesAuthServiceApi.class);
CertificateManager certificateManager = context.get(CertificateManager.class);

greengrassCoreIPCService.setSubscribeToCertificateUpdatesHandler(context ->
new SubscribeToCertificateUpdatesOperationHandler(context, certificateManager, authorizationHandler));
greengrassCoreIPCService.setVerifyClientDeviceIdentityHandler(context ->
new VerifyClientDeviceIdentityOperationHandler(context, serviceApi,
authorizationHandler, cloudCallThreadPool));
greengrassCoreIPCService.setGetClientDeviceAuthTokenHandler(context ->
new GetClientDeviceAuthTokenOperationHandler(context, serviceApi, authorizationHandler,
cloudCallThreadPool));
greengrassCoreIPCService.setAuthorizeClientDeviceActionHandler(context ->
new AuthorizeClientDeviceActionOperationHandler(context, serviceApi,
authorizationHandler));
}

public CertificateManager getCertificateManager() {
return context.get(CertificateManager.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,12 @@
public class SessionCreator {
private static final Logger logger = LogManager.getLogger(SessionCreator.class);

@SuppressWarnings("PMD.UnusedPrivateField")
private final Map<String, SessionFactory> factoryMap;

private SessionCreator() {
public SessionCreator() {
factoryMap = new ConcurrentHashMap<>();
}

private static class SessionFactorySingleton {
@SuppressWarnings("PMD.AccessorClassGeneration")
private static final SessionCreator INSTANCE = new SessionCreator();
}

/**
* Create a client device session.
*
Expand All @@ -35,9 +29,9 @@ private static class SessionFactorySingleton {
* @return new session if the client can be authenticated
* @throws AuthenticationException if the client fails to be authenticated
*/
public static Session createSession(String credentialType, Map<String, String> credentialMap)
public Session createSession(String credentialType, Map<String, String> credentialMap)
throws AuthenticationException {
SessionFactory sessionFactory = SessionFactorySingleton.INSTANCE.factoryMap.get(credentialType);
SessionFactory sessionFactory = factoryMap.get(credentialType);
if (sessionFactory == null) {
logger.atWarn().kv("credentialType", credentialType)
.log("no registered handler to process device credentials");
Expand All @@ -46,11 +40,7 @@ public static Session createSession(String credentialType, Map<String, String> c
return sessionFactory.createSession(credentialMap);
}

public static void registerSessionFactory(String credentialType, SessionFactory sessionFactory) {
SessionFactorySingleton.INSTANCE.factoryMap.put(credentialType, sessionFactory);
}

public static void unregisterSessionFactory(String credentialType) {
SessionFactorySingleton.INSTANCE.factoryMap.remove(credentialType);
public void registerSessionFactory(String credentialType, SessionFactory sessionFactory) {
factoryMap.put(credentialType, sessionFactory);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;
import javax.inject.Inject;

/**
* Singleton class for managing AuthN and AuthZ sessions.
*/
public class SessionManager {
private static final Logger logger = LogManager.getLogger(SessionManager.class);
private static final String SESSION_ID = "SessionId";

private final SessionCreator sessionCreator;
private final SessionConfig sessionConfig;

// Thread-safe LRU Session Cache that evicts the eldest entry (based on access order) upon reaching its size.
// TODO: Support time-based cache eviction (Session timeout) and Session deduping.
@Getter(AccessLevel.PACKAGE)
Expand All @@ -40,7 +41,11 @@ protected boolean removeEldestEntry(Map.Entry<String, Session> eldest) {
}
});

private SessionConfig sessionConfig;
@Inject
public SessionManager(SessionCreator sessionCreator, SessionConfig sessionConfig) {
this.sessionCreator = sessionCreator;
this.sessionConfig = sessionConfig;
}

/**
* Looks up a session by id.
Expand All @@ -62,7 +67,7 @@ public Session findSession(String sessionId) {
*/
public String createSession(String credentialType, Map<String, String> credentialMap)
throws AuthenticationException {
Session session = SessionCreator.createSession(credentialType, credentialMap);
Session session = sessionCreator.createSession(credentialType, credentialMap);
return addSessionInternal(session);
}

Expand All @@ -76,15 +81,6 @@ public void closeSession(String sessionId) {
closeSessionInternal(sessionId);
}

/**
* Session configuration setter.
*
* @param sessionConfig session configuration
*/
public void setSessionConfig(SessionConfig sessionConfig) {
this.sessionConfig = sessionConfig;
}

private synchronized void closeSessionInternal(String sessionId) {
sessionMap.remove(sessionId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import com.aws.greengrass.clientdevices.auth.exception.AuthenticationException;
import com.aws.greengrass.testcommons.testutilities.GGExtension;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand All @@ -27,18 +27,20 @@ public class SessionCreatorTest {
private static final String mqttCredentialType = "mqtt";
private static final String unknownCredentialType = "unknown";

private SessionCreator sessionCreator;

@Mock
private MqttSessionFactory mqttSessionFactory;

@AfterEach
void afterEach() {
SessionCreator.unregisterSessionFactory(mqttCredentialType);
@BeforeEach
void beforeEach() {
sessionCreator = new SessionCreator();
}

@Test
void GIVEN_noRegisteredFactories_WHEN_createSession_THEN_throwsException() {
Assertions.assertThrows(IllegalArgumentException.class,
() -> SessionCreator.createSession(mqttCredentialType, new HashMap<>()));
() -> sessionCreator.createSession(mqttCredentialType, new HashMap<>()));
}

@Test
Expand All @@ -47,14 +49,14 @@ void GIVEN_registeredMqttSessionFactory_WHEN_createSessionWithMqttCredentials_TH
Session mockSession = mock(SessionImpl.class);
when(mqttSessionFactory.createSession(any())).thenReturn(mockSession);

SessionCreator.registerSessionFactory(mqttCredentialType, mqttSessionFactory);
assertThat(SessionCreator.createSession(mqttCredentialType, new HashMap<>()), is(mockSession));
sessionCreator.registerSessionFactory(mqttCredentialType, mqttSessionFactory);
assertThat(sessionCreator.createSession(mqttCredentialType, new HashMap<>()), is(mockSession));
}

@Test
void GIVEN_registeredMqttSessionFactory_WHEN_createSession_WithNonMqttCredentials_THEN_throwsException() {
SessionCreator.registerSessionFactory(mqttCredentialType, mqttSessionFactory);
sessionCreator.registerSessionFactory(mqttCredentialType, mqttSessionFactory);
Assertions.assertThrows(IllegalArgumentException.class,
() -> SessionCreator.createSession(unknownCredentialType, new HashMap<>()));
() -> sessionCreator.createSession(unknownCredentialType, new HashMap<>()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import com.aws.greengrass.clientdevices.auth.exception.AuthenticationException;
import com.aws.greengrass.testcommons.testutilities.GGExtension;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -35,6 +34,7 @@ class SessionManagerTest {
private static final String CREDENTIAL_TYPE = "mqtt";
private static final int MOCK_SESSION_CAPACITY = 10;

private SessionCreator sessionCreator;
private SessionManager sessionManager;
@Mock
private MqttSessionFactory mockSessionFactory;
Expand Down Expand Up @@ -66,19 +66,14 @@ class SessionManagerTest {
@BeforeEach
void beforeEach() throws AuthenticationException {
lenient().when(mockSessionConfig.getSessionCapacity()).thenReturn(MOCK_SESSION_CAPACITY);
sessionManager = new SessionManager();
sessionManager.setSessionConfig(mockSessionConfig);
SessionCreator.registerSessionFactory(CREDENTIAL_TYPE, mockSessionFactory);
sessionCreator = new SessionCreator();
sessionManager = new SessionManager(sessionCreator, mockSessionConfig);
sessionCreator.registerSessionFactory(CREDENTIAL_TYPE, mockSessionFactory);
lenient().when(mockSessionFactory.createSession(credentialMap)).thenReturn(mockSession);
lenient().when(mockSessionFactory.createSession(credentialMap2)).thenReturn(mockSession2);
lenient().when(mockSessionFactory.createSession(invalidCredentialMap)).thenThrow(new AuthenticationException(""));
}

@AfterEach
void afterEach() {
SessionCreator.unregisterSessionFactory(CREDENTIAL_TYPE);
}

@Test
void GIVEN_validDeviceCredentials_WHEN_createSession_THEN_sessionCreatedWithUniqueIds()
throws AuthenticationException {
Expand Down Expand Up @@ -111,7 +106,7 @@ void GIVEN_invalidDeviceCredentials_WHEN_createSession_THEN_throwsAuthentication
}

@Test
void GIVEN_validDeviceCredentials_WHEN_createSession_beyond_capacity_THEN_passes_evicting_eldest_session()
void GIVEN_maxOpenSessions_WHEN_createSession_THEN_oldestSessionIsEvicted()
throws AuthenticationException {
reset(mockSessionConfig);
reset(mockSessionFactory);
Expand Down Expand Up @@ -150,10 +145,8 @@ void GIVEN_validDeviceCredentials_WHEN_createSession_beyond_capacity_THEN_passes
when(mockSessionFactory.createSession(credentialMap3)).thenReturn(mockSession3);
when(mockSessionFactory.createSession(credentialMap4)).thenReturn(mockSession4);

int mockSessionCapacity = 3;
when(mockSessionConfig.getSessionCapacity()).thenReturn(mockSessionCapacity);
SessionManager sessionManager = new SessionManager();
sessionManager.setSessionConfig(mockSessionConfig);
int sessionCapacity = 3;
when(mockSessionConfig.getSessionCapacity()).thenReturn(sessionCapacity);

// fill session cache to its capacity
String id1 = sessionManager.createSession(CREDENTIAL_TYPE, credentialMap1);
Expand Down

0 comments on commit 3f0b018

Please sign in to comment.