Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,18 @@ async def get_authorization_token(
decoded_bytes = base64.b64decode(additional_authentication_context[self.CLAIMS_KEY])
decoded_claim = decoded_bytes.decode("utf-8")

if not self._scopes:
self._scopes = [f"{parsed_url.scheme}://{parsed_url.netloc}/.default"]
span.set_attribute(self.SCOPES, ",".join(self._scopes))
# Derive the scope per-call from the request hostname.
scopes = self._resolve_scopes(parsed_url, span)
span.set_attribute(self.SCOPES, ",".join(scopes))
span.set_attribute(self.ADDITIONAL_CLAIMS_PROVIDED, bool(self._options))

if self._options:
result = self._credentials.get_token(
*self._scopes,
claims=decoded_claim,
enable_cae=self._is_cae_enabled,
**self._options
*scopes, claims=decoded_claim, enable_cae=self._is_cae_enabled, **self._options
)
else:
result = self._credentials.get_token(
*self._scopes, claims=decoded_claim, enable_cae=self._is_cae_enabled
*scopes, claims=decoded_claim, enable_cae=self._is_cae_enabled
)

if inspect.isawaitable(result):
Expand All @@ -127,3 +124,24 @@ def get_allowed_hosts_validator(self) -> AllowedHostsValidator:
AllowedHostsValidator: The allowed hosts validator.
"""
return self._allowed_hosts_validator

def _resolve_scopes(self, parsed_url, span) -> list[str]:
"""Return the scopes to pass to `get_token` for this request.

Caller-supplied scopes are returned verbatim. Otherwise a default
`.default` scope is derived from the request hostname only, so that
userinfo (`user:password@`) and ports (which Entra ID rejects for
`.default` scopes) are never copied into the scope or telemetry.
IPv6 literal brackets stripped by `urlparse` are re-added.
"""
if self._scopes:
return self._scopes
hostname = parsed_url.hostname
if not hostname:
span.set_attribute(self.IS_VALID_URL, False)
exc = HTTPError("Valid url scheme and host required")
span.record_exception(exc)
raise exc
if ":" in hostname:
hostname = f"[{hostname}]"
return [f"{parsed_url.scheme}://{hostname}/.default"]
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,96 @@ async def test_get_authorization_token_localhost():
token_provider = AzureIdentityAccessTokenProvider(DummySyncAzureTokenCredential(), None)
token = await token_provider.get_authorization_token('HTTP://LOCALHOST:8080')
assert token



class RecordingSyncAzureTokenCredential(DummySyncAzureTokenCredential):
"""Sync credential that records the scopes passed to get_token."""

def __init__(self):
self.received_scopes: list[tuple[str, ...]] = []

def get_token(self, *scopes, **kwargs):
self.received_scopes.append(scopes)
return super().get_token(*scopes, **kwargs)


@pytest.mark.asyncio
async def test_derived_scope_strips_userinfo_and_port():
"""The default `.default` scope passed to `get_token` must be derived
from the hostname only — never include userinfo or
a `:port` (which Entra ID rejects for `.default` scopes).
"""
credential = RecordingSyncAzureTokenCredential()
token_provider = AzureIdentityAccessTokenProvider(credential, None)

await token_provider.get_authorization_token(
'https://alice:secret@graph.microsoft.com:8443/v1.0/me'
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
)

assert credential.received_scopes == [('https://graph.microsoft.com/.default',)]


@pytest.mark.asyncio
async def test_derived_scope_is_not_cached_across_hosts():
"""The first URL's derived scope must not be reused for later URLs.

Previously the scope was assigned to `self._scopes`, making it sticky for
the lifetime of the provider instance and causing tokens to be requested
for the wrong audience after the first call.
"""
credential = RecordingSyncAzureTokenCredential()
token_provider = AzureIdentityAccessTokenProvider(credential, None)

await token_provider.get_authorization_token('https://graph.microsoft.com/v1.0/me')
await token_provider.get_authorization_token('https://graph.microsoft.us/v1.0/me')

assert credential.received_scopes == [
('https://graph.microsoft.com/.default',),
('https://graph.microsoft.us/.default',),
]
# Provider must not have cached derived scopes into `_scopes`.
assert token_provider._scopes == []


@pytest.mark.asyncio
async def test_explicit_scopes_are_respected():
credential = RecordingSyncAzureTokenCredential()
token_provider = AzureIdentityAccessTokenProvider(
credential, None, scopes=['https://graph.microsoft.com/.default']
)

await token_provider.get_authorization_token('https://graph.microsoft.com/v1.0/me')
await token_provider.get_authorization_token('https://graph.microsoft.us/v1.0/me')

assert credential.received_scopes == [
('https://graph.microsoft.com/.default',),
('https://graph.microsoft.com/.default',),
]


@pytest.mark.asyncio
async def test_derived_scope_rejects_url_without_hostname():
"""A URI whose netloc has no hostname (e.g. `https://@/path`) must not
silently derive a scope like `https://None/.default`; it must raise.
"""
credential = RecordingSyncAzureTokenCredential()
token_provider = AzureIdentityAccessTokenProvider(credential, None)

with pytest.raises(Exception):
await token_provider.get_authorization_token('https://@/path')
assert credential.received_scopes == []


@pytest.mark.asyncio
async def test_derived_scope_brackets_ipv6_hostname():
"""`urlparse` strips brackets from IPv6 literals; the derived scope
must re-add them so the resulting URL is syntactically valid.
"""
credential = RecordingSyncAzureTokenCredential()
token_provider = AzureIdentityAccessTokenProvider(credential, None)

await token_provider.get_authorization_token('https://[2001:db8::1]/v1.0/me')

assert credential.received_scopes == [('https://[2001:db8::1]/.default',)]


Loading