Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import os
import sys
import time
import uuid
from urllib.parse import urlparse # Python 3+
from collections import UserDict # Python 3+
from typing import List, Optional, Union # Needed in Python 3.7 & 3.8
from .token_cache import TokenCache
from .individual_cache import _IndividualCache as IndividualCache
from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser
from .cloudshell import _is_running_in_cloud_shell
from .sku import SKU, __version__


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -480,7 +482,12 @@ def _obtain_token_on_azure_vm(http_client, managed_identity, resource):
"AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.254"
).strip("/") + "/metadata/identity/oauth2/token",
params=params,
headers={"Metadata": "true"},
headers={
"Metadata": "true",
"x-client-SKU": SKU,
"x-client-Ver": __version__,
"x-ms-client-request-id": str(uuid.uuid4()),
},
)
try:
payload = json.loads(resp.text)
Expand Down
39 changes: 34 additions & 5 deletions tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import sys
import time
import uuid
from typing import List, Optional
import unittest
try:
Expand Down Expand Up @@ -32,6 +33,8 @@
)
from msal.token_cache import is_subdict_of

EXPECTED_SKU = "MSAL.Python" # Hardcoded constant, not imported from product


class ManagedIdentityTestCase(unittest.TestCase):
def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_from_file_or_env_var(self):
Expand Down Expand Up @@ -181,14 +184,23 @@ def _test_happy_path(self) -> callable:
return mocked_method

def test_happy_path_of_vm(self):
self._test_happy_path().assert_called_with(
mock_get = self._test_happy_path()
mock_get.assert_called_with(
# The last call contained claims_challenge
# but since IMDS doesn't support token_sha256_to_refresh,
# the request shall remain the same as before
'http://169.254.169.254/metadata/identity/oauth2/token',
params={'api-version': '2018-02-01', 'resource': 'R'},
headers={'Metadata': 'true'},
headers={
'Metadata': 'true',
'x-client-SKU': EXPECTED_SKU,
'x-client-Ver': ANY,
'x-ms-client-request-id': ANY,
},
)
# Validate correlation ID is a valid UUID
corr_id = mock_get.call_args.kwargs["headers"]["x-ms-client-request-id"]
uuid.UUID(corr_id)

@patch.object(ManagedIdentityClient, "_ManagedIdentityClient__instance", "MixedCaseHostName")
def test_happy_path_of_theoretical_mixed_case_hostname(self):
Expand All @@ -200,11 +212,20 @@ def test_happy_path_of_theoretical_mixed_case_hostname(self):

@patch.dict(os.environ, {"AZURE_POD_IDENTITY_AUTHORITY_HOST": "http://localhost:1234//"})
def test_happy_path_of_pod_identity(self):
self._test_happy_path().assert_called_with(
mock_get = self._test_happy_path()
mock_get.assert_called_with(
'http://localhost:1234/metadata/identity/oauth2/token',
params={'api-version': '2018-02-01', 'resource': 'R'},
headers={'Metadata': 'true'},
headers={
'Metadata': 'true',
'x-client-SKU': EXPECTED_SKU,
'x-client-Ver': ANY,
'x-ms-client-request-id': ANY,
},
)
# Validate correlation ID is a valid UUID
corr_id = mock_get.call_args.kwargs["headers"]["x-ms-client-request-id"]
uuid.UUID(corr_id)

def test_vm_error_should_be_returned_as_is(self):
raw_error = '{"raw": "error format is undefined"}'
Expand All @@ -229,8 +250,16 @@ def test_vm_resource_id_parameter_should_be_msi_res_id(self):
mocked_method.assert_called_with(
'http://169.254.169.254/metadata/identity/oauth2/token',
params={'api-version': '2018-02-01', 'resource': 'R', 'msi_res_id': '1234'},
headers={'Metadata': 'true'},
headers={
'Metadata': 'true',
'x-client-SKU': EXPECTED_SKU,
'x-client-Ver': ANY,
'x-ms-client-request-id': ANY,
},
)
# Validate correlation ID is a valid UUID
corr_id = mocked_method.call_args.kwargs["headers"]["x-ms-client-request-id"]
uuid.UUID(corr_id)


@patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"})
Expand Down
Loading