View Javadoc
1   package com.guinetik.rr.auth;
2   
3   import org.slf4j.Logger;
4   import org.slf4j.LoggerFactory;
5   
6   import java.io.IOException;
7   import java.util.HashMap;
8   import java.util.Map;
9   
10  /**
11   * Authentication strategy that implements OAuth 2.0 assertion flow.
12   * This strategy implements a two-step OAuth flow:
13   * 1. Get an assertion from the Identity Provider endpoint by providing a private key;
14   * 2. Use the assertion to get the actual OAuth token from the token endpoint.
15   * This can be used with various identity providers like SAP, Azure AD, Okta, etc.
16   */
17  public class OAuth2AssertionStrategy extends AbstractOAuth2Strategy {
18  
19      private static final Logger logger = LoggerFactory.getLogger(OAuth2AssertionStrategy.class);
20  
21      private final String clientId;
22      private final String userId;
23      private final String privateKey;
24      private final String companyId;
25      private final String grantType;
26      private final String assertionUrl;
27      private final String tokenUrl;
28      private final Map<String, String> additionalAssertionParams;
29  
30      /**
31       * Creates a new OAuth 2.0 assertion strategy.
32       *
33       * @param clientId       the OAuth client ID
34       * @param userId         the user ID
35       * @param privateKey     the private key for assertion
36       * @param companyId      the company ID (optional, can be null)
37       * @param grantType      the OAuth grant type
38       * @param assertionUrl   the assertion endpoint URL
39       * @param tokenUrl       the token endpoint URL
40       */
41      public OAuth2AssertionStrategy(String clientId, String userId, String privateKey, 
42                            String companyId, String grantType, String assertionUrl, String tokenUrl) {
43          this(clientId, userId, privateKey, companyId, grantType, assertionUrl, tokenUrl, 
44               new HashMap<>(), new HashMap<>());
45      }
46  
47      /**
48       * Creates a new OAuth 2.0 assertion strategy with additional parameters.
49       *
50       * @param clientId                 the OAuth client ID
51       * @param userId                   the user ID
52       * @param privateKey               the private key for assertion
53       * @param companyId                the company ID (optional, can be null)
54       * @param grantType                the OAuth grant type
55       * @param assertionUrl             the assertion endpoint URL
56       * @param tokenUrl                 the token endpoint URL
57       * @param additionalAssertionParams additional parameters for assertion request
58       * @param additionalTokenParams    additional parameters for token request
59       */
60      public OAuth2AssertionStrategy(String clientId, String userId, String privateKey, 
61                            String companyId, String grantType, String assertionUrl, String tokenUrl,
62                            Map<String, String> additionalAssertionParams,
63                            Map<String, String> additionalTokenParams) {
64          super(tokenUrl, additionalTokenParams);
65          this.clientId = clientId;
66          this.userId = userId;
67          this.privateKey = privateKey;
68          this.companyId = companyId;
69          this.grantType = grantType;
70          this.assertionUrl = assertionUrl;
71          this.tokenUrl = tokenUrl;
72          this.additionalAssertionParams = additionalAssertionParams;
73      }
74  
75      @Override
76      public AuthType getType() {
77          return AuthType.OAUTH_ASSERTION;
78      }
79  
80      /**
81       * {@inheritDoc}
82       * @throws TokenRefreshException if any of the required parameters (clientId, userId, privateKey,
83       * grantType, assertionUrl, or tokenUrl) are missing.
84       */
85      @Override
86      protected void validateCredentials() {
87          if (clientId == null || userId == null || privateKey == null || 
88              grantType == null || assertionUrl == null || tokenUrl == null) {
89              throw new TokenRefreshException("Required credentials are missing for OAuth 2.0 assertion flow");
90          }
91      }
92  
93      /**
94       * {@inheritDoc}
95       * <p>
96       * This implementation handles the two-step OAuth 2.0 assertion flow:
97       * <ol>
98       *   <li>It first calls {@link #getAssertion()} to obtain an assertion from the configured assertion URL.</li>
99       *   <li>Then, it uses this assertion along with other parameters (clientId, userId, grantType, companyId if present)
100      *       to call the {@code super.refreshCredentials()} method, which performs the actual token request
101      *       to the configured token URL.</li>
102      * </ol>
103      *
104      * @return {@code true} if the token was successfully refreshed, {@code false} otherwise.
105      * @throws TokenRefreshException if token refresh fails at any step.
106      */
107     @Override
108     public boolean refreshCredentials() {
109         try {
110             // Step 1: Get assertion from the assertion endpoint
111             String assertion = getAssertion();
112             if (assertion.isEmpty()) {
113                 logger.error("Failed to get assertion from assertion endpoint");
114                 return false;
115             }
116             // Step 2: Get token using assertion and the parent class functionality
117             // Prepare token parameters including the assertion we just got
118             additionalParams.put("client_id", clientId);
119             additionalParams.put("user_id", userId);
120             additionalParams.put("grant_type", grantType);
121             if (companyId != null) {
122                 additionalParams.put("company_id", companyId);
123             }
124             additionalParams.put("assertion", assertion);
125             // Call the parent implementation to get the token
126             return super.refreshCredentials();
127         } catch (Exception e) {
128             logger.error("Failed to refresh token", e);
129             throw new TokenRefreshException("Failed to refresh token: " + e.getMessage());
130         }
131     }
132 
133     /**
134      * Retrieves an assertion token from the configured assertion URL.
135      * This method makes a POST request to the {@code assertionUrl} using parameters such as
136      * clientId, userId, the privateKey, and the target tokenUrl.
137      *
138      * @return The assertion string obtained from the endpoint.
139      * @throws IOException if an I/O error occurs during the request to the assertion endpoint.
140      * @throws TokenRefreshException if the assertion endpoint returns an error or an empty assertion.
141      */
142     private String getAssertion() throws IOException {
143         Map<String, String> assertionParams = new HashMap<>(additionalAssertionParams);
144         assertionParams.put("client_id", clientId);
145         assertionParams.put("user_id", userId);
146         assertionParams.put("token_url", tokenUrl);
147         assertionParams.put("private_key", privateKey);
148 
149         // Make the request to the assertion endpoint using the parent class's POST method
150         try {
151             return post(assertionUrl, assertionParams).trim();
152         } catch (Exception e) {
153             logger.error("Error getting assertion", e);
154             throw new IOException("Error getting assertion: " + e.getMessage(), e);
155         }
156     }
157 
158     /**
159      * {@inheritDoc}
160      * <p>
161      * Prepares parameters for the token request part of the assertion flow.
162      * This method is typically called by the parent class's {@code refreshToken} method.
163      * It includes clientId, userId, grantType, and companyId (if available).
164      * The assertion itself is expected to have been added to {@code additionalParams} by the
165      * overridden {@link #refreshCredentials()} method before this is called.
166      */
167     @Override
168     protected Map<String, String> prepareTokenRequestParams() {
169         // This method is called by the parent class's refreshToken method
170         Map<String, String> params = new HashMap<>();
171         params.put("client_id", clientId);
172         params.put("user_id", userId);
173         params.put("grant_type", grantType);
174         if (companyId != null) {
175             params.put("company_id", companyId);
176         }
177         // The assertion should be in additionalParams at this point        
178         return params;
179     }
180 }