AbstractOAuth2Strategy.java
package com.guinetik.rr.auth;
import com.guinetik.rr.RocketRestOptions;
import com.guinetik.rr.http.DefaultHttpClient;
import com.guinetik.rr.http.RocketHeaders;
import com.guinetik.rr.http.RocketRestException;
import com.guinetik.rr.json.JsonObjectMapper;
import com.guinetik.rr.request.RequestBuilder;
import com.guinetik.rr.request.RequestSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.time.Instant;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
/**
* Abstract base class for OAuth 2.0 authentication strategies.
* Provides common functionality for different OAuth2 flows.
*/
public abstract class AbstractOAuth2Strategy implements AuthStrategy, RocketSSL.SSLAware {
protected final Logger logger = LoggerFactory.getLogger(getClass());
protected final Map<String, String> additionalParams;
/**
* OAuth 2.0 token endpoint URL to be used when refreshing credentials.
*/
protected final String oauthTokenUrl;
protected String accessToken;
protected Date tokenExpiryTime;
protected boolean isRefreshing;
protected DefaultHttpClient httpClient;
/**
* Creates a new OAuth 2.0 strategy with additional parameters.
*
* @param tokenUrl the OAuth 2.0 token endpoint URL
* @param additionalParams additional parameters to include in the token request
*/
protected AbstractOAuth2Strategy(String tokenUrl, Map<String, String> additionalParams) {
this.oauthTokenUrl = tokenUrl;
this.additionalParams = additionalParams != null ? additionalParams : new HashMap<>();
this.isRefreshing = false;
}
@Override
public RocketHeaders applyAuthHeaders(RocketHeaders headers) {
if (accessToken != null && !accessToken.isEmpty()) {
headers.bearerAuth(accessToken);
}
return headers;
}
@Override
public boolean needsTokenRefresh() {
if (accessToken == null || accessToken.isEmpty()) {
return true;
}
if (tokenExpiryTime == null) {
return true;
}
// Refresh the token if it's expired or about to expire in the next 5 minutes
return Instant.now().plusSeconds(300).isAfter(tokenExpiryTime.toInstant());
}
/**
* Helper method to perform POST requests.
*
* @param url the URL to send the request to
* @param formParams the form parameters to include in the request
* @return the response body as a string
* @throws IOException if an I/O error occurs
*/
protected String post(String url, Map<String, String> formParams) throws IOException {
// Initialize the HTTP client if not already done
if (httpClient == null) {
String baseUrl = url.substring(0, url.lastIndexOf("/") + 1);
RocketRestOptions options = new RocketRestOptions();
options.set(RocketRestOptions.LOGGING_ENABLED, true);
options.set(RocketRestOptions.LOG_RAW_RESPONSE, true);
options.set(RocketRestOptions.LOG_REQUEST_BODY, true);
options.set(RocketRestOptions.LOG_RAW_RESPONSE, true);
httpClient = new DefaultHttpClient(baseUrl, options);
}
// Build the form body
StringBuilder formBody = new StringBuilder();
try {
boolean first = true;
for (Map.Entry<String, String> entry : formParams.entrySet()) {
if (!first) {
formBody.append("&");
}
first = false;
formBody.append(entry.getKey())
.append("=")
.append(java.net.URLEncoder.encode(entry.getValue(), "UTF-8"));
}
} catch (UnsupportedEncodingException e) {
logger.error("Error encoding form parameters", e);
throw new TokenRefreshException("Error encoding form parameters", e);
}
// Create headers
RocketHeaders headers = new RocketHeaders()
.contentType(RocketHeaders.ContentTypes.APPLICATION_FORM);
// Extract a path from the full URL for the request
String endpoint = url.substring(url.lastIndexOf("/") + 1);
// Create the request using a builder
RequestSpec<String, String> requestSpec = RequestBuilder.<String, String>post(endpoint)
.headers(headers)
.body(formBody.toString())
.responseType(String.class)
.build();
// Execute the request
try {
return httpClient.execute(requestSpec);
} catch (RocketRestException e) {
logger.error("Error executing POST request", e);
logger.error("Status Code: {}", e.getStatusCode());
logger.error("Response Body: {}", e.getResponseBody());
throw e;
}
catch (Exception e) {
logger.error("Error executing POST request", e);
throw new IOException("Error executing POST request", e);
}
}
@Override
public boolean refreshCredentials() {
if (isRefreshing) {
logger.warn("Token refresh already in progress");
return false;
}
// Check if the token URL is provided
if (oauthTokenUrl == null || oauthTokenUrl.isEmpty()) {
throw new TokenRefreshException("Token URL is required for OAuth2 flow");
}
// Validate the credentials
validateCredentials();
// Set the refreshing flag
isRefreshing = true;
// Try to refresh the token
try {
// Prepare form parameters
Map<String, String> formParams = prepareTokenRequestParams();
// Add any additional parameters
formParams.putAll(additionalParams);
// Execute POST request to get token
String responseString = post(oauthTokenUrl, formParams);
// Parse the response. Flexing some of the one-liners we pack with JsonObjectMapper
Map<String, Object> tokenResponse = JsonObjectMapper.jsonNodeToMap(JsonObjectMapper.getJsonNode(responseString));
// Process the token response
return processTokenResponse(tokenResponse);
} catch (Exception e) {
logger.error("Error during token refresh", e);
throw new TokenRefreshException("Error during token refresh", e);
} finally {
isRefreshing = false;
}
}
/**
* Validates that all required credentials are present.
* @throws TokenRefreshException if any required credentials are missing
*/
protected abstract void validateCredentials();
/**
* Prepares the parameters for the token request.
* @return map of parameters to include in the token request
*/
protected abstract Map<String, String> prepareTokenRequestParams();
/**
* Processes the token response and extracts relevant information.
* @param tokenResponse the parsed token response
* @return true if the token was successfully refreshed, false otherwise
* @throws TokenRefreshException if there was an error processing the token response
*/
protected boolean processTokenResponse(Map<String, Object> tokenResponse) {
Object tokenObj = tokenResponse.get("access_token");
Object expiresInObj = tokenResponse.get("expires_in");
if (tokenObj != null) {
accessToken = tokenObj.toString();
// Calculate expiry time
long expiresIn = 3600; // Default to 1 hour
if (expiresInObj != null) {
try {
expiresIn = Long.parseLong(expiresInObj.toString());
} catch (NumberFormatException e) {
logger.warn("Invalid expires_in value: {}", expiresInObj);
}
}
tokenExpiryTime = Date.from(Instant.now().plusSeconds(expiresIn));
logger.debug("Token refreshed successfully, expires in {} seconds", expiresIn);
return true;
} else {
logger.error("Token response did not contain access_token: {}", tokenResponse);
throw new TokenRefreshException("Token response did not contain access_token");
}
}
/**
* Gets the current access token.
* @return the current access token, or null if not yet obtained
*/
public String getAccessToken() {
return accessToken;
}
/**
* Gets the token expiry time.
* @return the token expiry time, or null if not yet obtained
*/
public Date getTokenExpiryTime() {
return tokenExpiryTime;
}
/**
* Sets the SSL context for secure token requests.
* @param sslContext the SSL context to use
*/
public void configureSsl(SSLContext sslContext) {
if (httpClient != null) {
httpClient.configureSsl(sslContext);
}
}
}