Skip to content

Commit

Permalink
Refactor fortinet preferred_kex fix and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ktbyers committed Nov 8, 2024
1 parent d55d217 commit 8f30127
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 14 deletions.
12 changes: 9 additions & 3 deletions netmiko/base_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,12 +472,18 @@ def __init__(
self.system_host_keys = system_host_keys
self.alt_host_keys = alt_host_keys
self.alt_key_file = alt_key_file
self.disabled_algorithms = disabled_algorithms or {}
self.disabled_algorithms = disabled_algorithms

if disable_sha2_fix:
sha2_pubkeys = ["rsa-sha2-256", "rsa-sha2-512"]
# Merge sha2_pubkeys into pubkeys and prevent duplicates with a set
self.disabled_algorithms["pubkeys"] = list(set(self.disabled_algorithms.get("pubkeys", []) + sha2_pubkeys))
if self.disabled_algorithms is None:
self.disabled_algorithms = {"pubkeys": sha2_pubkeys}
else:
# Merge sha2_pubkeys into pubkeys and prevent duplicates
current_pubkeys = self.disabled_algorithms.get("pubkeys", [])
self.disabled_algorithms["pubkeys"] = list(
set(current_pubkeys + sha2_pubkeys)
)

# For SSH proxy support
self.ssh_config_file = ssh_config_file
Expand Down
20 changes: 10 additions & 10 deletions netmiko/fortinet/fortinet_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@

class FortinetSSH(NoConfig, NoEnable, CiscoSSHConnection):
prompt_pattern = r"[#$]"
preferred_kex = {
"diffie-hellman-group14-sha1",
"diffie-hellman-group-exchange-sha1",
"diffie-hellman-group-exchange-sha256",
"diffie-hellman-group1-sha1",
}

def __init__(self, *args: Any, **kwargs: Any) -> None:
disabled_algorithms = kwargs.get("disabled_algorithms")
if disabled_algorithms is None:
# We only want these and disable the rest
_preferred_kex = {
"diffie-hellman-group14-sha1",
"diffie-hellman-group-exchange-sha1",
"diffie-hellman-group-exchange-sha256",
"diffie-hellman-group1-sha1",
}
paramiko_transport = getattr(paramiko, "Transport")
kwargs["disabled_algorithms"] = {
"kex": list(set(paramiko_transport._preferred_kex) - _preferred_kex)
}
paramiko_cur_kex = set(paramiko_transport._preferred_kex)
# Disable any kex not in allowed fortinet set
disabled_kex = list(paramiko_cur_kex - self.preferred_kex)
kwargs["disabled_algorithms"] = {"kex": disabled_kex}

super().__init__(*args, **kwargs)

Expand Down
62 changes: 61 additions & 1 deletion tests/unit/test_base_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from os.path import dirname, join
from threading import Lock

from netmiko import NetmikoTimeoutException, log
import paramiko
from netmiko import NetmikoTimeoutException, log, ConnectHandler
from netmiko.base_connection import BaseConnection

RESOURCE_FOLDER = join(dirname(dirname(__file__)), "etc")
Expand Down Expand Up @@ -493,3 +494,62 @@ def test_remove_SecretsFilter_after_disconnection():
connection.disconnect()

assert not log.filters


def test_fortinet_kex_values():
"""Verify KEX override in Fortinet driver works properly"""
connection = ConnectHandler(
host="testhost",
device_type="fortinet",
auto_connect=False, # No need to connect for the test purposes
)
paramiko_transport = getattr(paramiko, "Transport")
paramiko_default_kex = set(paramiko_transport._preferred_kex)

allowed_fortinet_kex = set(connection.preferred_kex)
disabled_kex = list(paramiko_default_kex - allowed_fortinet_kex)
allowed_kex = paramiko_default_kex & allowed_fortinet_kex

# Ensure disabled_kex matches expectations
assert disabled_kex == connection.disabled_algorithms.get("kex", [])
# Ensure allowed_kex is not an empty set
assert allowed_kex

connection.disconnect()


def test_disable_sha2_fix():
"""
Verify SHA2 fix works properly; test with fortinet device_type as it is more of an edge
case.
"""
connection = ConnectHandler(
host="testhost",
device_type="fortinet",
disable_sha2_fix=True,
auto_connect=False, # No need to connect for the test purposes
)
paramiko_transport = getattr(paramiko, "Transport")

# Verify fortinet kex fix and disable_sha2_fix work properly together
paramiko_default_kex = set(paramiko_transport._preferred_kex)
allowed_fortinet_kex = set(connection.preferred_kex)
disabled_kex = list(paramiko_default_kex - allowed_fortinet_kex)
allowed_kex = paramiko_default_kex & allowed_fortinet_kex

# Ensure disabled_kex matches expectations
assert disabled_kex == connection.disabled_algorithms.get("kex", [])
# Ensure allowed_kex is not an empty set
assert allowed_kex

# Verify 'sha2' algorithms have been disabled
paramiko_default_pubkeys = set(paramiko_transport._preferred_keys)
disabled_pubkey_algos = set(connection.disabled_algorithms.get("pubkeys", []))

allowed_pubkeys = paramiko_default_pubkeys - disabled_pubkey_algos
# Check allowed_pubkeys is not an empty set
assert allowed_pubkeys
# Check both 'sha2' pubkeys are not in allowed_pubkeys
assert {"rsa-sha2-512", "rsa-sha2-256"} & allowed_pubkeys == set()

connection.disconnect()

0 comments on commit 8f30127

Please sign in to comment.