diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 422b76e3..b2fc446c 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -8,6 +8,7 @@ import os import sys import time +import uuid from urllib.parse import urlparse # Python 3+ from collections import UserDict # Python 3+ from typing import List, Optional, Union # Needed in Python 3.7 & 3.8 @@ -15,6 +16,7 @@ from .individual_cache import _IndividualCache as IndividualCache from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser from .cloudshell import _is_running_in_cloud_shell +from .sku import SKU, __version__ logger = logging.getLogger(__name__) @@ -480,7 +482,12 @@ def _obtain_token_on_azure_vm(http_client, managed_identity, resource): "AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.254" ).strip("/") + "/metadata/identity/oauth2/token", params=params, - headers={"Metadata": "true"}, + headers={ + "Metadata": "true", + "x-client-SKU": SKU, + "x-client-Ver": __version__, + "x-ms-client-request-id": str(uuid.uuid4()), + }, ) try: payload = json.loads(resp.text) diff --git a/tests/test_mi.py b/tests/test_mi.py index 8e6b6b14..fd3834c8 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -3,6 +3,7 @@ import os import sys import time +import uuid from typing import List, Optional import unittest try: @@ -32,6 +33,8 @@ ) from msal.token_cache import is_subdict_of +EXPECTED_SKU = "MSAL.Python" # Hardcoded constant, not imported from product + class ManagedIdentityTestCase(unittest.TestCase): def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_from_file_or_env_var(self): @@ -181,14 +184,23 @@ def _test_happy_path(self) -> callable: return mocked_method def test_happy_path_of_vm(self): - self._test_happy_path().assert_called_with( + mock_get = self._test_happy_path() + mock_get.assert_called_with( # The last call contained claims_challenge # but since IMDS doesn't support token_sha256_to_refresh, # the request shall remain the same as before 'http://169.254.169.254/metadata/identity/oauth2/token', params={'api-version': '2018-02-01', 'resource': 'R'}, - headers={'Metadata': 'true'}, + headers={ + 'Metadata': 'true', + 'x-client-SKU': EXPECTED_SKU, + 'x-client-Ver': ANY, + 'x-ms-client-request-id': ANY, + }, ) + # Validate correlation ID is a valid UUID + corr_id = mock_get.call_args.kwargs["headers"]["x-ms-client-request-id"] + uuid.UUID(corr_id) @patch.object(ManagedIdentityClient, "_ManagedIdentityClient__instance", "MixedCaseHostName") def test_happy_path_of_theoretical_mixed_case_hostname(self): @@ -200,11 +212,20 @@ def test_happy_path_of_theoretical_mixed_case_hostname(self): @patch.dict(os.environ, {"AZURE_POD_IDENTITY_AUTHORITY_HOST": "http://localhost:1234//"}) def test_happy_path_of_pod_identity(self): - self._test_happy_path().assert_called_with( + mock_get = self._test_happy_path() + mock_get.assert_called_with( 'http://localhost:1234/metadata/identity/oauth2/token', params={'api-version': '2018-02-01', 'resource': 'R'}, - headers={'Metadata': 'true'}, + headers={ + 'Metadata': 'true', + 'x-client-SKU': EXPECTED_SKU, + 'x-client-Ver': ANY, + 'x-ms-client-request-id': ANY, + }, ) + # Validate correlation ID is a valid UUID + corr_id = mock_get.call_args.kwargs["headers"]["x-ms-client-request-id"] + uuid.UUID(corr_id) def test_vm_error_should_be_returned_as_is(self): raw_error = '{"raw": "error format is undefined"}' @@ -229,8 +250,16 @@ def test_vm_resource_id_parameter_should_be_msi_res_id(self): mocked_method.assert_called_with( 'http://169.254.169.254/metadata/identity/oauth2/token', params={'api-version': '2018-02-01', 'resource': 'R', 'msi_res_id': '1234'}, - headers={'Metadata': 'true'}, + headers={ + 'Metadata': 'true', + 'x-client-SKU': EXPECTED_SKU, + 'x-client-Ver': ANY, + 'x-ms-client-request-id': ANY, + }, ) + # Validate correlation ID is a valid UUID + corr_id = mocked_method.call_args.kwargs["headers"]["x-ms-client-request-id"] + uuid.UUID(corr_id) @patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"})