package com.guinetik.rr.http;

import com.guinetik.rr.request.RequestSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLContext;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
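
/**
 * {@link RocketClient} decorator that adds circuit breaker behaviour around another client.
 * <p>
 * Failures reported by the delegate are counted against a threshold; once the threshold is
 * reached the circuit opens and subsequent requests fail fast with
 * {@link CircuitBreakerOpenException}. After {@code resetTimeoutMs} the breaker moves to
 * {@link State#HALF_OPEN} and lets a single test request through: success closes the circuit,
 * failure re-opens it. Which exceptions count as failures is governed by a
 * {@link FailurePolicy}, optionally backed by a custom predicate. Counters for requests,
 * failures, rejections and trips are exposed via {@link #getMetrics()}.
 * <p>
 * Illustrative usage (a sketch only; {@code httpClient} stands for whatever
 * {@link RocketClient} implementation you already have, and the request/response names
 * are placeholders):
 * <pre>{@code
 * CircuitBreakerClient breaker = new CircuitBreakerClient(httpClient, 5, 30_000L);
 * try {
 *     UserDto user = breaker.execute(userRequestSpec);
 * } catch (CircuitBreakerOpenException open) {
 *     // Circuit is open: fail fast or serve a cached/fallback value.
 * } catch (RocketRestException failure) {
 *     // The delegate call itself failed.
 * }
 * }</pre>
 */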
public class CircuitBreakerClient implements RocketClient {

    private static final Logger logger = LoggerFactory.getLogger(CircuitBreakerClient.class);
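
    /**
     * States of the circuit breaker lifecycle.
     */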
    public enum State {
        /** Requests flow through to the delegate client. */
        CLOSED,
        /** Requests are rejected immediately with {@link CircuitBreakerOpenException}. */
        OPEN,
        /** A single test request is allowed through to probe recovery. */
        HALF_OPEN
    }
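
    /**
     * Policies controlling which {@link RocketRestException}s count towards the failure threshold.
     */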
    public enum FailurePolicy {
        /** Every exception counts as a failure. */
        ALL_EXCEPTIONS,
        /** Only 5xx responses count as failures. */
        SERVER_ERRORS_ONLY,
        /** Everything except 4xx responses counts as a failure. */
        EXCLUDE_CLIENT_ERRORS,
        /** Failures are decided by a caller-supplied predicate. */
        CUSTOM
    }

    private final RocketClient delegate;
    private final AtomicReference<State> state = new AtomicReference<>(State.CLOSED);
    private final AtomicInteger failureCount = new AtomicInteger(0);
    private final AtomicLong lastFailureTime = new AtomicLong(0);
    private final AtomicLong lastResetTime = new AtomicLong(System.currentTimeMillis());
    private final AtomicBoolean halfOpenTestInProgress = new AtomicBoolean(false);
    private final int failureThreshold;
    private final long resetTimeoutMs;
    private final long failureDecayTimeMs;
    private final FailurePolicy failurePolicy;
    private final Predicate<RocketRestException> failurePredicate;

    // Metrics counters
    private final AtomicInteger totalRequests = new AtomicInteger(0);
    private final AtomicInteger successfulRequests = new AtomicInteger(0);
    private final AtomicInteger failedRequests = new AtomicInteger(0);
    private final AtomicInteger rejectedRequests = new AtomicInteger(0);
    private final AtomicInteger circuitTrips = new AtomicInteger(0);
    private final Map<Integer, AtomicInteger> statusCodeCounts = new ConcurrentHashMap<>();
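
    /**
     * Creates a circuit breaker around the given client using the library's default
     * failure threshold and reset timeout.
     *
     * @param delegate the client that actually executes requests
     */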
    public CircuitBreakerClient(RocketClient delegate) {
        this(delegate,
                HttpConstants.CircuitBreaker.DEFAULT_FAILURE_THRESHOLD,
                HttpConstants.CircuitBreaker.DEFAULT_RESET_TIMEOUT_MS);
    }
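
    /**
     * Creates a circuit breaker with a custom failure threshold and reset timeout,
     * counting all exceptions as failures.
     *
     * @param delegate         the client that actually executes requests
     * @param failureThreshold number of counted failures that trips the circuit
     * @param resetTimeoutMs   how long the circuit stays open before a half-open test is allowed
     */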
    public CircuitBreakerClient(RocketClient delegate, int failureThreshold, long resetTimeoutMs) {
        this(delegate, failureThreshold, resetTimeoutMs,
                HttpConstants.CircuitBreaker.DEFAULT_FAILURE_DECAY_TIME_MS,
                FailurePolicy.ALL_EXCEPTIONS, null);
    }
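
    /**
     * Creates a fully configured circuit breaker.
     *
     * @param delegate           the client that actually executes requests
     * @param failureThreshold   number of counted failures that trips the circuit (must be &gt;= 1)
     * @param resetTimeoutMs     how long the circuit stays open before a half-open test is allowed
     * @param failureDecayTimeMs how long the circuit must stay closed before accumulated failures are forgotten
     * @param failurePolicy      policy deciding which exceptions count as failures; defaults to
     *                           {@link FailurePolicy#ALL_EXCEPTIONS} when {@code null}
     * @param failurePredicate   custom predicate, only used when {@code failurePolicy} is {@link FailurePolicy#CUSTOM}
     */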
    public CircuitBreakerClient(RocketClient delegate, int failureThreshold, long resetTimeoutMs,
                                long failureDecayTimeMs, FailurePolicy failurePolicy,
                                Predicate<RocketRestException> failurePredicate) {
        this.delegate = Objects.requireNonNull(delegate, "delegate must not be null");

        if (failureThreshold < 1) {
            throw new IllegalArgumentException("failureThreshold must be at least 1");
        }
        if (resetTimeoutMs < 0) {
            throw new IllegalArgumentException("resetTimeoutMs must not be negative");
        }
        if (failureDecayTimeMs < 0) {
            throw new IllegalArgumentException("failureDecayTimeMs must not be negative");
        }

        this.failureThreshold = failureThreshold;
        this.resetTimeoutMs = resetTimeoutMs;
        this.failureDecayTimeMs = failureDecayTimeMs;
        this.failurePolicy = failurePolicy != null ? failurePolicy : FailurePolicy.ALL_EXCEPTIONS;

        // Use the caller's predicate only for CUSTOM; otherwise derive one from the policy.
        if (failurePolicy == FailurePolicy.CUSTOM && failurePredicate != null) {
            this.failurePredicate = failurePredicate;
        } else {
            this.failurePredicate = createDefaultPredicate(this.failurePolicy);
        }
    }

    @Override
    public <Req, Res> Res execute(RequestSpec<Req, Res> requestSpec) throws RocketRestException {
        // Decay stale failures before evaluating the circuit state.
        checkFailureDecay();

        totalRequests.incrementAndGet();

        State currentState = state.get();
        boolean isTestRequest = false;

        if (currentState == State.OPEN) {
            if (System.currentTimeMillis() - lastFailureTime.get() >= resetTimeoutMs) {
                // Reset timeout has elapsed; move to HALF_OPEN to probe recovery.
                if (state.compareAndSet(State.OPEN, State.HALF_OPEN)) {
                    logger.info(HttpConstants.CircuitBreaker.LOG_CIRCUIT_HALF_OPEN);
                    currentState = State.HALF_OPEN;
                } else {
                    // Another thread changed the state first; use whatever it is now.
                    currentState = state.get();
                }
            } else {
                // Circuit is still open; reject the request without calling the delegate.
                rejectedRequests.incrementAndGet();

                long millisSinceFailure = System.currentTimeMillis() - lastFailureTime.get();

                throw new CircuitBreakerOpenException(
                        HttpConstants.CircuitBreaker.CIRCUIT_OPEN,
                        millisSinceFailure,
                        resetTimeoutMs
                );
            }
        }

        // In HALF_OPEN, only one test request may be in flight at a time.
        if (currentState == State.HALF_OPEN) {
            if (!halfOpenTestInProgress.compareAndSet(false, true)) {
                logger.debug(HttpConstants.CircuitBreaker.LOG_HALF_OPEN_TEST_IN_PROGRESS);
                rejectedRequests.incrementAndGet();

                long millisSinceFailure = System.currentTimeMillis() - lastFailureTime.get();
                throw new CircuitBreakerOpenException(
                        HttpConstants.CircuitBreaker.CIRCUIT_OPEN,
                        millisSinceFailure,
                        resetTimeoutMs
                );
            }
            isTestRequest = true;
        }

        try {
            Res response = delegate.execute(requestSpec);

            // A successful HALF_OPEN test closes the circuit again.
            State stateBeforeSuccess = state.get();
            if (stateBeforeSuccess == State.HALF_OPEN) {
                if (state.compareAndSet(State.HALF_OPEN, State.CLOSED)) {
                    failureCount.set(0);
                    logger.info(HttpConstants.CircuitBreaker.LOG_CIRCUIT_CLOSED);
                }
            }

            successfulRequests.incrementAndGet();

            return response;
        } catch (RocketRestException e) {
            failedRequests.incrementAndGet();

            // Track per-status-code counts for the metrics snapshot.
            int statusCode = e.getStatusCode();
            if (statusCode > 0) {
                statusCodeCounts.computeIfAbsent(statusCode, code -> new AtomicInteger(0))
                        .incrementAndGet();
            }

            boolean isCountableFailure = shouldCountAsFailure(e);
            if (isCountableFailure) {
                handleFailure(e);
            }

            // If this failure just tripped the circuit, surface that to the caller.
            State currentStateAfterFailure = state.get();
            if (isCountableFailure && currentState == State.CLOSED && currentStateAfterFailure == State.OPEN) {
                throw new CircuitBreakerOpenException(
                        "Circuit opened due to failure: " + e.getMessage(),
                        e,
                        0,
                        resetTimeoutMs
                );
            }

            throw e;
        } finally {
            // Release the HALF_OPEN test slot regardless of outcome.
            if (isTestRequest) {
                halfOpenTestInProgress.set(false);
            }
        }
    }
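
    /**
     * Executes the given request directly against the delegate as a health probe,
     * bypassing the circuit state. A successful probe closes the circuit; a failed
     * probe while half-open re-opens it.
     *
     * @param healthCheckRequest the request to use as a probe
     * @return {@code true} if the probe succeeded, {@code false} otherwise
     */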
    public <Req, Res> boolean performHealthCheck(RequestSpec<Req, Res> healthCheckRequest) {
        try {
            delegate.execute(healthCheckRequest);

            // A successful probe closes the circuit from any non-CLOSED state.
            State currentState = state.get();
            if (currentState != State.CLOSED) {
                state.set(State.CLOSED);
                failureCount.set(0);
                halfOpenTestInProgress.set(false);
                logger.info(HttpConstants.CircuitBreaker.LOG_CIRCUIT_CLOSED);
            }

            return true;
        } catch (RocketRestException e) {
            // A failed probe while HALF_OPEN re-opens the circuit.
            if (state.get() == State.HALF_OPEN) {
                state.set(State.OPEN);
                lastFailureTime.set(System.currentTimeMillis());
                halfOpenTestInProgress.set(false);
                logger.warn(HttpConstants.CircuitBreaker.LOG_TEST_FAILED);
            }

            return false;
        }
    }
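
    /**
     * Manually forces the circuit back to {@link State#CLOSED} and clears the failure count.
     */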
    public void resetCircuit() {
        state.set(State.CLOSED);
        failureCount.set(0);
        halfOpenTestInProgress.set(false);
        logger.info(HttpConstants.CircuitBreaker.LOG_CIRCUIT_CLOSED + " (manual reset)");
    }
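
    /**
     * @return the current circuit state
     */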
    public State getState() {
        return state.get();
    }
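
    /**
     * @return the number of failures counted since the last reset or decay
     */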
    public int getFailureCount() {
        return failureCount.get();
    }
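
    /**
     * Returns a snapshot of the breaker's metrics: current state, failure counts,
     * request counters, circuit trips and per-status-code totals.
     *
     * @return an unmodifiable map of metric names to values
     */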
    public Map<String, Object> getMetrics() {
        Map<String, Object> metrics = new HashMap<>();

        metrics.put("state", getStateAsString());
        metrics.put("failureCount", failureCount.get());
        metrics.put("failureThreshold", failureThreshold);
        metrics.put("totalRequests", totalRequests.get());
        metrics.put("successfulRequests", successfulRequests.get());
        metrics.put("failedRequests", failedRequests.get());
        metrics.put("rejectedRequests", rejectedRequests.get());
        metrics.put("circuitTrips", circuitTrips.get());
        metrics.put("halfOpenTestInProgress", halfOpenTestInProgress.get());

        Map<String, Integer> statusCounts = new HashMap<>();
        statusCodeCounts.forEach((code, count) -> statusCounts.put(code.toString(), count.get()));
        metrics.put("statusCodes", statusCounts);

        long lastFailure = lastFailureTime.get();
        if (lastFailure > 0) {
            metrics.put("millisSinceLastFailure", System.currentTimeMillis() - lastFailure);
        }

        return Collections.unmodifiableMap(metrics);
    }
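
    /**
     * Maps the current state to its string constant for the metrics snapshot.
     */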
    private String getStateAsString() {
        switch (state.get()) {
            case OPEN:
                return HttpConstants.CircuitBreaker.STATUS_OPEN;
            case CLOSED:
                return HttpConstants.CircuitBreaker.STATUS_CLOSED;
            case HALF_OPEN:
                return HttpConstants.CircuitBreaker.STATUS_HALF_OPEN;
            default:
                return state.get().toString();
        }
    }
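
    /**
     * Records a countable failure: a failed half-open test re-opens the circuit,
     * otherwise the failure count is incremented and the circuit trips once the
     * threshold is reached.
     */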
    private void handleFailure(RocketRestException e) {
        if (state.get() == State.HALF_OPEN) {
            // A failed HALF_OPEN test re-opens the circuit immediately.
            state.set(State.OPEN);
            lastFailureTime.set(System.currentTimeMillis());
            logger.warn(HttpConstants.CircuitBreaker.LOG_TEST_FAILED);
            return;
        }

        int currentFailures = failureCount.incrementAndGet();
        if (currentFailures >= failureThreshold && state.compareAndSet(State.CLOSED, State.OPEN)) {
            lastFailureTime.set(System.currentTimeMillis());
            logger.warn(HttpConstants.CircuitBreaker.LOG_CIRCUIT_OPENED, currentFailures);
            circuitTrips.incrementAndGet();
        }
    }
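
    /**
     * Clears accumulated failures once the decay window has elapsed while the circuit is closed.
     */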
    private void checkFailureDecay() {
        long now = System.currentTimeMillis();
        long lastReset = lastResetTime.get();

        // While CLOSED, forget accumulated failures once the decay window has passed.
        if (state.get() == State.CLOSED && failureCount.get() > 0 &&
                (now - lastReset) >= failureDecayTimeMs) {
            if (failureCount.getAndSet(0) > 0) {
                logger.debug(HttpConstants.CircuitBreaker.LOG_DECAY_RESET);
            }
            lastResetTime.set(now);
        }
    }
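
    /**
     * Builds the failure predicate implied by the given policy.
     */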
    private Predicate<RocketRestException> createDefaultPredicate(FailurePolicy policy) {
        switch (policy) {
            case SERVER_ERRORS_ONLY:
                return e -> e.getStatusCode() >= HttpConstants.StatusCodes.SERVER_ERROR_MIN &&
                        e.getStatusCode() <= HttpConstants.StatusCodes.SERVER_ERROR_MAX;
            case EXCLUDE_CLIENT_ERRORS:
                return e -> e.getStatusCode() < HttpConstants.StatusCodes.CLIENT_ERROR_MIN ||
                        e.getStatusCode() > HttpConstants.StatusCodes.CLIENT_ERROR_MAX;
            case ALL_EXCEPTIONS:
            default:
                return e -> true;
        }
    }
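
    /**
     * @return {@code true} if the exception should count towards the failure threshold
     */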
    private boolean shouldCountAsFailure(RocketRestException e) {
        return failurePredicate.test(e);
    }

    @Override
    public void configureSsl(SSLContext sslContext) {
        delegate.configureSsl(sslContext);
    }

    @Override
    public void setBaseUrl(String baseUrl) {
        this.delegate.setBaseUrl(baseUrl);
    }
}