1 package com.guinetik.rr.auth;
2
3 import org.slf4j.Logger;
4 import org.slf4j.LoggerFactory;
5
6 import java.io.IOException;
7 import java.util.HashMap;
8 import java.util.Map;
9
10 /**
11 * Authentication strategy that implements OAuth 2.0 assertion flow.
12 * This strategy implements a two-step OAuth flow:
13 * 1. Get an assertion from the Identity Provider endpoint by providing a private key;
14 * 2. Use the assertion to get the actual OAuth token from the token endpoint.
15 * This can be used with various identity providers like SAP, Azure AD, Okta, etc.
16 */
17 public class OAuth2AssertionStrategy extends AbstractOAuth2Strategy {
18
19 private static final Logger logger = LoggerFactory.getLogger(OAuth2AssertionStrategy.class);
20
21 private final String clientId;
22 private final String userId;
23 private final String privateKey;
24 private final String companyId;
25 private final String grantType;
26 private final String assertionUrl;
27 private final String tokenUrl;
28 private final Map<String, String> additionalAssertionParams;
29
30 /**
31 * Creates a new OAuth 2.0 assertion strategy.
32 *
33 * @param clientId the OAuth client ID
34 * @param userId the user ID
35 * @param privateKey the private key for assertion
36 * @param companyId the company ID (optional, can be null)
37 * @param grantType the OAuth grant type
38 * @param assertionUrl the assertion endpoint URL
39 * @param tokenUrl the token endpoint URL
40 */
41 public OAuth2AssertionStrategy(String clientId, String userId, String privateKey,
42 String companyId, String grantType, String assertionUrl, String tokenUrl) {
43 this(clientId, userId, privateKey, companyId, grantType, assertionUrl, tokenUrl,
44 new HashMap<>(), new HashMap<>());
45 }
46
47 /**
48 * Creates a new OAuth 2.0 assertion strategy with additional parameters.
49 *
50 * @param clientId the OAuth client ID
51 * @param userId the user ID
52 * @param privateKey the private key for assertion
53 * @param companyId the company ID (optional, can be null)
54 * @param grantType the OAuth grant type
55 * @param assertionUrl the assertion endpoint URL
56 * @param tokenUrl the token endpoint URL
57 * @param additionalAssertionParams additional parameters for assertion request
58 * @param additionalTokenParams additional parameters for token request
59 */
60 public OAuth2AssertionStrategy(String clientId, String userId, String privateKey,
61 String companyId, String grantType, String assertionUrl, String tokenUrl,
62 Map<String, String> additionalAssertionParams,
63 Map<String, String> additionalTokenParams) {
64 super(tokenUrl, additionalTokenParams);
65 this.clientId = clientId;
66 this.userId = userId;
67 this.privateKey = privateKey;
68 this.companyId = companyId;
69 this.grantType = grantType;
70 this.assertionUrl = assertionUrl;
71 this.tokenUrl = tokenUrl;
72 this.additionalAssertionParams = additionalAssertionParams;
73 }
74
75 @Override
76 public AuthType getType() {
77 return AuthType.OAUTH_ASSERTION;
78 }
79
80 /**
81 * {@inheritDoc}
82 * @throws TokenRefreshException if any of the required parameters (clientId, userId, privateKey,
83 * grantType, assertionUrl, or tokenUrl) are missing.
84 */
85 @Override
86 protected void validateCredentials() {
87 if (clientId == null || userId == null || privateKey == null ||
88 grantType == null || assertionUrl == null || tokenUrl == null) {
89 throw new TokenRefreshException("Required credentials are missing for OAuth 2.0 assertion flow");
90 }
91 }
92
93 /**
94 * {@inheritDoc}
95 * <p>
96 * This implementation handles the two-step OAuth 2.0 assertion flow:
97 * <ol>
98 * <li>It first calls {@link #getAssertion()} to obtain an assertion from the configured assertion URL.</li>
99 * <li>Then, it uses this assertion along with other parameters (clientId, userId, grantType, companyId if present)
100 * to call the {@code super.refreshCredentials()} method, which performs the actual token request
101 * to the configured token URL.</li>
102 * </ol>
103 *
104 * @return {@code true} if the token was successfully refreshed, {@code false} otherwise.
105 * @throws TokenRefreshException if token refresh fails at any step.
106 */
107 @Override
108 public boolean refreshCredentials() {
109 try {
110 // Step 1: Get assertion from the assertion endpoint
111 String assertion = getAssertion();
112 if (assertion.isEmpty()) {
113 logger.error("Failed to get assertion from assertion endpoint");
114 return false;
115 }
116 // Step 2: Get token using assertion and the parent class functionality
117 // Prepare token parameters including the assertion we just got
118 additionalParams.put("client_id", clientId);
119 additionalParams.put("user_id", userId);
120 additionalParams.put("grant_type", grantType);
121 if (companyId != null) {
122 additionalParams.put("company_id", companyId);
123 }
124 additionalParams.put("assertion", assertion);
125 // Call the parent implementation to get the token
126 return super.refreshCredentials();
127 } catch (Exception e) {
128 logger.error("Failed to refresh token", e);
129 throw new TokenRefreshException("Failed to refresh token: " + e.getMessage());
130 }
131 }
132
133 /**
134 * Retrieves an assertion token from the configured assertion URL.
135 * This method makes a POST request to the {@code assertionUrl} using parameters such as
136 * clientId, userId, the privateKey, and the target tokenUrl.
137 *
138 * @return The assertion string obtained from the endpoint.
139 * @throws IOException if an I/O error occurs during the request to the assertion endpoint.
140 * @throws TokenRefreshException if the assertion endpoint returns an error or an empty assertion.
141 */
142 private String getAssertion() throws IOException {
143 Map<String, String> assertionParams = new HashMap<>(additionalAssertionParams);
144 assertionParams.put("client_id", clientId);
145 assertionParams.put("user_id", userId);
146 assertionParams.put("token_url", tokenUrl);
147 assertionParams.put("private_key", privateKey);
148
149 // Make the request to the assertion endpoint using the parent class's POST method
150 try {
151 return post(assertionUrl, assertionParams).trim();
152 } catch (Exception e) {
153 logger.error("Error getting assertion", e);
154 throw new IOException("Error getting assertion: " + e.getMessage(), e);
155 }
156 }
157
158 /**
159 * {@inheritDoc}
160 * <p>
161 * Prepares parameters for the token request part of the assertion flow.
162 * This method is typically called by the parent class's {@code refreshToken} method.
163 * It includes clientId, userId, grantType, and companyId (if available).
164 * The assertion itself is expected to have been added to {@code additionalParams} by the
165 * overridden {@link #refreshCredentials()} method before this is called.
166 */
167 @Override
168 protected Map<String, String> prepareTokenRequestParams() {
169 // This method is called by the parent class's refreshToken method
170 Map<String, String> params = new HashMap<>();
171 params.put("client_id", clientId);
172 params.put("user_id", userId);
173 params.put("grant_type", grantType);
174 if (companyId != null) {
175 params.put("company_id", companyId);
176 }
177 // The assertion should be in additionalParams at this point
178 return params;
179 }
180 }