Skip to content

Commit 13a6e33

Browse files
committed
Project template validation. Add tests for initializing project from all bundled templates.
1 parent e5dd690 commit 13a6e33

File tree

6 files changed

+177
-50
lines changed

6 files changed

+177
-50
lines changed

agentstack/cli/cli.py

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from agentstack import generation
3232
from agentstack.utils import open_json_file, term_color, is_snake_case
3333
from agentstack.update import AGENTSTACK_PACKAGE
34+
from agentstack.proj_templates import TemplateConfig
3435

3536

3637
PREFERRED_MODELS = [
@@ -57,45 +58,34 @@ def init_project_builder(
5758

5859
template_data = None
5960
if template is not None:
60-
url_start = "https://"
61-
if template[: len(url_start)] == url_start:
62-
# template is a url
63-
response = requests.get(template)
64-
if response.status_code == 200:
65-
template_data = response.json()
66-
else:
67-
print(
68-
term_color(
69-
f"Failed to fetch template data from {template}. Status code: {response.status_code}",
70-
'red',
71-
)
72-
)
61+
if template.startswith("https://"):
62+
try:
63+
template_data = TemplateConfig.from_url(template)
64+
except Exception as e:
65+
print(term_color(f"Failed to fetch template data from {template}", 'red'))
7366
sys.exit(1)
7467
else:
75-
with importlib.resources.path(
76-
'agentstack.templates.proj_templates', f'{template}.json'
77-
) as template_path:
78-
if template_path is None:
79-
print(term_color(f"No such template {template} found", 'red'))
80-
sys.exit(1)
81-
template_data = open_json_file(template_path)
68+
try:
69+
template_data = TemplateConfig.from_template_name(template)
70+
except Exception as e:
71+
print(term_color(f"Failed to load template {template}", 'red'))
72+
sys.exit(1)
8273

8374
if template_data:
8475
project_details = {
85-
"name": slug_name or template_data['name'],
76+
"name": slug_name or template_data.name,
8677
"version": "0.0.1",
87-
"description": template_data['description'],
78+
"description": template_data.description,
8879
"author": "Name <Email>",
8980
"license": "MIT",
9081
}
91-
framework = template_data['framework']
82+
framework = template_data.framework
9283
design = {
93-
'agents': template_data['agents'],
94-
'tasks': template_data['tasks'],
95-
'inputs': template_data['inputs'],
84+
'agents': template_data.agents,
85+
'tasks': template_data.tasks,
86+
'inputs': template_data.inputs,
9687
}
97-
98-
tools = template_data['tools']
88+
tools = template_data.tools
9989

10090
elif use_wizard:
10191
welcome_message()
@@ -390,7 +380,7 @@ def insert_template(
390380
project_details: dict,
391381
framework_name: str,
392382
design: dict,
393-
template_data: Optional[dict] = None,
383+
template_data: Optional[TemplateConfig] = None,
394384
):
395385
framework = FrameworkData(framework_name.lower())
396386
project_metadata = ProjectMetadata(
@@ -400,8 +390,8 @@ def insert_template(
400390
version="0.0.1",
401391
license="MIT",
402392
year=datetime.now().year,
403-
template=template_data['name'] if template_data else 'none',
404-
template_version=template_data['template_version'] if template_data else '0',
393+
template=template_data.name if template_data else 'none',
394+
template_version=template_data.template_version if template_data else '0',
405395
)
406396

407397
project_structure = ProjectStructure()

agentstack/proj_templates.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from typing import Optional
2+
import os, sys
3+
from pathlib import Path
4+
import pydantic
5+
import requests
6+
from agentstack import ValidationError
7+
from agentstack.utils import get_package_path, open_json_file, term_color
8+
9+
10+
class TemplateConfig(pydantic.BaseModel):
11+
"""
12+
Interface for interacting with template configuration files.
13+
14+
Templates are read-only.
15+
16+
Template Schema
17+
-------------
18+
name: str
19+
The name of the project.
20+
description: str
21+
A description of the template.
22+
template_version: str
23+
The version of the template.
24+
framework: str
25+
The framework the template is for.
26+
method: str
27+
The method used by the project. ie. "sequential"
28+
agents: list[dict]
29+
A list of agents used by the project. TODO vaidate this against an agent schema
30+
tasks: list[dict]
31+
A list of tasks used by the project. TODO validate this against a task schema
32+
tools: list[dict]
33+
A list of tools used by the project. TODO validate this against a tool schema
34+
inputs: list[str]
35+
A list of inputs used by the project.
36+
"""
37+
38+
name: str
39+
description: str
40+
template_version: int
41+
framework: str
42+
method: str
43+
agents: list[dict]
44+
tasks: list[dict]
45+
tools: list[dict]
46+
inputs: list[str]
47+
48+
@classmethod
49+
def from_template_name(cls, name: str) -> 'TemplateConfig':
50+
path = get_package_path() / f'templates/proj_templates/{name}.json'
51+
if not os.path.exists(path): # TODO raise exceptions and handle message/exit in cli
52+
print(term_color(f'No known agentstack tool: {name}', 'red'))
53+
sys.exit(1)
54+
return cls.from_json(path)
55+
56+
@classmethod
57+
def from_json(cls, path: Path) -> 'ToolConfig':
58+
data = open_json_file(path)
59+
try:
60+
return cls(**data)
61+
except pydantic.ValidationError as e:
62+
# TODO raise exceptions and handle message/exit in cli
63+
print(term_color(f"Error validating template config JSON: \n{path}", 'red'))
64+
for error in e.errors():
65+
print(f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}")
66+
sys.exit(1)
67+
68+
@classmethod
69+
def from_url(cls, url: str) -> 'TemplateConfig':
70+
if not url.startswith("https://"):
71+
raise ValidationError(f"Invalid URL: {url}")
72+
response = requests.get(url)
73+
if response.status_code != 200:
74+
raise ValidationError(f"Failed to fetch template from {url}")
75+
return cls(**response.json())
76+
77+
78+
def get_all_template_paths() -> list[Path]:
79+
paths = []
80+
templates_dir = get_package_path() / 'templates/proj_templates'
81+
for file in templates_dir.iterdir():
82+
if file.suffix == '.json':
83+
paths.append(file)
84+
return paths
85+
86+
87+
def get_all_template_names() -> list[str]:
88+
return [path.stem for path in get_all_template_paths()]
89+
90+
91+
def get_all_templates() -> list[TemplateConfig]:
92+
return [TemplateConfig.from_json(path) for path in get_all_template_paths()]

agentstack/templates/proj_templates/content_creator.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"name": "content_creation",
2+
"name": "content_creator",
33
"description": "Multi-agent system for creating high-quality content",
44
"template_version": 1,
55
"framework": "crewai",

tests/test_cli_init.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import subprocess
2+
import os, sys
3+
import unittest
4+
from parameterized import parameterized
5+
from pathlib import Path
6+
import shutil
7+
from agentstack.proj_templates import get_all_template_names
8+
9+
BASE_PATH = Path(__file__).parent
10+
CLI_ENTRY = [
11+
sys.executable,
12+
"-m",
13+
"agentstack.main",
14+
]
15+
16+
17+
class CLIInitTest(unittest.TestCase):
18+
def setUp(self):
19+
self.project_dir = Path(BASE_PATH / 'tmp/cli_init')
20+
os.makedirs(self.project_dir)
21+
22+
def tearDown(self):
23+
shutil.rmtree(self.project_dir)
24+
25+
def _run_cli(self, *args):
26+
"""Helper method to run the CLI with arguments."""
27+
return subprocess.run([*CLI_ENTRY, *args], capture_output=True, text=True)
28+
29+
def test_init_command(self):
30+
"""Test the 'init' command to create a project directory."""
31+
os.chdir(self.project_dir)
32+
result = self._run_cli('init', str(self.project_dir))
33+
self.assertEqual(result.returncode, 0)
34+
self.assertTrue(self.project_dir.exists())
35+
36+
@parameterized.expand([(x, ) for x in get_all_template_names()])
37+
def test_init_command_for_template(self, template_name):
38+
"""Test the 'init' command to create a project directory with a template."""
39+
os.chdir(self.project_dir)
40+
result = self._run_cli('init', str(self.project_dir), '--template', template_name)
41+
self.assertEqual(result.returncode, 0)
42+
self.assertTrue(self.project_dir.exists())
43+

tests/test_cli_loads.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99

1010
class TestAgentStackCLI(unittest.TestCase):
11-
# Replace with your actual CLI entry point if different
1211
CLI_ENTRY = [
1312
sys.executable,
1413
"-m",
@@ -32,23 +31,6 @@ def test_invalid_command(self):
3231
self.assertNotEqual(result.returncode, 0)
3332
self.assertIn("usage:", result.stderr)
3433

35-
def test_init_command(self):
36-
"""Test the 'init' command to create a project directory."""
37-
test_dir = Path(BASE_PATH / 'tmp/test_project')
38-
39-
# Ensure the directory doesn't exist from previous runs
40-
if test_dir.exists():
41-
shutil.rmtree(test_dir)
42-
os.makedirs(test_dir)
43-
44-
os.chdir(test_dir)
45-
result = self.run_cli("init", str(test_dir))
46-
self.assertEqual(result.returncode, 0)
47-
self.assertTrue(test_dir.exists())
48-
49-
# Clean up
50-
shutil.rmtree(test_dir)
51-
5234
def test_run_command_invalid_project(self):
5335
"""Test the 'run' command on an invalid project."""
5436
test_dir = Path(BASE_PATH / 'tmp/test_project')

tests/test_templates_config.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import json
2+
import unittest
3+
from pathlib import Path
4+
from agentstack.proj_templates import TemplateConfig, get_all_template_names, get_all_template_paths
5+
6+
BASE_PATH = Path(__file__).parent
7+
8+
9+
class TemplateConfigTest(unittest.TestCase):
10+
def test_all_configs_from_template_name(self):
11+
for template_name in get_all_template_names():
12+
config = TemplateConfig.from_template_name(template_name)
13+
assert config.name == template_name
14+
# We can assume that pydantic validation caught any other issues
15+
16+
def test_all_configs_from_template_path(self):
17+
for path in get_all_template_paths():
18+
config = TemplateConfig.from_json(path)
19+
assert config.name == path.stem
20+
# We can assume that pydantic validation caught any other issues

0 commit comments

Comments
 (0)