From f13084a789a6a847303c63bef3a4f487375ad543 Mon Sep 17 00:00:00 2001 From: Dan Wang Date: Thu, 24 Oct 2024 14:24:36 +0800 Subject: [PATCH] fix(java-client): fix `Negotiation failed` after a session with authentication enabled is closed (#2132) Resolve https://github.com/apache/incubator-pegasus/issues/2133. After a session with authentication enabled to a remote server is closed (e.g. the server is killed), the flag (i.e. `authSucceed`) marking whether the session has been authenticated is still kept as `true`. Afterwards, while trying to create a new connection to the same remote server for this session, some requests would be pending before it is connected successfully. Once the session is connected, the pending requests would be sent to the remote server. Typically, the first element of the pending queue would be a non-negotiation request (`query_cfg_request` to the meta server, for example), and the second one would be a negotiation request. Normally, the non-negotiation request would be moved to another pending queue for authentication; however, since the flag still marks that the session has been authenticated successfully, the non-negotiation request would be directed to the remote server, which would never reply to the client since the server thinks that the negotiation has not been launched. However, since the client has actually launched the negotiation, it would close the session due to timeout or not having received a response from the server for a long time. Therefore, the above process would repeat over and over. To fix this problem, the flag should be reset after the session is closed. Then, the client would never send non-negotiation requests before the negotiation with the server becomes successful. 
--- .../pegasus/rpc/async/ReplicaSession.java | 240 +++++++++++------- .../AuthReplicaSessionInterceptor.java | 4 +- .../pegasus/rpc/async/ReplicaSessionTest.java | 228 ++++++++++++----- 3 files changed, 303 insertions(+), 169 deletions(-) diff --git a/java-client/src/main/java/org/apache/pegasus/rpc/async/ReplicaSession.java b/java-client/src/main/java/org/apache/pegasus/rpc/async/ReplicaSession.java index ff0fd7960a..a1bd36407e 100644 --- a/java-client/src/main/java/org/apache/pegasus/rpc/async/ReplicaSession.java +++ b/java-client/src/main/java/org/apache/pegasus/rpc/async/ReplicaSession.java @@ -122,19 +122,20 @@ public void asyncSend( VolatileFields cache = fields; if (cache.state == ConnState.CONNECTED) { write(entry, cache); - } else { - synchronized (pendingSend) { - cache = fields; - if (cache.state == ConnState.CONNECTED) { - write(entry, cache); - } else { - if (!pendingSend.offer(entry)) { - logger.warn("pendingSend queue is full, drop the request"); - } + return; + } + + synchronized (pendingSend) { + cache = fields; + if (cache.state == ConnState.CONNECTED) { + write(entry, cache); + } else { + if (!pendingSend.offer(entry)) { + logger.warn("pendingSend queue is full for session {}, drop the request", name()); } } - tryConnect(); } + tryConnect(); } public void closeSession() { @@ -145,9 +146,9 @@ public void closeSession() { // but the connection may not be completely closed then, that is, // the state may not be marked as DISCONNECTED immediately. 
f.nettyChannel.close().sync(); - logger.info("channel to {} closed", address.toString()); + logger.info("channel to {} closed", name()); } catch (Exception ex) { - logger.warn("close channel {} failed: ", address.toString(), ex); + logger.warn("close channel {} failed: ", name(), ex); } } else if (f.state == ConnState.CONNECTING) { // f.nettyChannel == null // If our actively-close strategy fails to reconnect the session due to @@ -189,9 +190,7 @@ public ChannelFuture tryConnect() { boolean needConnect = false; synchronized (pendingSend) { if (fields.state == ConnState.DISCONNECTED) { - VolatileFields cache = new VolatileFields(); - cache.state = ConnState.CONNECTING; - fields = cache; + fields = new VolatileFields(ConnState.CONNECTING); needConnect = true; } } @@ -221,79 +220,96 @@ public void operationComplete(ChannelFuture channelFuture) throws Exception { } }); } catch (UnknownHostException ex) { - logger.error("invalid address: {}", address.toString()); + logger.error("invalid address: {}", name()); assert false; return null; // unreachable } } private void markSessionConnected(Channel activeChannel) { - VolatileFields newCache = new VolatileFields(); - newCache.state = ConnState.CONNECTED; - newCache.nettyChannel = activeChannel; + VolatileFields newCache = new VolatileFields(ConnState.CONNECTED, activeChannel); + // Note that actions in interceptor such as Negotiation might send request by + // ReplicaSession#asyncSend(), inside which ReplicaSession#tryConnect() would + // also be called since current state of this session is still CONNECTING. + // However, it would never create another connection, thus it is safe to do + // `fields = newCache` at the end. interceptorManager.onConnected(this); synchronized (pendingSend) { if (fields.state != ConnState.CONNECTING) { - // this session may have been closed or connected already + // This session may have been closed or connected already. 
logger.info("{}: session is {}, skip to mark it connected", name(), fields.state); return; } - while (!pendingSend.isEmpty()) { - RequestEntry e = pendingSend.poll(); - if (pendingResponse.get(e.sequenceId) != null) { - write(e, newCache); - } else { - logger.info("{}: {} is removed from pending, perhaps timeout", name(), e.sequenceId); - } - } + // Once the authentication is enabled, any request except Negotiation such as + // query_cfg_operator for meta would be cached in authPendingSend and sent after + // Negotiation is successful. Negotiation would be performed first before any other + // request for the reason that AuthProtocol#isAuthRequest() would return true for + // negotiation_operator. + sendPendingRequests(pendingSend, newCache); fields = newCache; } } void markSessionDisconnect() { VolatileFields cache = fields; + if (cache.state == ConnState.DISCONNECTED) { + logger.warn("{}: session is closed already", name()); + resetAuth(); + return; + } + synchronized (pendingSend) { - if (cache.state != ConnState.DISCONNECTED) { - // NOTICE: - // 1. when a connection is reset, the timeout response - // is not answered in the order they query - // 2. It's likely that when the session is disconnecting - // but the caller of the api query/asyncQuery didn't notice - // this. In this case, we are relying on the timeout task. 
- try { - while (!pendingSend.isEmpty()) { - RequestEntry e = pendingSend.poll(); - tryNotifyFailureWithSeqID( - e.sequenceId, error_code.error_types.ERR_SESSION_RESET, false); - } - List l = new LinkedList(); - for (Map.Entry entry : pendingResponse.entrySet()) { - l.add(entry.getValue()); - } - for (RequestEntry e : l) { - tryNotifyFailureWithSeqID( - e.sequenceId, error_code.error_types.ERR_SESSION_RESET, false); - } - } catch (Exception e) { - logger.error( - "failed to notify callers due to unexpected exception [state={}]: ", - cache.state.toString(), - e); - } finally { - logger.info("{}: mark the session to be disconnected from state={}", name(), cache.state); - // ensure the state must be set DISCONNECTED - cache = new VolatileFields(); - cache.state = ConnState.DISCONNECTED; - cache.nettyChannel = null; - fields = cache; + // NOTICE: + // 1. when a connection is reset, the timeout response + // is not answered in the order they query + // 2. It's likely that when the session is disconnecting + // but the caller of the api query/asyncQuery didn't notice + // this. In this case, we are relying on the timeout task. 
+ try { + while (!pendingSend.isEmpty()) { + RequestEntry entry = pendingSend.poll(); + tryNotifyFailureWithSeqID( + entry.sequenceId, error_code.error_types.ERR_SESSION_RESET, false); } - } else { - logger.warn("{}: session is closed already", name()); + List pendingEntries = new LinkedList<>(); + for (Map.Entry entry : pendingResponse.entrySet()) { + pendingEntries.add(entry.getValue()); + } + for (RequestEntry entry : pendingEntries) { + tryNotifyFailureWithSeqID( + entry.sequenceId, error_code.error_types.ERR_SESSION_RESET, false); + } + } catch (Exception ex) { + logger.error( + "{}: failed to notify callers due to unexpected exception [state={}]: ", + name(), + cache.state.toString(), + ex); + } finally { + logger.info("{}: mark the session to be disconnected from state={}", name(), cache.state); + fields = new VolatileFields(ConnState.DISCONNECTED); } } + + // Reset the authentication once the connection is closed. + resetAuth(); + } + + // After the authentication is reset, a new Negotiation would be launched. + private void resetAuth() { + int pendingSize; + synchronized (authPendingSend) { + authSucceed = false; + pendingSize = authPendingSend.size(); + } + + logger.info( + "authentication is reset for session {}, with still {} request entries pending", + name(), + pendingSize); } // Notify the RPC sender if failure occurred. @@ -342,6 +358,7 @@ void tryNotifyFailureWithSeqID(int seqID, error_code.error_types errno, boolean } private void write(final RequestEntry entry, VolatileFields cache) { + // Under some circumstances requests are not allowed to be sent or delayed. 
if (!interceptorManager.onSendMessage(this, entry)) { return; } @@ -378,8 +395,8 @@ private ScheduledFuture addTimer(final int seqID, long timeoutInMillseconds) public void run() { try { tryNotifyFailureWithSeqID(seqID, error_code.error_types.ERR_TIMEOUT, true); - } catch (Exception e) { - logger.warn("try notify with sequenceID {} exception!", seqID, e); + } catch (Exception ex) { + logger.warn("{}: try notify with sequenceID {} exception!", name(), seqID, ex); } } }, @@ -388,43 +405,52 @@ public void run() { } public void onAuthSucceed() { - Queue swappedPendingSend = new LinkedList<>(); + Queue swappedPendingSend; synchronized (authPendingSend) { authSucceed = true; - swappedPendingSend.addAll(authPendingSend); + swappedPendingSend = new LinkedList<>(authPendingSend); authPendingSend.clear(); } - while (!swappedPendingSend.isEmpty()) { - RequestEntry e = swappedPendingSend.poll(); - if (pendingResponse.get(e.sequenceId) != null) { - write(e, fields); + logger.info( + "authentication is successful for session {}, then {} pending request entries would be sent", + name(), + swappedPendingSend.size()); + + sendPendingRequests(swappedPendingSend, fields); + } + + private void sendPendingRequests(Queue pendingEntries, VolatileFields cache) { + while (!pendingEntries.isEmpty()) { + RequestEntry entry = pendingEntries.poll(); + if (pendingResponse.get(entry.sequenceId) != null) { + write(entry, cache); } else { - logger.info("{}: {} is removed from pending, perhaps timeout", name(), e.sequenceId); + logger.info("{}: {} is removed from pending, perhaps timeout", name(), entry.sequenceId); } } } - // return value: + // Return value: // true - pend succeed // false - pend failed public boolean tryPendRequest(RequestEntry entry) { - // double check. the first one doesn't lock the lock. - // Because authSucceed only transfered from false to true. - // So if it is true now, it will not change in the later. - // But if it is false now, maybe it will change soon. 
So we should use lock to protect it. - if (!this.authSucceed) { - synchronized (authPendingSend) { - if (!this.authSucceed) { - if (!authPendingSend.offer(entry)) { - logger.warn("{}: pend request {} failed", name(), entry.sequenceId); - } - return true; - } + // Double check. + if (this.authSucceed) { + return false; + } + + synchronized (authPendingSend) { + if (this.authSucceed) { + return false; + } + + if (!authPendingSend.offer(entry)) { + logger.warn("{}: pend request {} failed", name(), entry.sequenceId); } } - return false; + return true; } final class DefaultHandler extends SimpleChannelInboundHandler { @@ -463,38 +489,56 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E } } - // for test + // Only for test. ConnState getState() { return fields.state; } + interface AuthPendingChecker { + void onCheck(Queue realAuthPendingSend); + } + + void checkAuthPending(AuthPendingChecker checker) { + synchronized (authPendingSend) { + checker.onCheck(authPendingSend); + } + } + interface MessageResponseFilter { boolean abandonIt(error_code.error_types err, TMessage header); } MessageResponseFilter filter = null; - final ConcurrentHashMap pendingResponse = - new ConcurrentHashMap(); + final ConcurrentHashMap pendingResponse = new ConcurrentHashMap<>(); private final AtomicInteger seqId = new AtomicInteger(0); - final Queue pendingSend = new LinkedList(); + final Queue pendingSend = new LinkedList<>(); static final class VolatileFields { - public ConnState state = ConnState.DISCONNECTED; - public Channel nettyChannel = null; + public VolatileFields(ConnState state, Channel nettyChannel) { + this.state = state; + this.nettyChannel = nettyChannel; + } + + public VolatileFields(ConnState state) { + this(state, null); + } + + public ConnState state; + public Channel nettyChannel; } - volatile VolatileFields fields = new VolatileFields(); + volatile VolatileFields fields = new VolatileFields(ConnState.DISCONNECTED); private final 
rpc_address address; - private Bootstrap boot; - private EventLoopGroup timeoutTaskGroup; - private ReplicaSessionInterceptorManager interceptorManager; + private final Bootstrap boot; + private final EventLoopGroup timeoutTaskGroup; + private final ReplicaSessionInterceptorManager interceptorManager; private volatile boolean authSucceed; - final Queue authPendingSend = new LinkedList<>(); + private final Queue authPendingSend = new LinkedList<>(); - // Session will be actively closed if all the rpcs across `sessionResetTimeWindowMs` + // Session will be actively closed if all the RPCs across `sessionResetTimeWindowMs` // are timed out, in that case we suspect that the server is unavailable. // Timestamp of the first timed out rpc. diff --git a/java-client/src/main/java/org/apache/pegasus/security/AuthReplicaSessionInterceptor.java b/java-client/src/main/java/org/apache/pegasus/security/AuthReplicaSessionInterceptor.java index 6f99527232..55e1b9ba57 100644 --- a/java-client/src/main/java/org/apache/pegasus/security/AuthReplicaSessionInterceptor.java +++ b/java-client/src/main/java/org/apache/pegasus/security/AuthReplicaSessionInterceptor.java @@ -24,7 +24,7 @@ import org.apache.pegasus.rpc.interceptor.ReplicaSessionInterceptor; public class AuthReplicaSessionInterceptor implements ReplicaSessionInterceptor, Closeable { - private AuthProtocol protocol; + private final AuthProtocol protocol; public AuthReplicaSessionInterceptor(ClientOptions options) throws IllegalArgumentException { this.protocol = options.getCredential().getProtocol(); @@ -37,7 +37,7 @@ public void onConnected(ReplicaSession session) { @Override public boolean onSendMessage(ReplicaSession session, final ReplicaSession.RequestEntry entry) { - // tryPendRequest returns false means that the negotiation is succeed now + // tryPendRequest returns false means that the negotiation is successful now. 
return protocol.isAuthRequest(entry) || !session.tryPendRequest(entry); } diff --git a/java-client/src/test/java/org/apache/pegasus/rpc/async/ReplicaSessionTest.java b/java-client/src/test/java/org/apache/pegasus/rpc/async/ReplicaSessionTest.java index d33605de4e..98f8d979c2 100644 --- a/java-client/src/test/java/org/apache/pegasus/rpc/async/ReplicaSessionTest.java +++ b/java-client/src/test/java/org/apache/pegasus/rpc/async/ReplicaSessionTest.java @@ -18,15 +18,14 @@ */ package org.apache.pegasus.rpc.async; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.*; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import java.util.ArrayList; -import java.util.concurrent.Callable; +import java.util.Collection; +import java.util.Objects; +import java.util.Queue; import java.util.concurrent.ExecutionException; import java.util.concurrent.FutureTask; import java.util.concurrent.atomic.AtomicBoolean; @@ -38,8 +37,10 @@ import org.apache.pegasus.client.ClientOptions; import org.apache.pegasus.client.PegasusClient; import org.apache.pegasus.operator.client_operator; +import org.apache.pegasus.operator.query_cfg_operator; import org.apache.pegasus.operator.rrdb_get_operator; import org.apache.pegasus.operator.rrdb_put_operator; +import org.apache.pegasus.replication.query_cfg_request; import org.apache.pegasus.rpc.KeyHasher; import org.apache.pegasus.rpc.interceptor.ReplicaSessionInterceptorManager; import org.apache.pegasus.tools.Toollet; @@ -49,17 +50,22 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; import org.mockito.Mockito; import org.slf4j.Logger; public class ReplicaSessionTest { - private 
String metaList = "127.0.0.1:34601,127.0.0.1:34602,127.0.0.1:34603"; private final Logger logger = org.slf4j.LoggerFactory.getLogger(ReplicaSessionTest.class); private ClusterManager manager; @BeforeEach - public void before() throws Exception { - manager = new ClusterManager(ClientOptions.builder().metaServers(metaList).build()); + public void before(TestInfo testInfo) throws Exception { + manager = + new ClusterManager( + ClientOptions.builder() + .metaServers("127.0.0.1:34601,127.0.0.1:34602,127.0.0.1:34603") + .build()); + logger.info("test started: {}", testInfo.getDisplayName()); } @AfterEach @@ -71,16 +77,16 @@ public void after() throws Exception { @Test public void testConnect() throws Exception { // test1: connect to an invalid address. - rpc_address addr = new rpc_address(); - addr.fromString("127.0.0.1:12345"); - ReplicaSession rs = manager.getReplicaSession(addr); + ReplicaSession rs = + manager.getReplicaSession( + Objects.requireNonNull(rpc_address.fromIpPort("127.0.0.1:12345"))); - ArrayList> callbacks = new ArrayList>(); + ArrayList> callbacks = new ArrayList<>(); for (int i = 0; i < 100; ++i) { final client_operator op = new rrdb_put_operator(new gpid(-1, -1), "", null, 0); final FutureTask cb = - new FutureTask( + new FutureTask<>( () -> { assertEquals(error_code.error_types.ERR_SESSION_RESET, op.rpc_error.errno); return null; @@ -98,30 +104,33 @@ public void testConnect() throws Exception { } } - final ReplicaSession cp_rs = rs; - Toollet.waitCondition(() -> ReplicaSession.ConnState.DISCONNECTED == cp_rs.getState(), 5); + ReplicaSession finalRs = rs; + assertTrue( + Toollet.waitCondition( + () -> ReplicaSession.ConnState.DISCONNECTED == finalRs.getState(), 5)); // test2: connect to a valid address, and then close the server. 
- addr.fromString("127.0.0.1:34801"); callbacks.clear(); + final rpc_address addr = rpc_address.fromIpPort("127.0.0.1:34801"); + assertNotNull(addr); + rs = manager.getReplicaSession(addr); rs.setMessageResponseFilter((err, header) -> true); for (int i = 0; i < 20; ++i) { // Send query request to replica server, expect it to be timeout. - final int index = i; - update_request req = + final update_request req = new update_request(new blob("hello".getBytes()), new blob("world".getBytes()), 0); - final client_operator op = new Toollet.test_operator(new gpid(-1, -1), req); - final rpc_address cp_addr = addr; + + final int index = i; final FutureTask cb = - new FutureTask( + new FutureTask<>( () -> { assertEquals(error_code.error_types.ERR_TIMEOUT, op.rpc_error.errno); - // for the last request, we kill the server + // Kill the server at the last request. if (index == 19) { - Toollet.closeServer(cp_addr); + Toollet.closeServer(addr); } return null; }); @@ -132,18 +141,20 @@ public void testConnect() throws Exception { for (int i = 0; i < 80; ++i) { // Re-send query request to replica server, but the timeout is longer. - update_request req = + final update_request req = new update_request(new blob("hello".getBytes()), new blob("world".getBytes()), 0); final client_operator op = new Toollet.test_operator(new gpid(-1, -1), req); + final FutureTask cb = - new FutureTask( + new FutureTask<>( () -> { assertEquals(error_code.error_types.ERR_SESSION_RESET, op.rpc_error.errno); return null; }); callbacks.add(cb); - // The request has longer timeout, so it should be responsed later than the server been + + // The request has longer timeout, so response should be received after the server was // killed. 
rs.asyncSend(op, cb, 2000, false); } @@ -177,26 +188,23 @@ public void recv_data(TProtocol iprot) throws TException { } } - rpc_address addr = new rpc_address(); - addr.fromString("127.0.0.1:34801"); - ReplicaSession rs = manager.getReplicaSession(addr); + final ReplicaSession rs = + manager.getReplicaSession( + Objects.requireNonNull(rpc_address.fromIpPort("127.0.0.1:34801"))); for (int pid = 0; pid < 16; pid++) { - // find a valid partition held on 127.0.0.1:34801 + // Find a valid partition held on 127.0.0.1:34801. blob req = new blob(PegasusClient.generateKey("a".getBytes(), "".getBytes())); final client_operator op = new test_operator(new gpid(1, pid), req); - FutureTask cb = - new FutureTask( - new Callable() { - @Override - public Void call() throws Exception { - if (op.rpc_error.errno != error_code.error_types.ERR_OBJECT_NOT_FOUND - && op.rpc_error.errno != error_code.error_types.ERR_INVALID_STATE - && op.rpc_error.errno != error_code.error_types.ERR_SESSION_RESET) { - assertEquals(error_code.error_types.ERR_INVALID_DATA, op.rpc_error.errno); - } - return null; + final FutureTask cb = + new FutureTask<>( + () -> { + if (op.rpc_error.errno != error_code.error_types.ERR_OBJECT_NOT_FOUND + && op.rpc_error.errno != error_code.error_types.ERR_INVALID_STATE + && op.rpc_error.errno != error_code.error_types.ERR_SESSION_RESET) { + assertEquals(error_code.error_types.ERR_INVALID_DATA, op.rpc_error.errno); } + return null; }); rs.asyncSend(op, cb, 2000, false); Tools.waitUninterruptable(cb, Integer.MAX_VALUE); @@ -205,11 +213,11 @@ public Void call() throws Exception { @Test public void testTryNotifyWithSequenceID() throws Exception { - rpc_address addr = new rpc_address(); - addr.fromString("127.0.0.1:34801"); - ReplicaSession rs = manager.getReplicaSession(addr); + final ReplicaSession rs = + manager.getReplicaSession( + Objects.requireNonNull(rpc_address.fromIpPort("127.0.0.1:34801"))); - // no pending RequestEntry, ensure no NPE thrown + // There's no pending 
RequestEntry, ensure no NPE thrown. assertTrue(rs.pendingResponse.isEmpty()); try { rs.tryNotifyFailureWithSeqID(100, error_code.error_types.ERR_TIMEOUT, false); @@ -223,73 +231,155 @@ public void testTryNotifyWithSequenceID() throws Exception { ReplicaSession.RequestEntry entry = new ReplicaSession.RequestEntry(); entry.sequenceId = 100; entry.callback = () -> passed.set(true); - entry.timeoutTask = null; // simulate the timeoutTask has been null + entry.timeoutTask = null; // Simulate the timeoutTask has been null. entry.op = new rrdb_put_operator(new gpid(1, 1), null, null, 0); rs.pendingResponse.put(100, entry); rs.tryNotifyFailureWithSeqID(100, error_code.error_types.ERR_TIMEOUT, false); assertTrue(passed.get()); - // simulate the entry has been removed, ensure no NPE thrown + // Simulate the entry has been removed, ensure no NPE thrown. rs.getAndRemoveEntry(entry.sequenceId); rs.tryNotifyFailureWithSeqID(entry.sequenceId, entry.op.rpc_error.errno, true); - // ensure mark session state to disconnect when TryNotifyWithSequenceID incur any exception - ReplicaSession mockRs = Mockito.spy(rs); + // Ensure mark session state to disconnect when TryNotifyWithSequenceID incur any exception. 
+ final ReplicaSession mockRs = Mockito.spy(rs); mockRs.pendingSend.offer(entry); mockRs.fields.state = ReplicaSession.ConnState.CONNECTED; Mockito.doThrow(new Exception()) .when(mockRs) .tryNotifyFailureWithSeqID(entry.sequenceId, entry.op.rpc_error.errno, false); mockRs.markSessionDisconnect(); - assertEquals(mockRs.getState(), ReplicaSession.ConnState.DISCONNECTED); + assertEquals(ReplicaSession.ConnState.DISCONNECTED, mockRs.getState()); } @Test public void testCloseSession() throws InterruptedException { - rpc_address addr = new rpc_address(); - addr.fromString("127.0.0.1:34801"); - ReplicaSession rs = manager.getReplicaSession(addr); + final ReplicaSession rs = + manager.getReplicaSession( + Objects.requireNonNull(rpc_address.fromIpPort("127.0.0.1:34801"))); rs.tryConnect().awaitUninterruptibly(); Thread.sleep(200); - assertEquals(rs.getState(), ReplicaSession.ConnState.CONNECTED); + assertEquals(ReplicaSession.ConnState.CONNECTED, rs.getState()); rs.closeSession(); Thread.sleep(100); - assertEquals(rs.getState(), ReplicaSession.ConnState.DISCONNECTED); + assertEquals(ReplicaSession.ConnState.DISCONNECTED, rs.getState()); rs.fields.state = ReplicaSession.ConnState.CONNECTING; rs.closeSession(); Thread.sleep(100); - assertEquals(rs.getState(), ReplicaSession.ConnState.DISCONNECTED); + assertEquals(ReplicaSession.ConnState.DISCONNECTED, rs.getState()); } @Test public void testSessionConnectTimeout() throws InterruptedException { - rpc_address addr = new rpc_address(); - addr.fromString( - "www.baidu.com:34801"); // this website normally ignores incorrect request without replying + // This website normally ignores incorrect request without replying. 
+ final rpc_address addr = rpc_address.fromIpPort("www.baidu.com:34801"); - long start = System.currentTimeMillis(); - EventLoopGroup rpcGroup = new NioEventLoopGroup(4); - EventLoopGroup timeoutTaskGroup = new NioEventLoopGroup(4); - ReplicaSession rs = + final long start = System.currentTimeMillis(); + + final EventLoopGroup rpcGroup = new NioEventLoopGroup(4); + final EventLoopGroup timeoutTaskGroup = new NioEventLoopGroup(4); + final ReplicaSession rs = new ReplicaSession( addr, rpcGroup, timeoutTaskGroup, 1000, 30, (ReplicaSessionInterceptorManager) null); rs.tryConnect().awaitUninterruptibly(); - long end = System.currentTimeMillis(); + + final long end = System.currentTimeMillis(); + assertEquals((end - start) / 1000, 1); // ensure connect failed within 1sec Thread.sleep(100); - assertEquals(rs.getState(), ReplicaSession.ConnState.DISCONNECTED); + assertEquals(ReplicaSession.ConnState.DISCONNECTED, rs.getState()); } @Test public void testSessionTryConnectWhenConnected() throws InterruptedException { - rpc_address addr = new rpc_address(); - addr.fromString("127.0.0.1:34801"); - ReplicaSession rs = manager.getReplicaSession(addr); + final ReplicaSession rs = + manager.getReplicaSession( + Objects.requireNonNull(rpc_address.fromIpPort("127.0.0.1:34801"))); rs.tryConnect().awaitUninterruptibly(); Thread.sleep(100); - assertEquals(rs.getState(), ReplicaSession.ConnState.CONNECTED); + + assertEquals(ReplicaSession.ConnState.CONNECTED, rs.getState()); assertNull(rs.tryConnect()); // do not connect again } + + @Test + public void testSessionAuth() throws InterruptedException, ExecutionException { + // Connect to the meta server. + final ReplicaSession rs = + manager.getReplicaSession( + Objects.requireNonNull(rpc_address.fromIpPort("127.0.0.1:34601"))); + rs.tryConnect().awaitUninterruptibly(); + Thread.sleep(100); + assertEquals(ReplicaSession.ConnState.CONNECTED, rs.getState()); + + // Send query_cfg_request to the meta server. 
+ final query_cfg_request queryCfgReq = new query_cfg_request("temp", new ArrayList()); + final ReplicaSession.RequestEntry queryCfgEntry = new ReplicaSession.RequestEntry(); + queryCfgEntry.sequenceId = 100; + queryCfgEntry.op = new query_cfg_operator(new gpid(-1, -1), queryCfgReq); + final FutureTask cb = + new FutureTask<>( + () -> { + // queryCfgReq should be sent successfully to the meta server. + assertEquals(error_code.error_types.ERR_OK, queryCfgEntry.op.rpc_error.errno); + return null; + }); + queryCfgEntry.callback = cb; + queryCfgEntry.timeoutTask = null; + + // Also insert into pendingResponse since ReplicaSession#sendPendingRequests() would check + // sequenceId in it. + assertTrue(rs.pendingResponse.isEmpty()); + rs.pendingResponse.put(queryCfgEntry.sequenceId, queryCfgEntry); + + // Initially session has not been authenticated, and queryCfgEntry(id=100) would be pending. + assertTrue(rs.tryPendRequest(queryCfgEntry)); + + // queryCfgEntry(id=100) should be the only element in the pending queue. + rs.checkAuthPending( + (Queue realAuthPendingSend) -> { + assertEquals(1, realAuthPendingSend.size()); + ReplicaSession.RequestEntry entry = realAuthPendingSend.peek(); + assertNotNull(entry); + assertEquals(100, entry.sequenceId); + assertEquals(query_cfg_operator.class, entry.op.getClass()); + }); + + // Authentication is successful, then the pending queryCfgEntry(id=100) should be sent. + rs.onAuthSucceed(); + + // Wait callback to be done. + cb.get(); + + // After the callback for auth success was called, nothing should be in the queue. + rs.checkAuthPending( + (Queue realAuthPendingSend) -> { + assertTrue(realAuthPendingSend.isEmpty()); + }); + + // tryPendRequest would return false at any time once authentication passed. + queryCfgEntry.sequenceId = 101; + assertFalse(rs.tryPendRequest(queryCfgEntry)); + + // Now that authentication has passed, nothing should be in the queue. 
+ rs.checkAuthPending( + (Queue realAuthPendingSend) -> { + assertTrue(realAuthPendingSend.isEmpty()); + }); + + // The session should keep connected before it is closed. + assertEquals(ReplicaSession.ConnState.CONNECTED, rs.getState()); + rs.closeSession(); + Thread.sleep(100); + assertEquals(ReplicaSession.ConnState.DISCONNECTED, rs.getState()); + + // Authentication would be reset after the session is closed. + queryCfgEntry.sequenceId = 102; + assertTrue(rs.tryPendRequest(queryCfgEntry)); + + // Clear the pending queue. + rs.checkAuthPending(Collection::clear); + } }