View Javadoc
1   package com.guinetik.rr.auth;
2   
3   import com.guinetik.rr.RocketRestOptions;
4   import com.guinetik.rr.http.DefaultHttpClient;
5   import com.guinetik.rr.http.RocketHeaders;
6   import com.guinetik.rr.http.RocketRestException;
7   import com.guinetik.rr.json.JsonObjectMapper;
8   import com.guinetik.rr.request.RequestBuilder;
9   import com.guinetik.rr.request.RequestSpec;
10  import org.slf4j.Logger;
11  import org.slf4j.LoggerFactory;
12  
13  import javax.net.ssl.SSLContext;
14  import java.io.IOException;
15  import java.io.UnsupportedEncodingException;
16  import java.time.Instant;
17  import java.util.Date;
18  import java.util.HashMap;
19  import java.util.Map;
20  
21  /**
22   * Abstract base class for OAuth 2.0 authentication strategies.
23   * Provides common functionality for different OAuth2 flows.
24   */
25  public abstract class AbstractOAuth2Strategy implements AuthStrategy, RocketSSL.SSLAware {
26  
27      protected final Logger logger = LoggerFactory.getLogger(getClass());
28  
29      protected final Map<String, String> additionalParams;
30  
31      /**
32       * OAuth 2.0 token endpoint URL to be used when refreshing credentials.
33       */
34      protected final String oauthTokenUrl;
35  
36      protected String accessToken;
37      protected Date tokenExpiryTime;
38      protected boolean isRefreshing;
39      protected DefaultHttpClient httpClient;
40  
41      /**
42       * Creates a new OAuth 2.0 strategy with additional parameters.
43       *
44       * @param tokenUrl        the OAuth 2.0 token endpoint URL
45       * @param additionalParams additional parameters to include in the token request
46       */
47      protected AbstractOAuth2Strategy(String tokenUrl, Map<String, String> additionalParams) {
48          this.oauthTokenUrl = tokenUrl;
49          this.additionalParams = additionalParams != null ? additionalParams : new HashMap<>();
50          this.isRefreshing = false;
51      }
52  
53      @Override
54      public RocketHeaders applyAuthHeaders(RocketHeaders headers) {
55          if (accessToken != null && !accessToken.isEmpty()) {
56              headers.bearerAuth(accessToken);
57          }
58          return headers;
59      }
60  
61      @Override
62      public boolean needsTokenRefresh() {
63          if (accessToken == null || accessToken.isEmpty()) {
64              return true;
65          }
66          if (tokenExpiryTime == null) {
67              return true;
68          }
69          // Refresh the token if it's expired or about to expire in the next 5 minutes
70          return Instant.now().plusSeconds(300).isAfter(tokenExpiryTime.toInstant());
71      }
72  
73      /**
74       * Helper method to perform POST requests.
75       * 
76       * @param url the URL to send the request to
77       * @param formParams the form parameters to include in the request
78       * @return the response body as a string
79       * @throws IOException if an I/O error occurs
80       */
81      protected String post(String url, Map<String, String> formParams) throws IOException {
82          // Initialize the HTTP client if not already done
83          if (httpClient == null) {
84              String baseUrl = url.substring(0, url.lastIndexOf("/") + 1);
85              RocketRestOptions options = new RocketRestOptions();
86              options.set(RocketRestOptions.LOGGING_ENABLED, true);
87              options.set(RocketRestOptions.LOG_RAW_RESPONSE, true);
88              options.set(RocketRestOptions.LOG_REQUEST_BODY, true);
89              options.set(RocketRestOptions.LOG_RAW_RESPONSE, true);
90              httpClient = new DefaultHttpClient(baseUrl, options);
91          }
92          // Build the form body
93          StringBuilder formBody = new StringBuilder();
94          try {
95              boolean first = true;
96              for (Map.Entry<String, String> entry : formParams.entrySet()) {
97                  if (!first) {
98                      formBody.append("&");
99                  }
100                 first = false;
101                 formBody.append(entry.getKey())
102                         .append("=")
103                         .append(java.net.URLEncoder.encode(entry.getValue(), "UTF-8"));
104             }
105         } catch (UnsupportedEncodingException e) {
106             logger.error("Error encoding form parameters", e);
107             throw new TokenRefreshException("Error encoding form parameters", e);
108         }
109         // Create headers
110         RocketHeaders headers = new RocketHeaders()
111                 .contentType(RocketHeaders.ContentTypes.APPLICATION_FORM);
112         // Extract a path from the full URL for the request
113         String endpoint = url.substring(url.lastIndexOf("/") + 1);
114         // Create the request using a builder
115         RequestSpec<String, String> requestSpec = RequestBuilder.<String, String>post(endpoint)
116                 .headers(headers)
117                 .body(formBody.toString())
118                 .responseType(String.class)
119                 .build();
120         // Execute the request
121         try {
122             return httpClient.execute(requestSpec);
123         } catch (RocketRestException e) {
124             logger.error("Error executing POST request", e);
125             logger.error("Status Code: {}", e.getStatusCode());
126             logger.error("Response Body: {}", e.getResponseBody());
127             throw e;
128         }
129         catch (Exception e) {
130             logger.error("Error executing POST request", e);
131             throw new IOException("Error executing POST request", e);
132         }
133     }
134 
135     @Override
136     public boolean refreshCredentials() {
137         if (isRefreshing) {
138             logger.warn("Token refresh already in progress");
139             return false;
140         }
141         // Check if the token URL is provided
142         if (oauthTokenUrl == null || oauthTokenUrl.isEmpty()) {
143             throw new TokenRefreshException("Token URL is required for OAuth2 flow");
144         }
145         // Validate the credentials
146         validateCredentials();
147         // Set the refreshing flag
148         isRefreshing = true;
149         // Try to refresh the token
150         try {
151             // Prepare form parameters
152             Map<String, String> formParams = prepareTokenRequestParams();
153             // Add any additional parameters
154             formParams.putAll(additionalParams);
155             // Execute POST request to get token
156             String responseString = post(oauthTokenUrl, formParams);
157             // Parse the response. Flexing some of the one-liners we pack with JsonObjectMapper
158             Map<String, Object> tokenResponse = JsonObjectMapper.jsonNodeToMap(JsonObjectMapper.getJsonNode(responseString));
159             // Process the token response
160             return processTokenResponse(tokenResponse);
161         } catch (Exception e) {
162             logger.error("Error during token refresh", e);
163             throw new TokenRefreshException("Error during token refresh", e);
164         } finally {
165             isRefreshing = false;
166         }
167     }
168 
169     /**
170      * Validates that all required credentials are present.
171      * @throws TokenRefreshException if any required credentials are missing
172      */
173     protected abstract void validateCredentials();
174 
175     /**
176      * Prepares the parameters for the token request.
177      * @return map of parameters to include in the token request
178      */
179     protected abstract Map<String, String> prepareTokenRequestParams();
180 
181     /**
182      * Processes the token response and extracts relevant information.
183      * @param tokenResponse the parsed token response
184      * @return true if the token was successfully refreshed, false otherwise
185      * @throws TokenRefreshException if there was an error processing the token response
186      */
187     protected boolean processTokenResponse(Map<String, Object> tokenResponse) {
188         Object tokenObj = tokenResponse.get("access_token");
189         Object expiresInObj = tokenResponse.get("expires_in");
190         if (tokenObj != null) {
191             accessToken = tokenObj.toString();
192             // Calculate expiry time
193             long expiresIn = 3600; // Default to 1 hour
194             if (expiresInObj != null) {
195                 try {
196                     expiresIn = Long.parseLong(expiresInObj.toString());
197                 } catch (NumberFormatException e) {
198                     logger.warn("Invalid expires_in value: {}", expiresInObj);
199                 }
200             }
201             tokenExpiryTime = Date.from(Instant.now().plusSeconds(expiresIn));
202             logger.debug("Token refreshed successfully, expires in {} seconds", expiresIn);
203             return true;
204         } else {
205             logger.error("Token response did not contain access_token: {}", tokenResponse);
206             throw new TokenRefreshException("Token response did not contain access_token");
207         }
208     }
209 
210     /**
211      * Gets the current access token.
212      * @return the current access token, or null if not yet obtained
213      */
214     public String getAccessToken() {
215         return accessToken;
216     }
217 
218     /**
219      * Gets the token expiry time.
220      * @return the token expiry time, or null if not yet obtained
221      */
222     public Date getTokenExpiryTime() {
223         return tokenExpiryTime;
224     }
225 
226     /**
227      * Sets the SSL context for secure token requests.
228      * @param sslContext the SSL context to use
229      */
230     public void configureSsl(SSLContext sslContext) {
231         if (httpClient != null) {
232             httpClient.configureSsl(sslContext);
233         }
234     }
235 }