1   package com.guinetik.rr.http;
2   
3   import com.guinetik.rr.request.RequestSpec;
4   import org.slf4j.Logger;
5   import org.slf4j.LoggerFactory;
6   
7   import javax.net.ssl.SSLContext;
8   import java.util.Collections;
9   import java.util.HashMap;
10  import java.util.Map;
11  import java.util.Objects;
12  import java.util.concurrent.ConcurrentHashMap;
13  import java.util.concurrent.atomic.AtomicBoolean;
14  import java.util.concurrent.atomic.AtomicInteger;
15  import java.util.concurrent.atomic.AtomicLong;
16  import java.util.concurrent.atomic.AtomicReference;
17  import java.util.function.Predicate;
18  
19  /**
20   * Decorator that adds circuit breaker resilience pattern to any {@link RocketClient}.
21   *
22   * <p>The circuit breaker pattern prevents cascading failures by failing fast when a downstream
23   * service appears unhealthy. This gives the service time to recover without being overwhelmed
24   * by requests that are likely to fail.
25   *
26   * <h2>Circuit Breaker States</h2>
27   * <ul>
28   *   <li><b>CLOSED</b> - Normal operation, requests pass through</li>
29   *   <li><b>OPEN</b> - Circuit is open, requests fail fast with {@link CircuitBreakerOpenException}</li>
30   *   <li><b>HALF_OPEN</b> - Testing whether the service has recovered; the next request determines the state</li>
31   * </ul>
32   *
33   * <h2>Basic Usage</h2>
34   * <pre class="language-java"><code>
35   * // Wrap any RocketClient with circuit breaker
36   * RocketClient baseClient = new DefaultHttpClient("https://api.example.com");
37   * CircuitBreakerClient client = new CircuitBreakerClient(baseClient);
38   *
39   * try {
40   *     User user = client.execute(request);
41   * } catch (CircuitBreakerOpenException e) {
42   *     System.out.println("Service unavailable, retry after: " +
43   *         e.getEstimatedMillisUntilReset() + "ms");
44   * }
45   * </code></pre>
46   *
47   * <h2>Custom Configuration</h2>
48   * <pre class="language-java"><code>
49   * // Circuit opens after 3 failures, resets after 60 seconds
50   * CircuitBreakerClient client = new CircuitBreakerClient(
51   *     baseClient,
52   *     3,      // failure threshold
53   *     60000   // reset timeout in ms
54   * );
55   * </code></pre>
56   *
57   * <h2>Via RocketClientFactory</h2>
58   * <pre class="language-java"><code>
59   * RocketClient client = RocketClientFactory.builder("https://api.example.com")
60   *     .withCircuitBreaker(5, 30000)
61   *     .build();
62   * </code></pre>
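     *
     * <h2>Failure Policies</h2>
     * A minimal sketch of the fully customized constructor; the values below are illustrative, not library defaults:
     * <pre class="language-java"><code>
     * // Count only 5xx responses toward the failure threshold
     * CircuitBreakerClient client = new CircuitBreakerClient(
     *     baseClient,
     *     5,       // failure threshold
     *     30000,   // reset timeout in ms
     *     60000,   // failure decay time in ms
     *     CircuitBreakerClient.FailurePolicy.SERVER_ERRORS_ONLY,
     *     null     // custom predicate, only used with FailurePolicy.CUSTOM
     * );
     * </code></pre>
     *
     * <h2>Monitoring</h2>
     * A brief sketch of the introspection methods exposed by this class:
     * <pre class="language-java"><code>
     * CircuitBreakerClient.State state = client.getState();
     * Map&lt;String, Object&gt; metrics = client.getMetrics();
     * client.resetCircuit(); // manually force the circuit closed
     * </code></pre>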
63   *
64   * @author guinetik &lt;guinetik@gmail.com&gt;
65   * @see RocketClient
66   * @see CircuitBreakerOpenException
67   * @see RocketClientFactory
68   * @since 1.0.0
69   */
70  public class CircuitBreakerClient implements RocketClient {
71      private static final Logger logger = LoggerFactory.getLogger(CircuitBreakerClient.class);
72  
73      /**
74       * Circuit breaker state
75       * @see HttpConstants.CircuitBreaker#STATUS_CLOSED
76       * @see HttpConstants.CircuitBreaker#STATUS_OPEN
77       * @see HttpConstants.CircuitBreaker#STATUS_HALF_OPEN
78       */
79      public enum State {
80          /** Normal operation - {@link HttpConstants.CircuitBreaker#STATUS_CLOSED} */
81          CLOSED,
82          /** Circuit is open, fast-fail - {@link HttpConstants.CircuitBreaker#STATUS_OPEN} */
83          OPEN,
84          /** Testing if service is back - {@link HttpConstants.CircuitBreaker#STATUS_HALF_OPEN} */
85          HALF_OPEN
86      }
87  
88      /**
89      * Strategy for deciding which failures count toward the failure threshold
90       */
91      public enum FailurePolicy {
92          /** Counts all exceptions as failures */
93          ALL_EXCEPTIONS,
94          /** Only counts status codes from {@link HttpConstants.StatusCodes#SERVER_ERROR_MIN} (500) to {@link HttpConstants.StatusCodes#SERVER_ERROR_MAX} as failures */
95          SERVER_ERRORS_ONLY,
96          /** Excludes status codes in the range {@link HttpConstants.StatusCodes#CLIENT_ERROR_MIN} (400) to
97              {@link HttpConstants.StatusCodes#CLIENT_ERROR_MAX} (499) from the failure count */
98          EXCLUDE_CLIENT_ERRORS,
99          /** Uses a custom, caller-supplied predicate to decide which failures count */
100         CUSTOM
101     }
102 
103     private final RocketClient delegate;
104     private final AtomicReference<State> state = new AtomicReference<>(State.CLOSED);
105     private final AtomicInteger failureCount = new AtomicInteger(0);
106     private final AtomicLong lastFailureTime = new AtomicLong(0);
107     private final AtomicLong lastResetTime = new AtomicLong(System.currentTimeMillis());
108     private final AtomicBoolean halfOpenTestInProgress = new AtomicBoolean(false);
109     private final int failureThreshold;
110     private final long resetTimeoutMs;
111     private final long failureDecayTimeMs;
112     private final FailurePolicy failurePolicy;
113     private final Predicate<RocketRestException> failurePredicate;
114     
115     // Metrics
116     private final AtomicInteger totalRequests = new AtomicInteger(0);
117     private final AtomicInteger successfulRequests = new AtomicInteger(0);
118     private final AtomicInteger failedRequests = new AtomicInteger(0);
119     private final AtomicInteger rejectedRequests = new AtomicInteger(0);
120     private final AtomicInteger circuitTrips = new AtomicInteger(0);
121     private final Map<Integer, AtomicInteger> statusCodeCounts = new ConcurrentHashMap<>();
122 
123     /**
124      * Creates a circuit breaker with default settings
125      * 
126      * @param delegate The underlying client implementation
127      */
128     public CircuitBreakerClient(RocketClient delegate) {
129         this(delegate, 
130              HttpConstants.CircuitBreaker.DEFAULT_FAILURE_THRESHOLD,
131              HttpConstants.CircuitBreaker.DEFAULT_RESET_TIMEOUT_MS);
132     }
133 
134     /**
135      * Creates a circuit breaker with custom threshold and timeout
136      * 
137      * @param delegate The underlying client implementation
138      * @param failureThreshold Number of failures before opening circuit
139      * @param resetTimeoutMs Time in milliseconds before trying to close circuit
140      */
141     public CircuitBreakerClient(RocketClient delegate, int failureThreshold, long resetTimeoutMs) {
142         this(delegate, failureThreshold, resetTimeoutMs, 
143              HttpConstants.CircuitBreaker.DEFAULT_FAILURE_DECAY_TIME_MS,
144              FailurePolicy.ALL_EXCEPTIONS, null);
145     }
146 
147     /**
148      * Creates a fully customized circuit breaker
149      *
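         * <p>When {@code failurePolicy} is {@link FailurePolicy#CUSTOM}, supply a predicate; the one
         * below is a hypothetical example, not one shipped with the library:
         * <pre class="language-java"><code>
         * CircuitBreakerClient client = new CircuitBreakerClient(
         *     baseClient, 5, 30000, 60000,
         *     CircuitBreakerClient.FailurePolicy.CUSTOM,
         *     e -&gt; e.getStatusCode() == 503);
         * </code></pre>
         *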
150      * @param delegate The underlying client implementation (must not be null)
151      * @param failureThreshold Number of failures before opening circuit
152      * @param resetTimeoutMs Time in milliseconds before trying to close circuit
153      * @param failureDecayTimeMs Time in milliseconds after which the accumulated failure count is reset while the circuit is closed
154      * @param failurePolicy Strategy to determine what counts as a failure
155      * @param failurePredicate Custom predicate if policy is CUSTOM
156      * @throws NullPointerException if delegate is null
157      * @throws IllegalArgumentException if failureThreshold is less than 1 or timeouts are negative
158      */
159     public CircuitBreakerClient(RocketClient delegate, int failureThreshold, long resetTimeoutMs,
160                                long failureDecayTimeMs, FailurePolicy failurePolicy,
161                                Predicate<RocketRestException> failurePredicate) {
162         this.delegate = Objects.requireNonNull(delegate, "delegate must not be null");
163 
164         if (failureThreshold < 1) {
165             throw new IllegalArgumentException("failureThreshold must be at least 1");
166         }
167         if (resetTimeoutMs < 0) {
168             throw new IllegalArgumentException("resetTimeoutMs must not be negative");
169         }
170         if (failureDecayTimeMs < 0) {
171             throw new IllegalArgumentException("failureDecayTimeMs must not be negative");
172         }
173 
174         this.failureThreshold = failureThreshold;
175         this.resetTimeoutMs = resetTimeoutMs;
176         this.failureDecayTimeMs = failureDecayTimeMs;
177         this.failurePolicy = failurePolicy != null ? failurePolicy : FailurePolicy.ALL_EXCEPTIONS;
178 
179         // Set default predicate based on policy if not provided
180         if (failurePolicy == FailurePolicy.CUSTOM && failurePredicate != null) {
181             this.failurePredicate = failurePredicate;
182         } else {
183             this.failurePredicate = createDefaultPredicate(this.failurePolicy);
184         }
185     }
186 
187     @Override
188     public <Req, Res> Res execute(RequestSpec<Req, Res> requestSpec) throws RocketRestException {
189         // Check for periodic decay reset
190         checkFailureDecay();
191 
192         // Track metrics
193         totalRequests.incrementAndGet();
194 
195         // Check circuit state and handle state transitions
196         State currentState = state.get();
197         boolean isTestRequest = false;
198 
199         if (currentState == State.OPEN) {
200             if (System.currentTimeMillis() - lastFailureTime.get() >= resetTimeoutMs) {
201                 // Try moving to HALF_OPEN
202                 if (state.compareAndSet(State.OPEN, State.HALF_OPEN)) {
203                     logger.info(HttpConstants.CircuitBreaker.LOG_CIRCUIT_HALF_OPEN);
204                     currentState = State.HALF_OPEN;
205                 } else {
206                     // Another thread transitioned the state, re-read it
207                     currentState = state.get();
208                 }
209             } else {
210                 // Track rejected request metric
211                 rejectedRequests.incrementAndGet();
212 
213                 // Get time since last failure
214                 long millisSinceFailure = System.currentTimeMillis() - lastFailureTime.get();
215 
216                 // We're in OPEN state and the timeout hasn't elapsed, so fast-fail with circuit breaker exception
217                 throw new CircuitBreakerOpenException(
218                     HttpConstants.CircuitBreaker.CIRCUIT_OPEN,
219                     millisSinceFailure,
220                     resetTimeoutMs
221                 );
222             }
223         }
224 
225         // In HALF_OPEN state, only allow one test request at a time
226         if (currentState == State.HALF_OPEN) {
227             if (!halfOpenTestInProgress.compareAndSet(false, true)) {
228                 // Another thread is already testing, reject this request
229                 logger.debug(HttpConstants.CircuitBreaker.LOG_HALF_OPEN_TEST_IN_PROGRESS);
230                 rejectedRequests.incrementAndGet();
231 
232                 long millisSinceFailure = System.currentTimeMillis() - lastFailureTime.get();
233                 throw new CircuitBreakerOpenException(
234                     HttpConstants.CircuitBreaker.CIRCUIT_OPEN,
235                     millisSinceFailure,
236                     resetTimeoutMs
237                 );
238             }
239             isTestRequest = true;
240         }
241 
242         try {
243             // Execute the request with the delegate client
244             Res response = delegate.execute(requestSpec);
245 
246             // Success - reset circuit if needed (use compareAndSet to handle concurrent state changes)
247             State stateBeforeSuccess = state.get();
248             if (stateBeforeSuccess == State.HALF_OPEN) {
249                 if (state.compareAndSet(State.HALF_OPEN, State.CLOSED)) {
250                     failureCount.set(0);
251                     logger.info(HttpConstants.CircuitBreaker.LOG_CIRCUIT_CLOSED);
252                 }
253             }
254 
255             // Track metrics
256             successfulRequests.incrementAndGet();
257 
258             return response;
259         } catch (RocketRestException e) {
260             // Track all failures in metrics
261             failedRequests.incrementAndGet();
262 
263             // Track status code in metrics if available
264             int statusCode = e.getStatusCode();
265             if (statusCode > 0) {
266                 statusCodeCounts.computeIfAbsent(statusCode, code -> new AtomicInteger(0))
267                                 .incrementAndGet();
268             }
269 
270             // Handle failure according to policy
271             boolean isCountableFailure = shouldCountAsFailure(e);
272             if (isCountableFailure) {
273                 handleFailure(e);
274             }
275 
276             // Check if we just opened the circuit from this failure
277             // Re-read state to get current value, not stale snapshot
278             State currentStateAfterFailure = state.get();
279             if (isCountableFailure && currentState == State.CLOSED && currentStateAfterFailure == State.OPEN) {
280                 throw new CircuitBreakerOpenException(
281                     "Circuit opened due to failure: " + e.getMessage(),
282                     e,
283                     0,  // Just opened, so 0 time since failure
284                     resetTimeoutMs
285                 );
286             }
287 
288             // Otherwise rethrow the original exception
289             throw e;
290         } finally {
291             // Always release the test lock if we acquired it
292             if (isTestRequest) {
293                 halfOpenTestInProgress.set(false);
294             }
295         }
296     }
297     
298     /**
299      * Performs a health check by trying to execute the given request.
300      * This can be used to manually test if the service is healthy.
301      * <p>
302      * Note: This method bypasses the normal circuit breaker flow and directly
303      * executes the request against the delegate. It's intended for external
304      * health monitoring systems.
305      *
306      * @param <Req> Request type
307      * @param <Res> Response type
308      * @param healthCheckRequest The request to use as a health check
309      * @return true if the service is healthy
310      */
311     public <Req, Res> boolean performHealthCheck(RequestSpec<Req, Res> healthCheckRequest) {
312         try {
313             delegate.execute(healthCheckRequest);
314 
315             // If we get here, service is healthy, close circuit
316             State currentState = state.get();
317             if (currentState != State.CLOSED) {
318                 state.set(State.CLOSED);
319                 failureCount.set(0);
320                 halfOpenTestInProgress.set(false);
321                 logger.info(HttpConstants.CircuitBreaker.LOG_CIRCUIT_CLOSED);
322             }
323 
324             return true;
325         } catch (RocketRestException e) {
326             // Service still failing
327             if (state.get() == State.HALF_OPEN) {
328                 state.set(State.OPEN);
329                 lastFailureTime.set(System.currentTimeMillis());
330                 halfOpenTestInProgress.set(false);
331                 logger.warn(HttpConstants.CircuitBreaker.LOG_TEST_FAILED);
332             }
333 
334             return false;
335         }
336     }
337 
338     /**
339      * Manually resets the circuit to closed state.
340      * This also resets all internal state including failure counts and test flags.
341      */
342     public void resetCircuit() {
343         state.set(State.CLOSED);
344         failureCount.set(0);
345         halfOpenTestInProgress.set(false);
346         logger.info(HttpConstants.CircuitBreaker.LOG_CIRCUIT_CLOSED + " (manual reset)");
347     }
348     
349     /**
350      * Gets current circuit breaker state
351      * 
352      * @return Current state (OPEN, CLOSED, HALF_OPEN)
353      */
354     public State getState() {
355         return state.get();
356     }
357     
358     /**
359      * Gets current failure count
360      * 
361      * @return Current failure count
362      */
363     public int getFailureCount() {
364         return failureCount.get();
365     }
366     
367     /**
368      * Gets circuit breaker metrics
369      * 
370      * @return Map of metric name to value
371      */
372     public Map<String, Object> getMetrics() {
373         Map<String, Object> metrics = new HashMap<>();
374         
375         // Basic metrics
376         metrics.put("state", getStateAsString());
377         metrics.put("failureCount", failureCount.get());
378         metrics.put("failureThreshold", failureThreshold);
379         metrics.put("totalRequests", totalRequests.get());
380         metrics.put("successfulRequests", successfulRequests.get());
381         metrics.put("failedRequests", failedRequests.get());
382         metrics.put("rejectedRequests", rejectedRequests.get());
383         metrics.put("circuitTrips", circuitTrips.get());
384         metrics.put("halfOpenTestInProgress", halfOpenTestInProgress.get());
385         
386         // Add status code counts
387         Map<String, Integer> statusCounts = new HashMap<>();
388         statusCodeCounts.forEach((code, count) -> statusCounts.put(code.toString(), count.get()));
389         metrics.put("statusCodes", statusCounts);
390         
391         // Time metrics
392         long lastFailure = lastFailureTime.get();
393         if (lastFailure > 0) {
394             metrics.put("millisSinceLastFailure", System.currentTimeMillis() - lastFailure);
395         }
396         
397         return Collections.unmodifiableMap(metrics);
398     }
399     
400     /**
401      * Gets the state as a string constant
402      * 
403      * @return State as a string defined in HttpConstants
404      */
405     private String getStateAsString() {
406         switch (state.get()) {
407             case OPEN:
408                 return HttpConstants.CircuitBreaker.STATUS_OPEN;
409             case CLOSED:
410                 return HttpConstants.CircuitBreaker.STATUS_CLOSED;
411             case HALF_OPEN:
412                 return HttpConstants.CircuitBreaker.STATUS_HALF_OPEN;
413             default:
414                 return state.get().toString();
415         }
416     }
417 
418     private void handleFailure(RocketRestException e) {
419         if (state.get() == State.HALF_OPEN) {
420             // Failed during test request
421             state.set(State.OPEN);
422             lastFailureTime.set(System.currentTimeMillis());
423             logger.warn(HttpConstants.CircuitBreaker.LOG_TEST_FAILED);
424             return;
425         }
426 
427         int currentFailures = failureCount.incrementAndGet();
428         if (currentFailures >= failureThreshold && state.compareAndSet(State.CLOSED, State.OPEN)) {
429             lastFailureTime.set(System.currentTimeMillis());
430             logger.warn(HttpConstants.CircuitBreaker.LOG_CIRCUIT_OPENED, currentFailures);
431             circuitTrips.incrementAndGet();
432         }
433     }
434     
435     /**
436      * Resets the failure count once the decay period has elapsed while the circuit is closed
437      */
438     private void checkFailureDecay() {
439         long now = System.currentTimeMillis();
440         long lastReset = lastResetTime.get();
441         
442         // If we're in CLOSED state and decay time has passed, reset failure count
443         if (state.get() == State.CLOSED && failureCount.get() > 0 && 
444                 (now - lastReset) >= failureDecayTimeMs) {
445             if (failureCount.getAndSet(0) > 0) {
446                 logger.debug(HttpConstants.CircuitBreaker.LOG_DECAY_RESET);
447             }
448             lastResetTime.set(now);
449         }
450     }
451     
452     /**
453      * Creates appropriate failure predicate based on policy
454      */
455     private Predicate<RocketRestException> createDefaultPredicate(FailurePolicy policy) {
456         switch (policy) {
457             case SERVER_ERRORS_ONLY:
458                 return e -> e.getStatusCode() >= HttpConstants.StatusCodes.SERVER_ERROR_MIN &&
459                             e.getStatusCode() <= HttpConstants.StatusCodes.SERVER_ERROR_MAX;
460             case EXCLUDE_CLIENT_ERRORS:
461                 return e -> e.getStatusCode() < HttpConstants.StatusCodes.CLIENT_ERROR_MIN || 
462                             e.getStatusCode() > HttpConstants.StatusCodes.CLIENT_ERROR_MAX;
463             case ALL_EXCEPTIONS:
464             default:
465                 return e -> true;
466         }
467     }
468     
469     /**
470      * Determines if an exception should count toward failure threshold based on policy
471      */
472     private boolean shouldCountAsFailure(RocketRestException e) {
473         return failurePredicate.test(e);
474     }
475 
476     @Override
477     public void configureSsl(SSLContext sslContext) {
478         delegate.configureSsl(sslContext);
479     }
480 
481     @Override
482     public void setBaseUrl(String baseUrl) {
483         this.delegate.setBaseUrl(baseUrl);
484     }
485 }