1 package com.guinetik.rr.auth;
2
3 import com.guinetik.rr.RocketRestOptions;
4 import com.guinetik.rr.http.DefaultHttpClient;
5 import com.guinetik.rr.http.RocketHeaders;
6 import com.guinetik.rr.http.RocketRestException;
7 import com.guinetik.rr.json.JsonObjectMapper;
8 import com.guinetik.rr.request.RequestBuilder;
9 import com.guinetik.rr.request.RequestSpec;
10 import org.slf4j.Logger;
11 import org.slf4j.LoggerFactory;
12
13 import javax.net.ssl.SSLContext;
14 import java.io.IOException;
15 import java.io.UnsupportedEncodingException;
16 import java.time.Instant;
17 import java.util.Date;
18 import java.util.HashMap;
19 import java.util.Map;
20
21
22
23
24
25 public abstract class AbstractOAuth2Strategy implements AuthStrategy, RocketSSL.SSLAware {
26
27 protected final Logger logger = LoggerFactory.getLogger(getClass());
28
29 protected final Map<String, String> additionalParams;
30
31
32
33
34 protected final String oauthTokenUrl;
35
36 protected String accessToken;
37 protected Date tokenExpiryTime;
38 protected boolean isRefreshing;
39 protected DefaultHttpClient httpClient;
40
41
42
43
44
45
46
47 protected AbstractOAuth2Strategy(String tokenUrl, Map<String, String> additionalParams) {
48 this.oauthTokenUrl = tokenUrl;
49 this.additionalParams = additionalParams != null ? additionalParams : new HashMap<>();
50 this.isRefreshing = false;
51 }
52
53 @Override
54 public RocketHeaders applyAuthHeaders(RocketHeaders headers) {
55 if (accessToken != null && !accessToken.isEmpty()) {
56 headers.bearerAuth(accessToken);
57 }
58 return headers;
59 }
60
61 @Override
62 public boolean needsTokenRefresh() {
63 if (accessToken == null || accessToken.isEmpty()) {
64 return true;
65 }
66 if (tokenExpiryTime == null) {
67 return true;
68 }
69
70 return Instant.now().plusSeconds(300).isAfter(tokenExpiryTime.toInstant());
71 }
72
73
74
75
76
77
78
79
80
81 protected String post(String url, Map<String, String> formParams) throws IOException {
82
83 if (httpClient == null) {
84 String baseUrl = url.substring(0, url.lastIndexOf("/") + 1);
85 RocketRestOptions options = new RocketRestOptions();
86 options.set(RocketRestOptions.LOGGING_ENABLED, true);
87 options.set(RocketRestOptions.LOG_RAW_RESPONSE, true);
88 options.set(RocketRestOptions.LOG_REQUEST_BODY, true);
89 options.set(RocketRestOptions.LOG_RAW_RESPONSE, true);
90 httpClient = new DefaultHttpClient(baseUrl, options);
91 }
92
93 StringBuilder formBody = new StringBuilder();
94 try {
95 boolean first = true;
96 for (Map.Entry<String, String> entry : formParams.entrySet()) {
97 if (!first) {
98 formBody.append("&");
99 }
100 first = false;
101 formBody.append(entry.getKey())
102 .append("=")
103 .append(java.net.URLEncoder.encode(entry.getValue(), "UTF-8"));
104 }
105 } catch (UnsupportedEncodingException e) {
106 logger.error("Error encoding form parameters", e);
107 throw new TokenRefreshException("Error encoding form parameters", e);
108 }
109
110 RocketHeaders headers = new RocketHeaders()
111 .contentType(RocketHeaders.ContentTypes.APPLICATION_FORM);
112
113 String endpoint = url.substring(url.lastIndexOf("/") + 1);
114
115 RequestSpec<String, String> requestSpec = RequestBuilder.<String, String>post(endpoint)
116 .headers(headers)
117 .body(formBody.toString())
118 .responseType(String.class)
119 .build();
120
121 try {
122 return httpClient.execute(requestSpec);
123 } catch (RocketRestException e) {
124 logger.error("Error executing POST request", e);
125 logger.error("Status Code: {}", e.getStatusCode());
126 logger.error("Response Body: {}", e.getResponseBody());
127 throw e;
128 }
129 catch (Exception e) {
130 logger.error("Error executing POST request", e);
131 throw new IOException("Error executing POST request", e);
132 }
133 }
134
135 @Override
136 public boolean refreshCredentials() {
137 if (isRefreshing) {
138 logger.warn("Token refresh already in progress");
139 return false;
140 }
141
142 if (oauthTokenUrl == null || oauthTokenUrl.isEmpty()) {
143 throw new TokenRefreshException("Token URL is required for OAuth2 flow");
144 }
145
146 validateCredentials();
147
148 isRefreshing = true;
149
150 try {
151
152 Map<String, String> formParams = prepareTokenRequestParams();
153
154 formParams.putAll(additionalParams);
155
156 String responseString = post(oauthTokenUrl, formParams);
157
158 Map<String, Object> tokenResponse = JsonObjectMapper.jsonNodeToMap(JsonObjectMapper.getJsonNode(responseString));
159
160 return processTokenResponse(tokenResponse);
161 } catch (Exception e) {
162 logger.error("Error during token refresh", e);
163 throw new TokenRefreshException("Error during token refresh", e);
164 } finally {
165 isRefreshing = false;
166 }
167 }
168
169
170
171
172
173 protected abstract void validateCredentials();
174
175
176
177
178
179 protected abstract Map<String, String> prepareTokenRequestParams();
180
181
182
183
184
185
186
187 protected boolean processTokenResponse(Map<String, Object> tokenResponse) {
188 Object tokenObj = tokenResponse.get("access_token");
189 Object expiresInObj = tokenResponse.get("expires_in");
190 if (tokenObj != null) {
191 accessToken = tokenObj.toString();
192
193 long expiresIn = 3600;
194 if (expiresInObj != null) {
195 try {
196 expiresIn = Long.parseLong(expiresInObj.toString());
197 } catch (NumberFormatException e) {
198 logger.warn("Invalid expires_in value: {}", expiresInObj);
199 }
200 }
201 tokenExpiryTime = Date.from(Instant.now().plusSeconds(expiresIn));
202 logger.debug("Token refreshed successfully, expires in {} seconds", expiresIn);
203 return true;
204 } else {
205 logger.error("Token response did not contain access_token: {}", tokenResponse);
206 throw new TokenRefreshException("Token response did not contain access_token");
207 }
208 }
209
210
211
212
213
214 public String getAccessToken() {
215 return accessToken;
216 }
217
218
219
220
221
222 public Date getTokenExpiryTime() {
223 return tokenExpiryTime;
224 }
225
226
227
228
229
230 public void configureSsl(SSLContext sslContext) {
231 if (httpClient != null) {
232 httpClient.configureSsl(sslContext);
233 }
234 }
235 }