From c478d0bff412c67280dfe8f08568de733f9425a1 Mon Sep 17 00:00:00 2001
From: Daniel Watkins <oddbloke@ubuntu.com>
Date: Tue, 31 Mar 2020 13:52:21 -0400
Subject: [PATCH] distros: replace invalid characters in mirror URLs with
 hyphens (#291)

This modifies _get_package_mirror_info to convert the hostnames of generated mirror URLs to their IDNA form, and then iterate through them replacing any invalid characters (i.e. anything other than letters, digits or a hyphen) with a hyphen.

This commit introduces the following changes in behaviour:

* generated mirror URLs with Unicode characters in their hostnames will have their hostnames converted to their all-ASCII IDNA form
* generated mirror URLs with invalid-for-hostname characters in their hostname will have those characters converted to hyphens
* generated mirror URLs which cannot be parsed by `urllib.parse.urlsplit` will not be considered for use
  * other configured patterns will still be considered
  * if all configured patterns fail to produce a URL that parses then the fallback mirror URL will be used

LP: #1868232
---
 cloudinit/distros/__init__.py        | 109 ++++++++++++++++++++++++++-
 cloudinit/distros/tests/test_init.py |  84 ++++++++++++++++-----
 2 files changed, 174 insertions(+), 19 deletions(-)

--- a/cloudinit/distros/__init__.py
+++ b/cloudinit/distros/__init__.py
@@ -13,6 +13,8 @@ import abc
 import os
 import re
 import stat
+import string
+import urllib.parse
 from io import StringIO
 
 from cloudinit import importer
@@ -50,6 +52,9 @@ _EC2_AZ_RE = re.compile('^[a-z][a-z]-(?:
 # Default NTP Client Configurations
 PREFERRED_NTP_CLIENTS = ['chrony', 'systemd-timesyncd', 'ntp', 'ntpdate']
 
+# Letters/Digits/Hyphen characters, for use in domain name validation
+LDH_ASCII_CHARS = string.ascii_letters + string.digits + "-"
+
 
 class Distro(metaclass=abc.ABCMeta):
 
@@ -720,6 +725,102 @@ class Distro(metaclass=abc.ABCMeta):
                 LOG.info("Added user '%s' to group '%s'", member, name)
 
 
+def _apply_hostname_transformations_to_url(url: str, transformations: list):
+    """
+    Apply transformations to a URL's hostname, return transformed URL.
+
+    This is a separate function because unwrapping and rewrapping only the
+    hostname portion of a URL is complex.
+
+    :param url:
+        The URL to operate on.
+    :param transformations:
+        A list of ``(str) -> Optional[str]`` functions, which will be applied
+        in order to the hostname portion of the URL.  As soon as one of them
+        returns None, ``url`` is returned without any modification.
+
+    :return:
+        ``url`` with the hostname ``transformations`` applied, or ``None`` if
+        ``url`` is unparseable, has a malformed port, or has no hostname.
+    """
+    try:
+        parts = urllib.parse.urlsplit(url)
+        port = parts.port  # parsed lazily; raises ValueError if malformed
+    except ValueError:
+        # If we can't even parse the URL, we shouldn't use it for anything
+        return None
+    new_hostname = parts.hostname
+    if new_hostname is None:
+        # No hostname: nothing to transform, and unusable as a mirror
+        return None
+    for transformation in transformations:
+        new_hostname = transformation(new_hostname)
+        if new_hostname is None:
+            # A transformation returning None means: return `url` unmodified
+            return url
+    new_netloc = new_hostname
+    if port is not None:
+        new_netloc = "{}:{}".format(new_netloc, port)
+    return urllib.parse.urlunsplit(parts._replace(netloc=new_netloc))
+
+
+def _sanitize_mirror_url(url: str):
+    """
+    Given a mirror URL, replace any invalid characters in its hostname.
+
+    This performs the following actions on the URL's hostname:
+      * Checks if it is an IP address, returning the URL unmodified if it is
+      * Converts it to its IDN form (see below for details)
+      * Replaces any non-Letters/Digits/Hyphen (LDH) characters in it with
+        hyphens
+      * TODO: Remove any leading/trailing hyphens from each domain name label
+
+    Before we replace any invalid domain name characters, we first need to
+    ensure that any valid non-ASCII characters in the hostname will not be
+    replaced, by ensuring the hostname is in its Internationalized domain name
+    (IDN) representation (see RFC 5890).  This conversion has to be applied to
+    the whole hostname (rather than just the substitution variables), because
+    the Punycode algorithm used by IDNA transcodes each part of the hostname as
+    a whole string (rather than encoding individual characters).  It cannot be
+    applied to the whole URL, because (a) the Punycode algorithm expects to
+    operate on domain names so doesn't output a valid URL, and (b) non-ASCII
+    characters in non-hostname parts of the URL aren't encoded via Punycode.
+
+    To put this in RFC 5890's terminology: before we remove or replace any
+    characters from our domain name (which we do to ensure that each label is a
+    valid LDH Label), we first ensure each label is in its A-label form.
+
+    (Note that Python's builtin idna encoding is actually IDNA2003, not
+    IDNA2008.  This changes the specifics of how some characters are encoded to
+    ASCII, but doesn't affect the logic here.)
+
+    :param url:
+        The URL to operate on.
+
+    :return:
+        A sanitized version of the URL, which will have been IDNA encoded if
+        necessary, or ``None`` if the generated string is not a parseable URL.
+    """
+    # Acceptable characters are LDH characters, plus "." to separate each label
+    acceptable_chars = LDH_ASCII_CHARS + "."
+    transformations = [
+        # This is an IP address, not a hostname, so no need to apply the
+        # transformations
+        lambda hostname: None if net.is_ip_address(hostname) else hostname,
+
+        # Encode with IDNA to get the correct characters (as `bytes`), then
+        # decode with ASCII so we return a `str`
+        lambda hostname: hostname.encode('idna').decode('ascii'),
+
+        # Replace any unacceptable characters with "-"
+        lambda hostname: ''.join(
+            c if c in acceptable_chars else "-" for c in hostname
+        ),
+    ]
+
+    return _apply_hostname_transformations_to_url(url, transformations)
+
+
 def _get_package_mirror_info(mirror_info, data_source=None,
                              mirror_filter=util.search_for_mirror):
     # given a arch specific 'mirror_info' entry (from package_mirrors)
@@ -748,9 +849,13 @@ def _get_package_mirror_info(mirror_info
         mirrors = []
         for tmpl in searchlist:
             try:
-                mirrors.append(tmpl % subst)
+                mirror = tmpl % subst
             except KeyError:
-                pass
+                continue
+
+            mirror = _sanitize_mirror_url(mirror)
+            if mirror is not None:
+                mirrors.append(mirror)
 
         found = mirror_filter(mirrors)
         if found:
--- a/cloudinit/distros/tests/test_init.py
+++ b/cloudinit/distros/tests/test_init.py
@@ -9,7 +9,18 @@ from unittest import mock
 
 import pytest
 
-from cloudinit.distros import _get_package_mirror_info
+from cloudinit.distros import _get_package_mirror_info, LDH_ASCII_CHARS
+
+
+# The non-LDH characters (code points 0-126) we expect to become hyphens
+INVALID_URL_CHARS = [
+    chr(x) for x in range(127) if chr(x) not in LDH_ASCII_CHARS
+]
+for separator in [":", ".", "/", "#", "?", "@", "[", "]"]:
+    # Drop from the list characters that either separate hostname parts (":",
+    # "."), terminate hostnames ("/", "#", "?", "@"), or cause Python to be
+    # unable to parse URLs ("[", "]").
+    INVALID_URL_CHARS.remove(separator)
 
 
 class TestGetPackageMirrorInfo:
@@ -25,14 +36,16 @@ class TestGetPackageMirrorInfo:
         # Empty info gives empty return
         ({}, {}),
         # failsafe values used if present
-        ({'failsafe': {'primary': 'value', 'security': 'other'}},
-         {'primary': 'value', 'security': 'other'}),
+        ({'failsafe': {'primary': 'http://value', 'security': 'http://other'}},
+         {'primary': 'http://value', 'security': 'http://other'}),
         # search values used if present
-        ({'search': {'primary': ['value'], 'security': ['other']}},
-         {'primary': ['value'], 'security': ['other']}),
+        ({'search': {'primary': ['http://value'],
+                     'security': ['http://other']}},
+         {'primary': ['http://value'], 'security': ['http://other']}),
         # failsafe values used if search value not present
-        ({'search': {'primary': ['value']}, 'failsafe': {'security': 'other'}},
-         {'primary': ['value'], 'security': 'other'})
+        ({'search': {'primary': ['http://value']},
+          'failsafe': {'security': 'http://other'}},
+         {'primary': ['http://value'], 'security': 'http://other'})
     ])
     def test_get_package_mirror_info_failsafe(self, mirror_info, expected):
         """
@@ -48,26 +61,63 @@ class TestGetPackageMirrorInfo:
     def test_failsafe_used_if_all_search_results_filtered_out(self):
         """Test the failsafe option used if all search options eliminated."""
         mirror_info = {
-            'search': {'primary': ['value']}, 'failsafe': {'primary': 'other'}
+            'search': {'primary': ['http://value']},
+            'failsafe': {'primary': 'http://other'}
         }
-        assert {'primary': 'other'} == _get_package_mirror_info(
+        assert {'primary': 'http://other'} == _get_package_mirror_info(
             mirror_info, mirror_filter=lambda x: False)
 
     @pytest.mark.parametrize('availability_zone,region,patterns,expected', (
         # Test ec2_region alone
-        ('fk-fake-1f', None, ['EC2-%(ec2_region)s'], ['EC2-fk-fake-1']),
+        ('fk-fake-1f', None, ['http://EC2-%(ec2_region)s/ubuntu'],
+         ['http://ec2-fk-fake-1/ubuntu']),
         # Test availability_zone alone
-        ('fk-fake-1f', None, ['AZ-%(availability_zone)s'], ['AZ-fk-fake-1f']),
+        ('fk-fake-1f', None, ['http://AZ-%(availability_zone)s/ubuntu'],
+         ['http://az-fk-fake-1f/ubuntu']),
         # Test region alone
-        (None, 'fk-fake-1', ['RG-%(region)s'], ['RG-fk-fake-1']),
+        (None, 'fk-fake-1', ['http://RG-%(region)s/ubuntu'],
+         ['http://rg-fk-fake-1/ubuntu']),
         # Test that ec2_region is not available for non-matching AZs
         ('fake-fake-1f', None,
-         ['EC2-%(ec2_region)s', 'AZ-%(availability_zone)s'],
-         ['AZ-fake-fake-1f']),
+         ['http://EC2-%(ec2_region)s/ubuntu',
+          'http://AZ-%(availability_zone)s/ubuntu'],
+         ['http://az-fake-fake-1f/ubuntu']),
         # Test that template order maintained
-        (None, 'fake-region', ['RG-%(region)s-2', 'RG-%(region)s-1'],
-         ['RG-fake-region-2', 'RG-fake-region-1']),
-    ))
+        (None, 'fake-region',
+         ['http://RG-%(region)s-2/ubuntu', 'http://RG-%(region)s-1/ubuntu'],
+         ['http://rg-fake-region-2/ubuntu', 'http://rg-fake-region-1/ubuntu']),
+        # Test that non-ASCII hostnames are IDNA encoded;
+        # "IDNA-ТεЅТ̣".encode('idna') == b"xn--idna--4kd53hh6aba3q"
+        (None, 'ТεЅТ̣', ['http://www.IDNA-%(region)s.com/ubuntu'],
+         ['http://www.xn--idna--4kd53hh6aba3q.com/ubuntu']),
+        # Test that non-ASCII hostnames with a port are IDNA encoded;
+        # "IDNA-ТεЅТ̣".encode('idna') == b"xn--idna--4kd53hh6aba3q"
+        (None, 'ТεЅТ̣', ['http://www.IDNA-%(region)s.com:8080/ubuntu'],
+         ['http://www.xn--idna--4kd53hh6aba3q.com:8080/ubuntu']),
+        # Test that non-ASCII non-hostname parts of URLs are unchanged
+        (None, 'ТεЅТ̣', ['http://www.example.com/%(region)s/ubuntu'],
+         ['http://www.example.com/ТεЅТ̣/ubuntu']),
+        # Test that IPv4 addresses are unchanged
+        (None, 'fk-fake-1', ['http://192.168.1.1:8080/%(region)s/ubuntu'],
+         ['http://192.168.1.1:8080/fk-fake-1/ubuntu']),
+        # Test that IPv6 addresses are unchanged
+        (None, 'fk-fake-1',
+         ['http://[2001:67c:1360:8001::23]/%(region)s/ubuntu'],
+         ['http://[2001:67c:1360:8001::23]/fk-fake-1/ubuntu']),
+        # Test that unparseable URLs are filtered out of the mirror list
+        (None, 'inv[lid',
+         ['http://%(region)s.in.hostname/should/be/filtered',
+          'http://but.not.in.the.path/%(region)s'],
+         ['http://but.not.in.the.path/inv[lid']),
+    ) + (
+        # Dynamically generate a test case for each non-LDH
+        # (Letters/Digits/Hyphen) ASCII character, testing that it is
+        # substituted with a hyphen
+        tuple(
+            (None, 'fk{0}fake{0}1'.format(invalid_char),
+             ['http://%(region)s/ubuntu'], ['http://fk-fake-1/ubuntu'])
+            for invalid_char in INVALID_URL_CHARS))
+    )
     def test_substitution(self, availability_zone, region, patterns, expected):
         """Test substitution works as expected."""
         m_data_source = mock.Mock(
