Compare commits

...

4 Commits

Author SHA1 Message Date
Raito Bezarius
9a3760926a nixos/tests/systemd-credentials-tpm2: adopt newest TPM support in test infra 2023-04-30 04:25:45 +02:00
Raito Bezarius
c3e3968544 nixos/lib/test-driver: wire up TPMs into the test driver 2023-04-30 04:25:44 +02:00
Raito Bezarius
685da3c995 nixos/qemu-vm: add virtualisation.tpm for running TPM in QEMU infrastructure 2023-04-30 04:25:43 +02:00
Raito Bezarius
975b5d3152 nixos/lib/test-driver: add Tpm driver 2023-04-30 04:25:42 +02:00
8 changed files with 168 additions and 64 deletions

View File

@@ -3,6 +3,7 @@ import argparse
import ptpython.repl
import os
import time
import json
from test_driver.logger import rootlog
from test_driver.driver import Driver
@@ -77,6 +78,13 @@ def main() -> None:
nargs="*",
help="vlans to span by the driver",
)
arg_parser.add_argument(
"--tpms",
metavar="TPMs",
action=EnvDefault,
envvar="tpms",
help="tpms blob to initialize by the driver (in JSON)",
)
arg_parser.add_argument(
"-o",
"--output_directory",
@@ -101,6 +109,7 @@ def main() -> None:
with Driver(
args.start_scripts,
args.vlans,
json.loads(args.tpms),
args.testscript.read_text(),
args.output_directory.resolve(),
args.keep_vm_state,
@@ -120,7 +129,7 @@ def generate_driver_symbols() -> None:
in user's test scripts. That list is then used by pyflakes to lint those
scripts.
"""
d = Driver([], [], "", Path())
d = Driver([], [], [], "", Path())
test_symbols = d.test_symbols()
with open("driver-symbols", "w") as fp:
fp.write(",".join(test_symbols.keys()))

View File

@@ -1,12 +1,23 @@
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Iterator, List, Union, Optional, Callable, ContextManager
from typing import (
Any,
Dict,
Iterator,
List,
Union,
Optional,
Callable,
ContextManager,
Tuple,
)
import os
import tempfile
from test_driver.logger import rootlog
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.vlan import VLan
from test_driver.tpm import Tpm
from test_driver.polling_condition import PollingCondition
@@ -35,18 +46,21 @@ class Driver:
tests: str
vlans: List[VLan]
machines: List[Machine]
tpms: Dict[str, Tpm]
polling_conditions: List[PollingCondition]
def __init__(
self,
start_scripts: List[str],
vlans: List[int],
tpms: List[Dict[str, str]],
tests: str,
out_dir: Path,
keep_vm_state: bool = False,
):
self.tests = tests
self.out_dir = out_dir
self.polling_conditions = []
tmp_dir = get_tmp_dir()
@@ -54,12 +68,22 @@ class Driver:
vlans = list(set(vlans))
self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
with rootlog.nested("start all TPMs"):
self.tpms = {
tpm_data["machine_name"]: Tpm(
tpm_data["swtpm_binary_path"], tpm_data["socket_path"]
)
for tpm_data in tpms
}
for tpm in self.tpms.values():
tpm.start()
# Monitor TPMs for unexpected crashes.
self.polling_conditions.append(PollingCondition(tpm.check))
def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
for s in scripts:
yield NixStartScript(s)
self.polling_conditions = []
self.machines = [
Machine(
start_command=cmd,
@@ -68,6 +92,7 @@ class Driver:
tmp_dir=tmp_dir,
callbacks=[self.check_polling_conditions],
out_dir=self.out_dir,
tpm=self.tpms.get(cmd.machine_name),
)
for cmd in cmd(start_scripts)
]
@@ -111,13 +136,17 @@ class Driver:
serial_stdout_off=self.serial_stdout_off,
serial_stdout_on=self.serial_stdout_on,
polling_condition=self.polling_condition,
Machine=Machine, # for typing
# for typing
Machine=Machine,
Tpm=Tpm,
)
machine_symbols = {m.name: m for m in self.machines}
tpm_symbols = {f"tpm_{m.name}": m.tpm for m in self.machines}
# If there's exactly one machine, make it available under the name
# "machine", even if it's not called that.
if len(self.machines) == 1:
(machine_symbols["machine"],) = self.machines
(tpm_symbols["tpm_machine"],) = self.tpms.values()
vlan_symbols = {
f"vlan{v.nr}": self.vlans[idx] for idx, v in enumerate(self.vlans)
}
@@ -125,11 +154,13 @@ class Driver:
"additionally exposed symbols:\n "
+ ", ".join(map(lambda m: m.name, self.machines))
+ ",\n "
+ ", ".join(self.tpms.keys())
+ ",\n "
+ ", ".join(map(lambda v: f"vlan{v.nr}", self.vlans))
+ ",\n "
+ ", ".join(list(general_symbols.keys()))
)
return {**general_symbols, **machine_symbols, **vlan_symbols}
return {**general_symbols, **machine_symbols, **vlan_symbols, **tpm_symbols}
def test_script(self) -> None:
"""Run the test script"""

View File

@@ -17,6 +17,7 @@ import threading
import time
from test_driver.logger import rootlog
from test_driver.tpm import Tpm
CHAR_TO_KEY = {
"A": "shift-a",
@@ -318,6 +319,8 @@ class Machine:
shell: Optional[socket.socket]
serial_thread: Optional[threading.Thread]
tpm: Tpm | None
booted: bool
connected: bool
# Store last serial console lines for use
@@ -336,6 +339,7 @@ class Machine:
name: str = "machine",
keep_vm_state: bool = False,
callbacks: Optional[List[Callable]] = None,
tpm: Optional[Tpm] = None,
) -> None:
self.out_dir = out_dir
self.tmp_dir = tmp_dir
@@ -343,6 +347,7 @@ class Machine:
self.name = name
self.start_command = start_command
self.callbacks = callbacks if callbacks is not None else []
self.tpm = tpm
# set up directories
self.shared_dir = self.tmp_dir / "shared-xchg"

View File

@@ -0,0 +1,54 @@
import subprocess
from tempfile import TemporaryDirectory
class Tpm:
"""
This is a TPM driver for QEMU tests.
It gives you access to a TPM socket path on the host
you can access at `tpm_socket_path`.
"""
state_dir: TemporaryDirectory
swtpm_binary_path: str
tpm_socket_path: str
def __init__(self, swtpm_binary_path: str, tpm_socket_path: str):
self.state_dir = TemporaryDirectory()
self.swtpm_binary_path = swtpm_binary_path
self.tpm_socket_path = tpm_socket_path
self.start()
def start(self) -> None:
"""
Start swtpm binary and wait for its proper startup.
In case of failure, this will raise a runtime error.
"""
self.proc = subprocess.Popen(
[
self.swtpm_binary_path,
"socket",
"--tpmstate",
f"dir={self.state_dir.name}",
"--ctrl",
f"type=unixio,path={self.tpm_socket_path}",
"--tpm2",
]
)
# Check whether starting swtpm failed
try:
exit_code = self.proc.wait(timeout=0.2)
if exit_code is not None and exit_code != 0:
raise RuntimeError(f"failed to start swtpm, exit code: {exit_code}")
except subprocess.TimeoutExpired:
pass
def check(self) -> None:
"""
Check whether the swtpm process exited due to an error
Useful as a @polling_condition.
"""
exit_code = self.proc.poll()
if exit_code is not None and exit_code != 0:
raise RuntimeError("swtpm process died")

View File

@@ -3,6 +3,7 @@
from test_driver.driver import Driver
from test_driver.vlan import VLan
from test_driver.tpm import Tpm
from test_driver.machine import Machine
from test_driver.logger import Logger
from typing import Callable, Iterator, ContextManager, Optional, List, Dict, Any, Union

View File

@@ -14,6 +14,14 @@ let
vlans = map (m: m.virtualisation.vlans) (lib.attrValues config.nodes);
vms = map (m: m.system.build.vm) (lib.attrValues config.nodes);
tpms = map (n:
let m = config.nodes.${n};
in
{
swtpm_binary_path = "${lib.getExe m.virtualisation.tpm.package}";
socket_path = "${m.virtualisation.tpm.socketPath}";
machine_name = n;
}) (lib.attrNames config.nodes);
nodeHostNames =
let
@@ -31,6 +39,13 @@ let
uniqueVlans = lib.unique (builtins.concatLists vlans);
vlanNames = map (i: "vlan${toString i}: VLan;") uniqueVlans;
machineNames = map (name: "${name}: Machine;") nodeHostNames;
tpmNames = map (name:
let
tpmConfig = config.nodes.${name}.virtualisation.tpm;
in
''tpm_${name}: Tpm = Tpm(\"${lib.getExe tpmConfig.package}\", \"${tpmConfig.socketPath}\")''
) (lib.attrNames config.nodes);
withChecks =
if lib.length invalidNodeNames > 0 then
@@ -70,6 +85,7 @@ let
cat "${../test-script-prepend.py}" >> testScriptWithTypes
echo "${builtins.toString machineNames}" >> testScriptWithTypes
echo "${builtins.toString vlanNames}" >> testScriptWithTypes
echo "${builtins.toString tpmNames}" >> testScriptWithTypes
echo -n "$testScript" >> testScriptWithTypes
cat -n testScriptWithTypes
@@ -98,6 +114,7 @@ let
--set startScripts "''${vmStartScripts[*]}" \
--set testScript "$out/test-script" \
--set vlans '${toString vlans}' \
--set tpms '${builtins.toJSON tpms}' \
${lib.escapeShellArgs (lib.concatMap (arg: ["--add-flags" arg]) config.extraDriverArgs)}
'';

View File

@@ -796,6 +796,42 @@ in
};
};
virtualisation.tpm = {
enable = mkEnableOption ''a TPM device in the virtual machine with a driver, using swtpm.
To use it in a `test_script`, you can use `machine.tpm` which is a TPM driver offering some basic
facilities to manipulate the TPM socket on the host.
'';
package = mkPackageOptionMD cfg.host.pkgs "swtpm" { };
socketPath = mkOption {
type = types.str;
default = "/tmp/swtpm-sock";
description = lib.mdDoc "swtpm socket path on the host";
};
deviceModel = mkOption {
type = types.str;
default = ({
"i686-linux" = "tpm-tis";
"x86_64-linux" = "tpm-tis";
"ppc64-linux" = "tpm-spapr";
"armv7-linux" = "tpm-tis-device";
"aarch64-linux" = "tpm-tis-device";
}.${pkgs.hostPlatform.system});
defaultText = ''({
"i686-linux" = "tpm-tis";
"x86_64-linux" = "tpm-tis";
"ppc64-linux" = "tpm-spapr";
"armv7-linux" = "tpm-tis-device";
"aarch64-linux" = "tpm-tis-device";
}.''${pkgs.hostPlatform.system})'';
example = "tpm-tis-device";
description = lib.mdDoc "QEMU device model for the TPM, uses the appropriate default based on the system and the package passed.";
};
};
virtualisation.useDefaultFilesystems =
mkOption {
type = types.bool;
@@ -946,7 +982,8 @@ in
boot.initrd.availableKernelModules =
optional cfg.writableStore "overlay"
++ optional (cfg.qemu.diskInterface == "scsi") "sym53c8xx";
++ optional (cfg.qemu.diskInterface == "scsi") "sym53c8xx"
++ optional (cfg.tpm.enable) "tpm_tis";
virtualisation.additionalPaths = [ config.system.build.toplevel ];
@@ -1012,6 +1049,11 @@ in
(mkIf (!cfg.graphics) [
"-nographic"
])
(mkIf (cfg.tpm.enable) [
"-chardev socket,id=chrtpm,path=${cfg.tpm.socketPath}"
"-tpmdev emulator,id=tpm_dev_0,chardev=chrtpm"
"-device ${cfg.tpm.deviceModel},tpmdev=tpm_dev_0"
])
];
virtualisation.qemu.drives = mkMerge [

View File

@@ -1,13 +1,4 @@
import ./make-test-python.nix ({ lib, pkgs, system, ... }:
let
tpmSocketPath = "/tmp/swtpm-sock";
tpmDeviceModels = {
x86_64-linux = "tpm-tis";
aarch64-linux = "tpm-tis-device";
};
in
import ./make-test-python.nix ({ lib, pkgs, ... }:
{
name = "systemd-credentials-tpm2";
@@ -16,51 +7,11 @@ in
};
nodes.machine = { pkgs, ... }: {
virtualisation = {
qemu.options = [
"-chardev socket,id=chrtpm,path=${tpmSocketPath}"
"-tpmdev emulator,id=tpm_dev_0,chardev=chrtpm"
"-device ${tpmDeviceModels.${system}},tpmdev=tpm_dev_0"
];
};
boot.initrd.availableKernelModules = [ "tpm_tis" ];
virtualisation.tpm.enable = true;
environment.systemPackages = with pkgs; [ diffutils ];
};
testScript = ''
import subprocess
from tempfile import TemporaryDirectory
# From systemd-initrd-luks-tpm2.nix
class Tpm:
def __init__(self):
self.state_dir = TemporaryDirectory()
self.start()
def start(self):
self.proc = subprocess.Popen(["${pkgs.swtpm}/bin/swtpm",
"socket",
"--tpmstate", f"dir={self.state_dir.name}",
"--ctrl", "type=unixio,path=${tpmSocketPath}",
"--tpm2",
])
# Check whether starting swtpm failed
try:
exit_code = self.proc.wait(timeout=0.2)
if exit_code is not None and exit_code != 0:
raise Exception("failed to start swtpm")
except subprocess.TimeoutExpired:
pass
"""Check whether the swtpm process exited due to an error"""
def check(self):
exit_code = self.proc.poll()
if exit_code is not None and exit_code != 0:
raise Exception("swtpm process died")
CRED_NAME = "testkey"
CRED_RAW_FILE = f"/root/{CRED_NAME}"
CRED_FILE = f"/root/{CRED_NAME}.cred"
@@ -85,12 +36,6 @@ in
machine.log("systemd-run finished successfully")
tpm = Tpm()
@polling_condition
def swtpm_running():
tpm.check()
machine.wait_for_unit("multi-user.target")
with subtest("Check whether TPM device exists"):