Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: session creator is no longer singleton #154

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import com.aws.greengrass.clientdevices.auth.session.MqttSessionFactory;
import com.aws.greengrass.clientdevices.auth.session.SessionConfig;
import com.aws.greengrass.clientdevices.auth.session.SessionCreator;
import com.aws.greengrass.clientdevices.auth.session.SessionManager;
import com.aws.greengrass.clientdevices.auth.util.ResizableLinkedBlockingQueue;
import com.aws.greengrass.config.Node;
import com.aws.greengrass.config.Topics;
Expand Down Expand Up @@ -95,6 +94,7 @@ protected void install() throws InterruptedException {
initializeInfrastructure();
initializeHandlers();
subscribeToConfigChanges();
initializeIpc();
}

private int getValidCloudCallQueueSize(Topics topics) {
Expand Down Expand Up @@ -127,8 +127,8 @@ private void initializeInfrastructure() {

private void initializeHandlers() {
// Register auth session handlers
context.get(SessionManager.class).setSessionConfig(new SessionConfig(getConfig()));
SessionCreator.registerSessionFactory("mqtt", context.get(MqttSessionFactory.class));
context.put(SessionConfig.class, new SessionConfig(getConfig()));
context.get(SessionCreator.class).registerSessionFactory("mqtt", context.get(MqttSessionFactory.class));

// Register domain event handlers
context.get(CACertificateChainChangedHandler.class).listen();
Expand All @@ -137,6 +137,36 @@ private void initializeHandlers() {
context.get(SecurityConfigurationChangedHandler.class).listen();
}

private void initializeIpc() {
AuthorizationHandler authorizationHandler = context.get(AuthorizationHandler.class);
try {
authorizationHandler.registerComponent(this.getName(),
new HashSet<>(Arrays.asList(SUBSCRIBE_TO_CERTIFICATE_UPDATES,
VERIFY_CLIENT_DEVICE_IDENTITY,
GET_CLIENT_DEVICE_AUTH_TOKEN,
AUTHORIZE_CLIENT_DEVICE_ACTION)));
} catch (com.aws.greengrass.authorization.exceptions.AuthorizationException e) {
logger.atError("initialize-cda-service-authorization-error", e)
.log("Failed to initialize the client device auth service with the Authorization module.");
}

GreengrassCoreIPCService greengrassCoreIPCService = context.get(GreengrassCoreIPCService.class);
ClientDevicesAuthServiceApi serviceApi = context.get(ClientDevicesAuthServiceApi.class);
CertificateManager certificateManager = context.get(CertificateManager.class);

greengrassCoreIPCService.setSubscribeToCertificateUpdatesHandler(context ->
new SubscribeToCertificateUpdatesOperationHandler(context, certificateManager, authorizationHandler));
greengrassCoreIPCService.setVerifyClientDeviceIdentityHandler(context ->
new VerifyClientDeviceIdentityOperationHandler(context, serviceApi,
authorizationHandler, cloudCallThreadPool));
greengrassCoreIPCService.setGetClientDeviceAuthTokenHandler(context ->
new GetClientDeviceAuthTokenOperationHandler(context, serviceApi, authorizationHandler,
cloudCallThreadPool));
greengrassCoreIPCService.setAuthorizeClientDeviceActionHandler(context ->
new AuthorizeClientDeviceActionOperationHandler(context, serviceApi,
authorizationHandler));
}

private void subscribeToConfigChanges() {
onConfigurationChanged();
config.lookupTopics(CONFIGURATION_CONFIG_KEY).subscribe(this::configChangeHandler);
Expand Down Expand Up @@ -199,38 +229,6 @@ protected void shutdown() throws InterruptedException {
context.get(BackgroundCertificateRefresh.class).stop();
}

@Override
public void postInject() {
super.postInject();
AuthorizationHandler authorizationHandler = context.get(AuthorizationHandler.class);
try {
authorizationHandler.registerComponent(this.getName(),
new HashSet<>(Arrays.asList(SUBSCRIBE_TO_CERTIFICATE_UPDATES,
VERIFY_CLIENT_DEVICE_IDENTITY,
GET_CLIENT_DEVICE_AUTH_TOKEN,
AUTHORIZE_CLIENT_DEVICE_ACTION)));
} catch (com.aws.greengrass.authorization.exceptions.AuthorizationException e) {
logger.atError("initialize-cda-service-authorization-error", e)
.log("Failed to initialize the client device auth service with the Authorization module.");
}

GreengrassCoreIPCService greengrassCoreIPCService = context.get(GreengrassCoreIPCService.class);
ClientDevicesAuthServiceApi serviceApi = context.get(ClientDevicesAuthServiceApi.class);
CertificateManager certificateManager = context.get(CertificateManager.class);

greengrassCoreIPCService.setSubscribeToCertificateUpdatesHandler(context ->
new SubscribeToCertificateUpdatesOperationHandler(context, certificateManager, authorizationHandler));
greengrassCoreIPCService.setVerifyClientDeviceIdentityHandler(context ->
new VerifyClientDeviceIdentityOperationHandler(context, serviceApi,
authorizationHandler, cloudCallThreadPool));
greengrassCoreIPCService.setGetClientDeviceAuthTokenHandler(context ->
new GetClientDeviceAuthTokenOperationHandler(context, serviceApi, authorizationHandler,
cloudCallThreadPool));
greengrassCoreIPCService.setAuthorizeClientDeviceActionHandler(context ->
new AuthorizeClientDeviceActionOperationHandler(context, serviceApi,
authorizationHandler));
}

public CertificateManager getCertificateManager() {
return context.get(CertificateManager.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,12 @@
public class SessionCreator {
private static final Logger logger = LogManager.getLogger(SessionCreator.class);

@SuppressWarnings("PMD.UnusedPrivateField")
private final Map<String, SessionFactory> factoryMap;

private SessionCreator() {
public SessionCreator() {
factoryMap = new ConcurrentHashMap<>();
}

private static class SessionFactorySingleton {
@SuppressWarnings("PMD.AccessorClassGeneration")
private static final SessionCreator INSTANCE = new SessionCreator();
}

/**
* Create a client device session.
*
Expand All @@ -35,9 +29,9 @@ private static class SessionFactorySingleton {
* @return new session if the client can be authenticated
* @throws AuthenticationException if the client fails to be authenticated
*/
public static Session createSession(String credentialType, Map<String, String> credentialMap)
public Session createSession(String credentialType, Map<String, String> credentialMap)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change impacts Moquette?

throws AuthenticationException {
SessionFactory sessionFactory = SessionFactorySingleton.INSTANCE.factoryMap.get(credentialType);
SessionFactory sessionFactory = factoryMap.get(credentialType);
if (sessionFactory == null) {
logger.atWarn().kv("credentialType", credentialType)
.log("no registered handler to process device credentials");
Expand All @@ -46,11 +40,7 @@ public static Session createSession(String credentialType, Map<String, String> c
return sessionFactory.createSession(credentialMap);
}

public static void registerSessionFactory(String credentialType, SessionFactory sessionFactory) {
SessionFactorySingleton.INSTANCE.factoryMap.put(credentialType, sessionFactory);
}

public static void unregisterSessionFactory(String credentialType) {
SessionFactorySingleton.INSTANCE.factoryMap.remove(credentialType);
public void registerSessionFactory(String credentialType, SessionFactory sessionFactory) {
factoryMap.put(credentialType, sessionFactory);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;
import javax.inject.Inject;

/**
* Singleton class for managing AuthN and AuthZ sessions.
*/
public class SessionManager {
private static final Logger logger = LogManager.getLogger(SessionManager.class);
private static final String SESSION_ID = "SessionId";

private final SessionCreator sessionCreator;
private final SessionConfig sessionConfig;

// Thread-safe LRU Session Cache that evicts the eldest entry (based on access order) upon reaching its size.
// TODO: Support time-based cache eviction (Session timeout) and Session deduping.
@Getter(AccessLevel.PACKAGE)
Expand All @@ -40,7 +41,11 @@ protected boolean removeEldestEntry(Map.Entry<String, Session> eldest) {
}
});

private SessionConfig sessionConfig;
@Inject
public SessionManager(SessionCreator sessionCreator, SessionConfig sessionConfig) {
this.sessionCreator = sessionCreator;
this.sessionConfig = sessionConfig;
}

/**
* Looks up a session by id.
Expand All @@ -62,7 +67,7 @@ public Session findSession(String sessionId) {
*/
public String createSession(String credentialType, Map<String, String> credentialMap)
throws AuthenticationException {
Session session = SessionCreator.createSession(credentialType, credentialMap);
Session session = sessionCreator.createSession(credentialType, credentialMap);
return addSessionInternal(session);
}

Expand All @@ -76,15 +81,6 @@ public void closeSession(String sessionId) {
closeSessionInternal(sessionId);
}

/**
* Session configuration setter.
*
* @param sessionConfig session configuration
*/
public void setSessionConfig(SessionConfig sessionConfig) {
this.sessionConfig = sessionConfig;
}

private synchronized void closeSessionInternal(String sessionId) {
sessionMap.remove(sessionId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import com.aws.greengrass.clientdevices.auth.exception.AuthenticationException;
import com.aws.greengrass.testcommons.testutilities.GGExtension;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand All @@ -27,18 +27,20 @@ public class SessionCreatorTest {
private static final String mqttCredentialType = "mqtt";
private static final String unknownCredentialType = "unknown";

private SessionCreator sessionCreator;

@Mock
private MqttSessionFactory mqttSessionFactory;

@AfterEach
void afterEach() {
SessionCreator.unregisterSessionFactory(mqttCredentialType);
@BeforeEach
void beforeEach() {
sessionCreator = new SessionCreator();
}

@Test
void GIVEN_noRegisteredFactories_WHEN_createSession_THEN_throwsException() {
Assertions.assertThrows(IllegalArgumentException.class,
() -> SessionCreator.createSession(mqttCredentialType, new HashMap<>()));
() -> sessionCreator.createSession(mqttCredentialType, new HashMap<>()));
}

@Test
Expand All @@ -47,14 +49,14 @@ void GIVEN_registeredMqttSessionFactory_WHEN_createSessionWithMqttCredentials_TH
Session mockSession = mock(SessionImpl.class);
when(mqttSessionFactory.createSession(any())).thenReturn(mockSession);

SessionCreator.registerSessionFactory(mqttCredentialType, mqttSessionFactory);
assertThat(SessionCreator.createSession(mqttCredentialType, new HashMap<>()), is(mockSession));
sessionCreator.registerSessionFactory(mqttCredentialType, mqttSessionFactory);
assertThat(sessionCreator.createSession(mqttCredentialType, new HashMap<>()), is(mockSession));
}

@Test
void GIVEN_registeredMqttSessionFactory_WHEN_createSession_WithNonMqttCredentials_THEN_throwsException() {
SessionCreator.registerSessionFactory(mqttCredentialType, mqttSessionFactory);
sessionCreator.registerSessionFactory(mqttCredentialType, mqttSessionFactory);
Assertions.assertThrows(IllegalArgumentException.class,
() -> SessionCreator.createSession(unknownCredentialType, new HashMap<>()));
() -> sessionCreator.createSession(unknownCredentialType, new HashMap<>()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import com.aws.greengrass.clientdevices.auth.exception.AuthenticationException;
import com.aws.greengrass.testcommons.testutilities.GGExtension;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -35,6 +34,7 @@ class SessionManagerTest {
private static final String CREDENTIAL_TYPE = "mqtt";
private static final int MOCK_SESSION_CAPACITY = 10;

private SessionCreator sessionCreator;
private SessionManager sessionManager;
@Mock
private MqttSessionFactory mockSessionFactory;
Expand Down Expand Up @@ -66,19 +66,14 @@ class SessionManagerTest {
@BeforeEach
void beforeEach() throws AuthenticationException {
lenient().when(mockSessionConfig.getSessionCapacity()).thenReturn(MOCK_SESSION_CAPACITY);
sessionManager = new SessionManager();
sessionManager.setSessionConfig(mockSessionConfig);
SessionCreator.registerSessionFactory(CREDENTIAL_TYPE, mockSessionFactory);
sessionCreator = new SessionCreator();
sessionManager = new SessionManager(sessionCreator, mockSessionConfig);
sessionCreator.registerSessionFactory(CREDENTIAL_TYPE, mockSessionFactory);
lenient().when(mockSessionFactory.createSession(credentialMap)).thenReturn(mockSession);
lenient().when(mockSessionFactory.createSession(credentialMap2)).thenReturn(mockSession2);
lenient().when(mockSessionFactory.createSession(invalidCredentialMap)).thenThrow(new AuthenticationException(""));
}

@AfterEach
void afterEach() {
SessionCreator.unregisterSessionFactory(CREDENTIAL_TYPE);
}

@Test
void GIVEN_validDeviceCredentials_WHEN_createSession_THEN_sessionCreatedWithUniqueIds()
throws AuthenticationException {
Expand Down Expand Up @@ -111,7 +106,7 @@ void GIVEN_invalidDeviceCredentials_WHEN_createSession_THEN_throwsAuthentication
}

@Test
void GIVEN_validDeviceCredentials_WHEN_createSession_beyond_capacity_THEN_passes_evicting_eldest_session()
void GIVEN_maxOpenSessions_WHEN_createSession_THEN_oldestSessionIsEvicted()
throws AuthenticationException {
reset(mockSessionConfig);
reset(mockSessionFactory);
Expand Down Expand Up @@ -150,10 +145,8 @@ void GIVEN_validDeviceCredentials_WHEN_createSession_beyond_capacity_THEN_passes
when(mockSessionFactory.createSession(credentialMap3)).thenReturn(mockSession3);
when(mockSessionFactory.createSession(credentialMap4)).thenReturn(mockSession4);

int mockSessionCapacity = 3;
when(mockSessionConfig.getSessionCapacity()).thenReturn(mockSessionCapacity);
SessionManager sessionManager = new SessionManager();
sessionManager.setSessionConfig(mockSessionConfig);
int sessionCapacity = 3;
when(mockSessionConfig.getSessionCapacity()).thenReturn(sessionCapacity);

// fill session cache to its capacity
String id1 = sessionManager.createSession(CREDENTIAL_TYPE, credentialMap1);
Expand Down