From 8babf7645d841d91d9ca86b84dc1d564d59792ce Mon Sep 17 00:00:00 2001 From: harshil Date: Wed, 14 Jan 2026 17:10:07 +0000 Subject: [PATCH 01/17] Added layer wise reporting on push --- pkg/distribution/oci/remote/remote.go | 72 ++++++++++++++++----------- pkg/distribution/registry/client.go | 9 ++-- 2 files changed, 46 insertions(+), 35 deletions(-) diff --git a/pkg/distribution/oci/remote/remote.go b/pkg/distribution/oci/remote/remote.go index 6fb1dcf1..bf0c044e 100644 --- a/pkg/distribution/oci/remote/remote.go +++ b/pkg/distribution/oci/remote/remote.go @@ -693,8 +693,21 @@ func (l *remoteLayer) MediaType() (oci.MediaType, error) { return l.desc.MediaType, nil } +// syncWriter is a thread-safe wrapper around io.Writer for concurrent writes +type syncWriter struct { + w io.Writer + mu sync.Mutex +} + +// Write implements io.Writer interface with mutex protection +func (sw *syncWriter) Write(p []byte) (n int, err error) { + sw.mu.Lock() + defer sw.mu.Unlock() + return sw.w.Write(p) +} + // Write pushes an image to a registry. -func Write(ref reference.Reference, img oci.Image, opts ...Option) error { +func Write(ref reference.Reference, img oci.Image, w io.Writer, opts ...Option) error { o := makeOptions(opts...) 
// Pre-authorize with push scope to ensure we have the right permissions @@ -724,6 +737,12 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { totalSize += size } + // Create a thread-safe writer wrapper for concurrent progress reporting + var safeWriter io.Writer + if w != nil { + safeWriter = &syncWriter{w: w} + } + var completed int64 for _, layer := range layers { digest, err := layer.Digest() @@ -747,6 +766,13 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { Size: size, } + var pr *progress.Reporter + var progressChan chan<- oci.Update + if safeWriter != nil { + pr = progress.NewProgressReporter(safeWriter, progress.PushMsg, size, layer) + progressChan = pr.Updates() + } + rc, err := layer.Compressed() if err != nil { return fmt.Errorf("getting layer content: %w", err) @@ -759,29 +785,29 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { // If already exists, continue if errdefs.IsAlreadyExists(err) || strings.Contains(err.Error(), "already exists") { completed += size - if o.progress != nil { - o.progress <- oci.Update{ + if progressChan != nil { + progressChan <- oci.Update{ Complete: completed, - Total: totalSize, + Total: size, } } continue } - closeProgress(o.progress) + closeProgress(progressChan) return fmt.Errorf("pushing layer: %w", err) } // Wrap the reader with progress tracking to report incremental upload progress // Uses the shared progress.Reader from internal/progress package var reader io.Reader = rc - if o.progress != nil { - reader = progress.NewReaderWithOffset(rc, o.progress, completed) + if progressChan != nil { + reader = progress.NewReaderWithOffset(rc, progressChan, completed) } if _, err := io.Copy(cw, reader); err != nil { cw.Close() rc.Close() - closeProgress(o.progress) + closeProgress(progressChan) return fmt.Errorf("writing layer: %w", err) } @@ -789,27 +815,28 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { cw.Close() 
rc.Close() if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") { - closeProgress(o.progress) + closeProgress(progressChan) return fmt.Errorf("committing layer: %w", err) } // If it already exists, we still want to update progress completed += size - if o.progress != nil { - o.progress <- oci.Update{ + if progressChan != nil { + progressChan <- oci.Update{ Complete: completed, - Total: totalSize, + Total: size, } } } else { // Successfully committed, update progress completed += size - if o.progress != nil { - o.progress <- oci.Update{ + if progressChan != nil { + progressChan <- oci.Update{ Complete: completed, - Total: totalSize, + Total: size, } } } + closeProgress(progressChan) cw.Close() rc.Close() } @@ -834,20 +861,17 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { cw, err := pusher.Push(o.ctx, configDesc) if err != nil { if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") { - closeProgress(o.progress) return fmt.Errorf("pushing config: %w", err) } // If it already exists, we don't have a writer to close, just continue } else { if _, err := cw.Write(rawConfig); err != nil { cw.Close() - closeProgress(o.progress) return fmt.Errorf("writing config: %w", err) } if err := cw.Commit(o.ctx, int64(len(rawConfig)), configDesc.Digest); err != nil { cw.Close() if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") { - closeProgress(o.progress) return fmt.Errorf("committing config: %w", err) } } @@ -857,19 +881,16 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { // Push manifest rawManifest, err := img.RawManifest() if err != nil { - closeProgress(o.progress) return fmt.Errorf("getting manifest: %w", err) } manifest, err := img.Manifest() if err != nil { - closeProgress(o.progress) return fmt.Errorf("getting manifest object: %w", err) } manifestDigest, err := img.Digest() if err != nil { - closeProgress(o.progress) 
return fmt.Errorf("getting manifest digest: %w", err) } @@ -882,24 +903,18 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { cw, err = pusher.Push(o.ctx, manifestDesc) if err != nil { if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") { - closeProgress(o.progress) return fmt.Errorf("pushing manifest: %w", err) } - // If it already exists, we don't have a writer to close, just continue - // If it already exists, we still want to close progress and return success - closeProgress(o.progress) return nil } if _, err := cw.Write(rawManifest); err != nil { cw.Close() - closeProgress(o.progress) return fmt.Errorf("writing manifest: %w", err) } if err := cw.Commit(o.ctx, int64(len(rawManifest)), manifestDesc.Digest); err != nil { cw.Close() - closeProgress(o.progress) if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") { return fmt.Errorf("committing manifest: %w", err) } @@ -908,9 +923,6 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { } cw.Close() - // Close progress channel to signal completion - closeProgress(o.progress) - return nil } diff --git a/pkg/distribution/registry/client.go b/pkg/distribution/registry/client.go index 82fca7a8..d0c6fcf1 100644 --- a/pkg/distribution/registry/client.go +++ b/pkg/distribution/registry/client.go @@ -9,7 +9,6 @@ import ( "strings" "sync" - "github.com/docker/model-runner/pkg/distribution/internal/progress" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/oci/authn" "github.com/docker/model-runner/pkg/distribution/oci/reference" @@ -264,15 +263,15 @@ func (t *Target) Write(ctx context.Context, model types.ModelArtifact, progressW } imageSize += size } - pr := progress.NewProgressReporter(progressWriter, progress.PushMsg, imageSize, nil) - defer pr.Wait() + //pr := progress.NewProgressReporter(progressWriter, progress.PushMsg, imageSize, nil) + //defer 
pr.Wait() // Set up authentication options authOpts := []remote.Option{ remote.WithContext(ctx), remote.WithTransport(t.transport), remote.WithUserAgent(t.userAgent), - remote.WithProgress(pr.Updates()), + //remote.WithProgress(pr.Updates()), remote.WithPlainHTTP(t.plainHTTP), } @@ -283,7 +282,7 @@ func (t *Target) Write(ctx context.Context, model types.ModelArtifact, progressW authOpts = append(authOpts, remote.WithAuthFromKeychain(t.keychain)) } - if err := remote.Write(t.reference, model, authOpts...); err != nil { + if err := remote.Write(t.reference, model, progressWriter, authOpts...); err != nil { return fmt.Errorf("write to registry %q: %w", t.reference.String(), err) } return nil From 4f1834e4c5eaa379d95868b8223adfec7a01289c Mon Sep 17 00:00:00 2001 From: harshil Date: Wed, 14 Jan 2026 17:55:18 +0000 Subject: [PATCH 02/17] Handled safe channel closure --- pkg/distribution/oci/remote/remote.go | 17 ++++++++++++++++- pkg/distribution/registry/client.go | 3 --- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pkg/distribution/oci/remote/remote.go b/pkg/distribution/oci/remote/remote.go index bf0c044e..f6449ef2 100644 --- a/pkg/distribution/oci/remote/remote.go +++ b/pkg/distribution/oci/remote/remote.go @@ -743,8 +743,8 @@ func Write(ref reference.Reference, img oci.Image, w io.Writer, opts ...Option) safeWriter = &syncWriter{w: w} } - var completed int64 for _, layer := range layers { + var completed int64 digest, err := layer.Digest() if err != nil { return fmt.Errorf("getting layer digest: %w", err) @@ -790,10 +790,13 @@ func Write(ref reference.Reference, img oci.Image, w io.Writer, opts ...Option) Complete: completed, Total: size, } + closeProgress(progressChan) + closeReporter(pr) } continue } closeProgress(progressChan) + closeReporter(pr) return fmt.Errorf("pushing layer: %w", err) } @@ -808,6 +811,7 @@ func Write(ref reference.Reference, img oci.Image, w io.Writer, opts ...Option) cw.Close() rc.Close() closeProgress(progressChan) + 
closeReporter(pr) return fmt.Errorf("writing layer: %w", err) } @@ -816,6 +820,7 @@ func Write(ref reference.Reference, img oci.Image, w io.Writer, opts ...Option) rc.Close() if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") { closeProgress(progressChan) + closeReporter(pr) return fmt.Errorf("committing layer: %w", err) } // If it already exists, we still want to update progress @@ -837,6 +842,7 @@ func Write(ref reference.Reference, img oci.Image, w io.Writer, opts ...Option) } } closeProgress(progressChan) + closeReporter(pr) cw.Close() rc.Close() } @@ -933,6 +939,15 @@ func closeProgress(ch chan<- oci.Update) { } } +// closeReporter safely closes the progress reporter if not nil +func closeReporter(pr *progress.Reporter) { + if pr != nil { + if waitErr := pr.Wait(); waitErr != nil { + fmt.Printf("reporter finished with non-fatal error: %v\n", waitErr) + } + } +} + // Ensure remoteImage is cleaned up properly func (i *remoteImage) Close() error { // The local content store doesn't expose its root path, so cleanup diff --git a/pkg/distribution/registry/client.go b/pkg/distribution/registry/client.go index d0c6fcf1..493f6712 100644 --- a/pkg/distribution/registry/client.go +++ b/pkg/distribution/registry/client.go @@ -263,15 +263,12 @@ func (t *Target) Write(ctx context.Context, model types.ModelArtifact, progressW } imageSize += size } - //pr := progress.NewProgressReporter(progressWriter, progress.PushMsg, imageSize, nil) - //defer pr.Wait() // Set up authentication options authOpts := []remote.Option{ remote.WithContext(ctx), remote.WithTransport(t.transport), remote.WithUserAgent(t.userAgent), - //remote.WithProgress(pr.Updates()), remote.WithPlainHTTP(t.plainHTTP), } From 4eee0bce6659362e92b1baf665656d2d15629e45 Mon Sep 17 00:00:00 2001 From: Dorin Geman Date: Wed, 14 Jan 2026 17:40:08 +0200 Subject: [PATCH 03/17] feat(cli): conditionally init standalone runner in package command Only initialize standalone runner when --push 
is not used. When pushing to registry, no local runner is needed. This allows `docker model package --push` to work without Docker Desktop or a running model runner. Signed-off-by: Dorin Geman --- cmd/cli/commands/package.go | 4 ++++ cmd/cli/commands/root.go | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/cmd/cli/commands/package.go b/cmd/cli/commands/package.go index 1a8d23ad..f08f7e00 100644 --- a/cmd/cli/commands/package.go +++ b/cmd/cli/commands/package.go @@ -291,6 +291,10 @@ func packageModel(ctx context.Context, cmd *cobra.Command, client *desktop.Clien registry.WithUserAgent("docker-model-cli/" + desktop.Version), ).NewTarget(opts.tag) } else { + // Ensure standalone runner is available when loading locally + if _, err := ensureStandaloneRunnerAvailable(ctx, asPrinter(cmd), false); err != nil { + return fmt.Errorf("unable to initialize standalone model runner: %w", err) + } target, err = newModelRunnerTarget(client, opts.tag) } if err != nil { diff --git a/cmd/cli/commands/root.go b/cmd/cli/commands/root.go index 11896491..b81cfed1 100644 --- a/cmd/cli/commands/root.go +++ b/cmd/cli/commands/root.go @@ -101,7 +101,6 @@ func NewRootCmd(cli *command.DockerCli) *cobra.Command { newStatusCmd(), newPullCmd(), newPushCmd(), - newPackagedCmd(), newListCmd(), newLogsCmd(), newRemoveCmd(), @@ -122,5 +121,8 @@ func NewRootCmd(cli *command.DockerCli) *cobra.Command { // run command handles standalone runner initialization itself (needs debug flag) rootCmd.AddCommand(newRunCmd()) + // package command handles standalone runner initialization itself (only when not pushing) + rootCmd.AddCommand(newPackagedCmd()) + return rootCmd } From 01081ba9280467ed000ee3816b0920237c3a6a39 Mon Sep 17 00:00:00 2001 From: Dorin Geman Date: Wed, 14 Jan 2026 17:53:39 +0200 Subject: [PATCH 04/17] fix(cli): handle nil modelRunner in ensureStandaloneRunnerAvailable Add nil check for modelRunner before calling EngineKind() to prevent panic in test scenarios where 
packageModel is called directly without going through the root command's PersistentPreRunE initialization. Signed-off-by: Dorin Geman --- cmd/cli/commands/install-runner.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cmd/cli/commands/install-runner.go b/cmd/cli/commands/install-runner.go index 597ffd39..c830333c 100644 --- a/cmd/cli/commands/install-runner.go +++ b/cmd/cli/commands/install-runner.go @@ -79,6 +79,11 @@ func inspectStandaloneRunner(container container.Summary) *standaloneRunner { // use to initialize a default standalone model runner. It is a no-op in // unsupported contexts or if automatic installs have been disabled. func ensureStandaloneRunnerAvailable(ctx context.Context, printer standalone.StatusPrinter, debug bool) (*standaloneRunner, error) { + // If the model runner context wasn't initialized, then don't do anything. + if modelRunner == nil { + return nil, nil + } + // If we're not in a supported model runner context, then don't do anything. engineKind := modelRunner.EngineKind() standaloneSupported := engineKind == types.ModelRunnerEngineKindMoby || From 9cb4471e27a057da45d799783f566596ed4be394 Mon Sep 17 00:00:00 2001 From: Dorin Geman Date: Wed, 14 Jan 2026 18:28:56 +0200 Subject: [PATCH 05/17] feat(authn): support DOCKER_USERNAME/DOCKER_PASSWORD env vars Add support for both DOCKER_USERNAME/DOCKER_PASSWORD and the existing DOCKER_HUB_USER/DOCKER_HUB_PASSWORD environment variables in DefaultKeychain. This change is part of an effort to remove mdltool, which had its own env var handling for DOCKER_USERNAME/DOCKER_PASSWORD. By moving this support into DefaultKeychain, registry clients automatically pick up credentials without requiring explicit configuration. 
Signed-off-by: Dorin Geman --- pkg/distribution/oci/authn/authn.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/pkg/distribution/oci/authn/authn.go b/pkg/distribution/oci/authn/authn.go index bf2c0274..1883e7e1 100644 --- a/pkg/distribution/oci/authn/authn.go +++ b/pkg/distribution/oci/authn/authn.go @@ -92,12 +92,17 @@ func (k *defaultKeychain) Resolve(r Resource) (Authenticator, error) { registry := r.RegistryStr() // Try environment variables first - if username := os.Getenv("DOCKER_HUB_USER"); username != "" { - if password := os.Getenv("DOCKER_HUB_PASSWORD"); password != "" { - return &Basic{ - Username: username, - Password: password, - }, nil + for _, envPair := range []struct{ user, pass string }{ + {"DOCKER_USERNAME", "DOCKER_PASSWORD"}, + {"DOCKER_HUB_USER", "DOCKER_HUB_PASSWORD"}, + } { + if username := os.Getenv(envPair.user); username != "" { + if password := os.Getenv(envPair.pass); password != "" { + return &Basic{ + Username: username, + Password: password, + }, nil + } } } From 1b5781a5ebea8a8231bfbc258dfc902d1f18e172 Mon Sep 17 00:00:00 2001 From: Dorin Geman Date: Wed, 14 Jan 2026 19:07:38 +0200 Subject: [PATCH 06/17] chore: remove mdltool Thanks to https://github.com/docker/model-runner/pull/567 and https://github.com/docker/model-runner/pull/568. Tested in https://github.com/docker/model-publisher/pull/77. 
Signed-off-by: Dorin Geman --- .github/workflows/ci.yml | 3 - .gitignore | 9 - Makefile | 60 +--- cmd/mdltool/main.go | 613 ------------------------------------- cmd/mdltool/main_test.go | 179 ----------- pkg/distribution/Makefile | 85 ----- pkg/distribution/README.md | 92 +----- 7 files changed, 14 insertions(+), 1027 deletions(-) delete mode 100644 cmd/mdltool/main.go delete mode 100644 cmd/mdltool/main_test.go delete mode 100644 pkg/distribution/Makefile diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 44e4510b..1090460e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,6 +60,3 @@ jobs: - name: validate run: make validate - - - name: model-distribution # TODO: Create a single Makefile for the monorepo - run: make -C ./pkg/distribution/ all diff --git a/.gitignore b/.gitignore index e833ff3a..af5620ee 100644 --- a/.gitignore +++ b/.gitignore @@ -1,22 +1,13 @@ .idea/ -model-distribution-tool model-runner model-runner.sock docker-model # Default MODELS_PATH in Makefile models-store/ -# Default MODELS_PATH in mdltool -model-store/ # Directory where we store the updated llama.cpp updated-inference/ vendor/ -# monorepo migration -# model-distribution -pkg/distribution/bin/ -/parallelget -/cli - llamacpp/build llamacpp/install diff --git a/Makefile b/Makefile index 4578f874..82193c30 100644 --- a/Makefile +++ b/Makefile @@ -21,18 +21,11 @@ DOCKER_BUILD_ARGS := \ --target $(DOCKER_TARGET) \ -t $(DOCKER_IMAGE) -# Model distribution tool configuration -MDL_TOOL_NAME := model-distribution-tool -STORE_PATH ?= ./model-store -SOURCE ?= -TAG ?= -LICENSE ?= - # Test configuration BUILD_DMR ?= 1 # Main targets -.PHONY: build run clean test integration-tests test-docker-ce-installation docker-build docker-build-multiplatform docker-run docker-build-vllm docker-run-vllm docker-build-sglang docker-run-sglang docker-run-impl help validate lint model-distribution-tool +.PHONY: build run clean test integration-tests 
test-docker-ce-installation docker-build docker-build-multiplatform docker-run docker-build-vllm docker-run-vllm docker-build-sglang docker-run-sglang docker-run-impl help validate lint # Default target .DEFAULT_GOAL := build @@ -40,10 +33,6 @@ BUILD_DMR ?= 1 build: CGO_ENABLED=1 go build -ldflags="-s -w" -o $(APP_NAME) . -# Build model-distribution-tool -model-distribution-tool: - CGO_ENABLED=1 go build -ldflags="-s -w" -o $(MDL_TOOL_NAME) ./cmd/mdltool - # Run the application locally run: build @LLAMACPP_BIN="llamacpp/install/bin"; \ @@ -56,7 +45,6 @@ run: build # Clean build artifacts clean: rm -f $(APP_NAME) - rm -f $(MDL_TOOL_NAME) rm -f model-runner.sock rm -rf $(MODELS_PATH) @@ -145,40 +133,10 @@ docker-run-impl: DEBUG="${DEBUG}" \ scripts/docker-run.sh -# Model distribution tool operations -mdl-pull: model-distribution-tool - @echo "Pulling model from $(TAG)..." - ./$(MDL_TOOL_NAME) --store-path $(STORE_PATH) pull $(TAG) - -mdl-package: model-distribution-tool - @echo "Packaging model $(SOURCE) to $(TAG)..." - ./$(MDL_TOOL_NAME) package --tag $(TAG) $(if $(LICENSE),--licenses $(LICENSE)) $(SOURCE) - -mdl-list: model-distribution-tool - @echo "Listing models..." - ./$(MDL_TOOL_NAME) --store-path $(STORE_PATH) list - -mdl-get: model-distribution-tool - @echo "Getting model $(TAG)..." - ./$(MDL_TOOL_NAME) --store-path $(STORE_PATH) get $(TAG) - -mdl-get-path: model-distribution-tool - @echo "Getting path for model $(TAG)..." - ./$(MDL_TOOL_NAME) --store-path $(STORE_PATH) get-path $(TAG) - -mdl-rm: model-distribution-tool - @echo "Removing model $(TAG)..." - ./$(MDL_TOOL_NAME) --store-path $(STORE_PATH) rm $(TAG) - -mdl-tag: model-distribution-tool - @echo "Tagging model $(SOURCE) as $(TAG)..." 
- ./$(MDL_TOOL_NAME) --store-path $(STORE_PATH) tag $(SOURCE) $(TAG) - # Show help help: @echo "Available targets:" @echo " build - Build the Go application" - @echo " model-distribution-tool - Build the model distribution tool" @echo " run - Run the application locally" @echo " clean - Clean build artifacts" @echo " test - Run tests" @@ -195,15 +153,6 @@ help: @echo " docker-run-sglang - Run SGLang Docker container" @echo " help - Show this help message" @echo "" - @echo "Model distribution tool targets:" - @echo " mdl-pull - Pull a model (TAG=registry/model:tag)" - @echo " mdl-package - Package and push a model (SOURCE=path/to/model.gguf TAG=registry/model:tag LICENSE=path/to/license.txt)" - @echo " mdl-list - List all models" - @echo " mdl-get - Get model info (TAG=registry/model:tag)" - @echo " mdl-get-path - Get model path (TAG=registry/model:tag)" - @echo " mdl-rm - Remove a model (TAG=registry/model:tag)" - @echo " mdl-tag - Tag a model (SOURCE=registry/model:tag TAG=registry/model:newtag)" - @echo "" @echo "Backend configuration options:" @echo " LLAMA_ARGS - Arguments for llama.cpp (e.g., \"--verbose --jinja -ngl 999 --ctx-size 2048\")" @echo " LOCAL_LLAMA - Use local llama.cpp build from llamacpp/install/bin (set to 1 to enable)" @@ -212,10 +161,3 @@ help: @echo " make run LLAMA_ARGS=\"--verbose --jinja -ngl 999 --ctx-size 2048\"" @echo " make run LOCAL_LLAMA=1" @echo " make docker-run LLAMA_ARGS=\"--verbose --jinja -ngl 999 --threads 4 --ctx-size 2048\"" - @echo "" - @echo "Model distribution tool examples:" - @echo " make mdl-pull TAG=registry.example.com/models/llama:v1.0" - @echo " make mdl-package SOURCE=./model.gguf TAG=registry.example.com/models/llama:v1.0 LICENSE=./license.txt" - @echo " make mdl-package SOURCE=./qwen2.5-3b-instruct TAG=registry.example.com/models/qwen:v1.0" - @echo " make mdl-list" - @echo " make mdl-rm TAG=registry.example.com/models/llama:v1.0" diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go deleted file mode 100644 
index 3c2fc2eb..00000000 --- a/cmd/mdltool/main.go +++ /dev/null @@ -1,613 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/docker/model-runner/pkg/distribution/builder" - "github.com/docker/model-runner/pkg/distribution/distribution" - "github.com/docker/model-runner/pkg/distribution/packaging" - "github.com/docker/model-runner/pkg/distribution/registry" - "github.com/docker/model-runner/pkg/distribution/tarball" -) - -// stringSliceFlag is a flag that can be specified multiple times to collect multiple string values -type stringSliceFlag []string - -func (s *stringSliceFlag) String() string { - return strings.Join(*s, ", ") -} - -func (s *stringSliceFlag) Set(value string) error { - *s = append(*s, value) - return nil -} - -const ( - defaultStorePath = "./model-store" - version = "0.1.0" -) - -var ( - storePath string - showHelp bool - showVer bool -) - -func init() { - flag.StringVar(&storePath, "store-path", defaultStorePath, "Path to the model store") - flag.BoolVar(&showHelp, "help", false, "Show help") - flag.BoolVar(&showVer, "version", false, "Show version") -} - -func main() { - flag.Parse() - - if showVer { - fmt.Printf("model-distribution-tool version %s\n", version) - return - } - - if showHelp || flag.NArg() == 0 { - printUsage() - return - } - - // Create absolute path for store - absStorePath, err := filepath.Abs(storePath) - if err != nil { - fmt.Fprintf(os.Stderr, "Error resolving store path: %v\n", err) - os.Exit(1) - } - - // Create the client with auth if environment variables are set - clientOpts := []distribution.Option{ - distribution.WithStoreRootPath(absStorePath), - distribution.WithUserAgent("model-distribution-tool/" + version), - } - - if username := os.Getenv("DOCKER_USERNAME"); username != "" { - if password := os.Getenv("DOCKER_PASSWORD"); password != "" { - clientOpts = append(clientOpts, distribution.WithRegistryAuth(username, password)) - } - } - - client, err := 
distribution.NewClient(clientOpts...) - if err != nil { - fmt.Fprintf(os.Stderr, "Error creating client: %v\n", err) - os.Exit(1) - } - - // Get the command and arguments - command := flag.Arg(0) - args := flag.Args()[1:] - - // Execute the command - var exitCode int - switch command { - case "pull": - exitCode = cmdPull(client, args) - case "package": - exitCode = cmdPackage(args) - case "push": - exitCode = cmdPush(client, args) - case "list": - exitCode = cmdList(client, args) - case "get": - exitCode = cmdGet(client, args) - case "get-path": - exitCode = cmdGetPath(client, args) - case "rm": - exitCode = cmdRm(client, args) - case "tag": - exitCode = cmdTag(client, args) - case "load": - exitCode = cmdLoad(client, args) - case "bundle": - exitCode = cmdBundle(client, args) - default: - fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command) - printUsage() - exitCode = 1 - } - - os.Exit(exitCode) -} - -func printUsage() { - fmt.Println("Usage: model-distribution-tool [options] [arguments]") - fmt.Println("\nOptions:") - flag.PrintDefaults() - fmt.Println("\nCommands:") - fmt.Println(" pull Pull a model from a registry") - fmt.Println(" package Package a model file as an OCI artifact and push it to a registry") - fmt.Println(" (use --licenses to add license files, --mmproj for multimodal projector, --dir-tar for directories)") - fmt.Println(" push Push a model from the content store to the registry") - fmt.Println(" list List all models") - fmt.Println(" get Get a model by reference") - fmt.Println(" get-path Get the local file path for a model") - fmt.Println(" rm Remove a model by reference") - fmt.Println(" bundle Create a runtime bundle for model") - fmt.Println("\nExamples:") - fmt.Println(" model-distribution-tool --store-path ./models pull registry.example.com/models/llama:v1.0") - fmt.Println(" model-distribution-tool package ./model.gguf registry.example.com/models/llama:v1.0 --licenses ./license1.txt --licenses ./license2.txt") - fmt.Println(" 
model-distribution-tool package ./model.gguf registry.example.com/models/llama:v1.0 --mmproj ./model.mmproj") - fmt.Println(" model-distribution-tool package ./model.gguf registry.example.com/models/llama:v1.0 --dir-tar ./config --dir-tar ./templates") - fmt.Println(" model-distribution-tool push registry.example.com/models/llama:v1.0") - fmt.Println(" model-distribution-tool list") - fmt.Println(" model-distribution-tool rm registry.example.com/models/llama:v1.0") - fmt.Println(" model-distribution-tool bundle registry.example.com/models/llama:v1.0") -} - -func cmdPull(client *distribution.Client, args []string) int { - if len(args) < 1 { - fmt.Fprintf(os.Stderr, "Error: missing reference argument\n") - fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool pull \n") - return 1 - } - - reference := args[0] - ctx := context.Background() - - if err := client.PullModel(ctx, reference, os.Stdout); err != nil { - fmt.Fprintf(os.Stderr, "Error pulling model: %v\n", err) - return 1 - } - - fmt.Printf("Successfully pulled model: %s\n", reference) - return 0 -} - -func cmdPackage(args []string) int { - fs := flag.NewFlagSet("package", flag.ExitOnError) - var ( - licensePaths stringSliceFlag - dirTarPaths stringSliceFlag - contextSize uint64 - file string - tag string - mmproj string - chatTemplate string - ) - - fs.Var(&licensePaths, "licenses", "Paths to license files (can be specified multiple times)") - fs.Var(&dirTarPaths, "dir-tar", "Relative paths to directories to package as tar (can be specified multiple times)") - fs.Uint64Var(&contextSize, "context-size", 0, "Context size in tokens") - fs.StringVar(&mmproj, "mmproj", "", "Path to Multimodal Projector file") - fs.StringVar(&file, "file", "", "Write archived model to the given file") - fs.StringVar(&tag, "tag", "", "Push model to the given registry tag") - fs.StringVar(&chatTemplate, "chat-template", "", "Jinja chat template file") - - fs.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool 
package [OPTIONS] \n\n") - fmt.Fprintf(os.Stderr, "Examples:\n") - fmt.Fprintf(os.Stderr, " # GGUF model:\n") - fmt.Fprintf(os.Stderr, " model-distribution-tool package model.gguf --tag registry/model:tag\n\n") - fmt.Fprintf(os.Stderr, " # Safetensors model:\n") - fmt.Fprintf(os.Stderr, " model-distribution-tool package ./qwen-model-dir --tag registry/model:tag\n\n") - fmt.Fprintf(os.Stderr, "Options:\n") - fs.PrintDefaults() - } - - if err := fs.Parse(args); err != nil { - fmt.Fprintf(os.Stderr, "Error parsing flags: %v\n", err) - return 1 - } - args = fs.Args() - - // Get the source from positional argument - if len(args) < 1 { - fmt.Fprintf(os.Stderr, "Error: no model file or directory specified\n") - fs.Usage() - return 1 - } - - source := args[0] - var isSafetensors bool - var configArchive string // For safetensors config - var safetensorsPaths []string // For safetensors model files - - // Check if source exists - sourceInfo, err := os.Stat(source) - if os.IsNotExist(err) { - fmt.Fprintf(os.Stderr, "Error: source does not exist: %s\n", source) - return 1 - } - - // Handle directory-based packaging (for safetensors models) - if sourceInfo.IsDir() { - fmt.Printf("Detected directory, scanning for safetensors model...\n") - var err error - safetensorsPaths, configArchive, err = packaging.PackageFromDirectory(source) - if err != nil { - fmt.Fprintf(os.Stderr, "Error scanning directory: %v\n", err) - return 1 - } - - isSafetensors = true - fmt.Printf("Found %d safetensors file(s)\n", len(safetensorsPaths)) - - // Clean up temp config archive when done - if configArchive != "" { - defer os.Remove(configArchive) - fmt.Printf("Created temporary config archive from directory\n") - } - } else { - // Handle single file (GGUF model) - if strings.HasSuffix(strings.ToLower(source), ".gguf") { - isSafetensors = false - fmt.Println("Detected GGUF model file") - } else { - fmt.Fprintf(os.Stderr, "Warning: could not determine model type for: %s\n", source) - 
fmt.Fprintf(os.Stderr, "Assuming GGUF format.\n") - } - } - - if file == "" && tag == "" { - fmt.Fprintf(os.Stderr, "Error: one of --file or --tag is required\n") - fs.Usage() - return 1 - } - - ctx := context.Background() - - // Prepare registry client options - registryClientOpts := []registry.ClientOption{ - registry.WithUserAgent("model-distribution-tool/" + version), - } - - // Add auth if available - if username := os.Getenv("DOCKER_USERNAME"); username != "" { - if password := os.Getenv("DOCKER_PASSWORD"); password != "" { - registryClientOpts = append(registryClientOpts, registry.WithAuthConfig(username, password)) - } - } - - // Create registry client once with all options - registryClient := registry.NewClient(registryClientOpts...) - - var target builder.Target - if file != "" { - target = tarball.NewFileTarget(file) - } else { - var err error - target, err = registryClient.NewTarget(tag) - if err != nil { - fmt.Fprintf(os.Stderr, "Create packaging target: %v\n", err) - return 1 - } - } - - // Create builder based on model type - var b *builder.Builder - if isSafetensors { - fmt.Println("Creating safetensors model") - b, err = builder.FromPaths(safetensorsPaths) - if err != nil { - fmt.Fprintf(os.Stderr, "Error creating model from safetensors: %v\n", err) - return 1 - } - - // Add config archive if provided - if configArchive != "" { - fmt.Printf("Adding config archive: %s\n", configArchive) - b, err = b.WithConfigArchive(configArchive) - if err != nil { - fmt.Fprintf(os.Stderr, "Error adding config archive: %v\n", err) - return 1 - } - } - } else { - b, err = builder.FromPath(source) - if err != nil { - fmt.Fprintf(os.Stderr, "Error creating model from gguf: %v\n", err) - return 1 - } - } - - // Add all license files as layers - for _, path := range licensePaths { - fmt.Println("Adding license file:", path) - b, err = b.WithLicense(path) - if err != nil { - fmt.Fprintf(os.Stderr, "Error adding license layer for %s: %v\n", path, err) - return 1 - } - } - 
- if contextSize > 0 { - fmt.Println("Setting context size:", contextSize) - b = b.WithContextSize(int32(contextSize)) - } - - if mmproj != "" { - fmt.Println("Adding multimodal projector file:", mmproj) - b, err = b.WithMultimodalProjector(mmproj) - if err != nil { - fmt.Fprintf(os.Stderr, "Error adding multimodal projector layer for %s: %v\n", mmproj, err) - return 1 - } - } - - if chatTemplate != "" { - fmt.Println("Adding chat template file:", chatTemplate) - b, err = b.WithChatTemplateFile(chatTemplate) - if err != nil { - fmt.Fprintf(os.Stderr, "Error adding chat template layer for %s: %v\n", chatTemplate, err) - return 1 - } - } - - // Process directory tar archives - if len(dirTarPaths) > 0 { - // Determine base directory for resolving relative paths - var baseDir string - if isSafetensors { - baseDir = source - } else { - // For GGUF, use the directory containing the GGUF file - baseDir = filepath.Dir(source) - } - - processor := packaging.NewDirTarProcessor(dirTarPaths, baseDir) - tarPaths, cleanup, err := processor.Process() - if err != nil { - fmt.Fprintf(os.Stderr, "Error processing dir-tar paths: %v\n", err) - return 1 - } - defer cleanup() - - for _, tarPath := range tarPaths { - b, err = b.WithDirTar(tarPath) - if err != nil { - fmt.Fprintf(os.Stderr, "Error adding directory tar: %v\n", err) - return 1 - } - } - } - - // Push the image - if err := b.Build(ctx, target, os.Stdout); err != nil { - fmt.Fprintf(os.Stderr, "Error writing model to registry: %v\n", err) - return 1 - } - if tag != "" { - fmt.Printf("Successfully packaged and pushed model: %s\n", tag) - } else { - fmt.Printf("Successfully packaged model to file: %s\n", file) - } - return 0 -} - -func cmdLoad(client *distribution.Client, args []string) int { - fs := flag.NewFlagSet("load", flag.ExitOnError) - var ( - tag string - ) - fs.StringVar(&tag, "tag", "", "Apply tag to the loaded model") - fs.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool load [OPTIONS] 
\n\n") - fmt.Fprintf(os.Stderr, "Options:\n") - fs.PrintDefaults() - } - - if err := fs.Parse(args); err != nil { - fmt.Fprintf(os.Stderr, "Error parsing flags: %v\n", err) - return 1 - } - args = fs.Args() - - if len(args) < 1 { - fmt.Fprintf(os.Stderr, "Error: missing required argument\n") - fs.Usage() - return 1 - } - path := args[0] - - f, err := os.Open(path) - if err != nil { - fmt.Fprintf(os.Stderr, "Error opening model file: %v\n", err) - return 1 - } - defer f.Close() - - id, err := client.LoadModel(f, os.Stdout) - if err != nil { - fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err) - return 1 - } - fmt.Fprintln(os.Stdout, "Loaded model:", id) - if err := client.Tag(id, tag); err != nil { - fmt.Fprintf(os.Stderr, "Error tagging model: %v\n", err) - } - fmt.Fprintln(os.Stdout, "Tagged model:", tag) - return 0 -} - -func cmdPush(client *distribution.Client, args []string) int { - if len(args) < 1 { - fmt.Fprintf(os.Stderr, "Error: missing tag argument\n") - fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool push \n") - return 1 - } - - tag := args[0] - ctx := context.Background() - - if err := client.PushModel(ctx, tag, os.Stdout); err != nil { - fmt.Fprintf(os.Stderr, "Error pushing model: %v\n", err) - return 1 - } - - fmt.Printf("Successfully pushed model: %s\n", tag) - return 0 -} - -func cmdList(client *distribution.Client, args []string) int { - models, err := client.ListModels() - if err != nil { - fmt.Fprintf(os.Stderr, "Error listing models: %v\n", err) - return 1 - } - - if len(models) == 0 { - fmt.Println("No models found") - return 0 - } - - fmt.Println("Models:") - for i, model := range models { - id, err := model.ID() - if err != nil { - fmt.Fprintf(os.Stderr, "Error getting model ID: %v\n", err) - continue - } - fmt.Printf("%d. 
ID: %s\n", i+1, id) - fmt.Printf(" Tags: %s\n", strings.Join(model.Tags(), ", ")) - - ggufPaths, err := model.GGUFPaths() - if err == nil { - fmt.Print(" GGUF Paths:\n") - for _, path := range ggufPaths { - fmt.Printf("\t%s\n", path) - } - } - } - return 0 -} - -func cmdGet(client *distribution.Client, args []string) int { - if len(args) < 1 { - fmt.Fprintf(os.Stderr, "Error: missing reference argument\n") - fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool get \n") - return 1 - } - - reference := args[0] - - model, err := client.GetModel(reference) - if err != nil { - fmt.Fprintf(os.Stderr, "Error getting model: %v\n", err) - return 1 - } - - fmt.Printf("Model: %s\n", reference) - - id, err := model.ID() - if err != nil { - fmt.Fprintf(os.Stderr, "Error getting model ID %v\n", err) - return 1 - } - fmt.Printf("ID: %s\n", id) - - ggufPaths, err := model.GGUFPaths() - if err != nil { - fmt.Fprintf(os.Stderr, "Error getting gguf path %v\n", err) - return 1 - } - fmt.Print(" GGUF Paths:\n") - for _, path := range ggufPaths { - fmt.Printf("\t%s\n", path) - } - - cfg, err := model.Config() - if err != nil { - fmt.Fprintf(os.Stderr, "Error reading model config: %v\n", err) - return 1 - } - fmt.Printf("Format: %s\n", cfg.GetFormat()) - fmt.Printf("Architecture: %s\n", cfg.GetArchitecture()) - fmt.Printf("Parameters: %s\n", cfg.GetParameters()) - fmt.Printf("Quantization: %s\n", cfg.GetQuantization()) - return 0 -} - -func cmdGetPath(client *distribution.Client, args []string) int { - if len(args) < 1 { - fmt.Fprintf(os.Stderr, "Error: missing reference argument\n") - fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool get-path \n") - return 1 - } - - reference := args[0] - - model, err := client.GetModel(reference) - if err != nil { - fmt.Fprintf(os.Stderr, "Failed to get model: %v\n", err) - return 1 - } - - modelPaths, err := model.GGUFPaths() - if err != nil || len(modelPaths) == 0 { - fmt.Fprintf(os.Stderr, "Error getting model path: %v\n", err) - return 1 - } 
- - fmt.Println(modelPaths[0]) - return 0 -} - -func cmdRm(client *distribution.Client, args []string) int { - var force bool - fs := flag.NewFlagSet("rm", flag.ExitOnError) - fs.BoolVar(&force, "force", false, "Force remove the model") - - if err := fs.Parse(args); err != nil { - fmt.Fprintf(os.Stderr, "Error parsing flags: %v\n", err) - return 1 - } - args = fs.Args() - - if len(args) < 1 { - fmt.Fprintf(os.Stderr, "Error: missing reference argument\n") - fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool rm [--force] \n") - return 1 - } - - reference := args[0] - - if _, err := client.DeleteModel(reference, force); err != nil { - fmt.Fprintf(os.Stderr, "Error removing model: %v\n", err) - return 1 - } - - fmt.Printf("Successfully removed model: %s\n", reference) - return 0 -} - -func cmdTag(client *distribution.Client, args []string) int { - if len(args) != 2 { - fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool tag \n") - return 1 - } - - source := args[0] - target := args[1] - - if err := client.Tag(source, target); err != nil { - fmt.Fprintf(os.Stderr, "Error tagging model: %v\n", err) - return 1 - } - - fmt.Printf("Successfully applied tag %s to model: %s\n", target, source) - return 0 -} - -func cmdBundle(client *distribution.Client, args []string) int { - if len(args) != 1 { - fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool bundle \n") - return 1 - } - bundle, err := client.GetBundle(args[0]) - if err != nil { - fmt.Fprintf(os.Stderr, "Error getting model bundle: %v\n", err) - return 1 - } - fmt.Fprintf(os.Stderr, "Successfully created bundle for model %s\n", args[0]) - fmt.Fprint(os.Stdout, bundle.RootDir()) - return 0 -} diff --git a/cmd/mdltool/main_test.go b/cmd/mdltool/main_test.go deleted file mode 100644 index 36ec4e63..00000000 --- a/cmd/mdltool/main_test.go +++ /dev/null @@ -1,179 +0,0 @@ -package main - -import ( - "os" - "os/exec" - "path/filepath" - "strings" - "testing" - - 
"github.com/docker/model-runner/pkg/distribution/distribution" -) - -// TestMainHelp tests the help command -func TestMainHelp(t *testing.T) { - cmd := exec.Command("go", "run", "main.go", "--help") - output, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("Failed to run help command: %v\nOutput: %s", err, output) - } - - // Check that the output contains the usage information - if !strings.Contains(string(output), "Usage:") { - t.Errorf("Help output does not contain usage information") - } - - // Check that the output contains the commands - commands := []string{"pull", "package", "list", "get", "get-path"} - for _, cmd := range commands { - if !strings.Contains(string(output), cmd) { - t.Errorf("Help output does not contain command: %s", cmd) - } - } -} - -// TestMainVersion tests the version command -func TestMainVersion(t *testing.T) { - cmd := exec.Command("go", "run", "main.go", "--version") - output, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("Failed to run version command: %v\nOutput: %s", err, output) - } - - // Check that the output contains the version information - if !strings.Contains(string(output), "version") { - t.Errorf("Version output does not contain version information") - } -} - -// TestMainPull tests the pull command -func TestMainPull(t *testing.T) { - // Create a temporary directory for the test - tempDir, err := os.MkdirTemp("", "model-distribution-test-*") - if err != nil { - t.Fatalf("Failed to create temp directory: %v", err) - } - defer os.RemoveAll(tempDir) - - // Create a model store directory - storeDir := filepath.Join(tempDir, "model-store") - if err := os.MkdirAll(storeDir, 0755); err != nil { - t.Fatalf("Failed to create model store directory: %v", err) - } - - // Create a client for testing - client, err := distribution.NewClient(distribution.WithStoreRootPath(storeDir)) - if err != nil { - t.Fatalf("Failed to create client: %v", err) - } - - // Test the pull command with invalid arguments - exitCode := 
cmdPull(client, []string{}) - if exitCode != 1 { - t.Errorf("Pull command with invalid arguments should fail") - } -} - -// TestMainPackage tests the package command -func TestMainPackage(t *testing.T) { - // Create a temporary directory for the test - tempDir, err := os.MkdirTemp("", "model-distribution-test-*") - if err != nil { - t.Fatalf("Failed to create temp directory: %v", err) - } - defer os.RemoveAll(tempDir) - - // Test the package command with invalid arguments - exitCode := cmdPackage([]string{}) - if exitCode != 1 { - t.Errorf("Push command with invalid arguments should fail") - } -} - -// TestMainList tests the list command -func TestMainList(t *testing.T) { - // Create a temporary directory for the test - tempDir, err := os.MkdirTemp("", "model-distribution-test-*") - if err != nil { - t.Fatalf("Failed to create temp directory: %v", err) - } - defer os.RemoveAll(tempDir) - - // Create a client for testing - client, err := distribution.NewClient(distribution.WithStoreRootPath(tempDir)) - if err != nil { - t.Fatalf("Failed to create client: %v", err) - } - - // Test the list command - exitCode := cmdList(client, []string{}) - if exitCode != 0 { - t.Errorf("List command failed with exit code: %d", exitCode) - } -} - -// TestMainGet tests the get command -func TestMainGet(t *testing.T) { - // Create a temporary directory for the test - tempDir, err := os.MkdirTemp("", "model-distribution-test-*") - if err != nil { - t.Fatalf("Failed to create temp directory: %v", err) - } - defer os.RemoveAll(tempDir) - - // Create a client for testing - client, err := distribution.NewClient(distribution.WithStoreRootPath(tempDir)) - if err != nil { - t.Fatalf("Failed to create client: %v", err) - } - - // Test the get command with invalid arguments - exitCode := cmdGet(client, []string{}) - if exitCode != 1 { - t.Errorf("Get command with invalid arguments should fail") - } -} - -// TestMainGetPath tests the get-path command -func TestMainGetPath(t *testing.T) { - // 
Create a temporary directory for the test - tempDir, err := os.MkdirTemp("", "model-distribution-test-*") - if err != nil { - t.Fatalf("Failed to create temp directory: %v", err) - } - defer os.RemoveAll(tempDir) - - // Create a client for testing - client, err := distribution.NewClient(distribution.WithStoreRootPath(tempDir)) - if err != nil { - t.Fatalf("Failed to create client: %v", err) - } - - // Test the get-path command with invalid arguments - exitCode := cmdGetPath(client, []string{}) - if exitCode != 1 { - t.Errorf("Get-path command with invalid arguments should fail") - } -} - -// TestMainPush tests the push command -func TestMainPush(t *testing.T) { - // Create a temporary directory for the test - tempDir, err := os.MkdirTemp("", "model-distribution-test-*") - if err != nil { - t.Fatalf("Failed to create temp directory: %v", err) - } - defer os.RemoveAll(tempDir) - - // Create a client for testing - client, err := distribution.NewClient(distribution.WithStoreRootPath(tempDir)) - if err != nil { - t.Fatalf("Failed to create client: %v", err) - } - - // Test the push command with invalid arguments - exitCode := cmdPush(client, []string{}) - if exitCode != 1 { - t.Errorf("Push command with invalid arguments should fail") - } -} diff --git a/pkg/distribution/Makefile b/pkg/distribution/Makefile deleted file mode 100644 index 458ff6c8..00000000 --- a/pkg/distribution/Makefile +++ /dev/null @@ -1,85 +0,0 @@ -.PHONY: all build test clean lint run - -# Import env file if it exists --include .env - -# Build variables -BINARY_NAME=model-distribution-tool -VERSION?=0.1.0 - -# Go related variables -GOBASE=$(shell pwd) -GOBIN=$(GOBASE)/bin - -# Run configuration -SOURCE?= -TAG?= -STORE_PATH?=./model-store - -# Use linker flags to provide version/build information -LDFLAGS=-ldflags "-X main.Version=${VERSION}" - -all: clean lint build test - -build: - @echo "Building ${BINARY_NAME}..." 
- @mkdir -p ${GOBIN} - @go build ${LDFLAGS} -o ${GOBIN}/${BINARY_NAME} ../../cmd/mdltool/ - -test: - @echo "Running unit tests..." - @go test -race ./... ../../cmd/mdltool/... - -clean: - @echo "Cleaning..." - @rm -rf ${GOBIN} - @rm -f ${BINARY_NAME} - @rm -f *.test - @rm -rf test/artifacts/* - -lint: - @echo "Running linters..." - @gofmt -s -l . | tee /dev/stderr | xargs -r false - @go vet ./... - -run-pull: - @echo "Pulling model from ${TAG}..." - @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} pull ${TAG} - -run-package: - @echo "Pushing model ${SOURCE} to ${TAG}..." - @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} package ${SOURCE} ${TAG} ${LICENSE:+--license ${LICENSE}} - -run-list: - @echo "Listing models..." - @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} list - -run-get: - @echo "Getting model ${TAG}..." - @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} get ${TAG} - -run-get-path: - @echo "Getting path for model ${TAG}..." - @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} get-path ${TAG} - -run-rm: - @echo "Removing model ${TAG}..." - @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} rm ${TAG} - -run-tag: - @echo "Tagging model ${SOURCE} as ${TAG}..." 
- @${GOBIN}/${BINARY_NAME} --store-path ${STORE_PATH} tag ${SOURCE} ${TAG} - -help: - @echo "Available targets:" - @echo " all - Clean, build, and test" - @echo " build - Build the binary" - @echo " test - Run unit tests" - @echo " clean - Clean build artifacts" - @echo " run-pull - Pull a model (TAG=registry/model:tag)" - @echo " run-package - Package and push a model (SOURCE=path/to/model.gguf TAG=registry/model:tag LICENSE=path/to/license.txt)" - @echo " run-list - List all models" - @echo " run-get - Get model info (TAG=registry/model:tag)" - @echo " run-get-path - Get model path (TAG=registry/model:tag)" - @echo " run-rm - Remove a model (TAG=registry/model:tag)" - @echo " run-tag - Tag a model (SOURCE=registry/model:tag TAG=registry/model:newtag)" diff --git a/pkg/distribution/README.md b/pkg/distribution/README.md index 1042dcdb..5fe13fb1 100644 --- a/pkg/distribution/README.md +++ b/pkg/distribution/README.md @@ -1,10 +1,10 @@ # Model Distribution -A library and CLI tool for distributing models using container registries. +A Go library for distributing models using container registries. ## Overview -Model Distribution is a Go library and CLI tool that allows you to package, push, pull, and manage models using container registries. It provides a simple API and command-line interface for working with models in GGUF and Safetensors format. +Model Distribution is a Go library that allows you to package, push, pull, and manage models using container registries. It provides a simple API for working with models in GGUF and Safetensors format. 
## Features @@ -12,92 +12,26 @@ Model Distribution is a Go library and CLI tool that allows you to package, push - Pull models from container registries - Local model storage - Model metadata management -- Command-line interface for all operations -- GitHub workflows for automated model packaging - Support for both GGUF and Safetensors model formats ## Usage -### As a CLI Tool - -```bash -# Build the CLI tool -make build - -# Pull a model from a registry -./bin/model-distribution-tool pull registry.example.com/models/llama:v1.0 - -# Package a model and push to a registry -./bin/model-distribution-tool package --tag registry.example.com/models/llama:v1.0 ./model.gguf - -# Package a sharded model and push to a registry -./bin/model-distribution-tool package --tag registry.example.com/models/example ./model-00001-of-00007.gguf - -# Package a model with license files and push to a registry -./bin/model-distribution-tool package --licenses license1.txt --licenses license2.txt --tag registry.example.com/models/llama:v1.0 ./model.gguf - -# Package a model with a default context size and push to a registry -./bin/model-distribution-tool package --context-size 2048 --tag registry.example.com/models/llama:v1.0 ./model.gguf - -# Package a model with a multimodal projector file and push to a registry -./bin/model-distribution-tool package --mmproj ./model.mmproj --tag registry.example.com/models/llama:v1.0 ./model.gguf - -# Package a model with a custom chat template and push to a registry -./bin/model-distribution-tool package --chat-template ./template.jinja --tag registry.example.com/models/llama:v1.0 ./model.gguf - -# Package a model and output the result to a file -./bin/model-distribution-tool package --file ./model.tar ./model.gguf - -# Load a model from an archive into the local store -./bin/model-distribution-tool load ./model.tar - -# Push a model from the content store to the registry -./bin/model-distribution-tool push registry.example.com/models/llama:v1.0 - -# 
List all models in the local store -./bin/model-distribution-tool list - -# Get information about a model -./bin/model-distribution-tool get registry.example.com/models/llama:v1.0 - -# Get the local file path for a model -./bin/model-distribution-tool get-path registry.example.com/models/llama:v1.0 - -# Remove a model from the local store (will untag w/o deleting if there are multiple tags) -./bin/model-distribution-tool rm registry.example.com/models/llama:v1.0 - -# Force Removal of a model from the local store, even when there are multiple referring tags -./bin/model-distribution-tool rm --force sha256:0b329b335467cccf7aa219e8f5e1bd65e59b6dfa81cfa42fba2f8881268fbf82 - -# Tag a model with an additional reference -./bin/model-distribution-tool tag registry.example.com/models/llama:v1.0 registry.example.com/models/llama:latest - -# Create a runtime bundle for model -./bin/model-distribution-tool bundle registry.example.com/models/llama:v1.0 -``` - -For more information about the CLI tool, run: - -```bash -./bin/model-distribution-tool --help -``` - -### As a Library - ```go import ( "context" - "github.com/docker/model-runner/pkg/distribution/pkg/distribution" + "github.com/docker/model-runner/pkg/distribution/distribution" ) // Create a new client -client, err := distribution.NewClient("/path/to/cache") +client, err := distribution.NewClient( + distribution.WithStoreRootPath("/path/to/cache"), +) if err != nil { // Handle error } // Pull a model -err := client.PullModel(context.Background(), "registry.example.com/models/llama:v1.0", os.Stdout) +err = client.PullModel(context.Background(), "registry.example.com/models/llama:v1.0", os.Stdout) if err != nil { // Handle error } @@ -109,15 +43,15 @@ if err != nil { } // Create a bundle -bundlePath, err := client.GetBundle("registry.example.com/models/llama:v1.0") +bundle, err := client.GetBundle("registry.example.com/models/llama:v1.0") if err != nil { -// Handle error + // Handle error } // Get the GGUF file path 
within the bundle modelPath, err := bundle.GGUFPath() if err != nil { -// Handle error + // Handle error } fmt.Println("Model path:", modelPath) @@ -129,7 +63,7 @@ if err != nil { } // Delete a model -err = client.DeleteModel("registry.example.com/models/llama:v1.0", false) +_, err = client.DeleteModel("registry.example.com/models/llama:v1.0", false) if err != nil { // Handle error } @@ -141,8 +75,8 @@ if err != nil { } // Push a model -err = client.PushModel("registry.example.com/models/llama:v1.0") +err = client.PushModel(context.Background(), "registry.example.com/models/llama:v1.0", nil) if err != nil { // Handle error } -``` +``` \ No newline at end of file From a7368feea78f704babd98f564d61b2b1986853ac Mon Sep 17 00:00:00 2001 From: Dorin Geman Date: Wed, 14 Jan 2026 18:10:22 +0200 Subject: [PATCH 07/17] refactor(distribution): use shared registry client instead of duplicating options Replace WithTransport, WithUserAgent, WithPlainHTTP, and WithRegistryAuth options with a single WithRegistryClient option. This eliminates the duplicate registry client creation - callers now create one registry client and pass it to both the distribution client and use it directly. 
Signed-off-by: Dorin Geman --- pkg/distribution/distribution/client.go | 64 +++------------ pkg/distribution/distribution/client_test.go | 86 ++++++++------------ pkg/distribution/registry/client_test.go | 25 ++++++ pkg/inference/models/manager.go | 18 ++-- 4 files changed, 77 insertions(+), 116 deletions(-) diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index d153e96d..deb9a9e5 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io" - "net/http" "os" "slices" "strings" @@ -40,13 +39,9 @@ type Option func(*options) // options holds the configuration for a new Client type options struct { - storeRootPath string - logger *logrus.Entry - transport http.RoundTripper - userAgent string - username string - password string - plainHTTP bool + storeRootPath string + logger *logrus.Entry + registryClient *registry.Client } // WithStoreRootPath sets the store root path @@ -67,46 +62,18 @@ func WithLogger(logger *logrus.Entry) Option { } } -// WithTransport sets the HTTP transport to use when pulling and pushing models. -func WithTransport(transport http.RoundTripper) Option { +// WithRegistryClient sets the registry client to use for pulling and pushing models. +func WithRegistryClient(client *registry.Client) Option { return func(o *options) { - if transport != nil { - o.transport = transport + if client != nil { + o.registryClient = client } } } -// WithUserAgent sets the User-Agent header to use when pulling and pushing models. 
-func WithUserAgent(ua string) Option { - return func(o *options) { - if ua != "" { - o.userAgent = ua - } - } -} - -// WithRegistryAuth sets the registry authentication credentials -func WithRegistryAuth(username, password string) Option { - return func(o *options) { - if username != "" && password != "" { - o.username = username - o.password = password - } - } -} - -// WithPlainHTTP allows connecting to registries using plain HTTP instead of HTTPS. -func WithPlainHTTP(plain bool) Option { - return func(o *options) { - o.plainHTTP = plain - } -} - func defaultOptions() *options { return &options{ - logger: logrus.NewEntry(logrus.StandardLogger()), - transport: registry.DefaultTransport, - userAgent: registry.DefaultUserAgent, + logger: logrus.NewEntry(logrus.StandardLogger()), } } @@ -128,23 +95,16 @@ func NewClient(opts ...Option) (*Client, error) { return nil, fmt.Errorf("initializing store: %w", err) } - // Create registry client options - registryOpts := []registry.ClientOption{ - registry.WithTransport(options.transport), - registry.WithUserAgent(options.userAgent), - registry.WithPlainHTTP(options.plainHTTP), - } - - // Add auth if credentials are provided - if options.username != "" && options.password != "" { - registryOpts = append(registryOpts, registry.WithAuthConfig(options.username, options.password)) + registryClient := options.registryClient + if registryClient == nil { + registryClient = registry.NewClient() } options.logger.Infoln("Successfully initialized store") return &Client{ store: s, log: options.logger, - registry: registry.NewClient(registryOpts...), + registry: registryClient, }, nil } diff --git a/pkg/distribution/distribution/client_test.go b/pkg/distribution/distribution/client_test.go index 8e23a7bf..77613e72 100644 --- a/pkg/distribution/distribution/client_test.go +++ b/pkg/distribution/distribution/client_test.go @@ -32,6 +32,14 @@ var ( testGGUFFile = filepath.Join("..", "assets", "dummy.gguf") ) +// newTestClient creates a new 
client configured for testing with plain HTTP enabled. +func newTestClient(storeRootPath string) (*Client, error) { + return NewClient( + WithStoreRootPath(storeRootPath), + WithRegistryClient(mdregistry.NewClient(mdregistry.WithPlainHTTP(true))), + ) +} + func TestClientPullModel(t *testing.T) { // Set up test registry server := httptest.NewServer(testregistry.New()) @@ -50,7 +58,7 @@ func TestClientPullModel(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -152,7 +160,7 @@ func TestClientPullModel(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - testClient, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + testClient, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -205,7 +213,7 @@ func TestClientPullModel(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - testClient, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + testClient, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -309,7 +317,7 @@ func TestClientPullModel(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - testClient, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + testClient, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -460,7 +468,7 @@ func TestClientPullModel(t *testing.T) { } defer os.RemoveAll(clientTempDir) - testClient, err := NewClient(WithStoreRootPath(clientTempDir), WithPlainHTTP(true)) + testClient, err := newTestClient(clientTempDir) if err != nil { t.Fatalf("Failed to create test client: %v", err) } @@ -495,7 +503,7 @@ func 
TestClientPullModel(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - testClient, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + testClient, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -569,7 +577,7 @@ func TestClientPullModel(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - testClient, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + testClient, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -605,7 +613,7 @@ func TestClientGetModel(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -644,7 +652,7 @@ func TestClientGetModelNotFound(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -665,7 +673,7 @@ func TestClientListModels(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -746,7 +754,7 @@ func TestClientGetStorePath(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -844,44 +852,14 @@ func TestWithFunctionsNilChecks(t *testing.T) { } }) - // Test 
WithTransport with nil - t.Run("WithTransport nil", func(t *testing.T) { - // Create options with default transport - opts := defaultOptions() - defaultTransport := opts.transport - - // Try to override with nil - WithTransport(nil)(opts) - - // Verify the transport wasn't changed to nil - if opts.transport == nil { - t.Error("WithTransport with nil changed transport to nil") - } - - // Verify it's still the default transport - if opts.transport != defaultTransport { - t.Error("WithTransport with nil changed the transport") - } - }) - - // Test WithUserAgent with empty string - t.Run("WithUserAgent empty string", func(t *testing.T) { - // Create options with default user agent + t.Run("WithRegistryClient nil", func(t *testing.T) { opts := defaultOptions() - defaultUA := opts.userAgent - - // Try to override with empty string - WithUserAgent("")(opts) + opts.registryClient = mdregistry.NewClient() - // Verify the user agent wasn't changed to empty - if opts.userAgent == "" { - t.Error("WithUserAgent with empty string changed user agent to empty") - } + WithRegistryClient(nil)(opts) - // Verify it's still the default user agent - if opts.userAgent != defaultUA { - t.Errorf("WithUserAgent with empty string changed the user agent: got %q, want %q", - opts.userAgent, defaultUA) + if opts.registryClient == nil { + t.Error("WithRegistryClient with nil changed registryClient to nil") } }) } @@ -895,7 +873,7 @@ func TestNewReferenceError(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -921,7 +899,7 @@ func TestPush(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { 
t.Fatalf("Failed to create client: %v", err) } @@ -989,7 +967,7 @@ func TestPushProgress(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -1079,7 +1057,7 @@ func TestTag(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -1140,7 +1118,7 @@ func TestTagNotFound(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -1160,7 +1138,7 @@ func TestClientPushModelNotFound(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -1179,7 +1157,7 @@ func TestIsModelInStoreNotFound(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } @@ -1200,7 +1178,7 @@ func TestIsModelInStoreFound(t *testing.T) { defer os.RemoveAll(tempDir) // Create client with plainHTTP for test registry - client, err := NewClient(WithStoreRootPath(tempDir), WithPlainHTTP(true)) + client, err := newTestClient(tempDir) if err != nil { t.Fatalf("Failed to create client: %v", err) } diff --git 
a/pkg/distribution/registry/client_test.go b/pkg/distribution/registry/client_test.go index 4d65ecf1..6c1bf228 100644 --- a/pkg/distribution/registry/client_test.go +++ b/pkg/distribution/registry/client_test.go @@ -136,3 +136,28 @@ func resetOnceForTest() { once = sync.Once{} defaultRegistryOpts = nil } + +func TestWithTransportNil(t *testing.T) { + client := NewClient(WithTransport(nil)) + + if client.transport == nil { + t.Error("WithTransport with nil changed transport to nil") + } + + if client.transport != DefaultTransport { + t.Error("WithTransport with nil changed the transport from default") + } +} + +func TestWithUserAgentEmpty(t *testing.T) { + client := NewClient(WithUserAgent("")) + + if client.userAgent == "" { + t.Error("WithUserAgent with empty string changed user agent to empty") + } + + if client.userAgent != DefaultUserAgent { + t.Errorf("WithUserAgent with empty string changed the user agent: got %q, want %q", + client.userAgent, DefaultUserAgent) + } +} diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go index 825c8195..85839590 100644 --- a/pkg/inference/models/manager.go +++ b/pkg/inference/models/manager.go @@ -38,13 +38,18 @@ type Manager struct { // NewManager creates a new model models with the provided clients. func NewManager(log logging.Logger, c ClientConfig) *Manager { + // Create the registry client (shared between distribution and direct registry access). + registryClient := registry.NewClient( + registry.WithTransport(c.Transport), + registry.WithUserAgent(c.UserAgent), + registry.WithPlainHTTP(c.PlainHTTP), + ) + // Create the model distribution client. 
distributionClient, err := distribution.NewClient( distribution.WithStoreRootPath(c.StoreRootPath), distribution.WithLogger(c.Logger), - distribution.WithTransport(c.Transport), - distribution.WithUserAgent(c.UserAgent), - distribution.WithPlainHTTP(c.PlainHTTP), + distribution.WithRegistryClient(registryClient), ) if err != nil { log.Errorf("Failed to create distribution client: %v", err) @@ -52,13 +57,6 @@ func NewManager(log logging.Logger, c ClientConfig) *Manager { // respond to requests, but may return errors if the client is required. } - // Create the model registry client. - registryClient := registry.NewClient( - registry.WithTransport(c.Transport), - registry.WithUserAgent(c.UserAgent), - registry.WithPlainHTTP(c.PlainHTTP), - ) - tokens := make(chan struct{}, maximumConcurrentModelPulls) // Populate the pull concurrency semaphore. From 72f68d0e092bc4b12d2712b3db1a99ed7e2cb0d1 Mon Sep 17 00:00:00 2001 From: Dorin Geman Date: Thu, 15 Jan 2026 18:42:08 +0200 Subject: [PATCH 08/17] feat(cli/bench): add shell completion Signed-off-by: Dorin Geman --- cmd/cli/commands/bench.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cmd/cli/commands/bench.go b/cmd/cli/commands/bench.go index c8b9a546..4cc2d910 100644 --- a/cmd/cli/commands/bench.go +++ b/cmd/cli/commands/bench.go @@ -13,6 +13,7 @@ import ( "sync" "time" + "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/cmd/cli/desktop" "github.com/docker/model-runner/pkg/inference" "github.com/spf13/cobra" @@ -63,7 +64,8 @@ func newBenchCmd() *cobra.Command { This command runs a series of benchmarks with 1, 2, 4, and 8 concurrent requests by default, measuring the tokens per second (TPS) that the model can generate.`, - Args: cobra.ExactArgs(1), + Args: requireExactArgs(1, "bench", "MODEL"), + ValidArgsFunction: completion.ModelNames(getDesktopClient, 1), RunE: func(cmd *cobra.Command, args []string) error { model = args[0] From 
ac48f5390a4949235b6e7f2d4c0454686cfe3e23 Mon Sep 17 00:00:00 2001 From: Dorin Geman Date: Fri, 16 Jan 2026 11:38:26 +0200 Subject: [PATCH 09/17] fix(cli/completion): handle PersistentPreRunE error Signed-off-by: Dorin Geman --- cmd/cli/commands/completion/functions.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cmd/cli/commands/completion/functions.go b/cmd/cli/commands/completion/functions.go index 2930c022..2589ac78 100644 --- a/cmd/cli/commands/completion/functions.go +++ b/cmd/cli/commands/completion/functions.go @@ -17,7 +17,9 @@ func ModelNames(desktopClient func() *desktop.Client, limit int) cobra.Completio // HACK: Invoke rootCmd's PersistentPreRunE, which is needed for context // detection and client initialization. This function isn't invoked // automatically on autocompletion paths. - cmd.Parent().PersistentPreRunE(cmd, args) + if err := cmd.Parent().PersistentPreRunE(cmd, args); err != nil { + return nil, cobra.ShellCompDirectiveError + } if limit > 0 && len(args) >= limit { return nil, cobra.ShellCompDirectiveNoFileComp @@ -41,7 +43,9 @@ func ModelNamesAndTags(desktopClient func() *desktop.Client, limit int) cobra.Co // HACK: Invoke rootCmd's PersistentPreRunE, which is needed for context // detection and client initialization. This function isn't invoked // automatically on autocompletion paths. - cmd.Parent().PersistentPreRunE(cmd, args) + if err := cmd.Parent().PersistentPreRunE(cmd, args); err != nil { + return nil, cobra.ShellCompDirectiveError + } if limit > 0 && len(args) >= limit { return nil, cobra.ShellCompDirectiveNoFileComp From 59a52ba03a60bc1e46e76c20e3deb7e9ed7cb9e2 Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Fri, 16 Jan 2026 14:01:53 +0000 Subject: [PATCH 10/17] Add /set system command to interactive mode The /set system command allows users to set or update the system message during interactive sessions. 
The system prompt is now included in the message history sent to the chat endpoint, enabling customized behavior for the AI assistant. Signed-off-by: Eric Curtin --- cmd/cli/commands/run.go | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index 22780f27..2f18aaee 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -90,6 +90,7 @@ func readMultilineInput(cmd *cobra.Command, scanner *bufio.Scanner) (string, err func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.Client, model string) error { usage := func() { fmt.Fprintln(os.Stderr, "Available Commands:") + fmt.Fprintln(os.Stderr, " /set system Set or update the system message") fmt.Fprintln(os.Stderr, " /bye Exit") fmt.Fprintln(os.Stderr, " /?, /help Help for a command") fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") @@ -154,6 +155,7 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop. var sb strings.Builder var multiline bool var conversationHistory []desktop.OpenAIChatMessage + var systemPrompt string // Add a helper function to handle file inclusion when @ is pressed // We'll implement a basic version here that shows a message when @ is pressed @@ -217,6 +219,16 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop. 
usage() } continue + case strings.HasPrefix(line, "/set system ") || line == "/set system": + // Extract the system prompt text after "/set system " + systemPrompt = strings.TrimPrefix(line, "/set system ") + systemPrompt = strings.TrimSpace(systemPrompt) + if systemPrompt == "" { + fmt.Fprintln(os.Stderr, "Cleared system message.") + } else { + fmt.Fprintln(os.Stderr, "Set system message.") + } + continue case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"): return nil case strings.HasPrefix(line, "/"): @@ -245,7 +257,20 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop. } }() - assistantResponse, processedUserMessage, err := chatWithMarkdownContext(chatCtx, cmd, desktopClient, model, userInput, conversationHistory) + // Build message history with system prompt prepended if set + var messagesWithSystem []desktop.OpenAIChatMessage + if systemPrompt == "" { + messagesWithSystem = conversationHistory + } else { + messagesWithSystem = make([]desktop.OpenAIChatMessage, 1, 1+len(conversationHistory)) + messagesWithSystem[0] = desktop.OpenAIChatMessage{ + Role: "system", + Content: systemPrompt, + } + messagesWithSystem = append(messagesWithSystem, conversationHistory...) + } + + assistantResponse, processedUserMessage, err := chatWithMarkdownContext(chatCtx, cmd, desktopClient, model, userInput, messagesWithSystem) // Clean up signal handler signal.Stop(sigChan) From 7041fab59c98be560081b45a6cd765cfe02b246e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 19 Jan 2026 07:33:35 +0000 Subject: [PATCH 11/17] chore(deps): bump actions/setup-go in the github-actions group Bumps the github-actions group with 1 update: [actions/setup-go](https://github.com/actions/setup-go). 
Updates `actions/setup-go` from 6.1.0 to 6.2.0 - [Release notes](https://github.com/actions/setup-go/releases) - [Commits](https://github.com/actions/setup-go/compare/4dc6199c7b1a012772edbd06daecab0f50c9053c...7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5) --- updated-dependencies: - dependency-name: actions/setup-go dependency-version: 6.2.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: github-actions ... Signed-off-by: dependabot[bot] --- .github/workflows/ci.yml | 4 ++-- .github/workflows/cli-build.yml | 2 +- .github/workflows/integration-test.yml | 2 +- .github/workflows/release.yml | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1090460e..b18a33c5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 - name: Set up Go - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c + uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 with: go-version: 1.24.3 cache: true @@ -43,7 +43,7 @@ jobs: run: stat vendor && exit 1 || exit 0 - name: Set up Go - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c + uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 with: go-version: 1.24.3 cache: true diff --git a/.github/workflows/cli-build.yml b/.github/workflows/cli-build.yml index bd1175ad..814e2062 100644 --- a/.github/workflows/cli-build.yml +++ b/.github/workflows/cli-build.yml @@ -26,7 +26,7 @@ jobs: contents: read steps: - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 - - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c + - uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 with: go-version-file: cmd/cli/go.mod cache: true diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 5d34fb24..80b4888b 100644 --- 
a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -16,7 +16,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 - name: Set up Go - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c + uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 with: go-version: 1.24.3 cache: true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 6aa0019d..9a2d0b4f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -44,7 +44,7 @@ jobs: uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 - name: Set up Go - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c + uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 with: go-version: 1.24.3 cache: true From 79f6bc9e68c050e24b5f7de07d50109d4367cf87 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 19 Jan 2026 07:33:26 +0000 Subject: [PATCH 12/17] chore(deps): bump the go-modules-cli group in /cmd/cli with 2 updates Bumps the go-modules-cli group in /cmd/cli with 2 updates: [github.com/docker/cli](https://github.com/docker/cli) and [github.com/olekukonko/tablewriter](https://github.com/olekukonko/tablewriter). 
Updates `github.com/docker/cli` from 29.1.4+incompatible to 29.1.5+incompatible - [Commits](https://github.com/docker/cli/compare/v29.1.4...v29.1.5) Updates `github.com/olekukonko/tablewriter` from 1.1.2 to 1.1.3 - [Release notes](https://github.com/olekukonko/tablewriter/releases) - [Commits](https://github.com/olekukonko/tablewriter/compare/v1.1.2...v1.1.3) --- updated-dependencies: - dependency-name: github.com/docker/cli dependency-version: 29.1.5+incompatible dependency-type: direct:production update-type: version-update:semver-patch dependency-group: go-modules-cli - dependency-name: github.com/olekukonko/tablewriter dependency-version: 1.1.3 dependency-type: direct:production update-type: version-update:semver-patch dependency-group: go-modules-cli ... Signed-off-by: dependabot[bot] --- cmd/cli/go.mod | 10 +++++----- cmd/cli/go.sum | 22 ++++++++++------------ 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/cmd/cli/go.mod b/cmd/cli/go.mod index c56bf367..e7a7668a 100644 --- a/cmd/cli/go.mod +++ b/cmd/cli/go.mod @@ -5,7 +5,7 @@ go 1.24.3 require ( github.com/charmbracelet/glamour v0.10.0 github.com/containerd/errdefs v1.0.0 - github.com/docker/cli v29.1.4+incompatible + github.com/docker/cli v29.1.5+incompatible github.com/docker/cli-docs-tool v0.11.0 github.com/docker/docker v28.5.2+incompatible github.com/docker/go-connections v0.6.0 @@ -18,7 +18,7 @@ require ( github.com/moby/term v0.5.2 github.com/muesli/termenv v0.16.0 github.com/nxadm/tail v1.4.11 - github.com/olekukonko/tablewriter v1.1.2 + github.com/olekukonko/tablewriter v1.1.3 github.com/pkg/errors v0.9.1 github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.10 @@ -47,7 +47,7 @@ require ( github.com/charmbracelet/x/cellbuf v0.0.13 // indirect github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect github.com/charmbracelet/x/term v0.2.1 // indirect - github.com/clipperhouse/displaywidth v0.6.0 // indirect + github.com/clipperhouse/displaywidth v0.6.2 
// indirect github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.3.0 // indirect github.com/containerd/containerd/v2 v2.2.1 // indirect @@ -85,7 +85,7 @@ require ( github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.10 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-shellwords v1.0.12 // indirect github.com/microcosm-cc/bluemonday v1.0.27 // indirect @@ -105,7 +105,7 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect github.com/olekukonko/errors v1.1.0 // indirect - github.com/olekukonko/ll v0.1.3 // indirect + github.com/olekukonko/ll v0.1.4-0.20260115111900-9e59c2286df0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect diff --git a/cmd/cli/go.sum b/cmd/cli/go.sum index 34cb44a6..10054e20 100644 --- a/cmd/cli/go.sum +++ b/cmd/cli/go.sum @@ -42,8 +42,8 @@ github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf h1:rLG0Y github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf/go.mod h1:B3UgsnsBZS/eX42BlaNiJkD1pPOUa+oF1IYC6Yd2CEU= github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= -github.com/clipperhouse/displaywidth v0.6.0 h1:k32vueaksef9WIKCNcoqRNyKbyvkvkysNYnAWz2fN4s= -github.com/clipperhouse/displaywidth v0.6.0/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o= +github.com/clipperhouse/displaywidth v0.6.2 h1:ZDpTkFfpHOKte4RG5O/BOyf3ysnvFswpyYrV7z2uAKo= 
+github.com/clipperhouse/displaywidth v0.6.2/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o= github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4= @@ -81,8 +81,8 @@ github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5Qvfr github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/docker/cli v29.1.4+incompatible h1:AI8fwZhqsAsrqZnVv9h6lbexeW/LzNTasf6A4vcNN8M= -github.com/docker/cli v29.1.4+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/cli v29.1.5+incompatible h1:GckbANUt3j+lsnQ6eCcQd70mNSOismSHWt8vk2AX8ao= +github.com/docker/cli v29.1.5+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/cli-docs-tool v0.11.0 h1:7d8QARFb7QEobizqxmEM7fOteZEHwH/zWgHQtHZEcfE= github.com/docker/cli-docs-tool v0.11.0/go.mod h1:ma8BKiisUo8D6W05XEYIh3oa1UbgrZhi1nowyKFJa8Q= github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk= @@ -167,9 +167,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= 
-github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= @@ -223,10 +222,10 @@ github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 h1:zrbMGy9YXpIeTnGj github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6/go.mod h1:rEKTHC9roVVicUIfZK7DYrdIoM0EOr8mK1Hj5s3JjH0= github.com/olekukonko/errors v1.1.0 h1:RNuGIh15QdDenh+hNvKrJkmxxjV4hcS50Db478Ou5sM= github.com/olekukonko/errors v1.1.0/go.mod h1:ppzxA5jBKcO1vIpCXQ9ZqgDh8iwODz6OXIGKU8r5m4Y= -github.com/olekukonko/ll v0.1.3 h1:sV2jrhQGq5B3W0nENUISCR6azIPf7UBUpVq0x/y70Fg= -github.com/olekukonko/ll v0.1.3/go.mod h1:b52bVQRRPObe+yyBl0TxNfhesL0nedD4Cht0/zx55Ew= -github.com/olekukonko/tablewriter v1.1.2 h1:L2kI1Y5tZBct/O/TyZK1zIE9GlBj/TVs+AY5tZDCDSc= -github.com/olekukonko/tablewriter v1.1.2/go.mod h1:z7SYPugVqGVavWoA2sGsFIoOVNmEHxUAAMrhXONtfkg= +github.com/olekukonko/ll v0.1.4-0.20260115111900-9e59c2286df0 h1:jrYnow5+hy3WRDCBypUFvVKNSPPCdqgSXIE9eJDD8LM= +github.com/olekukonko/ll v0.1.4-0.20260115111900-9e59c2286df0/go.mod h1:b52bVQRRPObe+yyBl0TxNfhesL0nedD4Cht0/zx55Ew= +github.com/olekukonko/tablewriter v1.1.3 h1:VSHhghXxrP0JHl+0NnKid7WoEmd9/urKRJLysb70nnA= +github.com/olekukonko/tablewriter v1.1.3/go.mod h1:9VU0knjhmMkXjnMKrZ3+L2JhhtsQ/L38BbL3CRNE8tM= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -354,7 
+353,6 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From edd577a1fd50ab2ae79c34691587d82b59841661 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 19 Jan 2026 07:33:13 +0000 Subject: [PATCH 13/17] chore(deps): bump github.com/sirupsen/logrus Bumps the go-modules-root group with 1 update: [github.com/sirupsen/logrus](https://github.com/sirupsen/logrus). Updates `github.com/sirupsen/logrus` from 1.9.3 to 1.9.4 - [Release notes](https://github.com/sirupsen/logrus/releases) - [Changelog](https://github.com/sirupsen/logrus/blob/master/CHANGELOG.md) - [Commits](https://github.com/sirupsen/logrus/compare/v1.9.3...v1.9.4) --- updated-dependencies: - dependency-name: github.com/sirupsen/logrus dependency-version: 1.9.4 dependency-type: direct:production update-type: version-update:semver-patch dependency-group: go-modules-root ... 
Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index f4dfb5d5..a9c00dba 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/opencontainers/image-spec v1.1.1 github.com/prometheus/client_model v0.6.2 github.com/prometheus/common v0.67.5 - github.com/sirupsen/logrus v1.9.3 + github.com/sirupsen/logrus v1.9.4 github.com/stretchr/testify v1.11.1 golang.org/x/sync v0.19.0 ) diff --git a/go.sum b/go.sum index 4b2c4ed3..f4cf2de0 100644 --- a/go.sum +++ b/go.sum @@ -105,13 +105,12 @@ github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTU github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= +github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d h1:3VwvTjiRPA7cqtgOWddEL+JrcijMlXUmj99c/6YyZoY= github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= 
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -160,7 +159,6 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= @@ -189,7 +187,6 @@ google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= howett.net/plist v1.0.2-0.20250314012144-ee69052608d9 h1:eeH1AIcPvSc0Z25ThsYF+Xoqbn0CI/YnXVYoTLFdGQw= From 90300bb867b1484d4c3e9a730018961fb16de129 Mon Sep 17 00:00:00 2001 From: Dorin Geman Date: Mon, 19 Jan 2026 10:47:17 +0200 Subject: [PATCH 14/17] chore(deps): go mod tidy Signed-off-by: Dorin Geman --- cmd/cli/go.mod | 2 +- cmd/cli/go.sum | 7 ++----- go.work.sum | 1 + 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/cmd/cli/go.mod b/cmd/cli/go.mod 
index e7a7668a..32620942 100644 --- a/cmd/cli/go.mod +++ b/cmd/cli/go.mod @@ -115,7 +115,7 @@ require ( github.com/rivo/uniseg v0.4.7 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/shirou/gopsutil/v4 v4.25.6 // indirect - github.com/sirupsen/logrus v1.9.3 // indirect + github.com/sirupsen/logrus v1.9.4 // indirect github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect diff --git a/cmd/cli/go.sum b/cmd/cli/go.sum index 10054e20..b2b72973 100644 --- a/cmd/cli/go.sum +++ b/cmd/cli/go.sum @@ -255,8 +255,8 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs= github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= +github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d h1:3VwvTjiRPA7cqtgOWddEL+JrcijMlXUmj99c/6YyZoY= github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= @@ -268,7 +268,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod 
h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU= @@ -352,7 +351,6 @@ golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -393,7 +391,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= diff --git a/go.work.sum b/go.work.sum index 
e440bb0a..3b4c9dfc 100644 --- a/go.work.sum +++ b/go.work.sum @@ -816,6 +816,7 @@ golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= From 81daf12473b011f177e01a6e5e7ca96509411eb0 Mon Sep 17 00:00:00 2001 From: Ignasi Date: Mon, 19 Jan 2026 09:59:05 +0100 Subject: [PATCH 15/17] Add stable diffusion backend (#572) * feat(diffusers): implement diffusers backend for image generation * feat(diffusers): add support for DDUF (Diffusers Unified Format) file handling * feat(dduf): implement DDUF format support and enhance model loading * feat(dduf): calculate total size of files and add human-readable size format * feat(platform): restrict Diffusers support to Linux only until macOS distribution is designed * feat(diffusers): add support for DDUF file type handling in repository and config files * feat(diffusers): sanitize log output for Diffusers arguments * feat(docker): streamline Python server code copying in Dockerfile * feat(docker): specify exact versions for Python packages in Dockerfile * feat(model): add DDUF file support to packaging command and documentation * Update pkg/distribution/internal/bundle/unpack.go Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * refactor(dduf): replace formatDDUFSize with formatSize and clean up unused code * feat(docker): add support for building and running Diffusers Docker images * feat(client): add support for 
Diffusers format in GetSupportedFormats function * feat(docker): enhance GPU support for additional Docker image variants * feat: add support for image-generation mode in backend operations * feat(loader): support fallback for image-generation mode in runner config * feat(diffusers): initialize Diffusers backend in main.go * feat(diffusers): add error transformation for Python output and enhance backend error handling * fix(scripts/docker-run): conditionally add nvidia runtime flags Only add --gpus and --runtime=nvidia when the nvidia runtime is detected, allowing diffusers/sglang images to run on non-NVIDIA hosts. Signed-off-by: Dorin Geman --------- Signed-off-by: Dorin Geman Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> Co-authored-by: Dorin Geman --- Dockerfile | 60 ++- Makefile | 15 +- cmd/cli/commands/compose_test.go | 6 + cmd/cli/commands/configure_flags.go | 6 +- cmd/cli/commands/package.go | 40 +- .../reference/docker_model_configure.yaml | 3 +- .../docs/reference/docker_model_package.yaml | 16 +- cmd/cli/docs/reference/model.md | 52 +-- cmd/cli/docs/reference/model_package.md | 4 +- main.go | 21 +- pkg/distribution/distribution/client.go | 4 +- pkg/distribution/files/classify.go | 9 + pkg/distribution/format/dduf.go | 70 ++++ pkg/distribution/format/format.go | 15 + pkg/distribution/format/safetensors.go | 13 - pkg/distribution/huggingface/repository.go | 2 +- pkg/distribution/internal/bundle/bundle.go | 9 + pkg/distribution/internal/bundle/unpack.go | 38 +- pkg/distribution/internal/partial/partial.go | 4 + pkg/distribution/internal/store/model.go | 4 + pkg/distribution/modelpack/types.go | 2 + pkg/distribution/packaging/safetensors.go | 2 +- pkg/distribution/types/config.go | 5 + pkg/distribution/types/model.go | 2 + pkg/inference/backend.go | 7 + pkg/inference/backends/diffusers/diffusers.go | 237 +++++++++++ .../backends/diffusers/diffusers_config.go | 48 +++ pkg/inference/backends/diffusers/errors.go | 65 +++ 
.../backends/diffusers/errors_test.go | 97 +++++ .../backends/llamacpp/llamacpp_config.go | 2 +- .../backends/llamacpp/llamacpp_config_test.go | 4 + pkg/inference/backends/mlx/mlx_config.go | 2 +- pkg/inference/backends/runner.go | 15 +- .../backends/sglang/sglang_config.go | 3 +- .../backends/sglang/sglang_config_test.go | 4 + pkg/inference/backends/vllm/vllm_config.go | 7 +- .../backends/vllm/vllm_config_test.go | 4 + pkg/inference/platform/platform.go | 7 + pkg/inference/scheduling/api.go | 3 + pkg/inference/scheduling/http_handler.go | 3 + pkg/inference/scheduling/loader.go | 5 +- python/diffusers_server/__init__.py | 2 + python/diffusers_server/server.py | 374 ++++++++++++++++++ scripts/docker-run.sh | 11 +- 44 files changed, 1225 insertions(+), 77 deletions(-) create mode 100644 pkg/distribution/format/dduf.go create mode 100644 pkg/inference/backends/diffusers/diffusers.go create mode 100644 pkg/inference/backends/diffusers/diffusers_config.go create mode 100644 pkg/inference/backends/diffusers/errors.go create mode 100644 pkg/inference/backends/diffusers/errors_test.go create mode 100644 python/diffusers_server/__init__.py create mode 100644 python/diffusers_server/server.py diff --git a/Dockerfile b/Dockerfile index b4a8160c..e6738a69 100644 --- a/Dockerfile +++ b/Dockerfile @@ -75,7 +75,7 @@ ENV MODEL_RUNNER_PORT=12434 ENV LLAMA_SERVER_PATH=/app/bin ENV HOME=/home/modelrunner ENV MODELS_PATH=/models -ENV LD_LIBRARY_PATH=/app/lib:$LD_LIBRARY_PATH +ENV LD_LIBRARY_PATH=/app/lib # Label the image so that it's hidden on cloud engines. 
LABEL com.docker.desktop.service="model-runner" @@ -144,6 +144,60 @@ RUN curl -LsSf https://astral.sh/uv/install.sh | sh \ && ~/.local/bin/uv pip install --python /opt/sglang-env/bin/python "sglang==${SGLANG_VERSION}" RUN /opt/sglang-env/bin/python -c "import sglang; print(sglang.__version__)" > /opt/sglang-env/version + +# --- Diffusers variant --- +FROM llamacpp AS diffusers + +# Python package versions for reproducible builds +ARG DIFFUSERS_VERSION=0.36.0 +ARG TORCH_VERSION=2.9.1 +ARG TRANSFORMERS_VERSION=4.57.5 +ARG ACCELERATE_VERSION=1.3.0 +ARG SAFETENSORS_VERSION=0.5.2 +ARG HUGGINGFACE_HUB_VERSION=0.34.0 +ARG BITSANDBYTES_VERSION=0.49.1 +ARG FASTAPI_VERSION=0.115.12 +ARG UVICORN_VERSION=0.34.1 +ARG PILLOW_VERSION=11.2.1 + +USER root + +RUN apt update && apt install -y \ + python3 python3-venv python3-dev \ + curl ca-certificates build-essential \ + && rm -rf /var/lib/apt/lists/* + +RUN mkdir -p /opt/diffusers-env && chown -R modelrunner:modelrunner /opt/diffusers-env + +USER modelrunner + +# Install uv and diffusers as modelrunner user +RUN curl -LsSf https://astral.sh/uv/install.sh | sh \ + && ~/.local/bin/uv venv --python /usr/bin/python3 /opt/diffusers-env \ + && ~/.local/bin/uv pip install --python /opt/diffusers-env/bin/python \ + "diffusers==${DIFFUSERS_VERSION}" \ + "torch==${TORCH_VERSION}" \ + "transformers==${TRANSFORMERS_VERSION}" \ + "accelerate==${ACCELERATE_VERSION}" \ + "safetensors==${SAFETENSORS_VERSION}" \ + "huggingface_hub==${HUGGINGFACE_HUB_VERSION}" \ + "bitsandbytes==${BITSANDBYTES_VERSION}" \ + "fastapi==${FASTAPI_VERSION}" \ + "uvicorn[standard]==${UVICORN_VERSION}" \ + "pillow==${PILLOW_VERSION}" + +# Copy Python server code +USER root +COPY python/diffusers_server /tmp/diffusers_server/ +RUN PYTHON_SITE_PACKAGES=$(/opt/diffusers-env/bin/python -c "import site; print(site.getsitepackages()[0])") && \ + mkdir -p "$PYTHON_SITE_PACKAGES/diffusers_server" && \ + cp -r /tmp/diffusers_server/* "$PYTHON_SITE_PACKAGES/diffusers_server/" && \ 
+ chown -R modelrunner:modelrunner "$PYTHON_SITE_PACKAGES/diffusers_server/" && \ + rm -rf /tmp/diffusers_server +USER modelrunner + +RUN /opt/diffusers-env/bin/python -c "import diffusers; print(diffusers.__version__)" > /opt/diffusers-env/version + FROM llamacpp AS final-llamacpp # Copy the built binary from builder COPY --from=builder /app/model-runner /app/model-runner @@ -155,3 +209,7 @@ COPY --from=builder /app/model-runner /app/model-runner FROM sglang AS final-sglang # Copy the built binary from builder-sglang (without vLLM) COPY --from=builder-sglang /app/model-runner /app/model-runner + +FROM diffusers AS final-diffusers +# Copy the built binary from builder (with diffusers support) +COPY --from=builder /app/model-runner /app/model-runner diff --git a/Makefile b/Makefile index 82193c30..ff11138d 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,7 @@ VLLM_BASE_IMAGE := nvidia/cuda:13.0.2-runtime-ubuntu24.04 DOCKER_IMAGE := docker/model-runner:latest DOCKER_IMAGE_VLLM := docker/model-runner:latest-vllm-cuda DOCKER_IMAGE_SGLANG := docker/model-runner:latest-sglang +DOCKER_IMAGE_DIFFUSERS := docker/model-runner:latest-diffusers DOCKER_TARGET ?= final-llamacpp PORT := 8080 MODELS_PATH := $(shell pwd)/models-store @@ -25,7 +26,7 @@ DOCKER_BUILD_ARGS := \ BUILD_DMR ?= 1 # Main targets -.PHONY: build run clean test integration-tests test-docker-ce-installation docker-build docker-build-multiplatform docker-run docker-build-vllm docker-run-vllm docker-build-sglang docker-run-sglang docker-run-impl help validate lint +.PHONY: build run clean test integration-tests test-docker-ce-installation docker-build docker-build-multiplatform docker-run docker-build-vllm docker-run-vllm docker-build-sglang docker-run-sglang docker-run-impl help validate lint docker-build-diffusers docker-run-diffusers # Default target .DEFAULT_GOAL := build @@ -117,6 +118,16 @@ docker-build-sglang: docker-run-sglang: docker-build-sglang @$(MAKE) -s docker-run-impl 
DOCKER_IMAGE=$(DOCKER_IMAGE_SGLANG) +# Build Diffusers Docker image +docker-build-diffusers: + @$(MAKE) docker-build \ + DOCKER_TARGET=final-diffusers \ + DOCKER_IMAGE=$(DOCKER_IMAGE_DIFFUSERS) + +# Run Diffusers Docker container with TCP port access and mounted model storage +docker-run-diffusers: docker-build-diffusers + @$(MAKE) -s docker-run-impl DOCKER_IMAGE=$(DOCKER_IMAGE_DIFFUSERS) + # Common implementation for running Docker container docker-run-impl: @echo "" @@ -151,6 +162,8 @@ help: @echo " docker-run-vllm - Run vLLM Docker container" @echo " docker-build-sglang - Build SGLang Docker image" @echo " docker-run-sglang - Run SGLang Docker container" + @echo " docker-build-diffusers - Build Diffusers Docker image" + @echo " docker-run-diffusers - Run Diffusers Docker container" @echo " help - Show this help message" @echo "" @echo "Backend configuration options:" diff --git a/cmd/cli/commands/compose_test.go b/cmd/cli/commands/compose_test.go index d4a8b59e..f0960503 100644 --- a/cmd/cli/commands/compose_test.go +++ b/cmd/cli/commands/compose_test.go @@ -45,6 +45,12 @@ func TestParseBackendMode(t *testing.T) { expected: inference.BackendModeReranking, expectError: false, }, + { + name: "image-generation mode", + input: "image-generation", + expected: inference.BackendModeImageGeneration, + expectError: false, + }, { name: "invalid mode", input: "invalid", diff --git a/cmd/cli/commands/configure_flags.go b/cmd/cli/commands/configure_flags.go index 76991679..11e17cfd 100644 --- a/cmd/cli/commands/configure_flags.go +++ b/cmd/cli/commands/configure_flags.go @@ -146,7 +146,7 @@ func (f *ConfigureFlags) RegisterFlags(cmd *cobra.Command) { cmd.Flags().StringVar(&f.HFOverrides, "hf_overrides", "", "HuggingFace model config overrides (JSON) - vLLM only") cmd.Flags().Var(NewFloat64PtrValue(&f.GPUMemoryUtilization), "gpu-memory-utilization", "fraction of GPU memory to use for the model executor (0.0-1.0) - vLLM only") cmd.Flags().Var(NewBoolPtrValue(&f.Think), 
"think", "enable reasoning mode for thinking models") - cmd.Flags().StringVar(&f.Mode, "mode", "", "backend operation mode (completion, embedding, reranking)") + cmd.Flags().StringVar(&f.Mode, "mode", "", "backend operation mode (completion, embedding, reranking, image-generation)") } // BuildConfigureRequest builds a scheduling.ConfigureRequest from the flags. @@ -243,7 +243,9 @@ func parseBackendMode(mode string) (inference.BackendMode, error) { return inference.BackendModeEmbedding, nil case "reranking": return inference.BackendModeReranking, nil + case "image-generation": + return inference.BackendModeImageGeneration, nil default: - return inference.BackendModeCompletion, fmt.Errorf("invalid mode %q: must be one of completion, embedding, reranking", mode) + return inference.BackendModeCompletion, fmt.Errorf("invalid mode %q: must be one of completion, embedding, reranking, image-generation", mode) } } diff --git a/cmd/cli/commands/package.go b/cmd/cli/commands/package.go index f08f7e00..dd943462 100644 --- a/cmd/cli/commands/package.go +++ b/cmd/cli/commands/package.go @@ -38,11 +38,12 @@ func newPackagedCmd() *cobra.Command { var opts packageOptions c := &cobra.Command{ - Use: "package (--gguf | --safetensors-dir | --from ) [--license ...] [--mmproj ] [--context-size ] [--push] MODEL", - Short: "Package a GGUF file, Safetensors directory, or existing model into a Docker model OCI artifact.", - Long: "Package a GGUF file, Safetensors directory, or existing model into a Docker model OCI artifact, with optional licenses and multimodal projector. The package is sent to the model-runner, unless --push is specified.\n" + + Use: "package (--gguf | --safetensors-dir | --dduf | --from ) [--license ...] 
[--mmproj ] [--context-size ] [--push] MODEL", + Short: "Package a GGUF file, Safetensors directory, DDUF file, or existing model into a Docker model OCI artifact.", + Long: "Package a GGUF file, Safetensors directory, DDUF file, or existing model into a Docker model OCI artifact, with optional licenses and multimodal projector. The package is sent to the model-runner, unless --push is specified.\n" + "When packaging a sharded GGUF model, --gguf should point to the first shard. All shard files should be siblings and should include the index in the file name (e.g. model-00001-of-00015.gguf).\n" + "When packaging a Safetensors model, --safetensors-dir should point to a directory containing .safetensors files and config files (*.json, merges.txt). All files will be auto-discovered and config files will be packaged into a tar archive.\n" + + "When packaging a DDUF file (Diffusers Unified Format), --dduf should point to a .dduf archive file.\n" + "When packaging from an existing model using --from, you can modify properties like context size to create a variant of the original model.\n" + "For multimodal models, use --mmproj to include a multimodal projector file.", Args: func(cmd *cobra.Command, args []string) error { @@ -50,7 +51,7 @@ func newPackagedCmd() *cobra.Command { return err } - // Validate that exactly one of --gguf, --safetensors-dir, or --from is provided (mutually exclusive) + // Validate that exactly one of --gguf, --safetensors-dir, --dduf, or --from is provided (mutually exclusive) sourcesProvided := 0 if opts.ggufPath != "" { sourcesProvided++ @@ -58,19 +59,22 @@ func newPackagedCmd() *cobra.Command { if opts.safetensorsDir != "" { sourcesProvided++ } + if opts.ddufPath != "" { + sourcesProvided++ + } if opts.fromModel != "" { sourcesProvided++ } if sourcesProvided == 0 { return fmt.Errorf( - "One of --gguf, --safetensors-dir, or --from is required.\n\n" + + "One of --gguf, --safetensors-dir, --dduf, or --from is required.\n\n" + "See 'docker model 
package --help' for more information", ) } if sourcesProvided > 1 { return fmt.Errorf( - "Cannot specify more than one of --gguf, --safetensors-dir, or --from. Please use only one source.\n\n" + + "Cannot specify more than one of --gguf, --safetensors-dir, --dduf, or --from. Please use only one source.\n\n" + "See 'docker model package --help' for more information", ) } @@ -141,6 +145,15 @@ func newPackagedCmd() *cobra.Command { } } + // Validate DDUF path if provided + if opts.ddufPath != "" { + var err error + opts.ddufPath, err = validateAbsolutePath(opts.ddufPath, "DDUF") + if err != nil { + return err + } + } + // Validate dir-tar paths are relative (not absolute) for _, dirPath := range opts.dirTarPaths { if filepath.IsAbs(dirPath) { @@ -167,6 +180,7 @@ func newPackagedCmd() *cobra.Command { c.Flags().StringVar(&opts.ggufPath, "gguf", "", "absolute path to gguf file") c.Flags().StringVar(&opts.safetensorsDir, "safetensors-dir", "", "absolute path to directory containing safetensors files and config") + c.Flags().StringVar(&opts.ddufPath, "dduf", "", "absolute path to DDUF archive file (Diffusers Unified Format)") c.Flags().StringVar(&opts.fromModel, "from", "", "reference to an existing model to repackage") c.Flags().StringVar(&opts.chatTemplatePath, "chat-template", "", "absolute path to chat template file (must be Jinja format)") c.Flags().StringArrayVarP(&opts.licensePaths, "license", "l", nil, "absolute path to a license file") @@ -182,6 +196,7 @@ type packageOptions struct { contextSize uint64 ggufPath string safetensorsDir string + ddufPath string fromModel string licensePaths []string dirTarPaths []string @@ -197,7 +212,7 @@ type builderInitResult struct { cleanupFunc func() // Optional cleanup function for temporary files } -// initializeBuilder creates a package builder from GGUF, Safetensors, or existing model +// initializeBuilder creates a package builder from GGUF, Safetensors, DDUF, or existing model func initializeBuilder(cmd *cobra.Command, 
opts packageOptions) (*builderInitResult, error) { result := &builderInitResult{} @@ -246,7 +261,14 @@ func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitRes return nil, fmt.Errorf("add gguf file: %w", err) } result.builder = pkg - } else { + } else if opts.ddufPath != "" { + cmd.PrintErrf("Adding DDUF file from %q\n", opts.ddufPath) + pkg, err := builder.FromPath(opts.ddufPath) + if err != nil { + return nil, fmt.Errorf("add dduf file: %w", err) + } + result.builder = pkg + } else if opts.safetensorsDir != "" { // Safetensors model from directory cmd.PrintErrf("Scanning directory %q for safetensors model...\n", opts.safetensorsDir) safetensorsPaths, tempConfigArchive, err := packaging.PackageFromDirectory(opts.safetensorsDir) @@ -276,6 +298,8 @@ func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitRes } } result.builder = pkg + } else { + return nil, fmt.Errorf("no model source specified") } return result, nil diff --git a/cmd/cli/docs/reference/docker_model_configure.yaml b/cmd/cli/docs/reference/docker_model_configure.yaml index ce7ac015..a8bfb16a 100644 --- a/cmd/cli/docs/reference/docker_model_configure.yaml +++ b/cmd/cli/docs/reference/docker_model_configure.yaml @@ -40,7 +40,8 @@ options: swarm: false - option: mode value_type: string - description: backend operation mode (completion, embedding, reranking) + description: | + backend operation mode (completion, embedding, reranking, image-generation) deprecated: false hidden: false experimental: false diff --git a/cmd/cli/docs/reference/docker_model_package.yaml b/cmd/cli/docs/reference/docker_model_package.yaml index d59835ce..79d60899 100644 --- a/cmd/cli/docs/reference/docker_model_package.yaml +++ b/cmd/cli/docs/reference/docker_model_package.yaml @@ -1,13 +1,14 @@ command: docker model package short: | - Package a GGUF file, Safetensors directory, or existing model into a Docker model OCI artifact. 
+ Package a GGUF file, Safetensors directory, DDUF file, or existing model into a Docker model OCI artifact. long: |- - Package a GGUF file, Safetensors directory, or existing model into a Docker model OCI artifact, with optional licenses and multimodal projector. The package is sent to the model-runner, unless --push is specified. + Package a GGUF file, Safetensors directory, DDUF file, or existing model into a Docker model OCI artifact, with optional licenses and multimodal projector. The package is sent to the model-runner, unless --push is specified. When packaging a sharded GGUF model, --gguf should point to the first shard. All shard files should be siblings and should include the index in the file name (e.g. model-00001-of-00015.gguf). When packaging a Safetensors model, --safetensors-dir should point to a directory containing .safetensors files and config files (*.json, merges.txt). All files will be auto-discovered and config files will be packaged into a tar archive. + When packaging a DDUF file (Diffusers Unified Format), --dduf should point to a .dduf archive file. When packaging from an existing model using --from, you can modify properties like context size to create a variant of the original model. For multimodal models, use --mmproj to include a multimodal projector file. -usage: docker model package (--gguf | --safetensors-dir | --from ) [--license ...] [--mmproj ] [--context-size ] [--push] MODEL +usage: docker model package (--gguf | --safetensors-dir | --dduf | --from ) [--license ...] 
[--mmproj ] [--context-size ] [--push] MODEL pname: docker model plink: docker_model.yaml options: @@ -30,6 +31,15 @@ options: experimentalcli: false kubernetes: false swarm: false + - option: dduf + value_type: string + description: absolute path to DDUF archive file (Diffusers Unified Format) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false - option: dir-tar value_type: stringArray default_value: '[]' diff --git a/cmd/cli/docs/reference/model.md b/cmd/cli/docs/reference/model.md index e139fc45..d47df7f8 100644 --- a/cmd/cli/docs/reference/model.md +++ b/cmd/cli/docs/reference/model.md @@ -5,32 +5,32 @@ Docker Model Runner ### Subcommands -| Name | Description | -|:------------------------------------------------|:------------------------------------------------------------------------------------------------| -| [`bench`](model_bench.md) | Benchmark a model's performance at different concurrency levels | -| [`df`](model_df.md) | Show Docker Model Runner disk usage | -| [`inspect`](model_inspect.md) | Display detailed information on one model | -| [`install-runner`](model_install-runner.md) | Install Docker Model Runner (Docker Engine only) | -| [`list`](model_list.md) | List the models pulled to your local environment | -| [`logs`](model_logs.md) | Fetch the Docker Model Runner logs | -| [`package`](model_package.md) | Package a GGUF file, Safetensors directory, or existing model into a Docker model OCI artifact. 
| -| [`ps`](model_ps.md) | List running models | -| [`pull`](model_pull.md) | Pull a model from Docker Hub or HuggingFace to your local environment | -| [`purge`](model_purge.md) | Remove all models | -| [`push`](model_push.md) | Push a model to Docker Hub | -| [`reinstall-runner`](model_reinstall-runner.md) | Reinstall Docker Model Runner (Docker Engine only) | -| [`requests`](model_requests.md) | Fetch requests+responses from Docker Model Runner | -| [`restart-runner`](model_restart-runner.md) | Restart Docker Model Runner (Docker Engine only) | -| [`rm`](model_rm.md) | Remove local models downloaded from Docker Hub | -| [`run`](model_run.md) | Run a model and interact with it using a submitted prompt or chat mode | -| [`search`](model_search.md) | Search for models on Docker Hub and HuggingFace | -| [`start-runner`](model_start-runner.md) | Start Docker Model Runner (Docker Engine only) | -| [`status`](model_status.md) | Check if the Docker Model Runner is running | -| [`stop-runner`](model_stop-runner.md) | Stop Docker Model Runner (Docker Engine only) | -| [`tag`](model_tag.md) | Tag a model | -| [`uninstall-runner`](model_uninstall-runner.md) | Uninstall Docker Model Runner (Docker Engine only) | -| [`unload`](model_unload.md) | Unload running models | -| [`version`](model_version.md) | Show the Docker Model Runner version | +| Name | Description | +|:------------------------------------------------|:-----------------------------------------------------------------------------------------------------------| +| [`bench`](model_bench.md) | Benchmark a model's performance at different concurrency levels | +| [`df`](model_df.md) | Show Docker Model Runner disk usage | +| [`inspect`](model_inspect.md) | Display detailed information on one model | +| [`install-runner`](model_install-runner.md) | Install Docker Model Runner (Docker Engine only) | +| [`list`](model_list.md) | List the models pulled to your local environment | +| [`logs`](model_logs.md) | Fetch the 
Docker Model Runner logs | +| [`package`](model_package.md) | Package a GGUF file, Safetensors directory, DDUF file, or existing model into a Docker model OCI artifact. | +| [`ps`](model_ps.md) | List running models | +| [`pull`](model_pull.md) | Pull a model from Docker Hub or HuggingFace to your local environment | +| [`purge`](model_purge.md) | Remove all models | +| [`push`](model_push.md) | Push a model to Docker Hub | +| [`reinstall-runner`](model_reinstall-runner.md) | Reinstall Docker Model Runner (Docker Engine only) | +| [`requests`](model_requests.md) | Fetch requests+responses from Docker Model Runner | +| [`restart-runner`](model_restart-runner.md) | Restart Docker Model Runner (Docker Engine only) | +| [`rm`](model_rm.md) | Remove local models downloaded from Docker Hub | +| [`run`](model_run.md) | Run a model and interact with it using a submitted prompt or chat mode | +| [`search`](model_search.md) | Search for models on Docker Hub and HuggingFace | +| [`start-runner`](model_start-runner.md) | Start Docker Model Runner (Docker Engine only) | +| [`status`](model_status.md) | Check if the Docker Model Runner is running | +| [`stop-runner`](model_stop-runner.md) | Stop Docker Model Runner (Docker Engine only) | +| [`tag`](model_tag.md) | Tag a model | +| [`uninstall-runner`](model_uninstall-runner.md) | Uninstall Docker Model Runner (Docker Engine only) | +| [`unload`](model_unload.md) | Unload running models | +| [`version`](model_version.md) | Show the Docker Model Runner version | diff --git a/cmd/cli/docs/reference/model_package.md b/cmd/cli/docs/reference/model_package.md index eaf3da29..062f1581 100644 --- a/cmd/cli/docs/reference/model_package.md +++ b/cmd/cli/docs/reference/model_package.md @@ -1,9 +1,10 @@ # docker model package -Package a GGUF file, Safetensors directory, or existing model into a Docker model OCI artifact, with optional licenses and multimodal projector. The package is sent to the model-runner, unless --push is specified. 
+Package a GGUF file, Safetensors directory, DDUF file, or existing model into a Docker model OCI artifact, with optional licenses and multimodal projector. The package is sent to the model-runner, unless --push is specified. When packaging a sharded GGUF model, --gguf should point to the first shard. All shard files should be siblings and should include the index in the file name (e.g. model-00001-of-00015.gguf). When packaging a Safetensors model, --safetensors-dir should point to a directory containing .safetensors files and config files (*.json, merges.txt). All files will be auto-discovered and config files will be packaged into a tar archive. +When packaging a DDUF file (Diffusers Unified Format), --dduf should point to a .dduf archive file. When packaging from an existing model using --from, you can modify properties like context size to create a variant of the original model. For multimodal models, use --mmproj to include a multimodal projector file. @@ -13,6 +14,7 @@ For multimodal models, use --mmproj to include a multimodal projector file. 
|:--------------------|:--------------|:--------|:---------------------------------------------------------------------------------------| | `--chat-template` | `string` | | absolute path to chat template file (must be Jinja format) | | `--context-size` | `uint64` | `0` | context size in tokens | +| `--dduf` | `string` | | absolute path to DDUF archive file (Diffusers Unified Format) | | `--dir-tar` | `stringArray` | | relative path to directory to package as tar (can be specified multiple times) | | `--from` | `string` | | reference to an existing model to repackage | | `--gguf` | `string` | | absolute path to gguf file | diff --git a/main.go b/main.go index fea431ca..17826d41 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "github.com/docker/model-runner/pkg/anthropic" "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/backends/diffusers" "github.com/docker/model-runner/pkg/inference/backends/llamacpp" "github.com/docker/model-runner/pkg/inference/backends/mlx" "github.com/docker/model-runner/pkg/inference/backends/sglang" @@ -76,6 +77,7 @@ func main() { vllmServerPath := os.Getenv("VLLM_SERVER_PATH") sglangServerPath := os.Getenv("SGLANG_SERVER_PATH") mlxServerPath := os.Getenv("MLX_SERVER_PATH") + diffusersServerPath := os.Getenv("DIFFUSERS_SERVER_PATH") // Create a proxy-aware HTTP transport // Use a safe type assertion with fallback, and explicitly set Proxy to http.ProxyFromEnvironment @@ -156,10 +158,23 @@ func main() { log.Fatalf("unable to initialize %s backend: %v", sglang.Name, err) } + diffusersBackend, err := diffusers.New( + log, + modelManager, + log.WithFields(logrus.Fields{"component": diffusers.Name}), + nil, + diffusersServerPath, + ) + + if err != nil { + log.Fatalf("unable to initialize diffusers backend: %v", err) + } + backends := map[string]inference.Backend{ - llamacpp.Name: llamaCppBackend, - mlx.Name: mlxBackend, - sglang.Name: sglangBackend, + llamacpp.Name: llamaCppBackend, + 
mlx.Name: mlxBackend, + sglang.Name: sglangBackend, + diffusers.Name: diffusersBackend, } registerVLLMBackend(backends, vllmBackend) diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index deb9a9e5..9f16da1d 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -598,9 +598,9 @@ func (c *Client) GetBundle(ref string) (types.ModelBundle, error) { func GetSupportedFormats() []types.Format { if platform.SupportsVLLM() { - return []types.Format{types.FormatGGUF, types.FormatSafetensors} + return []types.Format{types.FormatGGUF, types.FormatSafetensors, types.FormatDiffusers} } - return []types.Format{types.FormatGGUF} + return []types.Format{types.FormatGGUF, types.FormatDiffusers} } func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string, progressWriter io.Writer) error { diff --git a/pkg/distribution/files/classify.go b/pkg/distribution/files/classify.go index 9e4b24b6..d28fa586 100644 --- a/pkg/distribution/files/classify.go +++ b/pkg/distribution/files/classify.go @@ -17,6 +17,8 @@ const ( FileTypeGGUF // FileTypeSafetensors is a safetensors model weight file FileTypeSafetensors + // FileTypeDDUF is a DDUF (Diffusers Unified Format) file + FileTypeDDUF // FileTypeConfig is a configuration file (json, txt, etc.) 
FileTypeConfig // FileTypeLicense is a license file @@ -32,6 +34,8 @@ func (ft FileType) String() string { return "gguf" case FileTypeSafetensors: return "safetensors" + case FileTypeDDUF: + return "dduf" case FileTypeConfig: return "config" case FileTypeLicense: @@ -74,6 +78,11 @@ func Classify(path string) FileType { return FileTypeSafetensors } + // Check for DDUF files (Diffusers Unified Format) + if strings.HasSuffix(lower, ".dduf") { + return FileTypeDDUF + } + // Check for chat template files (before generic config check) for _, ext := range ChatTemplateExtensions { if strings.HasSuffix(lower, ext) { diff --git a/pkg/distribution/format/dduf.go b/pkg/distribution/format/dduf.go new file mode 100644 index 00000000..d4fe6e83 --- /dev/null +++ b/pkg/distribution/format/dduf.go @@ -0,0 +1,70 @@ +package format + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/docker/model-runner/pkg/distribution/oci" + "github.com/docker/model-runner/pkg/distribution/types" +) + +// DDUFFormat implements the Format interface for DDUF (Diffusers Unified Format) model files. +// DDUF is a single-file archive format for diffusion models used by HuggingFace Diffusers. +type DDUFFormat struct{} + +// init registers the DDUF format implementation. +func init() { + Register(&DDUFFormat{}) +} + +// Name returns the format identifier for DDUF. +func (d *DDUFFormat) Name() types.Format { + return types.FormatDiffusers +} + +// MediaType returns the OCI media type for DDUF layers. +func (d *DDUFFormat) MediaType() oci.MediaType { + return types.MediaTypeDDUF +} + +// DiscoverShards finds all DDUF shard files for a model. +// DDUF is a single-file format, so this always returns a slice containing only the input path. +func (d *DDUFFormat) DiscoverShards(path string) ([]string, error) { + // DDUF files are single archives, not sharded + return []string{path}, nil +} + +// ExtractConfig parses DDUF file(s) and extracts model configuration metadata. 
+// DDUF files are zip archives containing model config, so we extract what we can. +func (d *DDUFFormat) ExtractConfig(paths []string) (types.Config, error) { + if len(paths) == 0 { + return types.Config{Format: types.FormatDiffusers}, nil + } + + // Calculate total size across all files + var totalSize int64 + for _, path := range paths { + info, err := os.Stat(path) + if err != nil { + return types.Config{}, fmt.Errorf("failed to stat file %s: %w", path, err) + } + totalSize += info.Size() + } + + // Extract the filename for metadata + ddufFile := filepath.Base(paths[0]) + + // Return config with diffusers-specific metadata + // In the future, we could extract model_index.json from the DDUF archive + // to get architecture details, etc. + return types.Config{ + Format: types.FormatDiffusers, + Architecture: "diffusers", + Size: formatSize(totalSize), + Diffusers: map[string]string{ + "layout": "dduf", + "dduf_file": ddufFile, + }, + }, nil +} diff --git a/pkg/distribution/format/format.go b/pkg/distribution/format/format.go index 6b4b5c40..b905e57b 100644 --- a/pkg/distribution/format/format.go +++ b/pkg/distribution/format/format.go @@ -6,6 +6,7 @@ package format import ( "fmt" + "github.com/docker/go-units" "github.com/docker/model-runner/pkg/distribution/files" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/types" @@ -60,6 +61,8 @@ func DetectFromPath(path string) (Format, error) { return Get(types.FormatGGUF) case files.FileTypeSafetensors: return Get(types.FormatSafetensors) + case files.FileTypeDDUF: + return Get(types.FormatDiffusers) case files.FileTypeUnknown, files.FileTypeConfig, files.FileTypeLicense, files.FileTypeChatTemplate: return nil, fmt.Errorf("unable to detect format from path: %s (file type: %s)", utils.SanitizeForLog(path), ft) } @@ -93,3 +96,15 @@ func DetectFromPaths(paths []string) (Format, error) { return format, nil } + +// formatParameters converts parameter count to 
human-readable format +// Returns format like "361.82M" or "1.5B" (no space before unit, base 1000, where B = Billion) +func formatParameters(params int64) string { + return units.CustomSize("%.2f%s", float64(params), 1000.0, []string{"", "K", "M", "B", "T"}) +} + +// formatSize converts bytes to human-readable format matching Docker's style +// Returns format like "256MB" (decimal units, no space, matching `docker images`) +func formatSize(bytes int64) string { + return units.CustomSize("%.2f%s", float64(bytes), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB"}) +} diff --git a/pkg/distribution/format/safetensors.go b/pkg/distribution/format/safetensors.go index 23162577..fa656379 100644 --- a/pkg/distribution/format/safetensors.go +++ b/pkg/distribution/format/safetensors.go @@ -11,7 +11,6 @@ import ( "sort" "strconv" - "github.com/docker/go-units" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/types" ) @@ -291,15 +290,3 @@ func (h *safetensorsHeader) extractMetadata() map[string]string { return metadata } - -// formatParameters converts parameter count to human-readable format -// Returns format like "361.82M" or "1.5B" (no space before unit, base 1000, where B = Billion) -func formatParameters(params int64) string { - return units.CustomSize("%.2f%s", float64(params), 1000.0, []string{"", "K", "M", "B", "T"}) -} - -// formatSize converts bytes to human-readable format matching Docker's style -// Returns format like "256MB" (decimal units, no space, matching `docker images`) -func formatSize(bytes int64) string { - return units.CustomSize("%.2f%s", float64(bytes), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB"}) -} diff --git a/pkg/distribution/huggingface/repository.go b/pkg/distribution/huggingface/repository.go index c0d69a95..1c7bb3f8 100644 --- a/pkg/distribution/huggingface/repository.go +++ b/pkg/distribution/huggingface/repository.go @@ -55,7 +55,7 @@ func 
FilterModelFiles(repoFiles []RepoFile) (weights []RepoFile, configs []RepoF weights = append(weights, f) case files.FileTypeConfig, files.FileTypeChatTemplate: configs = append(configs, f) - case files.FileTypeUnknown, files.FileTypeLicense: + case files.FileTypeUnknown, files.FileTypeLicense, files.FileTypeDDUF: // Skip these file types } } diff --git a/pkg/distribution/internal/bundle/bundle.go b/pkg/distribution/internal/bundle/bundle.go index b17b241d..e9d82a8b 100644 --- a/pkg/distribution/internal/bundle/bundle.go +++ b/pkg/distribution/internal/bundle/bundle.go @@ -17,6 +17,7 @@ type Bundle struct { mmprojPath string ggufFile string // path to GGUF file (first shard when model is split among files) safetensorsFile string // path to safetensors file (first shard when model is split among files) + ddufFile string // path to DDUF file (Diffusers Unified Format) runtimeConfig types.ModelConfig chatTemplatePath string } @@ -59,6 +60,14 @@ func (b *Bundle) SafetensorsPath() string { return filepath.Join(b.dir, ModelSubdir, b.safetensorsFile) } +// DDUFPath returns the path to the DDUF file (Diffusers Unified Format) or "" if none is present. +func (b *Bundle) DDUFPath() string { + if b.ddufFile == "" { + return "" + } + return filepath.Join(b.dir, ModelSubdir, b.ddufFile) +} + // RuntimeConfig returns config that should be respected by the backend at runtime. // Can return either Docker format (*types.Config) or ModelPack format (*modelpack.Model). 
func (b *Bundle) RuntimeConfig() types.ModelConfig { diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index b2e67d52..86e39eaf 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -38,8 +38,12 @@ func Unpack(dir string, model types.Model) (*Bundle, error) { if err := unpackSafetensors(bundle, model); err != nil { return nil, fmt.Errorf("unpack safetensors files: %w", err) } + case types.FormatDiffusers: + if err := unpackDDUF(bundle, model); err != nil { + return nil, fmt.Errorf("unpack DDUF file: %w", err) + } default: - return nil, fmt.Errorf("no supported model weights found (neither GGUF nor safetensors)") + return nil, fmt.Errorf("no supported model weights found (expected GGUF, safetensors, or diffusers/DDUF)") } // Unpack optional components based on their presence @@ -88,9 +92,41 @@ func detectModelFormat(model types.Model) types.Format { return types.FormatSafetensors } + // Check for DDUF files + ddufPaths, err := model.DDUFPaths() + if err == nil && len(ddufPaths) > 0 { + return types.FormatDiffusers + } + return "" } +// unpackDDUF unpacks a DDUF (Diffusers Unified Format) file to the bundle. 
+func unpackDDUF(bundle *Bundle, mdl types.Model) error { + ddufPaths, err := mdl.DDUFPaths() + if err != nil { + return fmt.Errorf("get DDUF files for model: %w", err) + } + + if len(ddufPaths) == 0 { + return fmt.Errorf("no DDUF files found") + } + + modelDir := filepath.Join(bundle.dir, ModelSubdir) + + // DDUF is a single-file format + ddufFilename := filepath.Base(ddufPaths[0]) + // Ensure the filename has the .dduf extension for proper detection by diffusers server + if !strings.HasSuffix(strings.ToLower(ddufFilename), ".dduf") { + ddufFilename = ddufFilename + ".dduf" + } + if err := unpackFile(filepath.Join(modelDir, ddufFilename), ddufPaths[0]); err != nil { + return err + } + bundle.ddufFile = ddufFilename + return nil +} + // hasLayerWithMediaType checks if the model contains a layer with the specified media type func hasLayerWithMediaType(model types.Model, targetMediaType oci.MediaType) bool { // Check specific media types using the model's methods diff --git a/pkg/distribution/internal/partial/partial.go b/pkg/distribution/internal/partial/partial.go index fdea593a..b1750cb6 100644 --- a/pkg/distribution/internal/partial/partial.go +++ b/pkg/distribution/internal/partial/partial.go @@ -124,6 +124,10 @@ func SafetensorsPaths(i WithLayers) ([]string, error) { return layerPathsByMediaType(i, types.MediaTypeSafetensors) } +func DDUFPaths(i WithLayers) ([]string, error) { + return layerPathsByMediaType(i, types.MediaTypeDDUF) +} + func ConfigArchivePath(i WithLayers) (string, error) { paths, err := layerPathsByMediaType(i, types.MediaTypeVLLMConfigArchive) if err != nil { diff --git a/pkg/distribution/internal/store/model.go b/pkg/distribution/internal/store/model.go index 0335c71d..4fa53acf 100644 --- a/pkg/distribution/internal/store/model.go +++ b/pkg/distribution/internal/store/model.go @@ -157,6 +157,10 @@ func (m *Model) SafetensorsPaths() ([]string, error) { return mdpartial.SafetensorsPaths(m) } +func (m *Model) DDUFPaths() ([]string, error) { + 
return mdpartial.DDUFPaths(m) +} + func (m *Model) ConfigArchivePath() (string, error) { return mdpartial.ConfigArchivePath(m) } diff --git a/pkg/distribution/modelpack/types.go b/pkg/distribution/modelpack/types.go index d6d1b5e8..af4345af 100644 --- a/pkg/distribution/modelpack/types.go +++ b/pkg/distribution/modelpack/types.go @@ -158,6 +158,8 @@ func (m *Model) GetFormat() types.Format { return types.FormatGGUF case "safetensors": return types.FormatSafetensors + case "diffusers": + return types.FormatDiffusers default: return types.Format(f) } diff --git a/pkg/distribution/packaging/safetensors.go b/pkg/distribution/packaging/safetensors.go index 2a5be421..a4348f20 100644 --- a/pkg/distribution/packaging/safetensors.go +++ b/pkg/distribution/packaging/safetensors.go @@ -40,7 +40,7 @@ func PackageFromDirectory(dirPath string) (safetensorsPaths []string, tempConfig safetensorsPaths = append(safetensorsPaths, fullPath) case files.FileTypeConfig, files.FileTypeChatTemplate: configFiles = append(configFiles, fullPath) - case files.FileTypeUnknown, files.FileTypeGGUF, files.FileTypeLicense: + case files.FileTypeUnknown, files.FileTypeGGUF, files.FileTypeLicense, files.FileTypeDDUF: // Skip these file types } } diff --git a/pkg/distribution/types/config.go b/pkg/distribution/types/config.go index 5b910585..ec4dab3d 100644 --- a/pkg/distribution/types/config.go +++ b/pkg/distribution/types/config.go @@ -25,6 +25,9 @@ const ( // MediaTypeDirTar indicates a tar archive containing a directory with its structure preserved. MediaTypeDirTar MediaType = "application/vnd.docker.ai.dir.tar" + // MediaTypeDDUF indicates a file in DDUF format (Diffusers Unified Format). 
+ MediaTypeDDUF MediaType = "application/vnd.docker.ai.dduf" + // MediaTypeLicense indicates a plain text file containing a license MediaTypeLicense MediaType = "application/vnd.docker.ai.license" @@ -36,6 +39,7 @@ const ( FormatGGUF = Format("gguf") FormatSafetensors = Format("safetensors") + FormatDiffusers = Format("diffusers") // OCI Annotation keys for model layers // See https://github.com/opencontainers/image-spec/blob/main/annotations.md @@ -82,6 +86,7 @@ type Config struct { Size string `json:"size,omitempty"` GGUF map[string]string `json:"gguf,omitempty"` Safetensors map[string]string `json:"safetensors,omitempty"` + Diffusers map[string]string `json:"diffusers,omitempty"` ContextSize *int32 `json:"context_size,omitempty"` } diff --git a/pkg/distribution/types/model.go b/pkg/distribution/types/model.go index 8fe2956d..350f4975 100644 --- a/pkg/distribution/types/model.go +++ b/pkg/distribution/types/model.go @@ -8,6 +8,7 @@ type Model interface { ID() (string, error) GGUFPaths() ([]string, error) SafetensorsPaths() ([]string, error) + DDUFPaths() ([]string, error) ConfigArchivePath() (string, error) MMPROJPath() (string, error) Config() (ModelConfig, error) @@ -27,6 +28,7 @@ type ModelBundle interface { RootDir() string GGUFPath() string SafetensorsPath() string + DDUFPath() string ChatTemplatePath() string MMPROJPath() string RuntimeConfig() ModelConfig diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go index c4c9dfb8..e83ff5f9 100644 --- a/pkg/inference/backend.go +++ b/pkg/inference/backend.go @@ -18,6 +18,9 @@ const ( // mode. BackendModeEmbedding BackendModeReranking + // BackendModeImageGeneration indicates that the backend should run in + // image generation mode. 
+ BackendModeImageGeneration ) type ErrGGUFParse struct { @@ -37,6 +40,8 @@ func (m BackendMode) String() string { return "embedding" case BackendModeReranking: return "reranking" + case BackendModeImageGeneration: + return "image-generation" default: return "unknown" } @@ -72,6 +77,8 @@ func ParseBackendMode(mode string) (BackendMode, bool) { return BackendModeEmbedding, true case "reranking": return BackendModeReranking, true + case "image-generation": + return BackendModeImageGeneration, true default: return BackendModeCompletion, false } diff --git a/pkg/inference/backends/diffusers/diffusers.go b/pkg/inference/backends/diffusers/diffusers.go new file mode 100644 index 00000000..a966c667 --- /dev/null +++ b/pkg/inference/backends/diffusers/diffusers.go @@ -0,0 +1,237 @@ +package diffusers + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/docker/model-runner/pkg/diskusage" + "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/backends" + "github.com/docker/model-runner/pkg/inference/models" + "github.com/docker/model-runner/pkg/inference/platform" + "github.com/docker/model-runner/pkg/internal/utils" + "github.com/docker/model-runner/pkg/logging" +) + +const ( + // Name is the backend name. + Name = "diffusers" + diffusersDir = "/opt/diffusers-env" +) + +var ( + ErrNotImplemented = errors.New("not implemented") + ErrDiffusersNotFound = errors.New("diffusers package not installed") + ErrPythonNotFound = errors.New("python3 not found in PATH") + ErrNoDDUFFile = errors.New("no DDUF file found in model bundle") +) + +// diffusers is the diffusers-based backend implementation for image generation. +type diffusers struct { + // log is the associated logger. + log logging.Logger + // modelManager is the shared model manager. + modelManager *models.Manager + // serverLog is the logger to use for the diffusers server process. 
+ serverLog logging.Logger + // config is the configuration for the diffusers backend. + config *Config + // status is the state in which the diffusers backend is in. + status string + // pythonPath is the path to the python3 binary. + pythonPath string + // customPythonPath is an optional custom path to the python3 binary. + customPythonPath string +} + +// New creates a new diffusers-based backend for image generation. +// customPythonPath is an optional path to a custom python3 binary; if empty, the default path is used. +func New(log logging.Logger, modelManager *models.Manager, serverLog logging.Logger, conf *Config, customPythonPath string) (inference.Backend, error) { + // If no config is provided, use the default configuration + if conf == nil { + conf = NewDefaultConfig() + } + + return &diffusers{ + log: log, + modelManager: modelManager, + serverLog: serverLog, + config: conf, + status: "not installed", + customPythonPath: customPythonPath, + }, nil +} + +// Name implements inference.Backend.Name. +func (d *diffusers) Name() string { + return Name +} + +// UsesExternalModelManagement implements inference.Backend.UsesExternalModelManagement. +// Diffusers uses the shared model manager with bundled DDUF files. +func (d *diffusers) UsesExternalModelManagement() bool { + return false // Use the bundle system for DDUF files +} + +// UsesTCP implements inference.Backend.UsesTCP. +// Diffusers uses TCP for communication, like SGLang. +func (d *diffusers) UsesTCP() bool { + return true +} + +// Install implements inference.Backend.Install. 
+func (d *diffusers) Install(_ context.Context, _ *http.Client) error { + if !platform.SupportsDiffusers() { + return ErrNotImplemented + } + + var pythonPath string + + // Use custom python path if specified + if d.customPythonPath != "" { + pythonPath = d.customPythonPath + } else { + venvPython := filepath.Join(diffusersDir, "bin", "python3") + pythonPath = venvPython + + if _, err := os.Stat(venvPython); err != nil { + // Fall back to system Python + systemPython, err := exec.LookPath("python3") + if err != nil { + d.status = ErrPythonNotFound.Error() + return ErrPythonNotFound + } + pythonPath = systemPython + } + } + + d.pythonPath = pythonPath + + // Check if diffusers is installed + if err := d.pythonCmd("-c", "import diffusers").Run(); err != nil { + d.status = "diffusers package not installed" + d.log.Warnf("diffusers package not found. Install with: uv pip install diffusers torch") + return ErrDiffusersNotFound + } + + // Get version + output, err := d.pythonCmd("-c", "import diffusers; print(diffusers.__version__)").Output() + if err != nil { + d.log.Warnf("could not get diffusers version: %v", err) + d.status = "running diffusers version: unknown" + } else { + d.status = fmt.Sprintf("running diffusers version: %s", strings.TrimSpace(string(output))) + } + + return nil +} + +// Run implements inference.Backend.Run. 
+func (d *diffusers) Run(ctx context.Context, socket, model string, modelRef string, mode inference.BackendMode, backendConfig *inference.BackendConfiguration) error { + if !platform.SupportsDiffusers() { + d.log.Warn("diffusers backend is not yet supported on this platform") + return ErrNotImplemented + } + + // For diffusers, we support image generation mode + if mode != inference.BackendModeImageGeneration { + return fmt.Errorf("diffusers backend only supports image-generation mode, got %s", mode) + } + + // Get the model bundle to find the DDUF file path + bundle, err := d.modelManager.GetBundle(model) + if err != nil { + return fmt.Errorf("failed to get model bundle for %s: %w", model, err) + } + + // Get the DDUF file path from the bundle + ddufPath := bundle.DDUFPath() + if ddufPath == "" { + return fmt.Errorf("%w: model %s", ErrNoDDUFFile, model) + } + + d.log.Infof("Loading DDUF file from: %s", ddufPath) + + args, err := d.config.GetArgs(ddufPath, socket, mode, backendConfig) + if err != nil { + return fmt.Errorf("failed to get diffusers arguments: %w", err) + } + + // Add served model name using the human-readable model reference + if modelRef != "" { + args = append(args, "--served-model-name", modelRef) + } + + d.log.Infof("Diffusers args: %v", utils.SanitizeForLog(strings.Join(args, " "))) + + if d.pythonPath == "" { + return fmt.Errorf("diffusers: python runtime not configured; did you forget to call Install") + } + + sandboxPath := "" + if _, err := os.Stat(diffusersDir); err == nil { + sandboxPath = diffusersDir + } + + return backends.RunBackend(ctx, backends.RunnerConfig{ + BackendName: "Diffusers", + Socket: socket, + BinaryPath: d.pythonPath, + SandboxPath: sandboxPath, + SandboxConfig: "", + Args: args, + Logger: d.log, + ServerLogWriter: d.serverLog.Writer(), + ErrorTransformer: ExtractPythonError, + }) +} + +// Status implements inference.Backend.Status. 
+func (d *diffusers) Status() string { + return d.status +} + +// GetDiskUsage implements inference.Backend.GetDiskUsage. +func (d *diffusers) GetDiskUsage() (int64, error) { + // Check if Docker installation exists + if _, err := os.Stat(diffusersDir); err == nil { + size, err := diskusage.Size(diffusersDir) + if err != nil { + return 0, fmt.Errorf("error while getting diffusers dir size: %w", err) + } + return size, nil + } + // Python installation doesn't have a dedicated installation directory + // It's installed via pip in the system Python environment + return 0, nil +} + +// GetRequiredMemoryForModel returns the estimated memory requirements for a model. +func (d *diffusers) GetRequiredMemoryForModel(_ context.Context, _ string, _ *inference.BackendConfiguration) (inference.RequiredMemory, error) { + if !platform.SupportsDiffusers() { + return inference.RequiredMemory{}, ErrNotImplemented + } + + // Stable Diffusion models typically require significant VRAM + // SD 1.5: ~4GB VRAM, SD 2.1: ~5GB VRAM, SDXL: ~8GB VRAM + return inference.RequiredMemory{ + RAM: 4 * 1024 * 1024 * 1024, // 4GB RAM + VRAM: 6 * 1024 * 1024 * 1024, // 6GB VRAM (average estimate) + }, nil +} + +// pythonCmd creates an exec.Cmd that runs python with the given arguments. +// It uses the configured pythonPath if available, otherwise falls back to "python3". +func (d *diffusers) pythonCmd(args ...string) *exec.Cmd { + pythonBinary := "python3" + if d.pythonPath != "" { + pythonBinary = d.pythonPath + } + return exec.Command(pythonBinary, args...) +} diff --git a/pkg/inference/backends/diffusers/diffusers_config.go b/pkg/inference/backends/diffusers/diffusers_config.go new file mode 100644 index 00000000..010445e6 --- /dev/null +++ b/pkg/inference/backends/diffusers/diffusers_config.go @@ -0,0 +1,48 @@ +package diffusers + +import ( + "fmt" + "net" + + "github.com/docker/model-runner/pkg/inference" +) + +// Config is the configuration for the diffusers backend. 
+type Config struct { + // Args are the base arguments that are always included. + Args []string +} + +// NewDefaultConfig creates a new Config with default values. +func NewDefaultConfig() *Config { + return &Config{} +} + +// GetArgs implements BackendConfig.GetArgs for the diffusers backend. +func (c *Config) GetArgs(model string, socket string, mode inference.BackendMode, config *inference.BackendConfiguration) ([]string, error) { + // Start with the arguments from Config + args := append([]string{}, c.Args...) + + // Diffusers uses Python module: python -m diffusers_server.server + args = append(args, "-m", "diffusers_server.server") + + // Add model path - for diffusers this can be a HuggingFace model ID or local path + args = append(args, "--model-path", model) + + // Parse host:port from socket + host, port, err := net.SplitHostPort(socket) + if err != nil { + return nil, fmt.Errorf("failed to parse host:port from %q: %w", socket, err) + } + args = append(args, "--host", host, "--port", port) + + // Add mode-specific arguments + switch mode { + case inference.BackendModeImageGeneration: + // Default mode for diffusers - image generation + case inference.BackendModeCompletion, inference.BackendModeEmbedding, inference.BackendModeReranking: + return nil, fmt.Errorf("unsupported backend mode %q for diffusers", mode) + } + + return args, nil +} diff --git a/pkg/inference/backends/diffusers/errors.go b/pkg/inference/backends/diffusers/errors.go new file mode 100644 index 00000000..e3bb6814 --- /dev/null +++ b/pkg/inference/backends/diffusers/errors.go @@ -0,0 +1,65 @@ +package diffusers + +import ( + "fmt" + "regexp" + "strings" +) + +// pythonErrorPatterns contains regex patterns to extract meaningful error messages +// from Python tracebacks. The patterns are tried in order, and the first match wins. 
+var pythonErrorPatterns = []*regexp.Regexp{ + // Custom error marker from our Python server (highest priority) + regexp.MustCompile(`(?m)^DIFFUSERS_ERROR:\s*(.+)$`), + // Python RuntimeError, ValueError, etc. + regexp.MustCompile(`(?m)^(RuntimeError|ValueError|TypeError|OSError|ImportError|ModuleNotFoundError):\s*(.+)$`), + // CUDA/GPU related errors + regexp.MustCompile(`(?mi)(CUDA|GPU|out of memory|OOM|No GPU found)[^.]*\.?`), + // Generic Python Exception with message + regexp.MustCompile(`(?m)^(\w+Error):\s*(.+)$`), +} + +// ExtractPythonError attempts to extract a meaningful error message from Python output. +// It looks for common error patterns and returns a cleaner, more user-friendly message. +// If no recognizable pattern is found, it returns the original output. +func ExtractPythonError(output string) string { + // Try each pattern in order + for i, pattern := range pythonErrorPatterns { + matches := pattern.FindStringSubmatch(output) + if len(matches) > 0 { + switch i { + case 0: + // Custom error marker: return just the message + return strings.TrimSpace(matches[1]) + case 1: + // Standard Python errors: "ErrorType: message" + return fmt.Sprintf("%s: %s", matches[1], strings.TrimSpace(matches[2])) + case 2: + // GPU/CUDA related errors + return strings.TrimSpace(matches[0]) + case 3: + // Generic Python errors + return fmt.Sprintf("%s: %s", matches[1], strings.TrimSpace(matches[2])) + } + } + } + + // No pattern matched - return original but try to trim some noise + // Take only the last few meaningful lines + lines := strings.Split(strings.TrimSpace(output), "\n") + if len(lines) > 5 { + // Return the last 5 non-empty, non-indented lines (indented lines are traceback frames) + var meaningful []string + for i := len(lines) - 1; i >= 0 && len(meaningful) < 5; i-- { + line := strings.TrimSpace(lines[i]) + if line != "" && !strings.HasPrefix(lines[i], " ") { + meaningful = append([]string{line}, meaningful...) 
+ } + } + if len(meaningful) > 0 { + return strings.Join(meaningful, "\n") + } + } + + return output +} diff --git a/pkg/inference/backends/diffusers/errors_test.go b/pkg/inference/backends/diffusers/errors_test.go new file mode 100644 index 00000000..69f84d0f --- /dev/null +++ b/pkg/inference/backends/diffusers/errors_test.go @@ -0,0 +1,97 @@ +package diffusers + +import ( + "testing" +) + +func TestExtractPythonError(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "custom diffusers error marker", + input: "DIFFUSERS_ERROR: No GPU found. A GPU is needed for quantization.", + expected: "No GPU found. A GPU is needed for quantization.", + }, + { + name: "custom error marker in traceback", + input: `Traceback (most recent call last): + File "server.py", line 350, in main + load_model(args.model_path) +RuntimeError: Failed to load DDUF file: No GPU found +DIFFUSERS_ERROR: No GPU found. A GPU is needed for quantization.`, + expected: "No GPU found. A GPU is needed for quantization.", + }, + { + name: "python runtime error", + input: `RuntimeError: Failed to load DDUF file: No GPU found. A GPU is needed for quantization. +RuntimeError: No GPU found. A GPU is needed for quantization.`, + expected: "RuntimeError: Failed to load DDUF file: No GPU found. A GPU is needed for quantization.", + }, + { + name: "full python traceback", + input: ` raise RuntimeError(f"Failed to load DDUF file: {e}") +RuntimeError: Failed to load DDUF file: No GPU found. A GPU is needed for quantization. +RuntimeError: No GPU found. A GPU is needed for quantization. 
+ +During handling of the above exception, another exception occurred: + +Traceback (most recent call last): + File "", line 198, in _run_module_as_main + File "", line 88, in _run_code + File "/opt/diffusers-env/lib/python3.12/site-packages/diffusers_server/server.py", line 358, in + main() + File "/opt/diffusers-env/lib/python3.12/site-packages/diffusers_server/server.py", line 350, in main + load_model(args.model_path) + File "/opt/diffusers-env/lib/python3.12/site-packages/diffusers_server/server.py", line 139, in load_model + pipeline = load_model_from_dduf(model_path, device, dtype)`, + expected: "RuntimeError: Failed to load DDUF file: No GPU found. A GPU is needed for quantization.", + }, + { + name: "GPU not found error", + input: "Some log output\nNo GPU found. A GPU is needed for quantization.\nMore logs", + expected: "No GPU found.", + }, + { + name: "CUDA out of memory error", + input: "CUDA out of memory. Tried to allocate 2.00 GiB", + expected: "CUDA out of memory.", + }, + { + name: "import error", + input: "ImportError: No module named 'torch'", + expected: "ImportError: No module named 'torch'", + }, + { + name: "module not found error", + input: "ModuleNotFoundError: No module named 'diffusers'", + expected: "ModuleNotFoundError: No module named 'diffusers'", + }, + { + name: "value error", + input: "ValueError: Invalid model path", + expected: "ValueError: Invalid model path", + }, + { + name: "short output without pattern", + input: "some random error", + expected: "some random error", + }, + { + name: "empty output", + input: "", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractPythonError(tt.input) + if result != tt.expected { + t.Errorf("ExtractPythonError() = %q, want %q", result, tt.expected) + } + }) + } +} diff --git a/pkg/inference/backends/llamacpp/llamacpp_config.go b/pkg/inference/backends/llamacpp/llamacpp_config.go index f0ed4106..87816410 100644 --- 
a/pkg/inference/backends/llamacpp/llamacpp_config.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config.go @@ -65,7 +65,7 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference args = append(args, "--embeddings") case inference.BackendModeReranking: args = append(args, "--embeddings", "--reranking") - default: + case inference.BackendModeImageGeneration: return nil, fmt.Errorf("unsupported backend mode %q", mode) } diff --git a/pkg/inference/backends/llamacpp/llamacpp_config_test.go b/pkg/inference/backends/llamacpp/llamacpp_config_test.go index 1a53a1c8..ee8223c1 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config_test.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config_test.go @@ -448,6 +448,10 @@ func (f *fakeBundle) SafetensorsPath() string { return "" } +func (f *fakeBundle) DDUFPath() string { + return "" +} + func (f *fakeBundle) RuntimeConfig() types.ModelConfig { if f.config == nil { return nil diff --git a/pkg/inference/backends/mlx/mlx_config.go b/pkg/inference/backends/mlx/mlx_config.go index bc4f605c..29f98638 100644 --- a/pkg/inference/backends/mlx/mlx_config.go +++ b/pkg/inference/backends/mlx/mlx_config.go @@ -49,7 +49,7 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference case inference.BackendModeReranking: // MLX may not support reranking mode return nil, fmt.Errorf("reranking mode not supported by MLX backend") - default: + case inference.BackendModeImageGeneration: return nil, fmt.Errorf("unsupported backend mode %q", mode) } diff --git a/pkg/inference/backends/runner.go b/pkg/inference/backends/runner.go index 244fa1c7..18685734 100644 --- a/pkg/inference/backends/runner.go +++ b/pkg/inference/backends/runner.go @@ -16,6 +16,11 @@ import ( "github.com/docker/model-runner/pkg/tailbuffer" ) +// ErrorTransformer is a function that transforms raw error output +// into a more user-friendly message. 
Backends can provide their own +// implementation to customize error presentation. +type ErrorTransformer func(output string) string + // RunnerConfig holds configuration for a backend runner type RunnerConfig struct { // BackendName is the display name of the backend (e.g., "llama.cpp", "vLLM") @@ -34,6 +39,9 @@ type RunnerConfig struct { Logger Logger // ServerLogWriter provides a writer for server logs ServerLogWriter io.WriteCloser + // ErrorTransformer is an optional function to transform error output + // into a more user-friendly message. If nil, the raw output is used. + ErrorTransformer ErrorTransformer } // Logger interface for backend logging @@ -103,7 +111,12 @@ func RunBackend(ctx context.Context, config RunnerConfig) error { } if errOutput.String() != "" { - backendErr = fmt.Errorf("%s exit status: %w\nwith output: %s", config.BackendName, backendErr, errOutput.String()) + errorMsg := errOutput.String() + // Apply error transformer if provided + if config.ErrorTransformer != nil { + errorMsg = config.ErrorTransformer(errorMsg) + } + backendErr = fmt.Errorf("%s failed: %s", config.BackendName, errorMsg) } else { backendErr = fmt.Errorf("%s exit status: %w", config.BackendName, backendErr) } diff --git a/pkg/inference/backends/sglang/sglang_config.go b/pkg/inference/backends/sglang/sglang_config.go index 4d220d96..814a516f 100644 --- a/pkg/inference/backends/sglang/sglang_config.go +++ b/pkg/inference/backends/sglang/sglang_config.go @@ -50,7 +50,8 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference case inference.BackendModeEmbedding: args = append(args, "--is-embedding") case inference.BackendModeReranking: - default: + // SGLang does not have a specific flag for reranking + case inference.BackendModeImageGeneration: return nil, fmt.Errorf("unsupported backend mode %q", mode) } diff --git a/pkg/inference/backends/sglang/sglang_config_test.go b/pkg/inference/backends/sglang/sglang_config_test.go index e4aed925..2a96b0bc 
100644 --- a/pkg/inference/backends/sglang/sglang_config_test.go +++ b/pkg/inference/backends/sglang/sglang_config_test.go @@ -35,6 +35,10 @@ func (m *mockModelBundle) RuntimeConfig() types.ModelConfig { return m.runtimeConfig } +func (m *mockModelBundle) DDUFPath() string { + return "" +} + func (m *mockModelBundle) RootDir() string { return "/path/to/bundle" } diff --git a/pkg/inference/backends/vllm/vllm_config.go b/pkg/inference/backends/vllm/vllm_config.go index b172637f..b3ad0d2d 100644 --- a/pkg/inference/backends/vllm/vllm_config.go +++ b/pkg/inference/backends/vllm/vllm_config.go @@ -45,10 +45,11 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference case inference.BackendModeCompletion: // Default mode for vLLM case inference.BackendModeEmbedding: - // vLLM doesn't have a specific embedding flag like llama.cpp - // Embedding models are detected automatically + // vLLM doesn't have a specific embedding flag like llama.cpp + // Embedding models are detected automatically case inference.BackendModeReranking: - default: + // vLLM does not have a specific flag for reranking + case inference.BackendModeImageGeneration: return nil, fmt.Errorf("unsupported backend mode %q", mode) } diff --git a/pkg/inference/backends/vllm/vllm_config_test.go b/pkg/inference/backends/vllm/vllm_config_test.go index c52d65e1..d6498534 100644 --- a/pkg/inference/backends/vllm/vllm_config_test.go +++ b/pkg/inference/backends/vllm/vllm_config_test.go @@ -35,6 +35,10 @@ func (m *mockModelBundle) RuntimeConfig() types.ModelConfig { return m.runtimeConfig } +func (m *mockModelBundle) DDUFPath() string { + return "" +} + func (m *mockModelBundle) RootDir() string { return "/path/to/bundle" } diff --git a/pkg/inference/platform/platform.go b/pkg/inference/platform/platform.go index 49bffb75..8f1e1ed1 100644 --- a/pkg/inference/platform/platform.go +++ b/pkg/inference/platform/platform.go @@ -17,3 +17,10 @@ func SupportsMLX() bool { func SupportsSGLang() bool { 
return runtime.GOOS == "linux" } + +// SupportsDiffusers returns true if diffusers is supported on the current platform. +// Diffusers is supported on Linux (for Docker/CUDA) and macOS (for MPS/Apple Silicon). +func SupportsDiffusers() bool { + // return runtime.GOOS == "linux" || runtime.GOOS == "darwin" + return runtime.GOOS == "linux" // Support for macOS disabled for now until we design a solution to distribute it via Docker Desktop. +} diff --git a/pkg/inference/scheduling/api.go b/pkg/inference/scheduling/api.go index f9460dc2..7cd444a4 100644 --- a/pkg/inference/scheduling/api.go +++ b/pkg/inference/scheduling/api.go @@ -41,6 +41,9 @@ func backendModeForRequest(path string) (inference.BackendMode, bool) { } else if strings.HasSuffix(path, "/v1/messages") || strings.HasSuffix(path, "/v1/messages/count_tokens") { // Anthropic Messages API - treated as completion mode return inference.BackendModeCompletion, true + } else if strings.HasSuffix(path, "/v1/images/generations") { + // OpenAI Images API - image generation mode + return inference.BackendModeImageGeneration, true } return inference.BackendMode(0), false } diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go index e36b6b4a..4f137879 100644 --- a/pkg/inference/scheduling/http_handler.go +++ b/pkg/inference/scheduling/http_handler.go @@ -66,6 +66,9 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc { "POST " + inference.InferencePrefix + "/rerank", "POST " + inference.InferencePrefix + "/{backend}/score", "POST " + inference.InferencePrefix + "/score", + // Image generation routes + "POST " + inference.InferencePrefix + "/{backend}/v1/images/generations", + "POST " + inference.InferencePrefix + "/v1/images/generations", } // Anthropic Messages API routes diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go index 4a40ee09..ddfe582f 100644 --- a/pkg/inference/scheduling/loader.go +++ 
b/pkg/inference/scheduling/loader.go @@ -288,6 +288,7 @@ func (l *loader) Unload(ctx context.Context, unload UnloadRequest) int { l.evictRunner(unload.Backend, modelID, inference.BackendModeCompletion) l.evictRunner(unload.Backend, modelID, inference.BackendModeEmbedding) l.evictRunner(unload.Backend, modelID, inference.BackendModeReranking) + l.evictRunner(unload.Backend, modelID, inference.BackendModeImageGeneration) } return len(l.runners) } @@ -425,8 +426,8 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string if runnerConfig.Speculative != nil && runnerConfig.Speculative.DraftModel != "" { draftModelID = l.modelManager.ResolveID(runnerConfig.Speculative.DraftModel) } - } else if mode == inference.BackendModeReranking { - // For reranking mode, fallback to completion config if specific config is not found. + } else if (mode == inference.BackendModeReranking) || (mode == inference.BackendModeImageGeneration) { + // For reranking or image-generation mode, fallback to completion config if specific config is not found. 
if rc, ok := l.runnerConfigs[makeConfigKey(backendName, modelID, inference.BackendModeCompletion)]; ok { runnerConfig = &rc if runnerConfig.Speculative != nil && runnerConfig.Speculative.DraftModel != "" { diff --git a/python/diffusers_server/__init__.py b/python/diffusers_server/__init__.py new file mode 100644 index 00000000..50af3f43 --- /dev/null +++ b/python/diffusers_server/__init__.py @@ -0,0 +1,2 @@ +# Diffusers Server for Docker Model Runner +# Provides OpenAI Images API compatible endpoint for Stable Diffusion models diff --git a/python/diffusers_server/server.py b/python/diffusers_server/server.py new file mode 100644 index 00000000..1db699b2 --- /dev/null +++ b/python/diffusers_server/server.py @@ -0,0 +1,374 @@ +""" +Diffusers Server for Docker Model Runner + +A FastAPI-based server that provides OpenAI Images API compatible endpoints +for Stable Diffusion and other diffusion models using the Hugging Face diffusers library. +""" + +import argparse +import base64 +import io +import logging +import os +import time +from typing import Optional, List, Literal + +import torch +from diffusers import DiffusionPipeline, StableDiffusionPipeline, AutoPipelineForText2Image +from fastapi import FastAPI, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field +import uvicorn + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +app = FastAPI(title="Diffusers Server", description="OpenAI Images API compatible server for diffusion models") + +# Global pipeline instance +pipeline: Optional[DiffusionPipeline] = None +current_model: Optional[str] = None +served_model_name: Optional[str] = None + + +class ImageGenerationRequest(BaseModel): + """Request model for image generation (OpenAI Images API compatible)""" + model: str = Field(..., description="The model to use for image generation") + prompt: str = Field(..., description="A text description of the desired image(s)") + n: 
int = Field(default=1, ge=1, le=10, description="The number of images to generate") + size: str = Field(default="512x512", description="The size of the generated images") + response_format: Literal["url", "b64_json"] = Field(default="b64_json", description="The format of the generated images") + quality: Optional[str] = Field(default="standard", description="The quality of the image") + style: Optional[str] = Field(default=None, description="The style of the generated images") + negative_prompt: Optional[str] = Field(default=None, description="Text to avoid in generation") + num_inference_steps: int = Field(default=50, ge=1, le=150, description="Number of denoising steps") + guidance_scale: float = Field(default=7.5, ge=1.0, le=20.0, description="Guidance scale for generation") + seed: Optional[int] = Field(default=None, description="Random seed for reproducibility") + + +class ImageData(BaseModel): + """Single image in the response""" + b64_json: Optional[str] = None + url: Optional[str] = None + revised_prompt: Optional[str] = None + + +class ImageGenerationResponse(BaseModel): + """Response model for image generation (OpenAI Images API compatible)""" + created: int + data: List[ImageData] + + +def parse_size(size: str) -> tuple[int, int]: + """Parse size string like '512x512' into (width, height) tuple""" + try: + parts = size.lower().split('x') + if len(parts) != 2: + raise ValueError(f"Invalid size format: {size}") + width = int(parts[0]) + height = int(parts[1]) + return width, height + except (ValueError, IndexError) as e: + raise ValueError(f"Invalid size format '{size}'. 
Expected format like '512x512': {e}") + + +def is_dduf_file(path: str) -> bool: + """Check if the given path is a DDUF file""" + return path.lower().endswith('.dduf') and os.path.isfile(path) + + +def load_model_from_dduf(dduf_path: str, device: str, dtype: torch.dtype) -> DiffusionPipeline: + """Load a diffusion model from a DDUF (Diffusers Unified Format) file""" + logger.info(f"Loading model from DDUF file: {dduf_path}") + + try: + # Get the directory and filename from the DDUF path + # DiffusionPipeline.from_pretrained() expects: + # - First arg: directory containing the DDUF file (or repo ID for HF Hub) + # - dduf_file: the filename (string) of the DDUF file within that directory + dduf_dir = os.path.dirname(dduf_path) + dduf_filename = os.path.basename(dduf_path) + + logger.info(f"Using directory: {dduf_dir}") + logger.info(f"Using DDUF filename: {dduf_filename}") + + # Load the pipeline from the DDUF file + # The diffusers library will internally read the DDUF file and extract components + pipe = DiffusionPipeline.from_pretrained( + dduf_dir, + dduf_file=dduf_filename, + torch_dtype=dtype, + ) + + pipe = pipe.to(device) + logger.info(f"Model loaded successfully from DDUF on {device}") + return pipe + + except Exception as e: + logger.exception("Error loading DDUF file") + raise RuntimeError(f"Failed to load DDUF file: {e}") + + +def load_model(model_path: str) -> DiffusionPipeline: + """Load a diffusion model from the given path, DDUF file, or HuggingFace model ID""" + global pipeline, current_model + + if pipeline is not None and current_model == model_path: + logger.info(f"Model {model_path} already loaded") + return pipeline + + logger.info(f"Loading model: {model_path}") + + # Determine device + if torch.cuda.is_available(): + device = "cuda" + dtype = torch.float16 + logger.info("Using CUDA device with float16") + elif torch.backends.mps.is_available(): + device = "mps" + dtype = torch.float16 + logger.info("Using MPS device (Apple Silicon) with 
float16") + else: + device = "cpu" + dtype = torch.float32 + logger.info("Using CPU device with float32") + + # Check if this is a DDUF file + if is_dduf_file(model_path): + pipeline = load_model_from_dduf(model_path, device, dtype) + current_model = model_path + return pipeline + + # Check if this is a directory containing a model + if os.path.isdir(model_path): + logger.info(f"Loading model from directory: {model_path}") + + try: + # Try to load using AutoPipelineForText2Image which handles most model types + pipeline = AutoPipelineForText2Image.from_pretrained( + model_path, + torch_dtype=dtype, + safety_checker=None, # Disable safety checker for performance + requires_safety_checker=False, + ) + except Exception as e: + logger.warning(f"AutoPipelineForText2Image failed: {e}, trying StableDiffusionPipeline") + try: + pipeline = StableDiffusionPipeline.from_pretrained( + model_path, + torch_dtype=dtype, + safety_checker=None, + requires_safety_checker=False, + ) + except Exception as e2: + logger.warning(f"StableDiffusionPipeline failed: {e2}, trying generic DiffusionPipeline") + pipeline = DiffusionPipeline.from_pretrained( + model_path, + torch_dtype=dtype, + ) + + pipeline = pipeline.to(device) + + # Enable memory efficient attention if available + if hasattr(pipeline, 'enable_attention_slicing'): + pipeline.enable_attention_slicing() + + current_model = model_path + logger.info(f"Model loaded successfully on {device}") + return pipeline + + +def generate_images( + prompt: str, + n: int = 1, + width: int = 512, + height: int = 512, + negative_prompt: Optional[str] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + seed: Optional[int] = None, +) -> List[bytes]: + """Generate images using the loaded pipeline""" + global pipeline + + if pipeline is None: + raise RuntimeError("No model loaded") + + # Set seed for reproducibility + generator = None + if seed is not None: + if torch.cuda.is_available(): + generator = 
torch.Generator(device="cuda").manual_seed(seed) + elif torch.backends.mps.is_available(): + generator = torch.Generator(device="mps").manual_seed(seed) + else: + generator = torch.Generator().manual_seed(seed) + + logger.info(f"Generating {n} image(s) with prompt: {prompt[:100]}...") + + # Generate images + images = [] + for i in range(n): + # If we have a seed, increment it for each image to get different but reproducible results + current_generator = None + if generator is not None and seed is not None: + if torch.cuda.is_available(): + current_generator = torch.Generator(device="cuda").manual_seed(seed + i) + elif torch.backends.mps.is_available(): + current_generator = torch.Generator(device="mps").manual_seed(seed + i) + else: + current_generator = torch.Generator().manual_seed(seed + i) + + result = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=current_generator, + ) + + image = result.images[0] + + # Convert to PNG bytes + buffer = io.BytesIO() + image.save(buffer, format="PNG") + images.append(buffer.getvalue()) + + logger.info(f"Generated {len(images)} image(s)") + return images + + +@app.get("/health") +async def health(): + """Health check endpoint""" + return {"status": "healthy", "model_loaded": current_model is not None} + + +@app.get("/v1/models") +async def list_models(): + """List available models (OpenAI API compatible)""" + models = [] + if served_model_name: + models.append({ + "id": served_model_name, + "object": "model", + "created": int(time.time()), + "owned_by": "diffusers", + }) + if current_model and current_model != served_model_name: + models.append({ + "id": current_model, + "object": "model", + "created": int(time.time()), + "owned_by": "diffusers", + }) + return {"object": "list", "data": models} + + +@app.post("/v1/images/generations", response_model=ImageGenerationResponse) +async def 
create_image(request: ImageGenerationRequest): + """Generate images from a prompt (OpenAI Images API compatible)""" + global pipeline + + # Check if the requested model matches + requested_model = request.model + if served_model_name and requested_model != served_model_name and requested_model != current_model: + raise HTTPException( + status_code=421, + detail=f"Model '{requested_model}' not loaded. Current model: {served_model_name or current_model}" + ) + + if pipeline is None: + raise HTTPException(status_code=503, detail="No model loaded. Server is not ready.") + + try: + # Parse size + width, height = parse_size(request.size) + + # Generate images + image_bytes_list = generate_images( + prompt=request.prompt, + n=request.n, + width=width, + height=height, + negative_prompt=request.negative_prompt, + num_inference_steps=request.num_inference_steps, + guidance_scale=request.guidance_scale, + seed=request.seed, + ) + + # Format response + data = [] + for img_bytes in image_bytes_list: + if request.response_format == "b64_json": + b64_str = base64.b64encode(img_bytes).decode("utf-8") + data.append(ImageData(b64_json=b64_str)) + else: + # URL format not supported in this implementation + raise HTTPException( + status_code=400, + detail="URL response format is not supported. Use 'b64_json' instead." 
+ ) + + return ImageGenerationResponse( + created=int(time.time()), + data=data + ) + + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.exception("Error generating image") + raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}") + + +@app.on_event("startup") +async def startup_event(): + """Startup event handler""" + logger.info("Diffusers server starting up...") + if current_model: + logger.info(f"Model path: {current_model}") + + +def main(): + """Main entry point for the diffusers server""" + parser = argparse.ArgumentParser(description="Diffusers Server - OpenAI Images API compatible server") + parser.add_argument("--model-path", type=str, required=True, help="Path to the diffusion model, DDUF file, or HuggingFace model ID") + parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") + parser.add_argument("--port", type=int, default=8000, help="Port to bind to") + parser.add_argument("--served-model-name", type=str, default=None, help="Name to serve the model as") + + args = parser.parse_args() + + global served_model_name + served_model_name = args.served_model_name or args.model_path + + try: + # Load the model at startup + load_model(args.model_path) + + # Start the server + logger.info(f"Starting server on {args.host}:{args.port}") + uvicorn.run(app, host=args.host, port=args.port, log_level="info") + except Exception as e: + # Extract the root cause error message for cleaner output + error_msg = str(e) + # If this is a chained exception, try to get the original cause + root_cause = e + while root_cause.__cause__ is not None: + root_cause = root_cause.__cause__ + if root_cause is not e: + error_msg = str(root_cause) + + # Print a clean, single-line error message that can be easily parsed + # This format is recognized by the Go backend for better error reporting + import sys + print(f"DIFFUSERS_ERROR: {error_msg}", file=sys.stderr) + 
sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/docker-run.sh b/scripts/docker-run.sh index a8b064b9..59fbaea6 100755 --- a/scripts/docker-run.sh +++ b/scripts/docker-run.sh @@ -1,9 +1,13 @@ #!/bin/bash add_accelerators() { - # Add NVIDIA GPU support for CUDA variants - if [[ "${DOCKER_IMAGE-}" == *"-cuda" ]]; then - args+=("--gpus" "all" "--runtime=nvidia") + # Add NVIDIA GPU support for CUDA variants and GPU-accelerated backends + if [[ "${DOCKER_IMAGE-}" == *"-cuda" ]] || \ + [[ "${DOCKER_IMAGE-}" == *"-diffusers" ]] || \ + [[ "${DOCKER_IMAGE-}" == *"-sglang" ]]; then + if docker info -f '{{range $k, $v := .Runtimes}}{{$k}}{{"\n"}}{{end}}' 2>/dev/null | grep -qx "nvidia"; then + args+=("--gpus" "all" "--runtime=nvidia") + fi fi # Add GPU/accelerator devices if present @@ -79,4 +83,3 @@ main() { } main "$@" - From 9f1a80a4930bf197f9c53d6b5c268905e28cb345 Mon Sep 17 00:00:00 2001 From: Dorin Geman Date: Mon, 19 Jan 2026 11:23:13 +0200 Subject: [PATCH 16/17] chore(go): 1.25.6 Signed-off-by: Dorin Geman --- .github/workflows/ci.yml | 4 ++-- .github/workflows/integration-test.yml | 2 +- .github/workflows/release.yml | 2 +- Dockerfile | 2 +- Makefile | 2 +- README.md | 2 +- cmd/cli/Dockerfile | 4 ++-- cmd/cli/go.mod | 2 +- go.mod | 2 +- go.work | 2 +- 10 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b18a33c5..cbc09cb9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: - name: Set up Go uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 with: - go-version: 1.24.3 + go-version: 1.25.6 cache: true - name: Install golangci-lint @@ -45,7 +45,7 @@ jobs: - name: Set up Go uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 with: - go-version: 1.24.3 + go-version: 1.25.6 cache: true - name: Check go mod tidy diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 
80b4888b..afadc87d 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -18,7 +18,7 @@ jobs: - name: Set up Go uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 with: - go-version: 1.24.3 + go-version: 1.25.6 cache: true - name: Set up Docker Buildx diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9a2d0b4f..9466c94a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -46,7 +46,7 @@ jobs: - name: Set up Go uses: actions/setup-go@7a3fe6cf4cb3a834922a1244abfce67bcef6a0c5 with: - go-version: 1.24.3 + go-version: 1.25.6 cache: true - name: Run tests diff --git a/Dockerfile b/Dockerfile index e6738a69..b20d796e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ # syntax=docker/dockerfile:1 -ARG GO_VERSION=1.24 +ARG GO_VERSION=1.25 ARG LLAMA_SERVER_VERSION=latest ARG LLAMA_SERVER_VARIANT=cpu ARG LLAMA_BINARY_PATH=/com.docker.llama-server.native.linux.${LLAMA_SERVER_VARIANT}.${TARGETARCH} diff --git a/Makefile b/Makefile index ff11138d..d39f2437 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ # Project variables APP_NAME := model-runner -GO_VERSION := 1.24.3 +GO_VERSION := 1.25.6 LLAMA_SERVER_VERSION := latest LLAMA_SERVER_VARIANT := cpu BASE_IMAGE := ubuntu:24.04 diff --git a/README.md b/README.md index 0c801b46..b0356690 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ https://docs.docker.com/ai/model-runner/get-started/ Before building from source, ensure you have the following installed: -- **Go 1.24+** - Required for building both model-runner and model-cli +- **Go 1.25+** - Required for building both model-runner and model-cli - **Git** - For cloning repositories - **Make** - For using the provided Makefiles - **Docker** (optional) - For building and running containerized versions diff --git a/cmd/cli/Dockerfile b/cmd/cli/Dockerfile index d45e7350..6a64d099 100644 --- a/cmd/cli/Dockerfile +++ b/cmd/cli/Dockerfile @@ -1,7 +1,7 
@@ # syntax=docker/dockerfile:1 -ARG GO_VERSION=1.24 -ARG ALPINE_VERSION=3.21 +ARG GO_VERSION=1.25 +ARG ALPINE_VERSION=3.23 ARG DOCS_FORMATS="md,yaml" diff --git a/cmd/cli/go.mod b/cmd/cli/go.mod index 32620942..1c2f7395 100644 --- a/cmd/cli/go.mod +++ b/cmd/cli/go.mod @@ -1,6 +1,6 @@ module github.com/docker/model-runner/cmd/cli -go 1.24.3 +go 1.25.6 require ( github.com/charmbracelet/glamour v0.10.0 diff --git a/go.mod b/go.mod index a9c00dba..3833fff3 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/docker/model-runner -go 1.24.3 +go 1.25.6 require ( github.com/containerd/containerd/v2 v2.2.1 diff --git a/go.work b/go.work index 32b1a003..1550b7c0 100644 --- a/go.work +++ b/go.work @@ -1,4 +1,4 @@ -go 1.24.3 +go 1.25.6 use ( . From e2f358aa95b50b85f05d3ddaac8d6b106a9b7838 Mon Sep 17 00:00:00 2001 From: harshil Date: Mon, 19 Jan 2026 12:35:25 +0000 Subject: [PATCH 17/17] Added layer wise update on push --- cmd/cli/desktop/api.go | 1 + cmd/cli/desktop/progress.go | 3 +- pkg/distribution/huggingface/downloader.go | 4 +- pkg/distribution/huggingface/model.go | 10 ++--- .../internal/progress/reporter.go | 39 +++++++++---------- .../internal/progress/reporter_test.go | 4 +- pkg/distribution/internal/store/store.go | 2 +- pkg/distribution/oci/remote/remote.go | 2 +- pkg/distribution/tarball/target.go | 2 +- 9 files changed, 33 insertions(+), 34 deletions(-) diff --git a/cmd/cli/desktop/api.go b/cmd/cli/desktop/api.go index f2c85a1a..166e6047 100644 --- a/cmd/cli/desktop/api.go +++ b/cmd/cli/desktop/api.go @@ -7,6 +7,7 @@ type ProgressMessage struct { Total uint64 `json:"total"` Pulled uint64 `json:"pulled"` // Deprecated: use Layer.Current Layer Layer `json:"layer"` // Current layer information + Mode string `json:"mode"` // "push", "pull" } type Layer struct { diff --git a/cmd/cli/desktop/progress.go b/cmd/cli/desktop/progress.go index 3dc2b714..364c311b 100644 --- a/cmd/cli/desktop/progress.go +++ b/cmd/cli/desktop/progress.go @@ -12,7 +12,6 @@ 
import ( "github.com/docker/docker/pkg/jsonmessage" "github.com/docker/go-units" "github.com/docker/model-runner/cmd/cli/pkg/standalone" - "github.com/docker/model-runner/pkg/distribution/oci" ) // DisplayProgress displays progress messages from a model pull/push operation @@ -157,7 +156,7 @@ func writeDockerProgress(w io.Writer, msg *ProgressMessage, layerStatus map[stri } // Detect if this is a push operation based on the sentinel layer ID - isPush := layerID == oci.UploadingLayerID + isPush := msg.Mode == "push" // Determine status based on progress var status string diff --git a/pkg/distribution/huggingface/downloader.go b/pkg/distribution/huggingface/downloader.go index 0bb53731..df451d46 100644 --- a/pkg/distribution/huggingface/downloader.go +++ b/pkg/distribution/huggingface/downloader.go @@ -184,7 +184,7 @@ func (d *Downloader) downloadFileWithProgress(ctx context.Context, file RepoFile // Write final progress for this file (100% complete) if progressWriter != nil { - _ = progress.WriteProgress(progressWriter, "", totalImageSize, fileSize, fileSize, fileID) + _ = progress.WriteProgress(progressWriter, "", totalImageSize, fileSize, fileSize, fileID, "") } return localPath, nil @@ -208,7 +208,7 @@ func (pr *progressReader) Read(p []byte) (n int, err error) { // Report progress periodically (every 1MB or when complete) if pr.progressWriter != nil && (pr.bytesRead-pr.lastReported >= progress.MinBytesForUpdate || pr.bytesRead == pr.fileSize) { - _ = progress.WriteProgress(pr.progressWriter, "", pr.totalImageSize, pr.fileSize, pr.bytesRead, pr.fileID) + _ = progress.WriteProgress(pr.progressWriter, "", pr.totalImageSize, pr.fileSize, pr.bytesRead, pr.fileID, "") pr.lastReported = pr.bytesRead } } diff --git a/pkg/distribution/huggingface/model.go b/pkg/distribution/huggingface/model.go index 37835230..aa458498 100644 --- a/pkg/distribution/huggingface/model.go +++ b/pkg/distribution/huggingface/model.go @@ -21,7 +21,7 @@ import ( func BuildModel(ctx 
context.Context, client *Client, repo, revision, tag string, tempDir string, progressWriter io.Writer) (types.ModelArtifact, error) { // List files in the repository if progressWriter != nil { - _ = progress.WriteProgress(progressWriter, "Fetching file list...", 0, 0, 0, "") + _ = progress.WriteProgress(progressWriter, "Fetching file list...", 0, 0, 0, "", "") } files, err := client.ListFiles(ctx, repo, revision) @@ -47,9 +47,9 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string, if progressWriter != nil { if tag == "" || tag == "latest" || tag == "main" { - _ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization (default)", DefaultGGUFQuantization), 0, 0, 0, "") + _ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization (default)", DefaultGGUFQuantization), 0, 0, 0, "", "") } else { - _ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization", tag), 0, 0, 0, "") + _ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization", tag), 0, 0, 0, "", "") } } } @@ -64,7 +64,7 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string, totalSize := TotalSize(allFiles) msg := fmt.Sprintf("Found %d files (%.2f MB total)", len(allFiles), float64(totalSize)/1024/1024) - _ = progress.WriteProgress(progressWriter, msg, uint64(totalSize), 0, 0, "") + _ = progress.WriteProgress(progressWriter, msg, uint64(totalSize), 0, 0, "", "") } // Step 3: Download all files @@ -76,7 +76,7 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string, // Step 4: Build the model artifact if progressWriter != nil { - _ = progress.WriteProgress(progressWriter, "Building model artifact...", 0, 0, 0, "") + _ = progress.WriteProgress(progressWriter, "Building model artifact...", 0, 0, 0, "", "") } model, err := buildModelFromFiles(result.LocalPaths, weightFiles, configFiles, tempDir) diff --git 
a/pkg/distribution/internal/progress/reporter.go b/pkg/distribution/internal/progress/reporter.go index b5e0cf74..15d70b7b 100644 --- a/pkg/distribution/internal/progress/reporter.go +++ b/pkg/distribution/internal/progress/reporter.go @@ -29,6 +29,7 @@ type Message struct { Total uint64 `json:"total"` Pulled uint64 `json:"pulled"` // Deprecated: use Layer.Current Layer Layer `json:"layer"` // Current layer information + Mode string `json:"mode"` // "push", "pull" } type Reporter struct { @@ -39,6 +40,7 @@ type Reporter struct { format progressF layer oci.Layer imageSize uint64 + mode string } type progressF func(update oci.Update) string @@ -51,7 +53,7 @@ func PushMsg(update oci.Update) string { return fmt.Sprintf("Uploaded: %.2f MB", float64(update.Complete)/1024/1024) } -func NewProgressReporter(w io.Writer, msgF progressF, imageSize int64, layer oci.Layer) *Reporter { +func NewProgressReporter(w io.Writer, msgF progressF, imageSize int64, layer oci.Layer, mode string) *Reporter { return &Reporter{ out: w, progress: make(chan oci.Update, 1), @@ -59,6 +61,7 @@ func NewProgressReporter(w io.Writer, msgF progressF, imageSize int64, layer oci format: msgF, layer: layer, imageSize: safeUint64(imageSize), + mode: mode, } } @@ -84,31 +87,26 @@ func (r *Reporter) Updates() chan<- oci.Update { now := time.Now() var layerSize uint64 var layerID string - if r.layer != nil { // In case of Pull - id, err := r.layer.DiffID() - if err != nil { - r.err = err - continue - } - layerID = id.String() - size, err := r.layer.Size() - if err != nil { - r.err = err - continue - } - layerSize = safeUint64(size) - } else { // In case of Push there is no layer yet - // Use imageSize as layer is not known at this point - layerSize = r.imageSize - layerID = oci.UploadingLayerID // Fake ID for push operations to enable progress display + id, err := r.layer.DiffID() + if err != nil { + r.err = err + continue + } + layerID = id.String() + size, err := r.layer.Size() + if err != nil { + r.err = 
err + continue } + layerSize = safeUint64(size) + incrementalBytes := p.Complete - lastComplete // Only update if enough time has passed or enough bytes downloaded or finished if now.Sub(lastUpdate) >= UpdateInterval || incrementalBytes >= MinBytesForUpdate || safeUint64(p.Complete) == layerSize { - if err := WriteProgress(r.out, r.format(p), r.imageSize, layerSize, safeUint64(p.Complete), layerID); err != nil { + if err := WriteProgress(r.out, r.format(p), r.imageSize, layerSize, safeUint64(p.Complete), layerID, r.mode); err != nil { r.err = err } lastUpdate = now @@ -127,7 +125,7 @@ func (r *Reporter) Wait() error { } // WriteProgress writes a progress update message -func WriteProgress(w io.Writer, msg string, imageSize, layerSize, current uint64, layerID string) error { +func WriteProgress(w io.Writer, msg string, imageSize, layerSize, current uint64, layerID string, mode string) error { return write(w, Message{ Type: "progress", Message: msg, @@ -138,6 +136,7 @@ func WriteProgress(w io.Writer, msg string, imageSize, layerSize, current uint64 Size: layerSize, Current: current, }, + Mode: mode, }) } diff --git a/pkg/distribution/internal/progress/reporter_test.go b/pkg/distribution/internal/progress/reporter_test.go index fdde9dcd..6f823148 100644 --- a/pkg/distribution/internal/progress/reporter_test.go +++ b/pkg/distribution/internal/progress/reporter_test.go @@ -59,7 +59,7 @@ func TestMessages(t *testing.T) { layer1 := newMockLayer(2016) layer2 := newMockLayer(1) - err := WriteProgress(&buf, PullMsg(update), uint64(layer1.size+layer2.size), uint64(layer1.size), uint64(update.Complete), layer1.diffID) + err := WriteProgress(&buf, PullMsg(update), uint64(layer1.size+layer2.size), uint64(layer1.size), uint64(update.Complete), layer1.diffID, "") if err != nil { t.Fatalf("Failed to write progress message: %v", err) } @@ -224,7 +224,7 @@ func TestProgressEmissionScenarios(t *testing.T) { var buf bytes.Buffer layer := newMockLayer(tt.layerSize) - reporter := 
NewProgressReporter(&buf, PullMsg, 0, layer) + reporter := NewProgressReporter(&buf, PullMsg, 0, layer, "") updates := reporter.Updates() // Send updates with delays diff --git a/pkg/distribution/internal/store/store.go b/pkg/distribution/internal/store/store.go index 13558dc5..e3f0d491 100644 --- a/pkg/distribution/internal/store/store.go +++ b/pkg/distribution/internal/store/store.go @@ -325,7 +325,7 @@ func (s *LocalStore) Write(mdl oci.Image, tags []string, w io.Writer, opts ...Wr var pr *progress.Reporter var progressChan chan<- oci.Update if safeWriter != nil { - pr = progress.NewProgressReporter(safeWriter, progress.PullMsg, imageSize, l) + pr = progress.NewProgressReporter(safeWriter, progress.PullMsg, imageSize, l, "pull") progressChan = pr.Updates() } diff --git a/pkg/distribution/oci/remote/remote.go b/pkg/distribution/oci/remote/remote.go index f6449ef2..aecd988f 100644 --- a/pkg/distribution/oci/remote/remote.go +++ b/pkg/distribution/oci/remote/remote.go @@ -769,7 +769,7 @@ func Write(ref reference.Reference, img oci.Image, w io.Writer, opts ...Option) var pr *progress.Reporter var progressChan chan<- oci.Update if safeWriter != nil { - pr = progress.NewProgressReporter(safeWriter, progress.PushMsg, size, layer) + pr = progress.NewProgressReporter(safeWriter, progress.PushMsg, size, layer, "push") progressChan = pr.Updates() } diff --git a/pkg/distribution/tarball/target.go b/pkg/distribution/tarball/target.go index 05934b53..b4961c5d 100644 --- a/pkg/distribution/tarball/target.go +++ b/pkg/distribution/tarball/target.go @@ -117,7 +117,7 @@ func (t *Target) addLayer(layer oci.Layer, tw *tar.Writer, progressWriter io.Wri if progressWriter != nil { pr = progress.NewProgressReporter(progressWriter, func(update oci.Update) string { return fmt.Sprintf("Transferred: %.2f MB", float64(update.Complete)/1024/1024) - }, imageSize, layer) + }, imageSize, layer, "") progressChan = pr.Updates() defer func() { close(progressChan)