diff --git a/gateway-service/src/test/java/org/zowe/apiml/gateway/ws/WebSocketProxyServerHandlerTest.java b/gateway-service/src/test/java/org/zowe/apiml/gateway/ws/WebSocketProxyServerHandlerTest.java index 9ac41fcc60..928902112d 100644 --- a/gateway-service/src/test/java/org/zowe/apiml/gateway/ws/WebSocketProxyServerHandlerTest.java +++ b/gateway-service/src/test/java/org/zowe/apiml/gateway/ws/WebSocketProxyServerHandlerTest.java @@ -14,6 +14,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.cloud.client.ServiceInstance; import org.springframework.cloud.client.loadbalancer.LoadBalancerClient; import org.springframework.scheduling.annotation.AsyncResult; @@ -30,6 +33,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.notNullValue; @@ -45,7 +49,9 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +@ExtendWith(MockitoExtension.class) class WebSocketProxyServerHandlerTest { + private WebSocketProxyServerHandler underTest; private WebSocketRoutedSessionFactory webSocketRoutedSessionFactory; private Map routedSessions; @@ -68,6 +74,7 @@ public void setup() { @Nested class WhenTheConnectionIsEstablished { + WebSocketSession establishedSession; @BeforeEach @@ -77,6 +84,7 @@ void prepareSessionMock() { @Nested class ThenTheValidSessionIsStoredInternally { + @BeforeEach void prepareRoutedService() { String serviceId = "valid-service"; @@ -99,15 +107,19 @@ void prepareRoutedService() { * Proper WebSocketSession is stored. */ @Test - void givenValidRoute() throws Exception { + void givenValidRoute() throws Exception { String path = "wss://gatewayHost:1443/valid-service/ws/v1/valid-path"; - when(webSocketRoutedSessionFactory.session(any(), any(), any())).thenReturn(mock(WebSocketRoutedSession.class)); + WebSocketRoutedSession routedSession = mock(WebSocketRoutedSession.class); + when(webSocketRoutedSessionFactory.session(any(), any(), any())).thenReturn(routedSession); ServiceInstance serviceInstance = mock(ServiceInstance.class); when(lbClient.choose(any())).thenReturn(serviceInstance); String establishedSessionId = "validAndUniqueId"; when(establishedSession.getId()).thenReturn(establishedSessionId); when(establishedSession.getUri()).thenReturn(new URI(path)); + CountDownLatch latch = new CountDownLatch(1); + when(routedSession.getWebSocketClientSession()).thenReturn(null); + underTest.afterConnectionEstablished(establishedSession); verify(webSocketRoutedSessionFactory).session(any(), any(), any()); @@ -119,6 +131,7 @@ void givenValidRoute() throws Exception { @Nested class ThenTheSocketIsClosed { + @BeforeEach void sessionIsOpen() { when(establishedSession.isOpen()).thenReturn(true); @@ -186,7 +199,6 @@ void givenNoInstanceOfTheServiceIsInTheRepository() throws Exception { @Test void givenNullService_thenCloseWebSocket() throws Exception { when(establishedSession.getUri()).thenReturn(new URI("wss://gatewayHost:1443/service-without-instance/ws/v1/valid-path")); - when(lbClient.choose(any())).thenReturn(null); RoutedServices routesForSpecificValidService = mock(RoutedServices.class); when(routesForSpecificValidService.findServiceByGatewayUrl("ws/v1")) .thenReturn(null); @@ -201,17 +213,18 @@ void givenNullService_thenCloseWebSocket() throws Exception { @Nested class GivenValidExistingSession { + + @Mock WebSocketSession establishedSession; + @Mock WebSocketRoutedSession internallyStoredSession; + @Mock WebSocketMessage passedMessage; @BeforeEach void prepareSessionMock() { - establishedSession = mock(WebSocketSession.class); String validSessionId = "123"; when(establishedSession.getId()).thenReturn(validSessionId); - passedMessage = mock(WebSocketMessage.class); - internallyStoredSession = mock(WebSocketRoutedSession.class); routedSessions.put(validSessionId, internallyStoredSession); } @@ -239,6 +252,10 @@ void whenTheConnectionIsClosed_thenClientSessionIsAlsoClosed() throws IOExceptio @Test void whenTheMessageIsReceived_thenTheMessageIsPassedToTheSession() throws Exception { + WebSocketSession clientSession = mock(WebSocketSession.class); + when(clientSession.isOpen()).thenReturn(true); + when(internallyStoredSession.getWebSocketClientSession()).thenReturn(AsyncResult.forValue(clientSession)); + underTest.handleMessage(establishedSession, passedMessage); verify(internallyStoredSession).sendMessageToServer(passedMessage); @@ -280,15 +297,18 @@ void whenClosingRoutedSessionThrowException_thenCatchIt() throws IOException { @Nested class WhenGettingRoutedSessions { + @Test void thenReturnThem() { Map expectedRoutedSessions = underTest.getRoutedSessions(); assertThat(expectedRoutedSessions, is(routedSessions)); } + } @Nested class WhenGettingSubProtocols { + @Test void thenReturnThem() { List protocol = new ArrayList<>(); @@ -297,5 +317,7 @@ void thenReturnThem() { List subProtocols = underTest.getSubProtocols(); assertThat(subProtocols, is(protocol)); } + } + }