Skip to content

Commit

Permalink
use callback
Browse files Browse the repository at this point in the history
Signed-off-by: Pablo Hernán Carle <pablo.carle@broadcom.com>
  • Loading branch information
Pablo Hernán Carle committed Aug 26, 2024
1 parent da8a193 commit 36fb764
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 59 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* This program and the accompanying materials are made available under the terms of the
* Eclipse Public License v2.0 which accompanies this distribution, and is available at
* https://www.eclipse.org/legal/epl-v20.html
*
* SPDX-License-Identifier: EPL-2.0
*
* Copyright Contributors to the Zowe Project.
*/

package org.zowe.apiml.gateway.ws;

import org.springframework.web.socket.WebSocketSession;

public interface ClientSessionFailureCallback {

void onClientSessionFailure(WebSocketRoutedSession webSocketRoutedSession, WebSocketSession serverSession, Throwable throwable);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* This program and the accompanying materials are made available under the terms of the
* Eclipse Public License v2.0 which accompanies this distribution, and is available at
* https://www.eclipse.org/legal/epl-v20.html
*
* SPDX-License-Identifier: EPL-2.0
*
* Copyright Contributors to the Zowe Project.
*/

package org.zowe.apiml.gateway.ws;

import org.springframework.web.socket.WebSocketSession;

public interface ClientSessionSuccessCallback {

void onClientSessionSuccess(WebSocketRoutedSession webSocketRoutedSession, WebSocketSession serverSession, WebSocketSession clientSession);

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
@Component
@Singleton
@Slf4j
public class WebSocketProxyServerHandler extends AbstractWebSocketHandler implements RoutedServicesUser, SubProtocolCapable {
public class WebSocketProxyServerHandler extends AbstractWebSocketHandler implements RoutedServicesUser, SubProtocolCapable, ClientSessionFailureCallback, ClientSessionSuccessCallback {

@Value("${server.webSocket.supportedProtocols:-}")
private List<String> subProtocols;
Expand Down Expand Up @@ -185,21 +185,7 @@ private void openWebSocketConnection(RoutedService service, ServiceInstance serv

log.debug(String.format("Opening routed WebSocket session from %s to %s with %s by %s", uri.toString(), targetUrl, webSocketClientFactory, this));

WebSocketRoutedSession session = webSocketRoutedSessionFactory.session(webSocketSession, targetUrl, webSocketClientFactory);
session.getWebSocketClientSession().addCallback(
successSession -> {
session.setClientSession(successSession);
routedSessions.put(webSocketSession.getId(), session);
},
failure -> {
log.debug("Failed opening client web socket session against {}. Server WebSocket session is {}", targetUrl, webSocketSession.getId(), failure);
try {
session.getClientHandler().handleMessage(webSocketSession, new TextMessage("Failed opening a session against " + targetUrl + ": " + failure.getMessage()));
webSocketSession.close(CloseStatus.SERVER_ERROR);
} catch (Exception e) {
log.debug("Failed sending / closing WebSocket session", e);
}
});
webSocketRoutedSessionFactory.session(webSocketSession, targetUrl, webSocketClientFactory, this, this);
}

@Override
Expand Down Expand Up @@ -282,7 +268,7 @@ public void handleMessage(WebSocketSession webSocketSession, WebSocketMessage<?>

private boolean isClientConnectionClosed(WebSocketRoutedSession session) {
WebSocketSession clientSession = session.getClientSession();
return session != null && session.getWebSocketClientSession().isDone() && clientSession != null && !clientSession.isOpen();
return session != null && clientSession != null && !clientSession.isOpen();
}

private boolean isClientConnectionReady(WebSocketRoutedSession session) {
Expand All @@ -292,4 +278,20 @@ private boolean isClientConnectionReady(WebSocketRoutedSession session) {
private WebSocketRoutedSession getRoutedSession(WebSocketSession webSocketSession) {
return routedSessions.get(webSocketSession.getId());
}

@Override
public void onClientSessionSuccess(WebSocketRoutedSession routedSession, WebSocketSession serverSession, WebSocketSession clientSession) {
routedSessions.put(serverSession.getId(), routedSession);
}

@Override
public void onClientSessionFailure(WebSocketRoutedSession routedSession, WebSocketSession serverSession, Throwable throwable) {
log.debug("Failed opening client web socket session against {}. Server WebSocket session is {}", routedSession.getTargetUrl(), serverSession.getId(), throwable);
try {
routedSession.getClientHandler().handleMessage(serverSession, new TextMessage("Failed opening a session against " + routedSession.getTargetUrl() + ": " + throwable.getMessage()));
serverSession.close(CloseStatus.SERVER_ERROR);
} catch (Exception e) {
log.debug("Failed sending / closing WebSocket session", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,17 @@
import com.google.common.annotations.VisibleForTesting;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpHeaders;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.jetty.JettyWebSocketClient;

import javax.annotation.Nullable;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;

/**
* Represents a connection in the proxying chain, establishes 'client' to
Expand All @@ -35,23 +34,33 @@
@Slf4j
public class WebSocketRoutedSession {

private final ListenableFuture<WebSocketSession> webSocketClientSession;
private volatile WebSocketSession clientSession;
private final WebSocketSession webSocketServerSession;
private final WebSocketProxyClientHandler clientHandler;
private final String targetUrl;

private WebSocketSession clientSession;
private final List<ClientSessionSuccessCallback> successCallbacks = new ArrayList<>();
private final List<ClientSessionFailureCallback> failureCallbacks = new ArrayList<>();

public WebSocketRoutedSession(WebSocketSession webSocketServerSession, String targetUrl, WebSocketClientFactory webSocketClientFactory) {
this.webSocketServerSession = webSocketServerSession;
this.targetUrl = targetUrl;
this.clientHandler = new WebSocketProxyClientHandler(webSocketServerSession);
this.webSocketClientSession = createWebSocketClientSession(webSocketServerSession, targetUrl, webSocketClientFactory);

public WebSocketRoutedSession(
WebSocketSession webSocketServerSession,
String targetUrl,
WebSocketClientFactory webSocketClientFactory,
ClientSessionSuccessCallback successCallback,
ClientSessionFailureCallback failureCallback) {
this.webSocketServerSession = webSocketServerSession;
this.targetUrl = targetUrl;
this.clientHandler = new WebSocketProxyClientHandler(webSocketServerSession);
this.successCallbacks.add(successCallback);
this.failureCallbacks.add(failureCallback);

createWebSocketClientSession(webSocketServerSession, targetUrl, webSocketClientFactory);
}

@VisibleForTesting
WebSocketRoutedSession(WebSocketSession webSocketServerSession, ListenableFuture<WebSocketSession> webSocketClientSession, WebSocketProxyClientHandler clientHandler, String targetUrl) {
this.webSocketClientSession = webSocketClientSession;
WebSocketRoutedSession(WebSocketSession webSocketServerSession, WebSocketSession webSocketClientSession, WebSocketProxyClientHandler clientHandler, String targetUrl) {
this.clientSession = webSocketClientSession;
this.webSocketServerSession = webSocketServerSession;
this.clientHandler = clientHandler;
this.targetUrl = targetUrl;
Expand All @@ -68,10 +77,6 @@ private WebSocketHttpHeaders getWebSocketHttpHeaders(WebSocketSession webSocketS
return headers;
}

public ListenableFuture<WebSocketSession> getWebSocketClientSession() {
return webSocketClientSession;
}

public WebSocketSession getWebSocketServerSession() {
return webSocketServerSession;
}
Expand All @@ -80,21 +85,29 @@ WebSocketProxyClientHandler getClientHandler() {
return clientHandler;
}

void setClientSession(WebSocketSession clientSession) {
this.clientSession = clientSession;
}

@Nullable
WebSocketSession getClientSession() {
return clientSession;
}

private ListenableFuture<WebSocketSession> createWebSocketClientSession(WebSocketSession webSocketServerSession, String targetUrl, WebSocketClientFactory webSocketClientFactory) {
String getTargetUrl() {
return targetUrl;
}

private void onSuccess(WebSocketSession serverSession, WebSocketSession clientSession) {
this.successCallbacks.forEach(callback -> callback.onClientSessionSuccess(this, serverSession, clientSession));
}

private void onFailure(WebSocketSession serverSession, Throwable throwable) {
this.failureCallbacks.forEach(callback -> callback.onClientSessionFailure(this, serverSession, throwable));
}

private void createWebSocketClientSession(WebSocketSession webSocketServerSession, String targetUrl, WebSocketClientFactory webSocketClientFactory) {
try {
JettyWebSocketClient client = webSocketClientFactory.getClientInstance(targetUrl);
URI targetURI = new URI(targetUrl);
WebSocketHttpHeaders headers = getWebSocketHttpHeaders(webSocketServerSession);
return client.doHandshake(clientHandler, headers, targetURI);
client.doHandshake(clientHandler, headers, targetURI)
.addCallback(clientSession -> this.onSuccess(webSocketServerSession, clientSession), e -> this.onFailure(webSocketServerSession, e));
} catch (IllegalStateException e) {
throw webSocketProxyException(targetUrl, e, webSocketServerSession, true);
} catch (Exception e) {
Expand All @@ -111,7 +124,7 @@ private WebSocketProxyError webSocketProxyException(String targetUrl, Exception
}

public void sendMessageToServer(WebSocketMessage<?> webSocketMessage) throws IOException {
log.debug("sendMessageToServer(session={}, message={})", webSocketClientSession, webSocketMessage);
log.debug("sendMessageToServer(session={}, message={})", clientSession, webSocketMessage);
if (isClientConnected()) {
try {
clientSession.sendMessage(webSocketMessage);
Expand All @@ -124,7 +137,7 @@ public void sendMessageToServer(WebSocketMessage<?> webSocketMessage) throws IOE
}

public boolean isClientConnected() {
return webSocketClientSession.isDone() && clientSession != null && clientSession.isOpen();
return clientSession != null && clientSession.isOpen();
}

public void close(CloseStatus status) throws IOException {
Expand Down Expand Up @@ -167,7 +180,7 @@ public String getClientUri() {
* @throws ServerNotYetAvailableException If a client session is not established
*/
public String getClientId() {
if (!webSocketClientSession.isDone()) {
if (clientSession == null) {
throw new ServerNotYetAvailableException();
}
if (isClientConnected()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ public interface WebSocketRoutedSessionFactory {
* @param webSocketSession Valid Server side WebSocket Session.
* @param targetUrl Full websocket URL towards the server
* @param webSocketClientFactory Factory producing the current SSL Context.
* @return Valid routed session handling the client session
* @param successCallback callback to be called if the session is established successfully
* @param failureCallback callback to be called if the session fails to be established
*/
WebSocketRoutedSession session(WebSocketSession webSocketSession, String targetUrl, WebSocketClientFactory webSocketClientFactory);
WebSocketRoutedSession session(
WebSocketSession webSocketSession,
String targetUrl,
WebSocketClientFactory webSocketClientFactory,
ClientSessionSuccessCallback successCallback,
ClientSessionFailureCallback failureCallback);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@
public class WebSocketRoutedSessionFactoryImpl implements WebSocketRoutedSessionFactory {

@Override
public WebSocketRoutedSession session(WebSocketSession webSocketSession, String targetUrl, WebSocketClientFactory webSocketClientFactory) {
return new WebSocketRoutedSession(webSocketSession, targetUrl, webSocketClientFactory);
public WebSocketRoutedSession session(
WebSocketSession webSocketSession,
String targetUrl,
WebSocketClientFactory webSocketClientFactory,
ClientSessionSuccessCallback successCallback,
ClientSessionFailureCallback failureCallback) {
return new WebSocketRoutedSession(webSocketSession, targetUrl, webSocketClientFactory, successCallback, failureCallback);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.LoadBalancerClient;
import org.springframework.scheduling.annotation.AsyncResult;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketMessage;
Expand Down Expand Up @@ -106,18 +105,17 @@ void prepareRoutedService() {
void givenValidRoute() throws Exception {
String path = "wss://gatewayHost:1443/valid-service/ws/v1/valid-path";
WebSocketRoutedSession routedSession = mock(WebSocketRoutedSession.class);
when(webSocketRoutedSessionFactory.session(any(), any(), any())).thenReturn(routedSession);
when(webSocketRoutedSessionFactory.session(any(), any(), any(), any(), any())).thenReturn(routedSession);
ServiceInstance serviceInstance = mock(ServiceInstance.class);
when(lbClient.choose(any())).thenReturn(serviceInstance);
String establishedSessionId = "validAndUniqueId";
when(establishedSession.getId()).thenReturn(establishedSessionId);
when(establishedSession.getUri()).thenReturn(new URI(path));

when(routedSession.getWebSocketClientSession()).thenReturn(AsyncResult.forValue(establishedSession));

underTest.onClientSessionSuccess(routedSession, establishedSession, establishedSession);
underTest.afterConnectionEstablished(establishedSession);

verify(webSocketRoutedSessionFactory).session(any(), any(), any());
verify(webSocketRoutedSessionFactory).session(any(), any(), any(), any(), any());
WebSocketRoutedSession preparedSession = routedSessions.get(establishedSessionId);
assertThat(preparedSession, is(notNullValue()));
}
Expand Down Expand Up @@ -252,7 +250,7 @@ void whenTheMessageIsReceived_thenTheMessageIsPassedToTheSession() throws Except
WebSocketSession clientSession = mock(WebSocketSession.class);
when(clientSession.isOpen()).thenReturn(true);
when(internallyStoredSession.isClientConnected()).thenReturn(true);
when(internallyStoredSession.getWebSocketClientSession()).thenReturn(AsyncResult.forValue(clientSession));
when(internallyStoredSession.getClientSession()).thenReturn(clientSession);
when(internallyStoredSession.getClientSession()).thenReturn(clientSession);

underTest.handleMessage(establishedSession, passedMessage);
Expand All @@ -263,7 +261,7 @@ void whenTheMessageIsReceived_thenTheMessageIsPassedToTheSession() throws Except
@Test
void whenExceptionIsThrown_thenRemoveRoutedSession() throws Exception {
doThrow(new WebSocketException("error")).when(routedSessions.get("123")).sendMessageToServer(passedMessage);
when(internallyStoredSession.getWebSocketClientSession()).thenReturn(AsyncResult.forValue(establishedSession));
when(internallyStoredSession.getClientSession()).thenReturn(establishedSession);
when(internallyStoredSession.getClientSession()).thenReturn(establishedSession);
when(internallyStoredSession.isClientConnected()).thenReturn(true);
when(establishedSession.isOpen()).thenReturn(true);
Expand Down
Loading

0 comments on commit 36fb764

Please sign in to comment.