diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 2980b0ce..9529daee 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -14,6 +14,8 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/delete" "github.com/brevdev/brev-cli/pkg/cmd/envvars" "github.com/brevdev/brev-cli/pkg/cmd/fu" + "github.com/brevdev/brev-cli/pkg/cmd/gpucreate" + "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" "github.com/brevdev/brev-cli/pkg/cmd/healthcheck" "github.com/brevdev/brev-cli/pkg/cmd/hello" "github.com/brevdev/brev-cli/pkg/cmd/importideconfig" @@ -157,8 +159,13 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin PersistentPreRunE: func(cmd *cobra.Command, args []string) error { breverrors.GetDefaultErrorReporter().AddTag("command", cmd.Name()) // version info gets in the way of the output for - // configure-env-vars, since shells are going to eval it - if featureflag.ShowVersionOnRun() && !printVersion && cmd.Name() != "configure-env-vars" { + // configure-env-vars (shells eval it) and gpu-create/provision (piped to other commands) + skipVersionCommands := map[string]bool{ + "configure-env-vars": true, + "gpu-create": true, + "provision": true, + } + if featureflag.ShowVersionOnRun() && !printVersion && !skipVersionCommands[cmd.Name()] { v, err := remoteversion.BuildCheckLatestVersionString(t, noLoginCmdStore) // todo this should not be fatal when it errors if err != nil { @@ -270,6 +277,8 @@ func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *stor } cmd.AddCommand(workspacegroups.NewCmdWorkspaceGroups(t, loginCmdStore)) cmd.AddCommand(scale.NewCmdScale(t, noLoginCmdStore)) + cmd.AddCommand(gpusearch.NewCmdGPUSearch(t, noLoginCmdStore)) + cmd.AddCommand(gpucreate.NewCmdGPUCreate(t, loginCmdStore)) cmd.AddCommand(configureenvvars.NewCmdConfigureEnvVars(t, loginCmdStore)) cmd.AddCommand(importideconfig.NewCmdImportIDEConfig(t, noLoginCmdStore)) cmd.AddCommand(shell.NewCmdShell(t, loginCmdStore, noLoginCmdStore)) diff --git 
a/pkg/cmd/cmderrors/cmderrors.go b/pkg/cmd/cmderrors/cmderrors.go index 6c290e59..4b07989a 100644 --- a/pkg/cmd/cmderrors/cmderrors.go +++ b/pkg/cmd/cmderrors/cmderrors.go @@ -39,7 +39,7 @@ func DisplayAndHandleError(err error) { case *breverrors.NvidiaMigrationError: // Handle nvidia migration error if nvErr, ok := errors.Cause(err).(*breverrors.NvidiaMigrationError); ok { - fmt.Println("\n This account has been migrated to NVIDIA Auth. Attempting to log in with NVIDIA account...") + fmt.Fprintln(os.Stderr, "\n This account has been migrated to NVIDIA Auth. Attempting to log in with NVIDIA account...") brevBin, err1 := os.Executable() if err1 == nil { cmd := exec.Command(brevBin, "login", "--auth", "nvidia") // #nosec G204 @@ -68,9 +68,9 @@ func DisplayAndHandleError(err error) { } } if featureflag.Debug() || featureflag.IsDev() { - fmt.Println(err) + fmt.Fprintln(os.Stderr, err) } else { - fmt.Println(prettyErr) + fmt.Fprintln(os.Stderr, prettyErr) } } } diff --git a/pkg/cmd/gpucreate/gpucreate.go b/pkg/cmd/gpucreate/gpucreate.go new file mode 100644 index 00000000..2af8f96f --- /dev/null +++ b/pkg/cmd/gpucreate/gpucreate.go @@ -0,0 +1,766 @@ +// Package gpucreate provides a command to create GPU instances with retry logic +package gpucreate + +import ( + "encoding/json" + "fmt" + "io" + "os" + "strings" + "sync" + "time" + "unicode" + + "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" + "github.com/brevdev/brev-cli/pkg/cmd/util" + "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/featureflag" + "github.com/brevdev/brev-cli/pkg/store" + "github.com/brevdev/brev-cli/pkg/terminal" + "github.com/spf13/cobra" +) + +var ( + long = `Create GPU instances with automatic retry across multiple instance types. + +This command attempts to create GPU instances, trying different instance types +until the desired number of instances are successfully created. 
Instance types +can be specified directly, piped from 'brev search', or auto-selected using defaults. + +Default Behavior: +If no instance types are specified (no --type flag and no piped input), the command +automatically searches for GPUs matching these criteria: + - Minimum 20GB total VRAM + - Minimum 500GB disk + - Compute capability 8.0+ (Ampere or newer) + - Boot time under 7 minutes +Results are sorted by price (cheapest first). + +Retry and Fallback Logic: +When multiple instance types are provided (via --type or piped input), the command +tries to create ALL instances using the first type before falling back to the next: + + 1. Try first type for all instances (using --parallel workers if specified) + 2. If first type succeeds for all instances, done + 3. If first type fails for some instances, try second type for remaining instances + 4. Continue until all instances are created or all types are exhausted + +Example with --count 2 and types [A, B]: + - Try A for instance-1 → success + - Try A for instance-2 → success + - Done! (both instances use type A) + +If type A fails for instance-2: + - Try A for instance-1 → success + - Try A for instance-2 → fail + - Try B for instance-2 → success + - Done! (instance-1 uses A, instance-2 uses B) + +Startup Scripts: +You can attach a startup script that runs when the instance boots using the +--startup-script flag. 
The script can be provided as: + - An inline string: --startup-script 'pip install torch' + - A file path (prefix with @): --startup-script @setup.sh + - An absolute file path: --startup-script @/path/to/setup.sh` + + example = ` + # Quick start: create an instance using smart defaults (sorted by price) + brev create my-instance + + # Create with explicit --name flag + brev create --name my-instance + + # Create and immediately open in VS Code + brev create my-instance | brev open + + # Create and SSH into the instance + brev shell $(brev create my-instance) + + # Create and run a command + brev create my-instance | brev shell -c "nvidia-smi" + + # Create with a specific GPU type + brev create my-instance --type g5.xlarge + + # Pipe instance types from brev search (tries first type, falls back if needed) + brev search --min-vram 24 | brev create my-instance + + # Create multiple instances (all use same type, with fallback) + brev create my-cluster --count 3 --type g5.xlarge + # Creates: my-cluster-1, my-cluster-2, my-cluster-3 (all g5.xlarge) + + # Create multiple instances with fallback types + brev search --gpu-name A100 | brev create my-cluster --count 2 + # Tries first A100 type for both instances, falls back to next type if needed + + # Create instances in parallel (faster, but may use more types on partial failures) + brev search --gpu-name A100 | brev create my-cluster --count 3 --parallel 3 + + # Try multiple specific types in order (fallback chain) + brev create my-instance --type g5.xlarge,g5.2xlarge,g4dn.xlarge + + # Attach a startup script from a file + brev create my-instance --type g5.xlarge --startup-script @setup.sh + + # Attach an inline startup script + brev create my-instance --startup-script 'pip install torch' + + # Combine: find cheapest A100, attach setup script + brev search --gpu-name A100 --sort price | brev create ml-box -s @ml-setup.sh +` +) + +// GPUCreateStore defines the interface for GPU create operations +type GPUCreateStore 
interface { + util.GetWorkspaceByNameOrIDErrStore + gpusearch.GPUSearchStore + GetActiveOrganizationOrDefault() (*entity.Organization, error) + GetCurrentUser() (*entity.User, error) + GetWorkspace(workspaceID string) (*entity.Workspace, error) + CreateWorkspace(organizationID string, options *store.CreateWorkspacesOptions) (*entity.Workspace, error) + DeleteWorkspace(workspaceID string) (*entity.Workspace, error) + GetAllInstanceTypesWithWorkspaceGroups(orgID string) (*gpusearch.AllInstanceTypesResponse, error) +} + +// Default filter values for automatic GPU selection +const ( + defaultMinTotalVRAM = 20.0 // GB + defaultMinDisk = 500.0 // GB + defaultMinCapability = 8.0 + defaultMaxBootTime = 7 // minutes +) + +// CreateResult holds the result of a workspace creation attempt +type CreateResult struct { + Workspace *entity.Workspace + InstanceType string + Error error +} + +// NewCmdGPUCreate creates the gpu-create command +func NewCmdGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore) *cobra.Command { + var name string + var instanceTypes string + var count int + var parallel int + var detached bool + var timeout int + var startupScript string + + cmd := &cobra.Command{ + Annotations: map[string]string{"workspace": ""}, + Use: "create [name]", + Aliases: []string{"provision", "gpu-create", "gpu-retry", "gcreate"}, + DisableFlagsInUseLine: true, + Short: "Create GPU instances with automatic retry", + Long: long, + Example: example, + RunE: func(cmd *cobra.Command, args []string) error { + // Accept name as positional arg or --name flag + if len(args) > 0 && name == "" { + name = args[0] + } + + // Check if output is being piped (for chaining with brev shell/open) + piped := isStdoutPiped() + + // Parse instance types from flag or stdin + types, err := parseInstanceTypes(instanceTypes) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + // If no types provided, use default filters to find suitable GPUs + if len(types) == 0 { + msg := 
fmt.Sprintf("No instance types specified, using defaults: min-total-vram=%.0fGB, min-disk=%.0fGB, min-capability=%.1f, max-boot-time=%dm\n\n", + defaultMinTotalVRAM, defaultMinDisk, defaultMinCapability, defaultMaxBootTime) + if piped { + fmt.Fprint(os.Stderr, msg) + } else { + t.Vprint(msg) + } + + types, err = getDefaultInstanceTypes(gpuCreateStore) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if len(types) == 0 { + return breverrors.NewValidationError("no GPU instances match the default filters. Try 'brev search' to see available options") + } + } + + if name == "" { + return breverrors.NewValidationError("name is required (as argument or --name flag)") + } + + if count < 1 { + return breverrors.NewValidationError("--count must be at least 1") + } + + if parallel < 1 { + parallel = 1 + } + + // Parse startup script (can be a string or @filepath) + scriptContent, err := parseStartupScript(startupScript) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + opts := GPUCreateOptions{ + Name: name, + InstanceTypes: types, + Count: count, + Parallel: parallel, + Detached: detached, + Timeout: time.Duration(timeout) * time.Second, + StartupScript: scriptContent, + } + + err = RunGPUCreate(t, gpuCreateStore, opts) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil + }, + } + + cmd.Flags().StringVarP(&name, "name", "n", "", "Base name for the instances (or pass as first argument)") + cmd.Flags().StringVarP(&instanceTypes, "type", "t", "", "Comma-separated list of instance types to try") + cmd.Flags().IntVarP(&count, "count", "c", 1, "Number of instances to create") + cmd.Flags().IntVarP(¶llel, "parallel", "p", 1, "Number of parallel creation attempts") + cmd.Flags().BoolVarP(&detached, "detached", "d", false, "Don't wait for instances to be ready") + cmd.Flags().IntVar(&timeout, "timeout", 300, "Timeout in seconds for each instance to become ready") + cmd.Flags().StringVarP(&startupScript, "startup-script", "s", "", 
"Startup script to run on instance (string or @filepath)") + + return cmd +} + +// InstanceSpec holds an instance type and its target disk size +type InstanceSpec struct { + Type string + DiskGB float64 // Target disk size in GB, 0 means use default +} + +// GPUCreateOptions holds the options for GPU instance creation +type GPUCreateOptions struct { + Name string + InstanceTypes []InstanceSpec + Count int + Parallel int + Detached bool + Timeout time.Duration + StartupScript string +} + +// parseStartupScript parses the startup script from a string or file path +// If the value starts with @, it's treated as a file path +func parseStartupScript(value string) (string, error) { + if value == "" { + return "", nil + } + + // Check if it's a file path (prefixed with @) + if strings.HasPrefix(value, "@") { + filePath := strings.TrimPrefix(value, "@") + content, err := os.ReadFile(filePath) + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + return string(content), nil + } + + // Otherwise, treat it as the script content directly + return value, nil +} + +// getDefaultInstanceTypes fetches GPU instance types using default filters +func getDefaultInstanceTypes(store GPUCreateStore) ([]InstanceSpec, error) { + response, err := store.GetInstanceTypes() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + if response == nil || len(response.Items) == 0 { + return nil, nil + } + + // Use gpusearch package to process, filter, and sort instances + instances := gpusearch.ProcessInstances(response.Items) + filtered := gpusearch.FilterInstances(instances, "", "", 0, defaultMinTotalVRAM, defaultMinCapability, defaultMinDisk, defaultMaxBootTime) + gpusearch.SortInstances(filtered, "price", false) + + // Convert to InstanceSpec with disk info + var specs []InstanceSpec + for _, inst := range filtered { + // For defaults, use the minimum disk size that meets the filter + diskGB := inst.DiskMin + if inst.DiskMin != inst.DiskMax && defaultMinDisk > 
inst.DiskMin && defaultMinDisk <= inst.DiskMax { + diskGB = defaultMinDisk + } + specs = append(specs, InstanceSpec{Type: inst.Type, DiskGB: diskGB}) + } + + return specs, nil +} + +// parseInstanceTypes parses instance types from flag value or stdin +// Returns InstanceSpec with type and optional disk size (from JSON input) +func parseInstanceTypes(flagValue string) ([]InstanceSpec, error) { + var specs []InstanceSpec + + // First check if there's a flag value + if flagValue != "" { + parts := strings.Split(flagValue, ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + specs = append(specs, InstanceSpec{Type: p}) + } + } + } + + // Check if there's piped input from stdin + stat, _ := os.Stdin.Stat() + if (stat.Mode() & os.ModeCharDevice) == 0 { + // Data is being piped to stdin - read all input first + input, err := io.ReadAll(os.Stdin) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + inputStr := strings.TrimSpace(string(input)) + if inputStr == "" { + return specs, nil + } + + // Check if input is JSON (starts with '[') + if strings.HasPrefix(inputStr, "[") { + jsonSpecs, err := parseJSONInput(inputStr) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + specs = append(specs, jsonSpecs...) + } else { + // Parse as table format + tableSpecs := parseTableInput(inputStr) + specs = append(specs, tableSpecs...) 
+ } + } + + return specs, nil +} + +// parseJSONInput parses JSON array input from gpu-search --json +func parseJSONInput(input string) ([]InstanceSpec, error) { + var instances []gpusearch.GPUInstanceInfo + if err := json.Unmarshal([]byte(input), &instances); err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + var specs []InstanceSpec + for _, inst := range instances { + spec := InstanceSpec{ + Type: inst.Type, + DiskGB: inst.TargetDisk, + } + specs = append(specs, spec) + } + return specs, nil +} + +// parseTableInput parses table format input from gpu-search +func parseTableInput(input string) []InstanceSpec { + var specs []InstanceSpec + lines := strings.Split(input, "\n") + + for i, line := range lines { + // Skip header line (first line typically contains column names) + if i == 0 && (strings.Contains(line, "TYPE") || strings.Contains(line, "GPU")) { + continue + } + + // Skip empty lines + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Skip summary lines (e.g., "Found X GPU instance types") + if strings.HasPrefix(line, "Found ") { + continue + } + + // Extract the first column (TYPE) from the table output + // The format is: TYPE GPU COUNT VRAM/GPU TOTAL VRAM CAPABILITY VCPUs $/HR + fields := strings.Fields(line) + if len(fields) > 0 { + instanceType := fields[0] + // Validate it looks like an instance type (contains letters and possibly numbers/dots) + if isValidInstanceType(instanceType) { + specs = append(specs, InstanceSpec{Type: instanceType}) + } + } + } + + return specs +} + +// isValidInstanceType checks if a string looks like a valid instance type. 
+// Instance types typically have formats like: g5.xlarge, p4d.24xlarge, n1-highmem-4:nvidia-tesla-t4:1 +func isValidInstanceType(s string) bool { + if len(s) < 2 { + return false + } + var hasLetter, hasDigit bool + for _, c := range s { + if unicode.IsLetter(c) { + hasLetter = true + } else if unicode.IsDigit(c) { + hasDigit = true + } + if hasLetter && hasDigit { + return true + } + } + return hasLetter && hasDigit +} + +// isStdoutPiped returns true if stdout is being piped (not a terminal) +func isStdoutPiped() bool { + stat, _ := os.Stdout.Stat() + return (stat.Mode() & os.ModeCharDevice) == 0 +} + +// stderrPrintf prints to stderr (used when stdout is piped) +func stderrPrintf(format string, a ...interface{}) { + fmt.Fprintf(os.Stderr, format, a...) +} + +// formatInstanceSpecs formats a slice of InstanceSpec for display +func formatInstanceSpecs(specs []InstanceSpec) string { + var parts []string + for _, spec := range specs { + if spec.DiskGB > 0 { + parts = append(parts, fmt.Sprintf("%s (%.0fGB disk)", spec.Type, spec.DiskGB)) + } else { + parts = append(parts, spec.Type) + } + } + return strings.Join(parts, ", ") +} + +// RunGPUCreate executes the GPU create with retry logic +func RunGPUCreate(t *terminal.Terminal, gpuCreateStore GPUCreateStore, opts GPUCreateOptions) error { + // Check if output is being piped (for chaining with brev shell/open) + piped := isStdoutPiped() + + // Helper to print progress - uses stderr when piped so only instance name goes to stdout + logf := func(format string, a ...interface{}) { + if piped { + fmt.Fprintf(os.Stderr, format, a...) + } else { + t.Vprintf(format, a...) 
+ } + } + + user, err := gpuCreateStore.GetCurrentUser() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + org, err := gpuCreateStore.GetActiveOrganizationOrDefault() + if err != nil { + return breverrors.WrapAndTrace(err) + } + if org == nil { + return breverrors.NewValidationError("no organization found") + } + + // Fetch instance types with workspace groups to determine correct workspace group ID + allInstanceTypes, err := gpuCreateStore.GetAllInstanceTypesWithWorkspaceGroups(org.ID) + if err != nil { + logf("Warning: could not fetch instance types with workspace groups: %s\n", err.Error()) + logf("Falling back to default workspace group\n") + allInstanceTypes = nil + } + + logf("Attempting to create %d instance(s) with %d parallel attempts\n", opts.Count, opts.Parallel) + logf("Instance types to try: %s\n\n", formatInstanceSpecs(opts.InstanceTypes)) + + // Track successful creations + var successfulWorkspaces []*entity.Workspace + var fatalError error + + // Try each instance type in order, attempting to create ALL instances with that type + // before falling back to the next type + for _, spec := range opts.InstanceTypes { + // Check if we've created enough instances + if len(successfulWorkspaces) >= opts.Count { + break + } + + remaining := opts.Count - len(successfulWorkspaces) + logf("Trying %s for %d instance(s)...\n", spec.Type, remaining) + + // Create instances with this type (in parallel if requested) + var mu sync.Mutex + var wg sync.WaitGroup + + // Determine how many parallel workers to use + workerCount := opts.Parallel + if workerCount > remaining { + workerCount = remaining + } + + // Track which instance indices need to be created + indicesToCreate := make(chan int, remaining) + for i := len(successfulWorkspaces); i < opts.Count; i++ { + indicesToCreate <- i + } + close(indicesToCreate) + + // Track results for this type + var typeSuccesses []*entity.Workspace + var typeHadFailure bool + + for i := 0; i < workerCount; i++ { + 
wg.Add(1) + go func(workerID int) { + defer wg.Done() + + for idx := range indicesToCreate { + // Check if we've already created enough + mu.Lock() + currentSuccessCount := len(successfulWorkspaces) + len(typeSuccesses) + if currentSuccessCount >= opts.Count { + mu.Unlock() + return + } + mu.Unlock() + + // Determine instance name + instanceName := opts.Name + if opts.Count > 1 { + instanceName = fmt.Sprintf("%s-%d", opts.Name, idx+1) + } + + logf("[Worker %d] Trying %s for instance '%s'...\n", workerID+1, spec.Type, instanceName) + + // Attempt to create the workspace + workspace, err := createWorkspaceWithType(gpuCreateStore, org.ID, instanceName, spec.Type, spec.DiskGB, user, allInstanceTypes, opts.StartupScript) + + if err != nil { + errStr := err.Error() + if piped { + logf("[Worker %d] %s Failed: %s\n", workerID+1, spec.Type, errStr) + } else { + logf("[Worker %d] %s Failed: %s\n", workerID+1, t.Yellow(spec.Type), errStr) + } + + mu.Lock() + typeHadFailure = true + // Check for fatal errors + if strings.Contains(errStr, "duplicate workspace") { + fatalError = fmt.Errorf("workspace '%s' already exists. Use a different name or delete the existing workspace", instanceName) + } + mu.Unlock() + } else { + if piped { + logf("[Worker %d] %s Success! Created instance '%s'\n", workerID+1, spec.Type, instanceName) + } else { + logf("[Worker %d] %s Success! Created instance '%s'\n", workerID+1, t.Green(spec.Type), instanceName) + } + mu.Lock() + typeSuccesses = append(typeSuccesses, workspace) + mu.Unlock() + } + } + }(i) + } + + wg.Wait() + + // Add successful creations from this type + successfulWorkspaces = append(successfulWorkspaces, typeSuccesses...) 
+ + // Check for fatal error + if fatalError != nil { + logf("\nError: %s\n", fatalError.Error()) + break + } + + // If this type worked for all remaining instances, we're done + if !typeHadFailure && len(successfulWorkspaces) >= opts.Count { + break + } + + // If we still need more instances and this type had failures, try the next type + if len(successfulWorkspaces) < opts.Count && typeHadFailure { + logf("\nType %s had failures, trying next type...\n\n", spec.Type) + } + } + + // Check if we created enough instances + if len(successfulWorkspaces) < opts.Count { + logf("\nWarning: Only created %d/%d instances\n", len(successfulWorkspaces), opts.Count) + + if len(successfulWorkspaces) > 0 { + logf("Successfully created instances:\n") + for _, ws := range successfulWorkspaces { + logf(" - %s (ID: %s)\n", ws.Name, ws.ID) + } + } + + return breverrors.NewValidationError(fmt.Sprintf("could only create %d/%d instances", len(successfulWorkspaces), opts.Count)) + } + + // If we created more than needed, clean up extras + if len(successfulWorkspaces) > opts.Count { + extras := successfulWorkspaces[opts.Count:] + logf("\nCleaning up %d extra instance(s)...\n", len(extras)) + + for _, ws := range extras { + logf(" Deleting %s...", ws.Name) + _, err := gpuCreateStore.DeleteWorkspace(ws.ID) + if err != nil { + logf(" Failed\n") + } else { + logf(" Done\n") + } + } + + successfulWorkspaces = successfulWorkspaces[:opts.Count] + } + + // Wait for instances to be ready (unless detached) + if !opts.Detached { + logf("\nWaiting for instance(s) to be ready...\n") + logf("You can safely ctrl+c to exit\n") + + for _, ws := range successfulWorkspaces { + err := pollUntilReady(t, ws.ID, gpuCreateStore, opts.Timeout, piped, logf) + if err != nil { + logf(" %s: Timeout waiting for ready state\n", ws.Name) + } + } + } + + // If output is piped, just print instance name(s) for chaining with brev shell/open + if piped { + for _, ws := range successfulWorkspaces { + fmt.Println(ws.Name) + } + 
return nil + } + + // Print summary + fmt.Print("\n") + t.Vprint(t.Green(fmt.Sprintf("Successfully created %d instance(s)!\n\n", len(successfulWorkspaces)))) + + for _, ws := range successfulWorkspaces { + t.Vprintf("Instance: %s\n", t.Green(ws.Name)) + t.Vprintf(" ID: %s\n", ws.ID) + t.Vprintf(" Type: %s\n", ws.InstanceType) + displayConnectBreadCrumb(t, ws) + fmt.Print("\n") + } + + return nil +} + +// createWorkspaceWithType creates a workspace with the specified instance type +func createWorkspaceWithType(gpuCreateStore GPUCreateStore, orgID, name, instanceType string, diskGB float64, user *entity.User, allInstanceTypes *gpusearch.AllInstanceTypesResponse, startupScript string) (*entity.Workspace, error) { + clusterID := config.GlobalConfig.GetDefaultClusterID() + cwOptions := store.NewCreateWorkspacesOptions(clusterID, name) + cwOptions.WithInstanceType(instanceType) + cwOptions = resolveWorkspaceUserOptions(cwOptions, user) + + // Set disk size if specified (convert GB to Gi format) + if diskGB > 0 { + cwOptions.DiskStorage = fmt.Sprintf("%.0fGi", diskGB) + } + + // Look up the workspace group ID for this instance type + if allInstanceTypes != nil { + workspaceGroupID := allInstanceTypes.GetWorkspaceGroupID(instanceType) + if workspaceGroupID != "" { + cwOptions.WorkspaceGroupID = workspaceGroupID + } + } + + // Set startup script if provided using VMBuild lifecycle script + if startupScript != "" { + cwOptions.VMBuild = &store.VMBuild{ + ForceJupyterInstall: true, + LifeCycleScriptAttr: &store.LifeCycleScriptAttr{ + Script: startupScript, + }, + } + } + + workspace, err := gpuCreateStore.CreateWorkspace(orgID, cwOptions) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + return workspace, nil +} + +// resolveWorkspaceUserOptions sets workspace template and class based on user type +func resolveWorkspaceUserOptions(options *store.CreateWorkspacesOptions, user *entity.User) *store.CreateWorkspacesOptions { + if options.WorkspaceTemplateID == 
"" { + if featureflag.IsAdmin(user.GlobalUserType) { + options.WorkspaceTemplateID = store.DevWorkspaceTemplateID + } else { + options.WorkspaceTemplateID = store.UserWorkspaceTemplateID + } + } + if options.WorkspaceClassID == "" { + if featureflag.IsAdmin(user.GlobalUserType) { + options.WorkspaceClassID = store.DevWorkspaceClassID + } else { + options.WorkspaceClassID = store.UserWorkspaceClassID + } + } + return options +} + +// pollUntilReady waits for a workspace to reach the running state +func pollUntilReady(t *terminal.Terminal, wsID string, gpuCreateStore GPUCreateStore, timeout time.Duration, piped bool, logf func(string, ...interface{})) error { + deadline := time.Now().Add(timeout) + + for time.Now().Before(deadline) { + ws, err := gpuCreateStore.GetWorkspace(wsID) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if ws.Status == entity.Running { + if piped { + logf(" %s: Ready\n", ws.Name) + } else { + logf(" %s: %s\n", ws.Name, t.Green("Ready")) + } + return nil + } + + if ws.Status == entity.Failure { + return breverrors.NewValidationError(fmt.Sprintf("instance %s failed", ws.Name)) + } + + time.Sleep(5 * time.Second) + } + + return breverrors.NewValidationError("timeout waiting for instance to be ready") +} + +// displayConnectBreadCrumb shows connection instructions +func displayConnectBreadCrumb(t *terminal.Terminal, workspace *entity.Workspace) { + t.Vprintf(" Connect:\n") + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("brev open %s", workspace.Name))) + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("brev shell %s", workspace.Name))) +} diff --git a/pkg/cmd/gpucreate/gpucreate_test.go b/pkg/cmd/gpucreate/gpucreate_test.go new file mode 100644 index 00000000..2ad98680 --- /dev/null +++ b/pkg/cmd/gpucreate/gpucreate_test.go @@ -0,0 +1,419 @@ +package gpucreate + +import ( + "strings" + "testing" + + "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/store" + 
"github.com/stretchr/testify/assert" +) + +// MockGPUCreateStore is a mock implementation of GPUCreateStore for testing +type MockGPUCreateStore struct { + User *entity.User + Org *entity.Organization + Workspaces map[string]*entity.Workspace + CreateError error + CreateErrorTypes map[string]error // Errors for specific instance types + DeleteError error + CreatedWorkspaces []*entity.Workspace + DeletedWorkspaceIDs []string +} + +func NewMockGPUCreateStore() *MockGPUCreateStore { + return &MockGPUCreateStore{ + User: &entity.User{ + ID: "user-123", + GlobalUserType: "Standard", + }, + Org: &entity.Organization{ + ID: "org-123", + Name: "test-org", + }, + Workspaces: make(map[string]*entity.Workspace), + CreateErrorTypes: make(map[string]error), + CreatedWorkspaces: []*entity.Workspace{}, + DeletedWorkspaceIDs: []string{}, + } +} + +func (m *MockGPUCreateStore) GetCurrentUser() (*entity.User, error) { + return m.User, nil +} + +func (m *MockGPUCreateStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + return m.Org, nil +} + +func (m *MockGPUCreateStore) GetWorkspace(workspaceID string) (*entity.Workspace, error) { + if ws, ok := m.Workspaces[workspaceID]; ok { + return ws, nil + } + return &entity.Workspace{ + ID: workspaceID, + Status: entity.Running, + }, nil +} + +func (m *MockGPUCreateStore) CreateWorkspace(organizationID string, options *store.CreateWorkspacesOptions) (*entity.Workspace, error) { + // Check for type-specific errors first + if err, ok := m.CreateErrorTypes[options.InstanceType]; ok { + return nil, err + } + + if m.CreateError != nil { + return nil, m.CreateError + } + + ws := &entity.Workspace{ + ID: "ws-" + options.Name, + Name: options.Name, + InstanceType: options.InstanceType, + Status: entity.Running, + } + m.Workspaces[ws.ID] = ws + m.CreatedWorkspaces = append(m.CreatedWorkspaces, ws) + return ws, nil +} + +func (m *MockGPUCreateStore) DeleteWorkspace(workspaceID string) (*entity.Workspace, error) { + if m.DeleteError 
!= nil { + return nil, m.DeleteError + } + + m.DeletedWorkspaceIDs = append(m.DeletedWorkspaceIDs, workspaceID) + ws := m.Workspaces[workspaceID] + delete(m.Workspaces, workspaceID) + return ws, nil +} + +func (m *MockGPUCreateStore) GetWorkspaceByNameOrID(orgID string, nameOrID string) ([]entity.Workspace, error) { + return []entity.Workspace{}, nil +} + +func (m *MockGPUCreateStore) GetAllInstanceTypesWithWorkspaceGroups(orgID string) (*gpusearch.AllInstanceTypesResponse, error) { + return nil, nil +} + +func (m *MockGPUCreateStore) GetInstanceTypes() (*gpusearch.InstanceTypesResponse, error) { + // Return a default set of instance types for testing + return &gpusearch.InstanceTypesResponse{ + Items: []gpusearch.InstanceType{ + { + Type: "g5.xlarge", + SupportedGPUs: []gpusearch.GPU{ + {Count: 1, Name: "A10G", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + SupportedStorage: []gpusearch.Storage{ + {Size: "500GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: gpusearch.BasePrice{Currency: "USD", Amount: "1.006"}, + EstimatedDeployTime: "5m0s", + }, + }, + }, nil +} + +func TestIsValidInstanceType(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"Valid AWS instance type", "g5.xlarge", true}, + {"Valid AWS large instance", "p4d.24xlarge", true}, + {"Valid GCP instance type", "n1-highmem-4:nvidia-tesla-t4:1", true}, + {"Single letter", "a", false}, + {"No numbers", "xlarge", false}, + {"No letters", "12345", false}, + {"Empty string", "", false}, + {"Single character", "1", false}, + {"Valid with colon", "g5:xlarge", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isValidInstanceType(tt.input) + assert.Equal(t, tt.expected, result, "Validation failed for %s", tt.input) + }) + } +} + +func TestParseInstanceTypesFromFlag(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + {"Single type", "g5.xlarge", []string{"g5.xlarge"}}, + {"Multiple 
types comma separated", "g5.xlarge,g5.2xlarge,p3.2xlarge", []string{"g5.xlarge", "g5.2xlarge", "p3.2xlarge"}}, + {"With spaces", "g5.xlarge, g5.2xlarge, p3.2xlarge", []string{"g5.xlarge", "g5.2xlarge", "p3.2xlarge"}}, + {"Empty string", "", []string{}}, + {"Only spaces", " ", []string{}}, + {"Trailing comma", "g5.xlarge,", []string{"g5.xlarge"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseInstanceTypes(tt.input) + assert.NoError(t, err) + + // Handle nil vs empty slice + if len(tt.expected) == 0 { + assert.Empty(t, result) + } else { + // Compare just the Type field of each InstanceSpec + var resultTypes []string + for _, spec := range result { + resultTypes = append(resultTypes, spec.Type) + } + assert.Equal(t, tt.expected, resultTypes) + } + }) + } +} + +func TestGPUCreateOptions(t *testing.T) { + opts := GPUCreateOptions{ + Name: "my-instance", + InstanceTypes: []InstanceSpec{ + {Type: "g5.xlarge", DiskGB: 500}, + {Type: "g5.2xlarge"}, + }, + Count: 2, + Parallel: 3, + Detached: true, + } + + assert.Equal(t, "my-instance", opts.Name) + assert.Len(t, opts.InstanceTypes, 2) + assert.Equal(t, "g5.xlarge", opts.InstanceTypes[0].Type) + assert.Equal(t, 500.0, opts.InstanceTypes[0].DiskGB) + assert.Equal(t, "g5.2xlarge", opts.InstanceTypes[1].Type) + assert.Equal(t, 0.0, opts.InstanceTypes[1].DiskGB) + assert.Equal(t, 2, opts.Count) + assert.Equal(t, 3, opts.Parallel) + assert.True(t, opts.Detached) +} + +func TestResolveWorkspaceUserOptionsStandard(t *testing.T) { + user := &entity.User{ + ID: "user-123", + GlobalUserType: "Standard", + } + + options := &store.CreateWorkspacesOptions{} + result := resolveWorkspaceUserOptions(options, user) + + assert.Equal(t, store.UserWorkspaceTemplateID, result.WorkspaceTemplateID) + assert.Equal(t, store.UserWorkspaceClassID, result.WorkspaceClassID) +} + +func TestResolveWorkspaceUserOptionsAdmin(t *testing.T) { + user := &entity.User{ + ID: "user-123", + GlobalUserType: 
"Admin", + } + + options := &store.CreateWorkspacesOptions{} + result := resolveWorkspaceUserOptions(options, user) + + assert.Equal(t, store.DevWorkspaceTemplateID, result.WorkspaceTemplateID) + assert.Equal(t, store.DevWorkspaceClassID, result.WorkspaceClassID) +} + +func TestResolveWorkspaceUserOptionsPreserveExisting(t *testing.T) { + user := &entity.User{ + ID: "user-123", + GlobalUserType: "Standard", + } + + options := &store.CreateWorkspacesOptions{ + WorkspaceTemplateID: "custom-template", + WorkspaceClassID: "custom-class", + } + result := resolveWorkspaceUserOptions(options, user) + + // Should preserve existing values + assert.Equal(t, "custom-template", result.WorkspaceTemplateID) + assert.Equal(t, "custom-class", result.WorkspaceClassID) +} + +func TestMockGPUCreateStoreBasics(t *testing.T) { + mock := NewMockGPUCreateStore() + + user, err := mock.GetCurrentUser() + assert.NoError(t, err) + assert.Equal(t, "user-123", user.ID) + + org, err := mock.GetActiveOrganizationOrDefault() + assert.NoError(t, err) + assert.Equal(t, "org-123", org.ID) +} + +func TestMockGPUCreateStoreCreateWorkspace(t *testing.T) { + mock := NewMockGPUCreateStore() + + options := store.NewCreateWorkspacesOptions("cluster-1", "test-instance") + options.WithInstanceType("g5.xlarge") + + ws, err := mock.CreateWorkspace("org-123", options) + assert.NoError(t, err) + assert.Equal(t, "test-instance", ws.Name) + assert.Equal(t, "g5.xlarge", ws.InstanceType) + assert.Len(t, mock.CreatedWorkspaces, 1) +} + +func TestMockGPUCreateStoreDeleteWorkspace(t *testing.T) { + mock := NewMockGPUCreateStore() + + // First create a workspace + options := store.NewCreateWorkspacesOptions("cluster-1", "test-instance") + ws, _ := mock.CreateWorkspace("org-123", options) + + // Then delete it + _, err := mock.DeleteWorkspace(ws.ID) + assert.NoError(t, err) + assert.Contains(t, mock.DeletedWorkspaceIDs, ws.ID) +} + +func TestMockGPUCreateStoreTypeSpecificError(t *testing.T) { + mock := 
NewMockGPUCreateStore() + mock.CreateErrorTypes["g5.xlarge"] = assert.AnError + + options := store.NewCreateWorkspacesOptions("cluster-1", "test-instance") + options.WithInstanceType("g5.xlarge") + + _, err := mock.CreateWorkspace("org-123", options) + assert.Error(t, err) + + // Different type should work + options2 := store.NewCreateWorkspacesOptions("cluster-1", "test-instance-2") + options2.WithInstanceType("g5.2xlarge") + + ws, err := mock.CreateWorkspace("org-123", options2) + assert.NoError(t, err) + assert.NotNil(t, ws) +} + +func TestGetDefaultInstanceTypes(t *testing.T) { + mock := NewMockGPUCreateStore() + + // Get default instance types - the mock returns a g5.xlarge which has: + // - 24GB VRAM (>= 20GB total VRAM requirement) + // - 500GB disk (>= 500GB requirement) + // - A10G GPU = 8.6 capability (>= 8.0 requirement) + // - 5m boot time (< 7m requirement) + specs, err := getDefaultInstanceTypes(mock) + assert.NoError(t, err) + assert.Len(t, specs, 1) + assert.Equal(t, "g5.xlarge", specs[0].Type) + assert.Equal(t, 500.0, specs[0].DiskGB) // Should use the instance's disk size +} + +func TestGetDefaultInstanceTypesFiltersOut(t *testing.T) { + // The mock returns a g5.xlarge which meets all requirements + mock := NewMockGPUCreateStore() + + specs, err := getDefaultInstanceTypes(mock) + assert.NoError(t, err) + // Should return the A10G instance which meets all requirements + assert.Len(t, specs, 1) + assert.Equal(t, "g5.xlarge", specs[0].Type) +} + +func TestParseInstanceTypesFromTableOutput(t *testing.T) { + // Simulated table output from brev gpus command + // Note: This tests the parsing logic, not actual stdin reading + tableLines := []string{ + "TYPE GPU COUNT VRAM/GPU TOTAL VRAM CAPABILITY VCPUs $/HR", + "g5.xlarge A10G 1 24 GB 24 GB 8.6 4 $1.01", + "g5.2xlarge A10G 1 24 GB 24 GB 8.6 8 $1.21", + "p4d.24xlarge A100 8 40 GB 320 GB 8.0 96 $32.77", + "", + "Found 3 GPU instance types", + } + + // Test parsing each line (simulating the scanner 
behavior) + var types []string + lineNum := 0 + for _, line := range tableLines { + lineNum++ + + // Skip header line + if lineNum == 1 && (strings.Contains(line, "TYPE") || strings.Contains(line, "GPU")) { + continue + } + + // Skip empty lines and summary + if line == "" || strings.HasPrefix(line, "Found ") { + continue + } + + // Extract first column + fields := strings.Fields(line) + if len(fields) > 0 && isValidInstanceType(fields[0]) { + types = append(types, fields[0]) + } + } + + assert.Len(t, types, 3) + assert.Contains(t, types, "g5.xlarge") + assert.Contains(t, types, "g5.2xlarge") + assert.Contains(t, types, "p4d.24xlarge") +} + +func TestParseJSONInput(t *testing.T) { + // Simulated JSON output from gpu-search --json + jsonInput := `[ + { + "type": "g5.xlarge", + "provider": "aws", + "gpu_name": "A10G", + "target_disk_gb": 1000 + }, + { + "type": "p4d.24xlarge", + "provider": "aws", + "gpu_name": "A100", + "target_disk_gb": 500 + }, + { + "type": "g6.xlarge", + "provider": "aws", + "gpu_name": "L4" + } + ]` + + specs, err := parseJSONInput(jsonInput) + assert.NoError(t, err) + assert.Len(t, specs, 3) + + // Check first instance with disk + assert.Equal(t, "g5.xlarge", specs[0].Type) + assert.Equal(t, 1000.0, specs[0].DiskGB) + + // Check second instance with different disk + assert.Equal(t, "p4d.24xlarge", specs[1].Type) + assert.Equal(t, 500.0, specs[1].DiskGB) + + // Check third instance without disk (should be 0) + assert.Equal(t, "g6.xlarge", specs[2].Type) + assert.Equal(t, 0.0, specs[2].DiskGB) +} + +func TestFormatInstanceSpecs(t *testing.T) { + specs := []InstanceSpec{ + {Type: "g5.xlarge", DiskGB: 1000}, + {Type: "p4d.24xlarge", DiskGB: 0}, + {Type: "g6.xlarge", DiskGB: 500}, + } + + result := formatInstanceSpecs(specs) + assert.Equal(t, "g5.xlarge (1000GB disk), p4d.24xlarge, g6.xlarge (500GB disk)", result) +} + diff --git a/pkg/cmd/gpusearch/gpusearch.go b/pkg/cmd/gpusearch/gpusearch.go new file mode 100644 index 00000000..59cfa3b6 --- 
// Storage represents a storage configuration within an instance type.
//
// extractDiskInfo in this file reads only the first Storage entry of an
// instance type. It treats an entry with both MinSize and MaxSize set as
// flexible (resizable) storage, and otherwise falls back to Size and then
// SizeBytes as a fixed size.
type Storage struct {
	Count int `json:"count"`
	// Size is a size string such as "500GiB"; "0B" and "" are treated as unset by extractDiskInfo.
	Size string `json:"size"`
	Type string `json:"type"`
	// MinSize and MaxSize are size strings (e.g. "10GiB", "16TiB"); both non-empty marks flexible storage.
	MinSize string `json:"min_size"`
	MaxSize string `json:"max_size"`
	// SizeBytes is the value/unit fallback consulted when Size yields no usable number.
	SizeBytes    MemoryBytes `json:"size_bytes"`
	PricePerGBHr BasePrice   `json:"price_per_gb_hr"` // Uses BasePrice since API returns {currency, amount}
}
AvailableLocations []string `json:"available_locations"` + Provider string `json:"provider"` + WorkspaceGroups []WorkspaceGroup `json:"workspace_groups"` + EstimatedDeployTime string `json:"estimated_deploy_time"` + Stoppable bool `json:"stoppable"` + Rebootable bool `json:"rebootable"` + CanModifyFirewallRules bool `json:"can_modify_firewall_rules"` +} + +// InstanceTypesResponse represents the API response +type InstanceTypesResponse struct { + Items []InstanceType `json:"items"` +} + +// AllInstanceTypesResponse represents the authenticated API response with workspace groups +type AllInstanceTypesResponse struct { + AllInstanceTypes []InstanceType `json:"allInstanceTypes"` +} + +// GetWorkspaceGroupID returns the workspace group ID for an instance type, or empty string if not found +func (r *AllInstanceTypesResponse) GetWorkspaceGroupID(instanceType string) string { + for _, it := range r.AllInstanceTypes { + if it.Type == instanceType { + if len(it.WorkspaceGroups) > 0 { + return it.WorkspaceGroups[0].ID + } + } + } + return "" +} + +// GPUSearchStore defines the interface for fetching instance types +type GPUSearchStore interface { + GetInstanceTypes() (*InstanceTypesResponse, error) +} + +var ( + long = `Search and filter GPU instance types available on Brev. + +Filter instances by GPU name, provider, VRAM, total VRAM, GPU compute capability, disk size, and boot time. +Sort results by various columns to find the best instance for your needs. 
+ +Features column shows instance capabilities: + S = Stoppable (can stop and restart without losing data) + R = Rebootable (can reboot the instance) + P = Flex Ports (can modify firewall/port rules)` + + example = ` + # List all GPU instances + brev search + + # Filter by GPU name (case-insensitive, partial match) + brev search --gpu-name A100 + brev search --gpu-name "L40S" + + # Filter by provider/cloud (case-insensitive, partial match) + brev search --provider aws + brev search --provider gcp + + # Filter by minimum VRAM per GPU (in GB) + brev search --min-vram 24 + + # Filter by minimum total VRAM (in GB) + brev search --min-total-vram 80 + + # Filter by minimum GPU compute capability + brev search --min-capability 8.0 + + # Filter by minimum disk size (in GB) + brev search --min-disk 500 + + # Filter by maximum boot time (in minutes) + brev search --max-boot-time 5 + + # Sort by different columns (price, gpu-count, vram, total-vram, vcpu, provider, disk, boot-time) + brev search --sort price + brev search --sort boot-time + brev search --sort disk --desc + + # Combine filters + brev search --gpu-name A100 --min-vram 40 --sort price + brev search --gpu-name H100 --max-boot-time 3 --sort price +` +) + +// NewCmdGPUSearch creates the search command +func NewCmdGPUSearch(t *terminal.Terminal, store GPUSearchStore) *cobra.Command { + var gpuName string + var provider string + var minVRAM float64 + var minTotalVRAM float64 + var minCapability float64 + var minDisk float64 + var maxBootTime int + var sortBy string + var descending bool + var jsonOutput bool + + cmd := &cobra.Command{ + Annotations: map[string]string{"workspace": ""}, + Use: "search", + Aliases: []string{"gpu-search", "gpu", "gpus", "gpu-list"}, + DisableFlagsInUseLine: true, + Short: "Search and filter GPU instance types", + Long: long, + Example: example, + RunE: func(cmd *cobra.Command, args []string) error { + err := RunGPUSearch(t, store, gpuName, provider, minVRAM, minTotalVRAM, minCapability, 
minDisk, maxBootTime, sortBy, descending, jsonOutput) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil + }, + } + + cmd.Flags().StringVarP(&gpuName, "gpu-name", "g", "", "Filter by GPU name (case-insensitive, partial match)") + cmd.Flags().StringVarP(&provider, "provider", "p", "", "Filter by provider/cloud (case-insensitive, partial match)") + cmd.Flags().Float64VarP(&minVRAM, "min-vram", "v", 0, "Minimum VRAM per GPU in GB") + cmd.Flags().Float64VarP(&minTotalVRAM, "min-total-vram", "t", 0, "Minimum total VRAM (GPU count * VRAM) in GB") + cmd.Flags().Float64VarP(&minCapability, "min-capability", "c", 0, "Minimum GPU compute capability (e.g., 8.0 for Ampere)") + cmd.Flags().Float64Var(&minDisk, "min-disk", 0, "Minimum disk size in GB") + cmd.Flags().IntVar(&maxBootTime, "max-boot-time", 0, "Maximum boot time in minutes") + cmd.Flags().StringVarP(&sortBy, "sort", "s", "price", "Sort by: price, gpu-count, vram, total-vram, vcpu, type, provider, disk, boot-time") + cmd.Flags().BoolVarP(&descending, "desc", "d", false, "Sort in descending order") + cmd.Flags().BoolVar(&jsonOutput, "json", false, "Output results as JSON") + + return cmd +} + +// GPUInstanceInfo holds processed GPU instance information for display +type GPUInstanceInfo struct { + Type string `json:"type"` + Cloud string `json:"cloud"` // Underlying cloud (e.g., hyperstack, aws, gcp) + Provider string `json:"provider"` // Provider/aggregator (e.g., shadeform, aws, gcp) + GPUName string `json:"gpu_name"` + GPUCount int `json:"gpu_count"` + VRAMPerGPU float64 `json:"vram_per_gpu_gb"` + TotalVRAM float64 `json:"total_vram_gb"` + Capability float64 `json:"capability"` + VCPUs int `json:"vcpus"` + Memory string `json:"memory"` + DiskMin float64 `json:"disk_min_gb"` + DiskMax float64 `json:"disk_max_gb"` + DiskPricePerMo float64 `json:"disk_price_per_gb_mo,omitempty"` // $/GB/month for flexible storage + BootTime int `json:"boot_time_seconds"` + Stoppable bool `json:"stoppable"` + 
// isStdoutPiped returns true if stdout is being piped (not a terminal).
//
// If stdout cannot be inspected (for example it is closed or nil), the function
// returns false so callers keep their default human-readable output instead of
// panicking on a nil FileInfo.
func isStdoutPiped() bool {
	stat, err := os.Stdout.Stat()
	if err != nil {
		return false
	}
	// A character device is a terminal; anything else (pipe, file) counts as piped.
	return stat.Mode()&os.ModeCharDevice == 0
}
// unitMultipliers maps size units to their GB equivalent
var unitMultipliers = map[string]float64{
	"TiB": 1024,
	"TB":  1000,
	"GiB": 1,
	"GB":  1,
	"MiB": 1.0 / 1024,
	"MB":  1.0 / 1000,
}

// sizeComponentRe matches one "<number><unit>" component of a size string, with
// optional whitespace between the number and the unit. It is compiled once at
// package scope because parseToGB runs for every GPU and storage entry.
var sizeComponentRe = regexp.MustCompile(`(\d+(?:\.\d+)?)\s*(TiB|TB|GiB|GB|MiB|MB)`)

// parseToGB converts size/memory strings like "22GiB360MiB", "16TiB", "2TiB768GiB" to GB.
//
// Every recognized "<number><unit>" component is scaled by unitMultipliers and
// summed. Input with no recognized component, including "", yields 0.
func parseToGB(s string) float64 {
	var totalGB float64
	for _, match := range sizeComponentRe.FindAllStringSubmatch(s, -1) {
		val, err := strconv.ParseFloat(match[1], 64)
		if err != nil {
			continue // skip a component whose number does not parse
		}
		totalGB += val * unitMultipliers[match[2]]
	}
	return totalGB
}
regexp.MustCompile(`(\d+)s`) + if match := sRe.FindStringSubmatch(duration); len(match) > 1 { + if val, err := strconv.Atoi(match[1]); err == nil { + totalSeconds += val + } + } + + return totalSeconds +} + +// extractDiskInfo extracts min/max disk size and price from storage configuration +// Returns (minGB, maxGB, pricePerGBMonth). For fixed-size storage, min==max and price is 0. +func extractDiskInfo(storage []Storage) (float64, float64, float64) { + if len(storage) == 0 { + return 0, 0, 0 + } + + // Use the first storage entry + s := storage[0] + + // Convert price per GB per hour to per month (730 hours average) + var pricePerGBMonth float64 + if s.PricePerGBHr.Amount != "" { + pricePerHr, _ := strconv.ParseFloat(s.PricePerGBHr.Amount, 64) + if pricePerHr > 0 { + pricePerGBMonth = pricePerHr * 730 + } + } + + // Check if it's flexible storage (has min_size and max_size) + if s.MinSize != "" && s.MaxSize != "" { + minGB := parseSizeToGB(s.MinSize) + maxGB := parseSizeToGB(s.MaxSize) + return minGB, maxGB, pricePerGBMonth + } + + // Fixed storage - use size or size_bytes + var sizeGB float64 + if s.Size != "" && s.Size != "0B" { + sizeGB = parseSizeToGB(s.Size) + } + + // Fallback to size_bytes + if sizeGB == 0 && s.SizeBytes.Value > 0 { + if s.SizeBytes.Unit == "GiB" || s.SizeBytes.Unit == "GB" { + sizeGB = float64(s.SizeBytes.Value) + } else if s.SizeBytes.Unit == "MiB" || s.SizeBytes.Unit == "MB" { + sizeGB = float64(s.SizeBytes.Value) / 1024 + } else if s.SizeBytes.Unit == "TiB" || s.SizeBytes.Unit == "TB" { + sizeGB = float64(s.SizeBytes.Value) * 1024 + } + } + + // Fixed storage doesn't have separate pricing - it's included in base price + return sizeGB, sizeGB, 0 +} + +// extractCloud extracts the underlying cloud from the instance type and provider +// For aggregators like shadeform, the cloud is in the type name prefix (e.g., "hyperstack_H100" -> "hyperstack") +// For direct providers like aws/gcp, the provider IS the cloud +func 
// gpuCapabilityEntry represents a GPU pattern and its compute capability
type gpuCapabilityEntry struct {
	pattern    string
	capability float64
}

// gpuCapabilities lists GPU name patterns with their compute capability. It is
// built once at package scope rather than on every lookup.
//
// Order matters: more specific patterns must come before less specific ones
// (e.g., "A100" before "A10", "L40S" before "L40"), because getGPUCapability
// returns the first pattern contained in the name.
var gpuCapabilities = []gpuCapabilityEntry{
	// NVIDIA Professional (before other RTX patterns)
	{"RTXPRO6000", 12.0},

	// NVIDIA Blackwell
	{"B300", 10.3},
	{"B200", 10.0},
	{"RTX5090", 10.0},

	// NVIDIA Hopper
	{"H100", 9.0},
	{"H200", 9.0},

	// NVIDIA Ada Lovelace (L40S before L40, L4; RTX*Ada before RTX*)
	{"L40S", 8.9},
	{"L40", 8.9},
	{"L4", 8.9},
	{"RTX6000ADA", 8.9},
	{"RTX4000ADA", 8.9},
	{"RTX4090", 8.9},
	{"RTX4080", 8.9},

	// NVIDIA Ampere (A100 before A10G, A10)
	{"A100", 8.0},
	{"A10G", 8.6},
	{"A10", 8.6},
	{"A40", 8.6},
	{"A6000", 8.6},
	{"A5000", 8.6},
	{"A4000", 8.6},
	{"A30", 8.0},
	{"A16", 8.6},
	{"RTX3090", 8.6},
	{"RTX3080", 8.6},

	// NVIDIA Turing
	{"T4", 7.5},
	{"RTX6000", 7.5},
	{"RTX2080", 7.5},

	// NVIDIA Volta
	{"V100", 7.0},

	// NVIDIA Pascal (P100 before P40, P4)
	{"P100", 6.0},
	{"P40", 6.1},
	{"P4", 6.1},

	// NVIDIA Maxwell
	{"M60", 5.2},

	// NVIDIA Kepler
	{"K80", 3.7},

	// Gaudi (Habana) - not CUDA compatible
	{"HL-205", 0},
	{"GAUDI3", 0},
	{"GAUDI2", 0},
	{"GAUDI", 0},
}

// getGPUCapability returns the compute capability for known GPU types.
//
// The name is upper-cased and matched by substring against gpuCapabilities in
// table order; the first hit wins. Unknown GPUs, and the Gaudi entries listed
// with capability 0, return 0.
func getGPUCapability(gpuName string) float64 {
	upper := strings.ToUpper(gpuName)
	for _, entry := range gpuCapabilities {
		if strings.Contains(upper, entry.pattern) {
			return entry.capability
		}
	}
	return 0
}
gpu.Manufacturer, + }) + } + } + + return instances +} + +// FilterInstances applies all filters to the instance list +func FilterInstances(instances []GPUInstanceInfo, gpuName, provider string, minVRAM, minTotalVRAM, minCapability, minDisk float64, maxBootTime int) []GPUInstanceInfo { + var filtered []GPUInstanceInfo + + for _, inst := range instances { + // Filter out non-NVIDIA GPUs (AMD, Intel/Habana, etc.) + if !strings.Contains(strings.ToUpper(inst.Manufacturer), "NVIDIA") { + continue + } + + // Filter by GPU name (case-insensitive partial match) + if gpuName != "" && !strings.Contains(strings.ToLower(inst.GPUName), strings.ToLower(gpuName)) { + continue + } + + // Filter by provider (case-insensitive partial match) + if provider != "" && !strings.Contains(strings.ToLower(inst.Provider), strings.ToLower(provider)) { + continue + } + + // Filter by minimum VRAM per GPU + if minVRAM > 0 && inst.VRAMPerGPU < minVRAM { + continue + } + + // Filter by minimum total VRAM + if minTotalVRAM > 0 && inst.TotalVRAM < minTotalVRAM { + continue + } + + // Filter by minimum GPU capability + if minCapability > 0 && inst.Capability < minCapability { + continue + } + + // Filter by minimum disk size (use max available size for comparison) + if minDisk > 0 && inst.DiskMax < minDisk { + continue + } + + // Filter by maximum boot time (convert minutes to seconds for comparison) + // Exclude instances with unknown boot time (0) when filter is specified + if maxBootTime > 0 && (inst.BootTime == 0 || inst.BootTime > maxBootTime*60) { + continue + } + + filtered = append(filtered, inst) + } + + return filtered +} + +// SortInstances sorts the instance list by the specified column +func SortInstances(instances []GPUInstanceInfo, sortBy string, descending bool) { + sort.Slice(instances, func(i, j int) bool { + var less bool + switch strings.ToLower(sortBy) { + case "price": + less = instances[i].PricePerHour < instances[j].PricePerHour + case "gpu-count": + less = 
instances[i].GPUCount < instances[j].GPUCount + case "vram": + less = instances[i].VRAMPerGPU < instances[j].VRAMPerGPU + case "total-vram": + less = instances[i].TotalVRAM < instances[j].TotalVRAM + case "vcpu": + less = instances[i].VCPUs < instances[j].VCPUs + case "type": + less = instances[i].Type < instances[j].Type + case "capability": + less = instances[i].Capability < instances[j].Capability + case "provider": + less = instances[i].Provider < instances[j].Provider + case "disk": + less = instances[i].DiskMax < instances[j].DiskMax + case "boot-time": + // Instances with no boot time (0) should always appear last + if instances[i].BootTime == 0 && instances[j].BootTime == 0 { + return false // both unknown, equal + } else if instances[i].BootTime == 0 { + return false // i unknown goes after j + } else if instances[j].BootTime == 0 { + return true // j unknown goes after i + } + less = instances[i].BootTime < instances[j].BootTime + default: + less = instances[i].PricePerHour < instances[j].PricePerHour + } + + if descending { + return !less + } + return less + }) +} + +// getBrevTableOptions returns table styling options +func getBrevTableOptions() table.Options { + options := table.OptionsDefault + options.DrawBorder = false + options.SeparateColumns = false + options.SeparateRows = false + options.SeparateHeader = false + return options +} + +// formatDiskSize formats the disk size for display +func formatDiskSize(minGB, maxGB float64) string { + if minGB == 0 && maxGB == 0 { + return "-" + } + + formatSize := func(gb float64) string { + if gb >= 1000 { + return fmt.Sprintf("%.0fTB", gb/1000) + } + return fmt.Sprintf("%.0fGB", gb) + } + + if minGB == maxGB { + // Fixed size + return formatSize(minGB) + } + // Range + return fmt.Sprintf("%s-%s", formatSize(minGB), formatSize(maxGB)) +} + +// formatBootTime formats boot time in seconds to a human-readable string +func formatBootTime(seconds int) string { + if seconds == 0 { + return "-" + } + minutes := 
// formatFeatures renders the capability flags as a compact code string.
// Each enabled flag contributes one letter, always in the order S, R, P:
// S=stoppable, R=rebootable, P=flex ports (can modify firewall).
// When no flag is set the result is "-".
func formatFeatures(stoppable, rebootable, flexPorts bool) string {
	codes := ""
	if stoppable {
		codes += "S"
	}
	if rebootable {
		codes += "R"
	}
	if flexPorts {
		codes += "P"
	}

	if codes == "" {
		return "-"
	}
	return codes
}
bootStr, + featuresStr, + inst.VCPUs, + priceStr, + } + ta.AppendRow(row) + } + + ta.Render() +} diff --git a/pkg/cmd/gpusearch/gpusearch_test.go b/pkg/cmd/gpusearch/gpusearch_test.go new file mode 100644 index 00000000..93110c9e --- /dev/null +++ b/pkg/cmd/gpusearch/gpusearch_test.go @@ -0,0 +1,550 @@ +package gpusearch + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// MockGPUSearchStore is a mock implementation of GPUSearchStore for testing +type MockGPUSearchStore struct { + Response *InstanceTypesResponse + Err error +} + +func (m *MockGPUSearchStore) GetInstanceTypes() (*InstanceTypesResponse, error) { + if m.Err != nil { + return nil, m.Err + } + return m.Response, nil +} + +func createTestInstanceTypes() *InstanceTypesResponse { + return &InstanceTypesResponse{ + Items: []InstanceType{ + { + Type: "g5.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "A10G", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "1.006"}, + }, + { + Type: "g5.2xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "A10G", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + Memory: "32GiB", + VCPU: 8, + BasePrice: BasePrice{Currency: "USD", Amount: "1.212"}, + }, + { + Type: "p3.2xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "V100", Manufacturer: "NVIDIA", Memory: "16GiB"}, + }, + Memory: "61GiB", + VCPU: 8, + BasePrice: BasePrice{Currency: "USD", Amount: "3.06"}, + }, + { + Type: "p3.8xlarge", + SupportedGPUs: []GPU{ + {Count: 4, Name: "V100", Manufacturer: "NVIDIA", Memory: "16GiB"}, + }, + Memory: "244GiB", + VCPU: 32, + BasePrice: BasePrice{Currency: "USD", Amount: "12.24"}, + }, + { + Type: "p4d.24xlarge", + SupportedGPUs: []GPU{ + {Count: 8, Name: "A100", Manufacturer: "NVIDIA", Memory: "40GiB"}, + }, + Memory: "1152GiB", + VCPU: 96, + BasePrice: BasePrice{Currency: "USD", Amount: "32.77"}, + }, + { + Type: "g4dn.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "T4", 
Manufacturer: "NVIDIA", Memory: "16GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "0.526"}, + }, + { + Type: "g6.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "L4", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "0.805"}, + }, + }, + } +} + +func TestParseMemoryToGB(t *testing.T) { + tests := []struct { + name string + input string + expected float64 + }{ + {"Simple GiB", "24GiB", 24}, + {"GiB with MiB", "22GiB360MiB", 22.3515625}, + {"Simple GB", "16GB", 16}, + {"Large GiB", "1152GiB", 1152}, + {"Empty string", "", 0}, + {"MiB only", "512MiB", 0.5}, + {"With spaces", "24 GiB", 24}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseMemoryToGB(tt.input) + assert.InDelta(t, tt.expected, result, 0.01, "Memory parsing failed for %s", tt.input) + }) + } +} + +func TestGetGPUCapability(t *testing.T) { + tests := []struct { + name string + gpuName string + expected float64 + }{ + {"A100", "A100", 8.0}, + {"A10G", "A10G", 8.6}, + {"V100", "V100", 7.0}, + {"T4", "T4", 7.5}, + {"L4", "L4", 8.9}, + {"L40S", "L40S", 8.9}, + {"H100", "H100", 9.0}, + {"Unknown GPU", "Unknown", 0}, + {"Case insensitive", "a100", 8.0}, + {"Gaudi", "HL-205", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getGPUCapability(tt.gpuName) + assert.Equal(t, tt.expected, result, "GPU capability mismatch for %s", tt.gpuName) + }) + } +} + +func TestProcessInstances(t *testing.T) { + response := createTestInstanceTypes() + instances := ProcessInstances(response.Items) + + assert.Len(t, instances, 7, "Expected 7 GPU instances") + + // Check specific instance + var a10gInstance *GPUInstanceInfo + for i := range instances { + if instances[i].Type == "g5.xlarge" { + a10gInstance = &instances[i] + break + } + } + + assert.NotNil(t, a10gInstance, "g5.xlarge instance should exist") + assert.Equal(t, "A10G", 
a10gInstance.GPUName) + assert.Equal(t, 1, a10gInstance.GPUCount) + assert.Equal(t, 24.0, a10gInstance.VRAMPerGPU) + assert.Equal(t, 24.0, a10gInstance.TotalVRAM) + assert.Equal(t, 8.6, a10gInstance.Capability) + assert.Equal(t, 4, a10gInstance.VCPUs) + assert.InDelta(t, 1.006, a10gInstance.PricePerHour, 0.001) +} + +func TestFilterInstancesByGPUName(t *testing.T) { + response := createTestInstanceTypes() + instances := ProcessInstances(response.Items) + + // Filter by A10G + filtered := FilterInstances(instances, "A10G", "", 0, 0, 0, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 A10G instances") + + // Filter by V100 + filtered = FilterInstances(instances, "V100", "", 0, 0, 0, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 V100 instances") + + // Filter by lowercase (case-insensitive) + filtered = FilterInstances(instances, "v100", "", 0, 0, 0, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 V100 instances (case-insensitive)") + + // Filter by partial match + filtered = FilterInstances(instances, "A1", "", 0, 0, 0, 0, 0) + assert.Len(t, filtered, 3, "Should have 3 instances matching 'A1' (A10G and A100)") +} + +func TestFilterInstancesByMinVRAM(t *testing.T) { + response := createTestInstanceTypes() + instances := ProcessInstances(response.Items) + + // Filter by min VRAM 24GB + filtered := FilterInstances(instances, "", "", 24, 0, 0, 0, 0) + assert.Len(t, filtered, 4, "Should have 4 instances with >= 24GB VRAM") + + // Filter by min VRAM 40GB + filtered = FilterInstances(instances, "", "", 40, 0, 0, 0, 0) + assert.Len(t, filtered, 1, "Should have 1 instance with >= 40GB VRAM") + assert.Equal(t, "A100", filtered[0].GPUName) +} + +func TestFilterInstancesByMinTotalVRAM(t *testing.T) { + response := createTestInstanceTypes() + instances := ProcessInstances(response.Items) + + // Filter by min total VRAM 60GB + filtered := FilterInstances(instances, "", "", 0, 60, 0, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 instances with >= 60GB total VRAM") + + // 
Filter by min total VRAM 300GB + filtered = FilterInstances(instances, "", "", 0, 300, 0, 0, 0) + assert.Len(t, filtered, 1, "Should have 1 instance with >= 300GB total VRAM") + assert.Equal(t, "p4d.24xlarge", filtered[0].Type) +} + +func TestFilterInstancesByMinCapability(t *testing.T) { + response := createTestInstanceTypes() + instances := ProcessInstances(response.Items) + + // Filter by capability >= 8.0 + filtered := FilterInstances(instances, "", "", 0, 0, 8.0, 0, 0) + assert.Len(t, filtered, 4, "Should have 4 instances with capability >= 8.0") + + // Filter by capability >= 8.5 + filtered = FilterInstances(instances, "", "", 0, 0, 8.5, 0, 0) + assert.Len(t, filtered, 3, "Should have 3 instances with capability >= 8.5") +} + +func TestFilterInstancesCombined(t *testing.T) { + response := createTestInstanceTypes() + instances := ProcessInstances(response.Items) + + // Filter by GPU name and min VRAM + filtered := FilterInstances(instances, "A10G", "", 24, 0, 0, 0, 0) + assert.Len(t, filtered, 2, "Should have 2 A10G instances with >= 24GB VRAM") + + // Filter by GPU name, min VRAM, and capability + filtered = FilterInstances(instances, "", "", 24, 0, 8.5, 0, 0) + assert.Len(t, filtered, 3, "Should have 3 instances with >= 24GB VRAM and capability >= 8.5") +} + +func TestSortInstancesByPrice(t *testing.T) { + response := createTestInstanceTypes() + instances := ProcessInstances(response.Items) + + // Sort by price ascending + SortInstances(instances, "price", false) + assert.Equal(t, "g4dn.xlarge", instances[0].Type, "Cheapest should be g4dn.xlarge") + assert.Equal(t, "p4d.24xlarge", instances[len(instances)-1].Type, "Most expensive should be p4d.24xlarge") + + // Sort by price descending + SortInstances(instances, "price", true) + assert.Equal(t, "p4d.24xlarge", instances[0].Type, "Most expensive should be first when descending") +} + +func TestSortInstancesByGPUCount(t *testing.T) { + response := createTestInstanceTypes() + instances := 
ProcessInstances(response.Items) + + // Sort by GPU count ascending + SortInstances(instances, "gpu-count", false) + assert.Equal(t, 1, instances[0].GPUCount, "Instances with 1 GPU should be first") + + // Sort by GPU count descending + SortInstances(instances, "gpu-count", true) + assert.Equal(t, 8, instances[0].GPUCount, "Instance with 8 GPUs should be first when descending") +} + +func TestSortInstancesByVRAM(t *testing.T) { + response := createTestInstanceTypes() + instances := ProcessInstances(response.Items) + + // Sort by VRAM ascending + SortInstances(instances, "vram", false) + assert.Equal(t, 16.0, instances[0].VRAMPerGPU, "Instances with 16GB VRAM should be first") + + // Sort by VRAM descending + SortInstances(instances, "vram", true) + assert.Equal(t, 40.0, instances[0].VRAMPerGPU, "Instance with 40GB VRAM should be first when descending") +} + +func TestSortInstancesByTotalVRAM(t *testing.T) { + response := createTestInstanceTypes() + instances := ProcessInstances(response.Items) + + // Sort by total VRAM ascending + SortInstances(instances, "total-vram", false) + assert.Equal(t, 16.0, instances[0].TotalVRAM, "Instances with 16GB total VRAM should be first") + + // Sort by total VRAM descending + SortInstances(instances, "total-vram", true) + assert.Equal(t, 320.0, instances[0].TotalVRAM, "Instance with 320GB total VRAM should be first when descending") +} + +func TestSortInstancesByVCPU(t *testing.T) { + response := createTestInstanceTypes() + instances := ProcessInstances(response.Items) + + // Sort by vCPU ascending + SortInstances(instances, "vcpu", false) + assert.Equal(t, 4, instances[0].VCPUs, "Instances with 4 vCPUs should be first") + + // Sort by vCPU descending + SortInstances(instances, "vcpu", true) + assert.Equal(t, 96, instances[0].VCPUs, "Instance with 96 vCPUs should be first when descending") +} + +func TestSortInstancesByCapability(t *testing.T) { + response := createTestInstanceTypes() + instances := 
ProcessInstances(response.Items) + + // Sort by capability ascending + SortInstances(instances, "capability", false) + assert.Equal(t, 7.0, instances[0].Capability, "Instances with capability 7.0 should be first") + + // Sort by capability descending + SortInstances(instances, "capability", true) + assert.Equal(t, 8.9, instances[0].Capability, "Instance with capability 8.9 should be first when descending") +} + +func TestSortInstancesByType(t *testing.T) { + response := createTestInstanceTypes() + instances := ProcessInstances(response.Items) + + // Sort by type ascending + SortInstances(instances, "type", false) + assert.Equal(t, "g4dn.xlarge", instances[0].Type, "g4dn.xlarge should be first alphabetically") + + // Sort by type descending + SortInstances(instances, "type", true) + assert.Equal(t, "p4d.24xlarge", instances[0].Type, "p4d.24xlarge should be first when descending") +} + +func TestEmptyInstanceTypes(t *testing.T) { + response := &InstanceTypesResponse{Items: []InstanceType{}} + instances := ProcessInstances(response.Items) + + assert.Len(t, instances, 0, "Should have 0 instances") + + filtered := FilterInstances(instances, "A100", "", 0, 0, 0, 0, 0) + assert.Len(t, filtered, 0, "Filtered should also be empty") +} + +func TestNonGPUInstancesAreFiltered(t *testing.T) { + response := &InstanceTypesResponse{ + Items: []InstanceType{ + { + Type: "m5.xlarge", + SupportedGPUs: []GPU{}, // No GPUs + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "0.192"}, + }, + { + Type: "g5.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "A10G", Manufacturer: "NVIDIA", Memory: "24GiB"}, + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "1.006"}, + }, + }, + } + + instances := ProcessInstances(response.Items) + assert.Len(t, instances, 1, "Should only have 1 GPU instance, non-GPU instances should be filtered") + assert.Equal(t, "g5.xlarge", instances[0].Type) +} + +func TestMemoryBytesAsFallback(t *testing.T) { 
+ response := &InstanceTypesResponse{ + Items: []InstanceType{ + { + Type: "test.xlarge", + SupportedGPUs: []GPU{ + {Count: 1, Name: "TestGPU", Manufacturer: "NVIDIA", Memory: "", MemoryBytes: MemoryBytes{Value: 24576, Unit: "MiB"}}, // 24GB in MiB + }, + Memory: "16GiB", + VCPU: 4, + BasePrice: BasePrice{Currency: "USD", Amount: "1.00"}, + }, + }, + } + + instances := ProcessInstances(response.Items) + assert.Len(t, instances, 1) + assert.Equal(t, 24.0, instances[0].VRAMPerGPU, "Should fall back to MemoryBytes when Memory string is empty") +} + +func TestFilterByMaxBootTimeExcludesUnknown(t *testing.T) { + response := &InstanceTypesResponse{ + Items: []InstanceType{ + { + Type: "fast-boot", + SupportedGPUs: []GPU{ + {Count: 1, Name: "A100", Manufacturer: "NVIDIA", Memory: "40GiB"}, + }, + Memory: "64GiB", + VCPU: 8, + BasePrice: BasePrice{Currency: "USD", Amount: "3.00"}, + EstimatedDeployTime: "5m0s", + }, + { + Type: "slow-boot", + SupportedGPUs: []GPU{ + {Count: 1, Name: "A100", Manufacturer: "NVIDIA", Memory: "40GiB"}, + }, + Memory: "64GiB", + VCPU: 8, + BasePrice: BasePrice{Currency: "USD", Amount: "2.50"}, + EstimatedDeployTime: "15m0s", + }, + { + Type: "unknown-boot", + SupportedGPUs: []GPU{ + {Count: 1, Name: "A100", Manufacturer: "NVIDIA", Memory: "40GiB"}, + }, + Memory: "64GiB", + VCPU: 8, + BasePrice: BasePrice{Currency: "USD", Amount: "2.00"}, + EstimatedDeployTime: "", // Unknown boot time + }, + }, + } + + instances := ProcessInstances(response.Items) + assert.Len(t, instances, 3, "Should have 3 instances before filtering") + + // Filter by max boot time of 10 minutes - should exclude unknown and slow-boot + filtered := FilterInstances(instances, "", "", 0, 0, 0, 0, 10) + assert.Len(t, filtered, 1, "Should have 1 instance with boot time <= 10 minutes") + assert.Equal(t, "fast-boot", filtered[0].Type, "Only fast-boot should match") + + // Verify unknown boot time instance is excluded + for _, inst := range filtered { + assert.NotEqual(t, 
"unknown-boot", inst.Type, "Unknown boot time should be excluded") + assert.NotEqual(t, 0, inst.BootTime, "Instances with 0 boot time should be excluded") + } + + // Without filter, all instances should be included + noFilter := FilterInstances(instances, "", "", 0, 0, 0, 0, 0) + assert.Len(t, noFilter, 3, "Without filter, all 3 instances should be included") +} + +func TestExtractCloud(t *testing.T) { + tests := []struct { + name string + instanceType string + provider string + expectedCloud string + }{ + // Direct providers - cloud equals provider + {"AWS direct", "g5.xlarge", "aws", "aws"}, + {"GCP direct", "n1-highmem-8:nvidia-tesla-v100:8", "gcp", "gcp"}, + {"Nebius direct", "gpu-h100-sxm.1gpu-16vcpu-200gb", "nebius", "nebius"}, + {"OCI direct", "oci.h100x8.sxm", "oci", "oci"}, + {"Lambda Labs direct", "gpu_1x_h100_sxm5", "lambda-labs", "lambda-labs"}, + {"Crusoe direct", "l40s-48gb.1x", "crusoe", "crusoe"}, + {"Launchpad direct", "dmz.h100x2.pcie", "launchpad", "launchpad"}, + + // Aggregators - extract cloud from type name prefix + {"Shadeform hyperstack", "hyperstack_H100", "shadeform", "hyperstack"}, + {"Shadeform latitude", "latitude_H100x4", "shadeform", "latitude"}, + {"Shadeform cudo", "cudo_A40", "shadeform", "cudo"}, + {"Shadeform horizon", "horizon_H100x8", "shadeform", "horizon"}, + {"Shadeform paperspace", "paperspace_H100", "shadeform", "paperspace"}, + + // Edge cases + {"Unknown aggregator no underscore", "someinstance", "unknown-agg", "unknown-agg"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractCloud(tt.instanceType, tt.provider) + assert.Equal(t, tt.expectedCloud, result) + }) + } +} + +func TestExtractDiskInfoWithPricing(t *testing.T) { + // Test flexible storage with pricing + storageWithPrice := []Storage{ + { + MinSize: "50GiB", + MaxSize: "2560GiB", + PricePerGBHr: BasePrice{Currency: "USD", Amount: "0.00014"}, + }, + } + + minGB, maxGB, pricePerMo := extractDiskInfo(storageWithPrice) + 
assert.Equal(t, 50.0, minGB) + assert.Equal(t, 2560.0, maxGB) + assert.InDelta(t, 0.1022, pricePerMo, 0.001, "Price should be ~$0.10/GB/mo (0.00014 * 730)") + + // Test fixed storage (no pricing) + fixedStorage := []Storage{ + { + Size: "500GiB", + }, + } + + minGB, maxGB, pricePerMo = extractDiskInfo(fixedStorage) + assert.Equal(t, 500.0, minGB) + assert.Equal(t, 500.0, maxGB) + assert.Equal(t, 0.0, pricePerMo, "Fixed storage should have no separate price") + + // Test empty storage + minGB, maxGB, pricePerMo = extractDiskInfo([]Storage{}) + assert.Equal(t, 0.0, minGB) + assert.Equal(t, 0.0, maxGB) + assert.Equal(t, 0.0, pricePerMo) +} + +func TestProcessInstancesCloudExtraction(t *testing.T) { + response := &InstanceTypesResponse{ + Items: []InstanceType{ + { + Type: "hyperstack_H100", + Provider: "shadeform", + SupportedGPUs: []GPU{ + {Count: 1, Name: "H100", Manufacturer: "NVIDIA", Memory: "80GiB"}, + }, + Memory: "180GiB", + VCPU: 28, + BasePrice: BasePrice{Currency: "USD", Amount: "2.28"}, + }, + { + Type: "gpu-h100-sxm.1gpu-16vcpu-200gb", + Provider: "nebius", + SupportedGPUs: []GPU{ + {Count: 1, Name: "H100", Manufacturer: "NVIDIA", Memory: "80GiB"}, + }, + Memory: "200GiB", + VCPU: 16, + BasePrice: BasePrice{Currency: "USD", Amount: "3.54"}, + }, + }, + } + + instances := ProcessInstances(response.Items) + assert.Len(t, instances, 2) + + // Shadeform instance should have cloud extracted from type name + assert.Equal(t, "hyperstack", instances[0].Cloud) + assert.Equal(t, "shadeform", instances[0].Provider) + + // Nebius instance should have cloud = provider + assert.Equal(t, "nebius", instances[1].Cloud) + assert.Equal(t, "nebius", instances[1].Provider) +} diff --git a/pkg/cmd/open/open.go b/pkg/cmd/open/open.go index f0691ece..f5a5fa3f 100644 --- a/pkg/cmd/open/open.go +++ b/pkg/cmd/open/open.go @@ -1,10 +1,12 @@ package open import ( + "bufio" "errors" "fmt" "os" "os/exec" + "runtime" "strings" "time" @@ -30,15 +32,69 @@ import ( ) const ( - EditorVSCode 
= "code" - EditorCursor = "cursor" - EditorWindsurf = "windsurf" - EditorTmux = "tmux" + EditorVSCode = "code" + EditorCursor = "cursor" + EditorWindsurf = "windsurf" + EditorTerminal = "terminal" + EditorTmux = "tmux" ) var ( - openLong = "[command in beta] This will open VS Code, Cursor, Windsurf, or tmux SSH-ed in to your instance. You must have the editor installed in your path." - openExample = "brev open instance_id_or_name\nbrev open instance\nbrev open instance code\nbrev open instance cursor\nbrev open instance windsurf\nbrev open instance tmux\nbrev open --set-default cursor\nbrev open --set-default windsurf\nbrev open --set-default tmux" + openLong = `[command in beta] This will open an editor SSH-ed in to your instance. + +Supported editors: + code - VS Code + cursor - Cursor + windsurf - Windsurf + terminal - Opens a new terminal window with SSH + tmux - Opens a new terminal window with SSH + tmux session + +Terminal support by platform: + macOS: Terminal.app + Linux: gnome-terminal, konsole, or xterm + WSL: Windows Terminal (wt.exe) + Windows: Windows Terminal or cmd + +You must have the editor installed in your path.` + openExample = ` # Open an instance by name or ID + brev open instance_id_or_name + brev open my-instance + + # Open multiple instances (each in separate editor window) + brev open instance1 instance2 instance3 + + # Open with a specific editor + brev open my-instance code + brev open my-instance cursor + brev open my-instance windsurf + brev open my-instance terminal + brev open my-instance tmux + + # Open multiple instances with specific editor (flag is explicit) + brev open instance1 instance2 --editor cursor + brev open instance1 instance2 -e cursor + + # Or use positional arg (last arg is editor if it matches code/cursor/windsurf/tmux) + brev open instance1 instance2 cursor + + # Set a default editor + brev open --set-default cursor + brev open --set-default windsurf + + # Create a GPU instance and open it immediately (reads 
instance name from stdin) + brev create my-instance | brev open + + # Open a cluster (multiple instances from stdin) + brev create my-cluster --count 3 | brev open + + # Create with specific GPU and open in Cursor + brev search --gpu-name A100 | brev create ml-box | brev open cursor + + # Open in a new terminal window with SSH + brev create my-instance | brev open terminal + + # Open in a new terminal window with tmux (supports multiple instances) + brev create my-cluster --count 3 | brev open tmux` ) type OpenStore interface { @@ -59,10 +115,11 @@ func NewCmdOpen(t *terminal.Terminal, store OpenStore, noLoginStartStore OpenSto var directory string var host bool var setDefault string + var editor string cmd := &cobra.Command{ Annotations: map[string]string{"access": ""}, - Use: "open", + Use: "open [instance...] [editor]", DisableFlagsInUseLine: true, Short: "[beta] Open VSCode, Cursor, Windsurf, or tmux to your instance", Long: openLong, @@ -72,7 +129,8 @@ func NewCmdOpen(t *terminal.Terminal, store OpenStore, noLoginStartStore OpenSto if setDefaultFlag != "" { return cobra.NoArgs(cmd, args) } - return cobra.RangeArgs(1, 2)(cmd, args) + // Allow arbitrary args: instance names can come from stdin, last arg might be editor + return nil }), ValidArgsFunction: completions.GetAllWorkspaceNameCompletionHandler(noLoginStartStore, t), RunE: func(cmd *cobra.Command, args []string) error { @@ -80,19 +138,41 @@ func NewCmdOpen(t *terminal.Terminal, store OpenStore, noLoginStartStore OpenSto return handleSetDefault(t, setDefault) } - setupDoneString := "------ Git repo cloned ------" - if waitForSetupToFinish { - setupDoneString = "------ Done running execs ------" + // Validate editor flag if provided + if editor != "" && !isEditorType(editor) { + return breverrors.NewValidationError(fmt.Sprintf("invalid editor: %s. 
Must be 'code', 'cursor', 'windsurf', or 'tmux'", editor)) } - editorType, err := determineEditorType(args) + // Get instance names and editor type from args or stdin + instanceNames, editorType, err := getInstanceNamesAndEditor(args, editor) if err != nil { return breverrors.WrapAndTrace(err) } - err = runOpenCommand(t, store, args[0], setupDoneString, directory, host, editorType) - if err != nil { - return breverrors.WrapAndTrace(err) + + setupDoneString := "------ Git repo cloned ------" + if waitForSetupToFinish { + setupDoneString = "------ Done running execs ------" + } + + // Open each instance + var lastErr error + for _, instanceName := range instanceNames { + if len(instanceNames) > 1 { + fmt.Fprintf(os.Stderr, "Opening %s...\n", instanceName) + } + err = runOpenCommand(t, store, instanceName, setupDoneString, directory, host, editorType) + if err != nil { + if len(instanceNames) > 1 { + fmt.Fprintf(os.Stderr, "Error opening %s: %v\n", instanceName, err) + lastErr = err + continue + } + return breverrors.WrapAndTrace(err) + } + } + if lastErr != nil { + return breverrors.NewValidationError("one or more instances failed to open") } return nil }, @@ -100,14 +180,70 @@ func NewCmdOpen(t *terminal.Terminal, store OpenStore, noLoginStartStore OpenSto cmd.Flags().BoolVarP(&host, "host", "", false, "ssh into the host machine instead of the container") cmd.Flags().BoolVarP(&waitForSetupToFinish, "wait", "w", false, "wait for setup to finish") cmd.Flags().StringVarP(&directory, "dir", "d", "", "directory to open") - cmd.Flags().StringVar(&setDefault, "set-default", "", "set default editor (code, cursor, windsurf, or tmux)") + cmd.Flags().StringVar(&setDefault, "set-default", "", "set default editor (code, cursor, windsurf, terminal, or tmux)") + cmd.Flags().StringVarP(&editor, "editor", "e", "", "editor to use (code, cursor, windsurf, terminal, or tmux)") return cmd } +// isEditorType checks if a string is a valid editor type +func isEditorType(s string) bool { + 
return s == EditorVSCode || s == EditorCursor || s == EditorWindsurf || s == EditorTerminal || s == EditorTmux +} + +// getInstanceNamesAndEditor gets instance names from args/stdin and determines editor type +// editorFlag takes precedence, otherwise last arg may be an editor type (code, cursor, windsurf, tmux) +func getInstanceNamesAndEditor(args []string, editorFlag string) ([]string, string, error) { + var names []string + editorType := editorFlag + + // If no editor flag, check if last arg is an editor type + if editorType == "" && len(args) > 0 && isEditorType(args[len(args)-1]) { + editorType = args[len(args)-1] + args = args[:len(args)-1] + } + + // Add names from remaining args + names = append(names, args...) + + // Check if stdin is piped + stat, _ := os.Stdin.Stat() + if (stat.Mode() & os.ModeCharDevice) == 0 { + // Stdin is piped, read instance names (one per line) + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + name := strings.TrimSpace(scanner.Text()) + if name != "" { + names = append(names, name) + } + } + } + + if len(names) == 0 { + return nil, "", breverrors.NewValidationError("instance name required: provide as argument or pipe from another command") + } + + // If no editor specified, get default + if editorType == "" { + homeDir, err := os.UserHomeDir() + if err != nil { + editorType = EditorVSCode + } else { + settings, err := files.ReadPersonalSettings(files.AppFs, homeDir) + if err != nil { + editorType = EditorVSCode + } else { + editorType = settings.DefaultEditor + } + } + } + + return names, editorType, nil +} + func handleSetDefault(t *terminal.Terminal, editorType string) error { - if editorType != EditorVSCode && editorType != EditorCursor && editorType != EditorWindsurf && editorType != EditorTmux { - return fmt.Errorf("invalid editor type: %s. Must be 'code', 'cursor', 'windsurf', or 'tmux'", editorType) + if !isEditorType(editorType) { + return fmt.Errorf("invalid editor type: %s. 
Must be 'code', 'cursor', 'windsurf', 'terminal', or 'tmux'", editorType) } homeDir, err := os.UserHomeDir() @@ -128,28 +264,6 @@ func handleSetDefault(t *terminal.Terminal, editorType string) error { return nil } -func determineEditorType(args []string) (string, error) { - if len(args) == 2 { - editorType := args[1] - if editorType != EditorVSCode && editorType != EditorCursor && editorType != EditorWindsurf && editorType != EditorTmux { - return "", fmt.Errorf("invalid editor type: %s. Must be 'code', 'cursor', 'windsurf', or 'tmux'", editorType) - } - return editorType, nil - } - - homeDir, err := os.UserHomeDir() - if err != nil { - return EditorVSCode, nil - } - - settings, err := files.ReadPersonalSettings(files.AppFs, homeDir) - if err != nil { - return EditorVSCode, nil - } - - return settings.DefaultEditor, nil -} - // Fetch workspace info, then open code editor func runOpenCommand(t *terminal.Terminal, tstore OpenStore, wsIDOrName string, setupDoneString string, directory string, host bool, editorType string) error { //nolint:funlen,gocyclo // define brev command // todo check if workspace is stopped and start if it if it is stopped @@ -360,6 +474,8 @@ func getEditorName(editorType string) string { return "Cursor" case EditorWindsurf: return "Windsurf" + case EditorTerminal: + return "Terminal" case EditorTmux: return "tmux" default: @@ -390,8 +506,10 @@ func openEditorByType(t *terminal.Terminal, editorType string, sshAlias string, case EditorWindsurf: tryToInstallWindsurfExtensions(t, extensions) return openWindsurf(sshAlias, path, tstore) + case EditorTerminal: + return openTerminal(sshAlias, path, tstore) case EditorTmux: - return openTmux(sshAlias, path, tstore) + return openTerminalWithTmux(sshAlias, path, tstore) default: tryToInstallExtensions(t, extensions) return openVsCode(sshAlias, path, tstore) @@ -539,8 +657,75 @@ func getWindowsWindsurfPaths(store vscodePathStore) []string { return paths } -func openTmux(sshAlias string, path string, store 
OpenStore) error { +// openInNewTerminalWindow opens a command in a new terminal window based on the platform +// macOS: Terminal.app via osascript +// Linux: gnome-terminal, konsole, or xterm (tries in order) +// Windows/WSL: Windows Terminal (wt.exe) +func openInNewTerminalWindow(command string) error { + switch runtime.GOOS { + case "darwin": + // macOS: use osascript to open Terminal.app + script := fmt.Sprintf(`tell application "Terminal" + activate + do script "%s" +end tell`, command) + cmd := exec.Command("osascript", "-e", script) // #nosec G204 + return cmd.Run() + + case "linux": + // Check if we're in WSL by looking for wt.exe + if _, err := exec.LookPath("wt.exe"); err == nil { + // WSL: use Windows Terminal + cmd := exec.Command("wt.exe", "new-tab", "bash", "-c", command) // #nosec G204 + return cmd.Run() + } + // Try gnome-terminal first (Ubuntu/GNOME) + if _, err := exec.LookPath("gnome-terminal"); err == nil { + cmd := exec.Command("gnome-terminal", "--", "bash", "-c", command+"; exec bash") // #nosec G204 + return cmd.Run() + } + // Try konsole (KDE) + if _, err := exec.LookPath("konsole"); err == nil { + cmd := exec.Command("konsole", "-e", "bash", "-c", command+"; exec bash") // #nosec G204 + return cmd.Run() + } + // Try xterm as fallback + if _, err := exec.LookPath("xterm"); err == nil { + cmd := exec.Command("xterm", "-e", "bash", "-c", command+"; exec bash") // #nosec G204 + return cmd.Run() + } + return breverrors.NewValidationError("no supported terminal emulator found. 
Install gnome-terminal, konsole, or xterm") + + case "windows": + // Windows: use Windows Terminal + if _, err := exec.LookPath("wt.exe"); err == nil { + cmd := exec.Command("wt.exe", "new-tab", "cmd", "/c", command) // #nosec G204 + return cmd.Run() + } + // Fallback to start cmd + cmd := exec.Command("cmd", "/c", "start", "cmd", "/k", command) // #nosec G204 + return cmd.Run() + + default: + return breverrors.NewValidationError(fmt.Sprintf("'terminal' editor is not supported on %s", runtime.GOOS)) + } +} + +func openTerminal(sshAlias string, path string, store OpenStore) error { _ = store // unused parameter required by interface + _ = path // unused, just opens SSH + + sshCmd := fmt.Sprintf("ssh %s", sshAlias) + err := openInNewTerminalWindow(sshCmd) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} + +func openTerminalWithTmux(sshAlias string, path string, store OpenStore) error { + _ = store // unused parameter required by interface + err := ensureTmuxInstalled(sshAlias) if err != nil { return breverrors.WrapAndTrace(err) @@ -548,23 +733,21 @@ func openTmux(sshAlias string, path string, store OpenStore) error { sessionName := "brev" + // Check if tmux session exists checkCmd := fmt.Sprintf("ssh %s 'tmux has-session -t %s 2>/dev/null'", sshAlias, sessionName) checkExec := exec.Command("bash", "-c", checkCmd) // #nosec G204 - err = checkExec.Run() + checkErr := checkExec.Run() var tmuxCmd string - if err == nil { + if checkErr == nil { + // Session exists, attach to it tmuxCmd = fmt.Sprintf("ssh -t %s 'tmux attach-session -t %s'", sshAlias, sessionName) } else { + // Create new session tmuxCmd = fmt.Sprintf("ssh -t %s 'cd %s && tmux new-session -s %s'", sshAlias, path, sessionName) } - sshCmd := exec.Command("bash", "-c", tmuxCmd) // #nosec G204 - sshCmd.Stderr = os.Stderr - sshCmd.Stdout = os.Stdout - sshCmd.Stdin = os.Stdin - - err = sshCmd.Run() + err = openInNewTerminalWindow(tmuxCmd) if err != nil { return 
breverrors.WrapAndTrace(err) } diff --git a/pkg/cmd/shell/shell.go b/pkg/cmd/shell/shell.go index 49474ad3..c24837f3 100644 --- a/pkg/cmd/shell/shell.go +++ b/pkg/cmd/shell/shell.go @@ -1,6 +1,7 @@ package shell import ( + "bufio" "errors" "fmt" "os" @@ -9,7 +10,6 @@ import ( "time" "github.com/brevdev/brev-cli/pkg/analytics" - "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmd/hello" "github.com/brevdev/brev-cli/pkg/cmd/refresh" @@ -26,7 +26,34 @@ import ( var ( openLong = "[command in beta] This will shell in to your instance" - openExample = "brev shell instance_id_or_name\nbrev shell instance\nbrev open h9fp5vxwe" + openExample = ` # SSH into an instance by name or ID + brev shell instance_id_or_name + brev shell my-instance + + # Run a command on the instance (non-interactive, pipes stdout/stderr) + brev shell my-instance -c "nvidia-smi" + brev shell my-instance -c "python train.py" + + # Run a command on multiple instances + brev shell instance1 instance2 instance3 -c "nvidia-smi" + + # Run a script file on the instance + brev shell my-instance -c @setup.sh + + # Chain: create and run a command (reads instance names from stdin) + brev create my-instance | brev shell -c "nvidia-smi" + + # Run command on a cluster (multiple instances from stdin) + brev create my-cluster --count 3 | brev shell -c "nvidia-smi" + + # Create a GPU instance and SSH into it (use command substitution for interactive shell) + brev shell $(brev create my-instance) + + # Create with specific GPU and connect + brev shell $(brev search --gpu-name A100 | brev create ml-box) + + # SSH into the host machine instead of the container + brev shell my-instance --host` ) type ShellStore interface { @@ -40,35 +67,111 @@ type ShellStore interface { } func NewCmdShell(t *terminal.Terminal, store ShellStore, noLoginStartStore ShellStore) *cobra.Command { - var runRemoteCMD bool - var directory string var host bool + 
var command string cmd := &cobra.Command{ Annotations: map[string]string{"access": ""}, - Use: "shell", + Use: "shell [instance...]", Aliases: []string{"ssh"}, DisableFlagsInUseLine: true, Short: "[beta] Open a shell in your instance", Long: openLong, Example: openExample, - Args: cmderrors.TransformToValidationError(cmderrors.TransformToValidationError(cobra.ExactArgs(1))), + Args: cobra.ArbitraryArgs, ValidArgsFunction: completions.GetAllWorkspaceNameCompletionHandler(noLoginStartStore, t), RunE: func(cmd *cobra.Command, args []string) error { - err := runShellCommand(t, store, args[0], directory, host) + // Get instance names from args or stdin + instanceNames, err := getInstanceNames(args) if err != nil { return breverrors.WrapAndTrace(err) } + + // Parse command (can be inline or @filepath) + cmdToRun, err := parseCommand(command) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + // Interactive shell only supports one instance + if cmdToRun == "" && len(instanceNames) > 1 { + return breverrors.NewValidationError("interactive shell only supports one instance; use -c to run a command on multiple instances") + } + + // Run on each instance + var lastErr error + for _, instanceName := range instanceNames { + if len(instanceNames) > 1 { + fmt.Fprintf(os.Stderr, "\n=== %s ===\n", instanceName) + } + err = runShellCommand(t, store, instanceName, host, cmdToRun) + if err != nil { + if len(instanceNames) > 1 { + fmt.Fprintf(os.Stderr, "Error on %s: %v\n", instanceName, err) + lastErr = err + continue + } + return breverrors.WrapAndTrace(err) + } + } + if lastErr != nil { + return breverrors.NewValidationError("one or more instances failed") + } return nil }, } cmd.Flags().BoolVarP(&host, "host", "", false, "ssh into the host machine instead of the container") - cmd.Flags().BoolVarP(&runRemoteCMD, "remote", "r", true, "run remote commands") - cmd.Flags().StringVarP(&directory, "dir", "d", "", "override directory to launch shell") + 
cmd.Flags().StringVarP(&command, "command", "c", "", "command to run on the instance (use @filename to run a script file)") return cmd } -func runShellCommand(t *terminal.Terminal, sstore ShellStore, workspaceNameOrID, directory string, host bool) error { +// getInstanceNames gets instance names from args or stdin (supports multiple) +func getInstanceNames(args []string) ([]string, error) { + var names []string + + // Add names from args + names = append(names, args...) + + // Check if stdin is piped + stat, _ := os.Stdin.Stat() + if (stat.Mode() & os.ModeCharDevice) == 0 { + // Stdin is piped, read instance names (one per line) + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + name := strings.TrimSpace(scanner.Text()) + if name != "" { + names = append(names, name) + } + } + } + + if len(names) == 0 { + return nil, breverrors.NewValidationError("instance name required: provide as argument or pipe from another command") + } + + return names, nil +} + +// parseCommand parses the command string, loading from file if prefixed with @ +func parseCommand(command string) (string, error) { + if command == "" { + return "", nil + } + + // If prefixed with @, read from file + if strings.HasPrefix(command, "@") { + filePath := strings.TrimPrefix(command, "@") + content, err := os.ReadFile(filePath) + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + return string(content), nil + } + + return command, nil +} + +func runShellCommand(t *terminal.Terminal, sstore ShellStore, workspaceNameOrID string, host bool, command string) error { s := t.NewSpinner() workspace, err := util.GetUserWorkspaceByNameOrIDErr(sstore, workspaceNameOrID) if err != nil { @@ -114,7 +217,7 @@ func runShellCommand(t *terminal.Terminal, sstore ShellStore, workspaceNameOrID, // legacy environments wont support this and cause errrors, // but we don't want to block the user from using the shell _ = writeconnectionevent.WriteWCEOnEnv(sstore, workspace.DNS) - err = runSSH(workspace, 
sshName, directory) + err = runSSH(workspace, sshName, command) if err != nil { return breverrors.WrapAndTrace(err) } @@ -162,14 +265,29 @@ func waitForSSHToBeAvailable(sshAlias string, s *spinner.Spinner) error { } } -func runSSH(_ *entity.Workspace, sshAlias, _ string) error { +func runSSH(_ *entity.Workspace, sshAlias string, command string) error { sshAgentEval := "eval $(ssh-agent -s)" - cmd := fmt.Sprintf("ssh %s", sshAlias) + + var cmd string + if command != "" { + // Non-interactive: run command and pipe stdout/stderr + // Escape the command for passing to SSH + escapedCmd := strings.ReplaceAll(command, "'", "'\\''") + cmd = fmt.Sprintf("ssh %s '%s'", sshAlias, escapedCmd) + } else { + // Interactive shell + cmd = fmt.Sprintf("ssh %s", sshAlias) + } + cmd = fmt.Sprintf("%s && %s", sshAgentEval, cmd) sshCmd := exec.Command("bash", "-c", cmd) //nolint:gosec //cmd is user input sshCmd.Stderr = os.Stderr sshCmd.Stdout = os.Stdout - sshCmd.Stdin = os.Stdin + + // Only attach stdin for interactive sessions + if command == "" { + sshCmd.Stdin = os.Stdin + } err := hello.SetHasRunShell(true) if err != nil { diff --git a/pkg/store/instancetypes.go b/pkg/store/instancetypes.go new file mode 100644 index 00000000..e2daa047 --- /dev/null +++ b/pkg/store/instancetypes.go @@ -0,0 +1,73 @@ +package store + +import ( + "encoding/json" + "fmt" + "runtime" + + "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + resty "github.com/go-resty/resty/v2" +) + +const ( + instanceTypesAPIURL = "https://api.brev.dev" + instanceTypesAPIPath = "v1/instance/types" + // Authenticated API for instance types with workspace groups + allInstanceTypesPathPattern = "api/instances/alltypesavailable/%s" +) + +// GetInstanceTypes fetches all available instance types from the public API +func (s NoAuthHTTPStore) GetInstanceTypes() (*gpusearch.InstanceTypesResponse, error) { + return fetchInstanceTypes() +} + +// GetInstanceTypes fetches all 
available instance types from the public API +func (s AuthHTTPStore) GetInstanceTypes() (*gpusearch.InstanceTypesResponse, error) { + return fetchInstanceTypes() +} + +// fetchInstanceTypes fetches instance types from the public Brev API +func fetchInstanceTypes() (*gpusearch.InstanceTypesResponse, error) { + client := resty.New() + client.SetBaseURL(instanceTypesAPIURL) + + res, err := client.R(). + SetHeader("Accept", "application/json"). + Get(instanceTypesAPIPath) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if res.IsError() { + return nil, NewHTTPResponseError(res) + } + + var result gpusearch.InstanceTypesResponse + err = json.Unmarshal(res.Body(), &result) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + return &result, nil +} + +// GetAllInstanceTypesWithWorkspaceGroups fetches instance types with workspace groups from the authenticated API +func (s AuthHTTPStore) GetAllInstanceTypesWithWorkspaceGroups(orgID string) (*gpusearch.AllInstanceTypesResponse, error) { + path := fmt.Sprintf(allInstanceTypesPathPattern, orgID) + + var result gpusearch.AllInstanceTypesResponse + res, err := s.authHTTPClient.restyClient.R(). + SetHeader("Content-Type", "application/json"). + SetQueryParam("utm_source", "cli"). + SetQueryParam("os", runtime.GOOS). + SetResult(&result). 
+ Get(path) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if res.IsError() { + return nil, NewHTTPResponseError(res) + } + + return &result, nil +} diff --git a/pkg/store/workspace.go b/pkg/store/workspace.go index 5190d313..5d119f7b 100644 --- a/pkg/store/workspace.go +++ b/pkg/store/workspace.go @@ -34,6 +34,17 @@ type ModifyWorkspaceRequest struct { InstanceType string `json:"instanceType,omitempty"` } +// LifeCycleScriptAttr holds the lifecycle script configuration +type LifeCycleScriptAttr struct { + Script string `json:"script,omitempty"` +} + +// VMBuild holds VM-specific build configuration +type VMBuild struct { + ForceJupyterInstall bool `json:"forceJupyterInstall,omitempty"` + LifeCycleScriptAttr *LifeCycleScriptAttr `json:"lifeCycleScriptAttr,omitempty"` +} + type CreateWorkspacesOptions struct { Name string `json:"name"` WorkspaceGroupID string `json:"workspaceGroupId"` @@ -57,6 +68,7 @@ type CreateWorkspacesOptions struct { DiskStorage string `json:"diskStorage"` BaseImage string `json:"baseImage"` VMOnlyMode bool `json:"vmOnlyMode"` + VMBuild *VMBuild `json:"vmBuild,omitempty"` PortMappings map[string]string `json:"portMappings"` Files interface{} `json:"files"` Labels interface{} `json:"labels"` @@ -88,6 +100,7 @@ var ( var DefaultApplicationList = []entity.Application{DefaultApplication} func NewCreateWorkspacesOptions(clusterID, name string) *CreateWorkspacesOptions { + isStoppable := false return &CreateWorkspacesOptions{ BaseImage: "", Description: "", @@ -95,12 +108,12 @@ func NewCreateWorkspacesOptions(clusterID, name string) *CreateWorkspacesOptions ExecsV1: &entity.ExecsV1{}, Files: nil, InstanceType: "", - IsStoppable: nil, + IsStoppable: &isStoppable, Labels: nil, LaunchJupyterOnStart: false, Name: name, - PortMappings: nil, - ReposV1: nil, + PortMappings: map[string]string{}, + ReposV1: &entity.ReposV1{}, VMOnlyMode: true, WorkspaceGroupID: "GCP", WorkspaceTemplateID: DefaultWorkspaceTemplateID, diff --git 
a/pkg/util/util.go b/pkg/util/util.go index d9004e1d..299ba5b1 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -6,6 +6,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "strings" breverrors "github.com/brevdev/brev-cli/pkg/errors" @@ -13,6 +14,53 @@ import ( "github.com/hashicorp/go-multierror" ) +// isWSL returns true if running in Windows Subsystem for Linux +func isWSL() bool { + if runtime.GOOS != "linux" { + return false + } + // Check for WSL-specific indicators + if _, err := os.Stat("/proc/sys/fs/binfmt_misc/WSLInterop"); err == nil { + return true + } + // Also check /proc/version for "microsoft" or "WSL" + if data, err := os.ReadFile("/proc/version"); err == nil { + lower := strings.ToLower(string(data)) + if strings.Contains(lower, "microsoft") || strings.Contains(lower, "wsl") { + return true + } + } + return false +} + +// wslPathToWindows converts a WSL path like /mnt/c/Users/... to C:\Users\... +func wslPathToWindows(wslPath string) string { + if strings.HasPrefix(wslPath, "/mnt/") && len(wslPath) > 6 { + // Extract drive letter: /mnt/c/... -> c + drive := strings.ToUpper(string(wslPath[5])) + // Get rest of path: /mnt/c/Users/... -> /Users/... + rest := wslPath[6:] + // Convert to Windows path: C:\Users\... + windowsPath := drive + ":" + strings.ReplaceAll(rest, "/", "\\") + return windowsPath + } + return wslPath +} + +// runWindowsExeInWSL runs a Windows executable from WSL using cmd.exe +func runWindowsExeInWSL(exePath string, args []string) ([]byte, error) { + // Convert WSL path to Windows path + windowsPath := wslPathToWindows(exePath) + + // Build the command string for cmd.exe + // We need to quote the path and args properly for Windows + cmdArgs := []string{"/c", windowsPath} + cmdArgs = append(cmdArgs, args...) + + cmd := exec.Command("cmd.exe", cmdArgs...) 
// #nosec G204 + return cmd.CombinedOutput() +} + // This package should only be used as a holding pattern to be later moved into more specific packages func MapAppend(m map[string]interface{}, n ...map[string]interface{}) map[string]interface{} { @@ -205,6 +253,15 @@ func runManyCursorCommand(cursorpaths []string, args []string) ([]byte, error) { } func runVsCodeCommand(vscodepath string, args []string) ([]byte, error) { + // In WSL, Windows .exe files need to be run through cmd.exe + if isWSL() && (strings.HasSuffix(vscodepath, ".exe") || strings.HasPrefix(vscodepath, "/mnt/")) { + res, err := runWindowsExeInWSL(vscodepath, args) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return res, nil + } + cmd := exec.Command(vscodepath, args...) // #nosec G204 res, err := cmd.CombinedOutput() if err != nil { @@ -214,6 +271,15 @@ func runVsCodeCommand(vscodepath string, args []string) ([]byte, error) { } func runCursorCommand(cursorpath string, args []string) ([]byte, error) { + // In WSL, Windows .exe files need to be run through cmd.exe + if isWSL() && (strings.HasSuffix(cursorpath, ".exe") || strings.HasPrefix(cursorpath, "/mnt/")) { + res, err := runWindowsExeInWSL(cursorpath, args) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return res, nil + } + cmd := exec.Command(cursorpath, args...) // #nosec G204 res, err := cmd.CombinedOutput() if err != nil { @@ -236,6 +302,15 @@ func runManyWindsurfCommand(windsurfpaths []string, args []string) ([]byte, erro } func runWindsurfCommand(windsurfpath string, args []string) ([]byte, error) { + // In WSL, Windows .exe files need to be run through cmd.exe + if isWSL() && (strings.HasSuffix(windsurfpath, ".exe") || strings.HasPrefix(windsurfpath, "/mnt/")) { + res, err := runWindowsExeInWSL(windsurfpath, args) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return res, nil + } + cmd := exec.Command(windsurfpath, args...) 
// #nosec G204 res, err := cmd.CombinedOutput() if err != nil {