diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e703e4..22a6400 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- `envfile` command to create `.env` files from `.env.example` templates recursively in the current directory without creating worktrees +- `--force` flag for `envfile` command to overwrite existing `.env` files +- `--dry-run` flag for `envfile` command to preview what would be created without actually creating files +- `--silent` flag for `envfile` command to suppress output except for prompts ### Changed diff --git a/src/sprout/cli.py b/src/sprout/cli.py index e6e23be..f25936d 100644 --- a/src/sprout/cli.py +++ b/src/sprout/cli.py @@ -5,6 +5,7 @@ from sprout import __version__ from sprout.commands.create import create_worktree +from sprout.commands.envfile import create_env_files from sprout.commands.ls import list_worktrees from sprout.commands.path import get_worktree_path from sprout.commands.rm import remove_worktree @@ -84,5 +85,28 @@ def path( get_worktree_path(identifier) +@app.command() +def envfile( + force: bool = typer.Option( + False, + "--force", + "-f", + help="Overwrite existing .env files", + ), + dry_run: bool = typer.Option( + False, + "--dry-run", + help="Show what would be created without creating files", + ), + silent: bool = typer.Option( + False, + "--silent", + help="Suppress output except for prompts", + ), +) -> None: + """Create .env files from .env.example templates recursively.""" + create_env_files(force=force, dry_run=dry_run, silent=silent) + + if __name__ == "__main__": app() diff --git a/src/sprout/commands/envfile.py b/src/sprout/commands/envfile.py new file mode 100644 index 0000000..2a0376c --- /dev/null +++ b/src/sprout/commands/envfile.py @@ -0,0 +1,168 @@ +"""Environment file management command implementation.""" + +import subprocess +from pathlib import Path +from typing import TypeAlias + +from rich.console import Console +from rich.table import Table + +from sprout.exceptions import SproutError +from sprout.utils import get_used_ports, parse_env_template + +PortSet: TypeAlias = set[int] + +console = Console() + + +def get_current_branch() -> str | None: + """Get the current git branch name.""" + try: + result = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0: + branch = result.stdout.strip() + if branch != "HEAD": + return branch + except (subprocess.SubprocessError, FileNotFoundError): + pass + return None + + +def find_env_example_files(root_dir: Path) -> list[Path]: + """Find all .env.example files recursively from the given directory.""" + return sorted(root_dir.rglob(".env.example")) + + +def create_env_files( + force: bool = False, + dry_run: bool = False, + silent: bool = False, +) -> None: + """Create .env files from .env.example templates in current directory and subdirectories. + + Args: + force: Overwrite existing .env files + dry_run: Show what would be created without creating files + silent: Suppress output except for prompts + """ + root_dir = Path.cwd() + env_examples = find_env_example_files(root_dir) + + if not env_examples: + if not silent: + console.print( + "[yellow]No .env.example files found in current directory " + "or subdirectories[/yellow]" + ) + return + + branch_name = get_current_branch() + used_ports = get_used_ports() + + created_files: list[tuple[Path, str]] = [] + skipped_files: list[tuple[Path, str]] = [] + errors: list[tuple[Path, str]] = [] + + if not silent: + console.print(f"[cyan]Found {len(env_examples)} .env.example file(s)[/cyan]") + if branch_name: + console.print(f"[dim]Current branch: {branch_name}[/dim]") + console.print() + + for example_path in env_examples: + env_path = example_path.parent / ".env" + + try: + relative_env = env_path.relative_to(root_dir) + except ValueError: + relative_env = env_path + + if env_path.exists() and not force: + skipped_files.append((relative_env, "already exists")) + if not silent: + console.print(f"[dim]Skipping {relative_env} (already exists)[/dim]") + continue + + if dry_run: + if env_path.exists(): + created_files.append((relative_env, "would overwrite")) + else: + created_files.append((relative_env, "would create")) + if not silent: + action = "Would overwrite" if env_path.exists() else "Would create" + console.print(f"[blue]{action}[/blue] {relative_env}") + continue + + try: + parsed_content = parse_env_template( + example_path, + silent=silent, + used_ports=used_ports, + branch_name=branch_name, + ) + + env_path.write_text(parsed_content) + + if env_path.exists() and force: + created_files.append((relative_env, "overwritten")) + if not silent: + console.print(f"[green]✓[/green] Overwritten {relative_env}") + else: + created_files.append((relative_env, "created")) + if not silent: + console.print(f"[green]✓[/green] Created {relative_env}") + + except SproutError as e: + errors.append((relative_env, str(e))) + if not silent: + console.print(f"[red]✗[/red] Failed to create {relative_env}: {e}") + except Exception as e: + errors.append((relative_env, f"Unexpected error: {e}")) + if not silent: + console.print( + f"[red]✗[/red] Failed to create {relative_env}: Unexpected error: {e}" + ) + + if not silent: + console.print() + _show_summary(created_files, skipped_files, errors, dry_run) + + +def _show_summary( + created_files: list[tuple[Path, str]], + skipped_files: list[tuple[Path, str]], + errors: list[tuple[Path, str]], + dry_run: bool, +) -> None: + """Show summary of the operation.""" + table = Table(title="Summary", show_header=True, header_style="bold cyan") + table.add_column("Status", style="cyan", no_wrap=True) + table.add_column("Count", justify="right") + + if dry_run: + would_create = len([f for f in created_files if f[1] == "would create"]) + would_overwrite = len([f for f in created_files if f[1] == "would overwrite"]) + if would_create: + table.add_row("Would create", str(would_create)) + if would_overwrite: + table.add_row("Would overwrite", str(would_overwrite)) + else: + created = len([f for f in created_files if f[1] == "created"]) + overwritten = len([f for f in created_files if f[1] == "overwritten"]) + if created: + table.add_row("Created", str(created)) + if overwritten: + table.add_row("Overwritten", str(overwritten)) + + if skipped_files: + table.add_row("Skipped", str(len(skipped_files))) + + if errors: + table.add_row("[red]Failed[/red]", str(len(errors))) + + console.print(table) diff --git a/tests/test_envfile.py b/tests/test_envfile.py new file mode 100644 index 0000000..939be74 --- /dev/null +++ b/tests/test_envfile.py @@ -0,0 +1,217 @@ +"""Tests for the envfile command.""" + +from unittest.mock import MagicMock, patch + +from sprout.commands.envfile import create_env_files, find_env_example_files, get_current_branch +from sprout.exceptions import SproutError + + +class TestGetCurrentBranch: + """Test get_current_branch function.""" + + def test_get_current_branch_success(self): + """Test successful branch retrieval.""" + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0, stdout="feature-branch\n") + assert get_current_branch() == "feature-branch" + + def test_get_current_branch_head_state(self): + """Test when in detached HEAD state.""" + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0, stdout="HEAD\n") + assert get_current_branch() is None + + def test_get_current_branch_failure(self): + """Test when git command fails.""" + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=1) + assert get_current_branch() is None + + def test_get_current_branch_no_git(self): + """Test when git is not available.""" + with patch("subprocess.run", side_effect=FileNotFoundError): + assert get_current_branch() is None + + +class TestFindEnvExampleFiles: + """Test find_env_example_files function.""" + + def test_find_env_example_files(self, tmp_path): + """Test finding .env.example files recursively.""" + # Create test structure + (tmp_path / "dir1").mkdir() + (tmp_path / "dir1" / ".env.example").touch() + (tmp_path / "dir2").mkdir() + (tmp_path / "dir2" / "subdir").mkdir() + (tmp_path / "dir2" / "subdir" / ".env.example").touch() + (tmp_path / ".env").touch() # Should not be found + + files = find_env_example_files(tmp_path) + assert len(files) == 2 + assert all(f.name == ".env.example" for f in files) + + def test_find_env_example_files_empty(self, tmp_path): + """Test when no .env.example files exist.""" + files = find_env_example_files(tmp_path) + assert files == [] + + def test_find_env_example_files_sorted(self, tmp_path): + """Test that files are returned sorted.""" + (tmp_path / "b").mkdir() + (tmp_path / "b" / ".env.example").touch() + (tmp_path / "a").mkdir() + (tmp_path / "a" / ".env.example").touch() + + files = find_env_example_files(tmp_path) + assert len(files) == 2 + assert files[0].parent.name == "a" + assert files[1].parent.name == "b" + + +class TestCreateEnvFiles: + """Test create_env_files function.""" + + @patch("sprout.commands.envfile.Path.cwd") + @patch("sprout.commands.envfile.get_current_branch") + @patch("sprout.commands.envfile.get_used_ports") + @patch("sprout.commands.envfile.parse_env_template") + def test_create_env_files_basic(self, mock_parse, mock_ports, mock_branch, mock_cwd, tmp_path): + """Test basic .env file creation.""" + # Setup + mock_cwd.return_value = tmp_path + mock_branch.return_value = "main" + mock_ports.return_value = set() + mock_parse.return_value = "KEY=value\n" + + # Create .env.example + (tmp_path / ".env.example").write_text("KEY={{ VALUE }}\n") + + # Run + with patch("sprout.commands.envfile.console"): + create_env_files() + + # Verify + assert (tmp_path / ".env").exists() + assert (tmp_path / ".env").read_text() == "KEY=value\n" + + @patch("sprout.commands.envfile.Path.cwd") + def test_create_env_files_skip_existing(self, mock_cwd, tmp_path): + """Test skipping existing .env files.""" + mock_cwd.return_value = tmp_path + + # Create existing .env and .env.example + (tmp_path / ".env").write_text("existing") + (tmp_path / ".env.example").write_text("new") + + with patch("sprout.commands.envfile.console"): + create_env_files(force=False) + + # Should not overwrite + assert (tmp_path / ".env").read_text() == "existing" + + @patch("sprout.commands.envfile.Path.cwd") + @patch("sprout.commands.envfile.get_current_branch") + @patch("sprout.commands.envfile.get_used_ports") + @patch("sprout.commands.envfile.parse_env_template") + def test_create_env_files_force(self, mock_parse, mock_ports, mock_branch, mock_cwd, tmp_path): + """Test forcing overwrite of existing .env files.""" + mock_cwd.return_value = tmp_path + mock_branch.return_value = "main" + mock_ports.return_value = set() + mock_parse.return_value = "new" + + # Create existing .env and .env.example + (tmp_path / ".env").write_text("existing") + (tmp_path / ".env.example").write_text("template") + + with patch("sprout.commands.envfile.console"): + create_env_files(force=True) + + # Should overwrite + assert (tmp_path / ".env").read_text() == "new" + + @patch("sprout.commands.envfile.Path.cwd") + @patch("sprout.commands.envfile.get_current_branch") + @patch("sprout.commands.envfile.get_used_ports") + def test_create_env_files_dry_run(self, mock_ports, mock_branch, mock_cwd, tmp_path): + """Test dry run mode.""" + mock_cwd.return_value = tmp_path + mock_branch.return_value = "main" + mock_ports.return_value = set() + + # Create .env.example + (tmp_path / ".env.example").write_text("KEY=value\n") + + with patch("sprout.commands.envfile.console"): + create_env_files(dry_run=True) + + # Should not create .env + assert not (tmp_path / ".env").exists() + + @patch("sprout.commands.envfile.Path.cwd") + def test_create_env_files_no_examples(self, mock_cwd, tmp_path): + """Test when no .env.example files exist.""" + mock_cwd.return_value = tmp_path + + with patch("sprout.commands.envfile.console") as mock_console: + create_env_files() + # Should print warning + mock_console.print.assert_called() + + @patch("sprout.commands.envfile.Path.cwd") + @patch("sprout.commands.envfile.get_current_branch") + @patch("sprout.commands.envfile.get_used_ports") + @patch("sprout.commands.envfile.parse_env_template") + def test_create_env_files_recursive( + self, mock_parse, mock_ports, mock_branch, mock_cwd, tmp_path + ): + """Test recursive .env file creation.""" + mock_cwd.return_value = tmp_path + mock_branch.return_value = "main" + mock_ports.return_value = set() + mock_parse.return_value = "KEY=value\n" + + # Create nested structure + (tmp_path / "app1").mkdir() + (tmp_path / "app1" / ".env.example").write_text("APP1={{ VALUE }}\n") + (tmp_path / "app2").mkdir() + (tmp_path / "app2" / ".env.example").write_text("APP2={{ VALUE }}\n") + + with patch("sprout.commands.envfile.console"): + create_env_files() + + # Verify both .env files created + assert (tmp_path / "app1" / ".env").exists() + assert (tmp_path / "app2" / ".env").exists() + + @patch("sprout.commands.envfile.Path.cwd") + @patch("sprout.commands.envfile.get_current_branch") + @patch("sprout.commands.envfile.get_used_ports") + @patch("sprout.commands.envfile.parse_env_template") + def test_create_env_files_error_handling( + self, mock_parse, mock_ports, mock_branch, mock_cwd, tmp_path + ): + """Test error handling during file creation.""" + mock_cwd.return_value = tmp_path + mock_branch.return_value = "main" + mock_ports.return_value = set() + mock_parse.side_effect = SproutError("Template error") + + # Create .env.example + (tmp_path / ".env.example").write_text("KEY={{ VALUE }}\n") + + with patch("sprout.commands.envfile.console"): + create_env_files() + + # Should not create .env due to error + assert not (tmp_path / ".env").exists() + + @patch("sprout.commands.envfile.Path.cwd") + def test_create_env_files_silent_mode(self, mock_cwd, tmp_path): + """Test silent mode suppresses output.""" + mock_cwd.return_value = tmp_path + + with patch("sprout.commands.envfile.console") as mock_console: + create_env_files(silent=True) + # Should not print anything in silent mode + mock_console.print.assert_not_called()