// CircuitBreakerClient.java
package com.guinetik.rr.http;
import com.guinetik.rr.request.RequestSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLContext;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
/**
* Decorator that adds circuit breaker resilience pattern to any {@link RocketClient}.
*
* <p>The circuit breaker pattern prevents cascading failures by failing fast when a downstream
* service appears unhealthy. This gives the service time to recover without being overwhelmed
* by requests that are likely to fail.
*
* <h2>Circuit Breaker States</h2>
* <ul>
* <li><b>CLOSED</b> - Normal operation, requests pass through</li>
* <li><b>OPEN</b> - Circuit is open, requests fail fast with {@link CircuitBreakerOpenException}</li>
* <li><b>HALF_OPEN</b> - Testing if service recovered, next request determines state</li>
* </ul>
*
* <h2>Basic Usage</h2>
* <pre class="language-java"><code>
* // Wrap any RocketClient with circuit breaker
* RocketClient baseClient = new DefaultHttpClient("https://api.example.com");
* CircuitBreakerClient client = new CircuitBreakerClient(baseClient);
*
 * try {
 *     User user = client.execute(request);
 * } catch (CircuitBreakerOpenException e) {
 *     System.out.println("Service unavailable, retry after: " +
 *             e.getEstimatedMillisUntilReset() + "ms");
 * }
* </code></pre>
*
* <h2>Custom Configuration</h2>
* <pre class="language-java"><code>
* // Circuit opens after 3 failures, resets after 60 seconds
 * CircuitBreakerClient client = new CircuitBreakerClient(
 *         baseClient,
 *         3,     // failure threshold
 *         60000  // reset timeout in ms
 * );
* </code></pre>
*
* <h2>Via RocketClientFactory</h2>
* <pre class="language-java"><code>
 * RocketClient client = RocketClientFactory.builder("https://api.example.com")
 *         .withCircuitBreaker(5, 30000)
 *         .build();
* </code></pre>
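 *
 * <h2>Monitoring</h2>
 * Given a {@code CircuitBreakerClient} instance named {@code client}, the current state and
 * counters can be inspected at runtime:
 * <pre class="language-java"><code>
 * if (client.getState() == CircuitBreakerClient.State.OPEN) {
 *     System.out.println("Circuit open, failures so far: " + client.getFailureCount());
 * }
 * System.out.println(client.getMetrics());
 * </code></pre>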
*
 * @author guinetik &lt;guinetik@gmail.com&gt;
* @see RocketClient
* @see CircuitBreakerOpenException
* @see RocketClientFactory
* @since 1.0.0
*/
public class CircuitBreakerClient implements RocketClient {
private static final Logger logger = LoggerFactory.getLogger(CircuitBreakerClient.class);
/**
* Circuit breaker state
* @see HttpConstants.CircuitBreaker#STATUS_CLOSED
* @see HttpConstants.CircuitBreaker#STATUS_OPEN
* @see HttpConstants.CircuitBreaker#STATUS_HALF_OPEN
*/
public enum State {
/** Normal operation - {@link HttpConstants.CircuitBreaker#STATUS_CLOSED} */
CLOSED,
/** Circuit is open, fast-fail - {@link HttpConstants.CircuitBreaker#STATUS_OPEN} */
OPEN,
/** Testing if service is back - {@link HttpConstants.CircuitBreaker#STATUS_HALF_OPEN} */
HALF_OPEN
}
/**
     * Strategy for deciding which exceptions count as failures toward the circuit breaker threshold.
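     * <p>
     * For example, a custom policy can be supplied through the full constructor (a sketch; the
     * {@code baseClient} variable and the numeric values are illustrative):
     * <pre class="language-java"><code>
     * // Only count 503 responses toward the failure threshold
     * CircuitBreakerClient client = new CircuitBreakerClient(
     *         baseClient, 5, 30000, 60000,
     *         CircuitBreakerClient.FailurePolicy.CUSTOM,
     *         e -> e.getStatusCode() == 503);
     * </code></pre>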
*/
public enum FailurePolicy {
/** Counts all exceptions as failures */
ALL_EXCEPTIONS,
        /** Only counts status codes from {@link HttpConstants.StatusCodes#SERVER_ERROR_MIN} (500)
         *  to {@link HttpConstants.StatusCodes#SERVER_ERROR_MAX} (599) as failures */
        SERVER_ERRORS_ONLY,
        /** Counts all exceptions except those with a status code from {@link HttpConstants.StatusCodes#CLIENT_ERROR_MIN} (400)
         *  to {@link HttpConstants.StatusCodes#CLIENT_ERROR_MAX} (499) */
        EXCLUDE_CLIENT_ERRORS,
        /** Uses the custom predicate supplied to the constructor */
CUSTOM
}
private final RocketClient delegate;
private final AtomicReference<State> state = new AtomicReference<>(State.CLOSED);
private final AtomicInteger failureCount = new AtomicInteger(0);
private final AtomicLong lastFailureTime = new AtomicLong(0);
private final AtomicLong lastResetTime = new AtomicLong(System.currentTimeMillis());
private final AtomicBoolean halfOpenTestInProgress = new AtomicBoolean(false);
private final int failureThreshold;
private final long resetTimeoutMs;
private final long failureDecayTimeMs;
private final FailurePolicy failurePolicy;
private final Predicate<RocketRestException> failurePredicate;
// Metrics
private final AtomicInteger totalRequests = new AtomicInteger(0);
private final AtomicInteger successfulRequests = new AtomicInteger(0);
private final AtomicInteger failedRequests = new AtomicInteger(0);
private final AtomicInteger rejectedRequests = new AtomicInteger(0);
private final AtomicInteger circuitTrips = new AtomicInteger(0);
private final Map<Integer, AtomicInteger> statusCodeCounts = new ConcurrentHashMap<>();
/**
     * Creates a circuit breaker with the default failure threshold and reset timeout (see {@link HttpConstants.CircuitBreaker})
*
* @param delegate The underlying client implementation
*/
public CircuitBreakerClient(RocketClient delegate) {
this(delegate,
HttpConstants.CircuitBreaker.DEFAULT_FAILURE_THRESHOLD,
HttpConstants.CircuitBreaker.DEFAULT_RESET_TIMEOUT_MS);
}
/**
* Creates a circuit breaker with custom threshold and timeout
*
* @param delegate The underlying client implementation
* @param failureThreshold Number of failures before opening circuit
* @param resetTimeoutMs Time in milliseconds before trying to close circuit
*/
public CircuitBreakerClient(RocketClient delegate, int failureThreshold, long resetTimeoutMs) {
this(delegate, failureThreshold, resetTimeoutMs,
HttpConstants.CircuitBreaker.DEFAULT_FAILURE_DECAY_TIME_MS,
FailurePolicy.ALL_EXCEPTIONS, null);
}
/**
* Creates a fully customized circuit breaker
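     * <p>
     * For example, to ignore 4xx responses and clear accumulated failures after a quiet minute
     * (a sketch; {@code baseClient} and the numeric values are illustrative):
     * <pre class="language-java"><code>
     * CircuitBreakerClient client = new CircuitBreakerClient(
     *         baseClient,
     *         5,      // open the circuit after 5 countable failures
     *         30000,  // wait 30 seconds before allowing a half-open test request
     *         60000,  // clear accumulated failures after 60 seconds in the CLOSED state
     *         CircuitBreakerClient.FailurePolicy.EXCLUDE_CLIENT_ERRORS,
     *         null);  // no custom predicate needed for this policy
     * </code></pre>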
*
* @param delegate The underlying client implementation (must not be null)
* @param failureThreshold Number of failures before opening circuit
* @param resetTimeoutMs Time in milliseconds before trying to close circuit
* @param failureDecayTimeMs Time after which failure count starts to decay
* @param failurePolicy Strategy to determine what counts as a failure
     * @param failurePredicate Custom predicate when the policy is CUSTOM; may be null, in which case all exceptions are counted
* @throws NullPointerException if delegate is null
* @throws IllegalArgumentException if failureThreshold is less than 1 or timeouts are negative
*/
public CircuitBreakerClient(RocketClient delegate, int failureThreshold, long resetTimeoutMs,
long failureDecayTimeMs, FailurePolicy failurePolicy,
Predicate<RocketRestException> failurePredicate) {
this.delegate = Objects.requireNonNull(delegate, "delegate must not be null");
if (failureThreshold < 1) {
throw new IllegalArgumentException("failureThreshold must be at least 1");
}
if (resetTimeoutMs < 0) {
throw new IllegalArgumentException("resetTimeoutMs must not be negative");
}
if (failureDecayTimeMs < 0) {
throw new IllegalArgumentException("failureDecayTimeMs must not be negative");
}
this.failureThreshold = failureThreshold;
this.resetTimeoutMs = resetTimeoutMs;
this.failureDecayTimeMs = failureDecayTimeMs;
this.failurePolicy = failurePolicy != null ? failurePolicy : FailurePolicy.ALL_EXCEPTIONS;
// Set default predicate based on policy if not provided
if (failurePolicy == FailurePolicy.CUSTOM && failurePredicate != null) {
this.failurePredicate = failurePredicate;
} else {
this.failurePredicate = createDefaultPredicate(this.failurePolicy);
}
}
@Override
public <Req, Res> Res execute(RequestSpec<Req, Res> requestSpec) throws RocketRestException {
// Check for periodic decay reset
checkFailureDecay();
// Track metrics
totalRequests.incrementAndGet();
// Check circuit state and handle state transitions
State currentState = state.get();
boolean isTestRequest = false;
if (currentState == State.OPEN) {
if (System.currentTimeMillis() - lastFailureTime.get() >= resetTimeoutMs) {
// Try moving to HALF_OPEN
if (state.compareAndSet(State.OPEN, State.HALF_OPEN)) {
logger.info(HttpConstants.CircuitBreaker.LOG_CIRCUIT_HALF_OPEN);
currentState = State.HALF_OPEN;
} else {
// Another thread transitioned the state, re-read it
currentState = state.get();
}
} else {
// Track rejected request metric
rejectedRequests.incrementAndGet();
// Get time since last failure
long millisSinceFailure = System.currentTimeMillis() - lastFailureTime.get();
// We're in OPEN state and the timeout hasn't elapsed, so fast-fail with circuit breaker exception
throw new CircuitBreakerOpenException(
HttpConstants.CircuitBreaker.CIRCUIT_OPEN,
millisSinceFailure,
resetTimeoutMs
);
}
}
// In HALF_OPEN state, only allow one test request at a time
if (currentState == State.HALF_OPEN) {
if (!halfOpenTestInProgress.compareAndSet(false, true)) {
// Another thread is already testing, reject this request
logger.debug(HttpConstants.CircuitBreaker.LOG_HALF_OPEN_TEST_IN_PROGRESS);
rejectedRequests.incrementAndGet();
long millisSinceFailure = System.currentTimeMillis() - lastFailureTime.get();
throw new CircuitBreakerOpenException(
HttpConstants.CircuitBreaker.CIRCUIT_OPEN,
millisSinceFailure,
resetTimeoutMs
);
}
isTestRequest = true;
}
try {
// Execute the request with the delegate client
Res response = delegate.execute(requestSpec);
// Success - reset circuit if needed (use compareAndSet to handle concurrent state changes)
State stateBeforeSuccess = state.get();
if (stateBeforeSuccess == State.HALF_OPEN) {
if (state.compareAndSet(State.HALF_OPEN, State.CLOSED)) {
failureCount.set(0);
logger.info(HttpConstants.CircuitBreaker.LOG_CIRCUIT_CLOSED);
}
}
// Track metrics
successfulRequests.incrementAndGet();
return response;
} catch (RocketRestException e) {
// Track all failures in metrics
failedRequests.incrementAndGet();
// Track status code in metrics if available
int statusCode = e.getStatusCode();
if (statusCode > 0) {
statusCodeCounts.computeIfAbsent(statusCode, code -> new AtomicInteger(0))
.incrementAndGet();
}
// Handle failure according to policy
boolean isCountableFailure = shouldCountAsFailure(e);
if (isCountableFailure) {
handleFailure(e);
}
// Check if we just opened the circuit from this failure
// Re-read state to get current value, not stale snapshot
State currentStateAfterFailure = state.get();
if (isCountableFailure && currentState == State.CLOSED && currentStateAfterFailure == State.OPEN) {
throw new CircuitBreakerOpenException(
"Circuit opened due to failure: " + e.getMessage(),
e,
0, // Just opened, so 0 time since failure
resetTimeoutMs
);
}
// Otherwise rethrow the original exception
throw e;
} finally {
// Always release the test lock if we acquired it
if (isTestRequest) {
halfOpenTestInProgress.set(false);
}
}
}
/**
* Performs a health check by trying to execute the given request.
* This can be used to manually test if the service is healthy.
* <p>
     * Note: This method bypasses the normal circuit breaker state checks and executes the
     * request directly against the delegate. On success the circuit is closed; on failure a
     * half-open circuit is re-opened. It is intended for external health monitoring systems.
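     * <p>
     * A minimal sketch ({@code pingRequest} is an illustrative {@link RequestSpec} targeting a
     * lightweight endpoint):
     * <pre class="language-java"><code>
     * if (client.performHealthCheck(pingRequest)) {
     *     // service healthy; the circuit was closed if it was not already
     * } else {
     *     // check failed; a half-open circuit is re-opened
     * }
     * </code></pre>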
*
* @param <Req> Request type
* @param <Res> Response type
* @param healthCheckRequest The request to use as a health check
     * @return {@code true} if the health check succeeded, {@code false} otherwise
*/
public <Req, Res> boolean performHealthCheck(RequestSpec<Req, Res> healthCheckRequest) {
try {
delegate.execute(healthCheckRequest);
// If we get here, service is healthy, close circuit
State currentState = state.get();
if (currentState != State.CLOSED) {
state.set(State.CLOSED);
failureCount.set(0);
halfOpenTestInProgress.set(false);
logger.info(HttpConstants.CircuitBreaker.LOG_CIRCUIT_CLOSED);
}
return true;
} catch (RocketRestException e) {
// Service still failing
if (state.get() == State.HALF_OPEN) {
state.set(State.OPEN);
lastFailureTime.set(System.currentTimeMillis());
halfOpenTestInProgress.set(false);
logger.warn(HttpConstants.CircuitBreaker.LOG_TEST_FAILED);
}
return false;
}
}
/**
     * Manually resets the circuit to the CLOSED state.
     * This also clears the failure count and the half-open test flag; metrics are not reset.
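     * <p>
     * Intended for operational use, for example after a known outage has been resolved:
     * <pre class="language-java"><code>
     * client.resetCircuit(); // circuit returns to CLOSED immediately
     * </code></pre>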
*/
public void resetCircuit() {
state.set(State.CLOSED);
failureCount.set(0);
halfOpenTestInProgress.set(false);
logger.info(HttpConstants.CircuitBreaker.LOG_CIRCUIT_CLOSED + " (manual reset)");
}
/**
* Gets current circuit breaker state
*
* @return Current state (OPEN, CLOSED, HALF_OPEN)
*/
public State getState() {
return state.get();
}
/**
* Gets current failure count
*
* @return Current failure count
*/
public int getFailureCount() {
return failureCount.get();
}
/**
* Gets circuit breaker metrics
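     * <p>
     * Keys include {@code state}, {@code failureCount}, {@code failureThreshold},
     * {@code totalRequests}, {@code successfulRequests}, {@code failedRequests},
     * {@code rejectedRequests}, {@code circuitTrips}, {@code halfOpenTestInProgress},
     * {@code statusCodes} and, once a failure has occurred, {@code millisSinceLastFailure}.
     * For example, the snapshot can be dumped to a log or monitoring endpoint:
     * <pre class="language-java"><code>
     * client.getMetrics().forEach((name, value) ->
     *         System.out.println(name + " = " + value));
     * </code></pre>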
*
     * @return Unmodifiable map of metric name to value
*/
public Map<String, Object> getMetrics() {
Map<String, Object> metrics = new HashMap<>();
// Basic metrics
metrics.put("state", getStateAsString());
metrics.put("failureCount", failureCount.get());
metrics.put("failureThreshold", failureThreshold);
metrics.put("totalRequests", totalRequests.get());
metrics.put("successfulRequests", successfulRequests.get());
metrics.put("failedRequests", failedRequests.get());
metrics.put("rejectedRequests", rejectedRequests.get());
metrics.put("circuitTrips", circuitTrips.get());
metrics.put("halfOpenTestInProgress", halfOpenTestInProgress.get());
// Add status code counts
Map<String, Integer> statusCounts = new HashMap<>();
statusCodeCounts.forEach((code, count) -> statusCounts.put(code.toString(), count.get()));
metrics.put("statusCodes", statusCounts);
// Time metrics
long lastFailure = lastFailureTime.get();
if (lastFailure > 0) {
metrics.put("millisSinceLastFailure", System.currentTimeMillis() - lastFailure);
}
return Collections.unmodifiableMap(metrics);
}
/**
* Gets the state as a string constant
*
* @return State as a string defined in HttpConstants
*/
private String getStateAsString() {
switch (state.get()) {
case OPEN:
return HttpConstants.CircuitBreaker.STATUS_OPEN;
case CLOSED:
return HttpConstants.CircuitBreaker.STATUS_CLOSED;
case HALF_OPEN:
return HttpConstants.CircuitBreaker.STATUS_HALF_OPEN;
default:
return state.get().toString();
}
}
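    /**
     * Records a countable failure: re-opens the circuit if a half-open test request failed,
     * otherwise increments the failure count and trips the circuit once the threshold is reached.
     */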
private void handleFailure(RocketRestException e) {
if (state.get() == State.HALF_OPEN) {
// Failed during test request
state.set(State.OPEN);
lastFailureTime.set(System.currentTimeMillis());
logger.warn(HttpConstants.CircuitBreaker.LOG_TEST_FAILED);
return;
}
int currentFailures = failureCount.incrementAndGet();
if (currentFailures >= failureThreshold && state.compareAndSet(State.CLOSED, State.OPEN)) {
lastFailureTime.set(System.currentTimeMillis());
logger.warn(HttpConstants.CircuitBreaker.LOG_CIRCUIT_OPENED, currentFailures);
circuitTrips.incrementAndGet();
}
}
/**
     * Resets the failure count if the circuit is CLOSED and the decay window has elapsed since the last reset
*/
private void checkFailureDecay() {
long now = System.currentTimeMillis();
long lastReset = lastResetTime.get();
// If we're in CLOSED state and decay time has passed, reset failure count
if (state.get() == State.CLOSED && failureCount.get() > 0 &&
(now - lastReset) >= failureDecayTimeMs) {
if (failureCount.getAndSet(0) > 0) {
logger.debug(HttpConstants.CircuitBreaker.LOG_DECAY_RESET);
}
lastResetTime.set(now);
}
}
/**
     * Creates the default failure predicate for the given policy; CUSTOM without a supplied predicate falls back to counting all exceptions
*/
private Predicate<RocketRestException> createDefaultPredicate(FailurePolicy policy) {
switch (policy) {
case SERVER_ERRORS_ONLY:
return e -> e.getStatusCode() >= HttpConstants.StatusCodes.SERVER_ERROR_MIN &&
e.getStatusCode() <= HttpConstants.StatusCodes.SERVER_ERROR_MAX;
case EXCLUDE_CLIENT_ERRORS:
return e -> e.getStatusCode() < HttpConstants.StatusCodes.CLIENT_ERROR_MIN ||
e.getStatusCode() > HttpConstants.StatusCodes.CLIENT_ERROR_MAX;
case ALL_EXCEPTIONS:
default:
return e -> true;
}
}
/**
* Determines if an exception should count toward failure threshold based on policy
*/
private boolean shouldCountAsFailure(RocketRestException e) {
return failurePredicate.test(e);
}
@Override
public void configureSsl(SSLContext sslContext) {
delegate.configureSsl(sslContext);
}
@Override
public void setBaseUrl(String baseUrl) {
this.delegate.setBaseUrl(baseUrl);
}
}