diff --git a/examples/inventory.hcloud.yml b/examples/inventory.hcloud.yml new file mode 100644 index 0000000..6a67e28 --- /dev/null +++ b/examples/inventory.hcloud.yml @@ -0,0 +1,7 @@ +# You can list the hosts using: +# ansible-inventory --list -i examples/inventory.hcloud.yml --extra-vars=network_name=my-network + +plugin: hetzner.hcloud.hcloud + +network: "{{ network_name }}" +status: [running] diff --git a/plugins/inventory/hcloud.py b/plugins/inventory/hcloud.py index 3bcddfb..f8ee043 100644 --- a/plugins/inventory/hcloud.py +++ b/plugins/inventory/hcloud.py @@ -144,8 +144,14 @@ from ansible.module_utils.common.text.converters import to_native from ansible.plugins.inventory import BaseInventoryPlugin, Cacheable, Constructable from ansible.utils.display import Display -from ..module_utils.client import HAS_DATEUTIL, HAS_REQUESTS -from ..module_utils.vendor import hcloud +from ..module_utils.client import ( + Client, + ClientException, + client_check_required_lib, + client_get_by_name_or_id, +) +from ..module_utils.vendor.hcloud import APIException +from ..module_utils.vendor.hcloud.networks import Network from ..module_utils.vendor.hcloud.servers import Server from ..module_utils.version import version @@ -196,13 +202,24 @@ else: InventoryServer = dict +def first_ipv6_address(network: str) -> str: + """ + Return the first address for a ipv6 network. + + :param network: IPv6 Network. + """ + return next(IPv6Network(network).hosts()) + + class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): NAME = "hetzner.hcloud.hcloud" inventory: InventoryData display: Display - client: hcloud.Client + client: Client + + network: Network | None def _configure_hcloud_client(self): # If api_token_env is not the default, print a deprecation warning and load the @@ -232,7 +249,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): # Resolve template string api_token = self.templar.template(api_token) - self.client = hcloud.Client( + self.client = Client( token=api_token, api_endpoint=api_endpoint, application_name="ansible-inventory", @@ -242,61 +259,47 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): try: # Ensure the api token is valid self.client.locations.get_list() - except hcloud.APIException as exception: + except APIException as exception: raise AnsibleError("Invalid Hetzner Cloud API Token.") from exception - def _get_servers(self): - if len(self.get_option("label_selector")) > 0: - self.servers = self.client.servers.get_all(label_selector=self.get_option("label_selector")) - else: - self.servers = self.client.servers.get_all() - - def _filter_servers(self): + def _validate_options(self) -> None: if self.get_option("network"): - network = self.templar.template(self.get_option("network"), fail_on_undefined=False) or self.get_option( - "network" - ) + network_param: str = self.get_option("network") + network_param = self.templar.template(network_param) + try: - self.network = self.client.networks.get_by_name(network) - if self.network is None: - self.network = self.client.networks.get_by_id(network) - except hcloud.APIException: - raise AnsibleError("The given network is not found.") + self.network = client_get_by_name_or_id(self.client, "networks", network_param) + except (ClientException, APIException) as exception: + raise AnsibleError(to_native(exception)) from exception - tmp = [] - for server in self.servers: - for server_private_network in server.private_net: - if server_private_network.network.id == self.network.id: - tmp.append(server) - self.servers = tmp + def _fetch_servers(self) -> list[Server]: + self._validate_options() - if self.get_option("locations"): - tmp = [] - for server in self.servers: - if server.datacenter.location.name in self.get_option("locations"): - tmp.append(server) - self.servers = tmp - - if self.get_option("types"): - tmp = [] - for server in self.servers: - if server.server_type.name in self.get_option("types"): - tmp.append(server) - self.servers = tmp - - if self.get_option("images"): - tmp = [] - for server in self.servers: - if server.image is not None and server.image.os_flavor in self.get_option("images"): - tmp.append(server) - self.servers = tmp + get_servers_params = {} + if self.get_option("label_selector"): + get_servers_params["label_selector"] = self.get_option("label_selector") if self.get_option("status"): - tmp = [] - for server in self.servers: - if server.status in self.get_option("status"): - tmp.append(server) - self.servers = tmp + get_servers_params["status"] = self.get_option("status") + + servers = self.client.servers.get_all(**get_servers_params) + + if self.get_option("network"): + servers = [s for s in servers if self.network.id in [p.network.id for p in s.private_net]] + + if self.get_option("locations"): + locations: list[str] = self.get_option("locations") + servers = [s for s in servers if s.datacenter.location.name in locations] + + if self.get_option("types"): + server_types: list[str] = self.get_option("types") + servers = [s for s in servers if s.server_type.name in server_types] + + if self.get_option("images"): + images: list[str] = self.get_option("images") + servers = [s for s in servers if s.image is not None and s.image.os_flavor in images] + + return servers def _build_inventory_server(self, server: Server) -> InventoryServer: server_dict: InventoryServer = {} @@ -311,7 +314,7 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): server_dict["ipv4"] = to_native(server.public_net.ipv4.ip) if server.public_net.ipv6: - server_dict["ipv6"] = to_native(self._first_ipv6_address(server.public_net.ipv6.ip)) + server_dict["ipv6"] = to_native(first_ipv6_address(server.public_net.ipv6.ip)) server_dict["ipv6_network"] = to_native(server.public_net.ipv6.network) server_dict["ipv6_network_mask"] = to_native(server.public_net.ipv6.network_mask) @@ -320,10 +323,11 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): ] if self.get_option("network"): - for server_private_network in server.private_net: + for private_net in server.private_net: # Set private_ipv4 if user filtered for one network - if server_private_network.network.id == self.network.id: - server_dict["private_ipv4"] = to_native(server_private_network.ip) + if private_net.network.id == self.network.id: + server_dict["private_ipv4"] = to_native(private_net.ip) + break # Server Type if server.server_type is not None: @@ -353,60 +357,54 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): return server_dict - def _get_server_ansible_host(self, server): + def _get_server_ansible_host(self, server: Server): if self.get_option("connect_with") == "public_ipv4": if server.public_net.ipv4: return to_native(server.public_net.ipv4.ip) - else: - raise AnsibleError("Server has no public ipv4, but connect_with=public_ipv4 was specified") + raise AnsibleError("Server has no public ipv4, but connect_with=public_ipv4 was specified") if self.get_option("connect_with") == "public_ipv6": if server.public_net.ipv6: - return to_native(self._first_ipv6_address(server.public_net.ipv6.ip)) - else: - raise AnsibleError("Server has no public ipv6, but connect_with=public_ipv6 was specified") + return to_native(first_ipv6_address(server.public_net.ipv6.ip)) + raise AnsibleError("Server has no public ipv6, but connect_with=public_ipv6 was specified") - elif self.get_option("connect_with") == "hostname": + if self.get_option("connect_with") == "hostname": # every server has a name, no need to guard this return to_native(server.name) - elif self.get_option("connect_with") == "ipv4_dns_ptr": + if self.get_option("connect_with") == "ipv4_dns_ptr": if server.public_net.ipv4: return to_native(server.public_net.ipv4.dns_ptr) - else: - raise AnsibleError("Server has no public ipv4, but connect_with=ipv4_dns_ptr was specified") + raise AnsibleError("Server has no public ipv4, but connect_with=ipv4_dns_ptr was specified") - elif self.get_option("connect_with") == "private_ipv4": + if self.get_option("connect_with") == "private_ipv4": if self.get_option("network"): - for server_private_network in server.private_net: - if server_private_network.network.id == self.network.id: - return to_native(server_private_network.ip) + for private_net in server.private_net: + if private_net.network.id == self.network.id: + return to_native(private_net.ip) else: raise AnsibleError("You can only connect via private IPv4 if you specify a network") - def _first_ipv6_address(self, network): - return next(IPv6Network(network).hosts()) - def verify_file(self, path): """Return the possibly of a file being consumable by this plugin.""" return super().verify_file(path) and path.endswith(("hcloud.yaml", "hcloud.yml")) - def _get_cached_result(self, path, cache) -> tuple[list[InventoryServer | None], bool]: + def _get_cached_result(self, path, cache) -> tuple[list[InventoryServer], bool]: # false when refresh_cache or --flush-cache is used if not cache: - return None, False + return [], False # get the user-specified directive if not self.get_option("cache"): - return None, False + return [], False cache_key = self.get_cache_key(path) try: cached_result = self._cache[cache_key] except KeyError: # if cache expires or cache file doesn"t exist - return None, False + return [], False return cached_result, True @@ -426,24 +424,27 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): def parse(self, inventory, loader, path, cache=True): super().parse(inventory, loader, path, cache) - if not HAS_REQUESTS: - raise AnsibleError("The Hetzner Cloud dynamic inventory plugin requires requests.") - if not HAS_DATEUTIL: - raise AnsibleError("The Hetzner Cloud dynamic inventory plugin requires python-dateutil.") + try: + client_check_required_lib() + except ClientException as exception: + raise AnsibleError(to_native(exception)) from exception + + # Allow using extra variables arguments as template variables (e.g. + # '--extra-vars my_var=my_value') + self.templar.available_variables = self._vars self._read_config_data(path) self._configure_hcloud_client() - self.servers, cached = self._get_cached_result(path, cache) + servers, cached = self._get_cached_result(path, cache) if not cached: - self._get_servers() - self._filter_servers() - self.servers = [self._build_inventory_server(server) for server in self.servers] + with self.client.cached_session(): + servers = [self._build_inventory_server(s) for s in self._fetch_servers()] # Add a top group self.inventory.add_group(group=self.get_option("group")) - for server in self.servers: + for server in servers: self.inventory.add_host(server["name"], group=self.get_option("group")) for key, value in server.items(): self.inventory.set_variable(server["name"], key, value) @@ -475,4 +476,4 @@ class InventoryModule(BaseInventoryPlugin, Constructable, Cacheable): strict=strict, ) - self._update_cached_result(path, cache, self.servers) + self._update_cached_result(path, cache, servers) diff --git a/plugins/module_utils/client.py b/plugins/module_utils/client.py index d2178d2..352ca58 100644 --- a/plugins/module_utils/client.py +++ b/plugins/module_utils/client.py @@ -2,9 +2,11 @@ from __future__ import annotations +from contextlib import contextmanager + from ansible.module_utils.basic import missing_required_lib -from .vendor.hcloud import APIException, Client +from .vendor.hcloud import APIException, Client as ClientBase HAS_REQUESTS = True HAS_DATEUTIL = True @@ -61,3 +63,40 @@ def client_get_by_name_or_id(client: Client, resource: str, param: str | int): if exception.code == "not_found": raise _client_resource_not_found(resource, param) from exception raise exception + + +if HAS_REQUESTS: + + class CachedSession(requests.Session): + cache: dict[str, requests.Response] = {} + + def send(self, request: requests.PreparedRequest, **kwargs) -> requests.Response: # type: ignore[no-untyped-def] + """ + Send a given PreparedRequest. + """ + if request.method != "GET" or request.url is None: + return super().send(request, **kwargs) + + if request.url in self.cache: + return self.cache[request.url] + + response = super().send(request, **kwargs) + if response.ok: + self.cache[request.url] = response + + return response + + +class Client(ClientBase): + @contextmanager + def cached_session(self) -> None: + """ + Swap the client session during the scope of the context. The session will cache + all GET requests. + + Cached response will not expire, therefore the cached client must not be used + for long living scopes. + """ + self._requests_session = CachedSession() + yield + self._requests_session = requests.Session()