diff --git a/cmd/cli/desktop/api.go b/cmd/cli/desktop/api.go index f2c85a1a..166e6047 100644 --- a/cmd/cli/desktop/api.go +++ b/cmd/cli/desktop/api.go @@ -7,6 +7,7 @@ type ProgressMessage struct { Total uint64 `json:"total"` Pulled uint64 `json:"pulled"` // Deprecated: use Layer.Current Layer Layer `json:"layer"` // Current layer information + Mode string `json:"mode"` // "push", "pull" } type Layer struct { diff --git a/cmd/cli/desktop/progress.go b/cmd/cli/desktop/progress.go index 3dc2b714..364c311b 100644 --- a/cmd/cli/desktop/progress.go +++ b/cmd/cli/desktop/progress.go @@ -12,7 +12,6 @@ import ( "github.com/docker/docker/pkg/jsonmessage" "github.com/docker/go-units" "github.com/docker/model-runner/cmd/cli/pkg/standalone" - "github.com/docker/model-runner/pkg/distribution/oci" ) // DisplayProgress displays progress messages from a model pull/push operation @@ -157,7 +156,7 @@ func writeDockerProgress(w io.Writer, msg *ProgressMessage, layerStatus map[stri } // Detect if this is a push operation based on the sentinel layer ID - isPush := layerID == oci.UploadingLayerID + isPush := msg.Mode == "push" // Determine status based on progress var status string diff --git a/pkg/distribution/huggingface/downloader.go b/pkg/distribution/huggingface/downloader.go index 0bb53731..df451d46 100644 --- a/pkg/distribution/huggingface/downloader.go +++ b/pkg/distribution/huggingface/downloader.go @@ -184,7 +184,7 @@ func (d *Downloader) downloadFileWithProgress(ctx context.Context, file RepoFile // Write final progress for this file (100% complete) if progressWriter != nil { - _ = progress.WriteProgress(progressWriter, "", totalImageSize, fileSize, fileSize, fileID) + _ = progress.WriteProgress(progressWriter, "", totalImageSize, fileSize, fileSize, fileID, "") } return localPath, nil @@ -208,7 +208,7 @@ func (pr *progressReader) Read(p []byte) (n int, err error) { // Report progress periodically (every 1MB or when complete) if pr.progressWriter != nil && (pr.bytesRead-pr.lastReported >= progress.MinBytesForUpdate || pr.bytesRead == pr.fileSize) { - _ = progress.WriteProgress(pr.progressWriter, "", pr.totalImageSize, pr.fileSize, pr.bytesRead, pr.fileID) + _ = progress.WriteProgress(pr.progressWriter, "", pr.totalImageSize, pr.fileSize, pr.bytesRead, pr.fileID, "") pr.lastReported = pr.bytesRead } } diff --git a/pkg/distribution/huggingface/model.go b/pkg/distribution/huggingface/model.go index 37835230..aa458498 100644 --- a/pkg/distribution/huggingface/model.go +++ b/pkg/distribution/huggingface/model.go @@ -21,7 +21,7 @@ import ( func BuildModel(ctx context.Context, client *Client, repo, revision, tag string, tempDir string, progressWriter io.Writer) (types.ModelArtifact, error) { // List files in the repository if progressWriter != nil { - _ = progress.WriteProgress(progressWriter, "Fetching file list...", 0, 0, 0, "") + _ = progress.WriteProgress(progressWriter, "Fetching file list...", 0, 0, 0, "", "") } files, err := client.ListFiles(ctx, repo, revision) @@ -47,9 +47,9 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string, if progressWriter != nil { if tag == "" || tag == "latest" || tag == "main" { - _ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization (default)", DefaultGGUFQuantization), 0, 0, 0, "") + _ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization (default)", DefaultGGUFQuantization), 0, 0, 0, "", "") } else { - _ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization", tag), 0, 0, 0, "") + _ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization", tag), 0, 0, 0, "", "") } } } @@ -64,7 +64,7 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string, totalSize := TotalSize(allFiles) msg := fmt.Sprintf("Found %d files (%.2f MB total)", len(allFiles), float64(totalSize)/1024/1024) - _ = progress.WriteProgress(progressWriter, msg, uint64(totalSize), 0, 0, "") + _ = progress.WriteProgress(progressWriter, msg, uint64(totalSize), 0, 0, "", "") } // Step 3: Download all files @@ -76,7 +76,7 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string, // Step 4: Build the model artifact if progressWriter != nil { - _ = progress.WriteProgress(progressWriter, "Building model artifact...", 0, 0, 0, "") + _ = progress.WriteProgress(progressWriter, "Building model artifact...", 0, 0, 0, "", "") } model, err := buildModelFromFiles(result.LocalPaths, weightFiles, configFiles, tempDir) diff --git a/pkg/distribution/internal/progress/reporter.go b/pkg/distribution/internal/progress/reporter.go index b5e0cf74..15d70b7b 100644 --- a/pkg/distribution/internal/progress/reporter.go +++ b/pkg/distribution/internal/progress/reporter.go @@ -29,6 +29,7 @@ type Message struct { Total uint64 `json:"total"` Pulled uint64 `json:"pulled"` // Deprecated: use Layer.Current Layer Layer `json:"layer"` // Current layer information + Mode string `json:"mode"` // "push", "pull" } type Reporter struct { @@ -39,6 +40,7 @@ type Reporter struct { format progressF layer oci.Layer imageSize uint64 + mode string } type progressF func(update oci.Update) string @@ -51,7 +53,7 @@ func PushMsg(update oci.Update) string { return fmt.Sprintf("Uploaded: %.2f MB", float64(update.Complete)/1024/1024) } -func NewProgressReporter(w io.Writer, msgF progressF, imageSize int64, layer oci.Layer) *Reporter { +func NewProgressReporter(w io.Writer, msgF progressF, imageSize int64, layer oci.Layer, mode string) *Reporter { return &Reporter{ out: w, progress: make(chan oci.Update, 1), @@ -59,6 +61,7 @@ func NewProgressReporter(w io.Writer, msgF progressF, imageSize int64, layer oci format: msgF, layer: layer, imageSize: safeUint64(imageSize), + mode: mode, } } @@ -84,31 +87,26 @@ func (r *Reporter) Updates() chan<- oci.Update { now := time.Now() var layerSize uint64 var layerID string - if r.layer != nil { // In case of Pull - id, err := r.layer.DiffID() - if err != nil { - r.err = err - continue - } - layerID = id.String() - size, err := r.layer.Size() - if err != nil { - r.err = err - continue - } - layerSize = safeUint64(size) - } else { // In case of Push there is no layer yet - // Use imageSize as layer is not known at this point - layerSize = r.imageSize - layerID = oci.UploadingLayerID // Fake ID for push operations to enable progress display + id, err := r.layer.DiffID() + if err != nil { + r.err = err + continue + } + layerID = id.String() + size, err := r.layer.Size() + if err != nil { + r.err = err + continue } + layerSize = safeUint64(size) + incrementalBytes := p.Complete - lastComplete // Only update if enough time has passed or enough bytes downloaded or finished if now.Sub(lastUpdate) >= UpdateInterval || incrementalBytes >= MinBytesForUpdate || safeUint64(p.Complete) == layerSize { - if err := WriteProgress(r.out, r.format(p), r.imageSize, layerSize, safeUint64(p.Complete), layerID); err != nil { + if err := WriteProgress(r.out, r.format(p), r.imageSize, layerSize, safeUint64(p.Complete), layerID, r.mode); err != nil { r.err = err } lastUpdate = now @@ -127,7 +125,7 @@ func (r *Reporter) Wait() error { } // WriteProgress writes a progress update message -func WriteProgress(w io.Writer, msg string, imageSize, layerSize, current uint64, layerID string) error { +func WriteProgress(w io.Writer, msg string, imageSize, layerSize, current uint64, layerID string, mode string) error { return write(w, Message{ Type: "progress", Message: msg, @@ -138,6 +136,7 @@ func WriteProgress(w io.Writer, msg string, imageSize, layerSize, current uint64 Size: layerSize, Current: current, }, + Mode: mode, }) } diff --git a/pkg/distribution/internal/progress/reporter_test.go b/pkg/distribution/internal/progress/reporter_test.go index fdde9dcd..6f823148 100644 --- a/pkg/distribution/internal/progress/reporter_test.go +++ b/pkg/distribution/internal/progress/reporter_test.go @@ -59,7 +59,7 @@ func TestMessages(t *testing.T) { layer1 := newMockLayer(2016) layer2 := newMockLayer(1) - err := WriteProgress(&buf, PullMsg(update), uint64(layer1.size+layer2.size), uint64(layer1.size), uint64(update.Complete), layer1.diffID) + err := WriteProgress(&buf, PullMsg(update), uint64(layer1.size+layer2.size), uint64(layer1.size), uint64(update.Complete), layer1.diffID, "") if err != nil { t.Fatalf("Failed to write progress message: %v", err) } @@ -224,7 +224,7 @@ func TestProgressEmissionScenarios(t *testing.T) { var buf bytes.Buffer layer := newMockLayer(tt.layerSize) - reporter := NewProgressReporter(&buf, PullMsg, 0, layer) + reporter := NewProgressReporter(&buf, PullMsg, 0, layer, "") updates := reporter.Updates() // Send updates with delays diff --git a/pkg/distribution/internal/store/store.go b/pkg/distribution/internal/store/store.go index 13558dc5..e3f0d491 100644 --- a/pkg/distribution/internal/store/store.go +++ b/pkg/distribution/internal/store/store.go @@ -325,7 +325,7 @@ func (s *LocalStore) Write(mdl oci.Image, tags []string, w io.Writer, opts ...Wr var pr *progress.Reporter var progressChan chan<- oci.Update if safeWriter != nil { - pr = progress.NewProgressReporter(safeWriter, progress.PullMsg, imageSize, l) + pr = progress.NewProgressReporter(safeWriter, progress.PullMsg, imageSize, l, "pull") progressChan = pr.Updates() } diff --git a/pkg/distribution/oci/remote/remote.go b/pkg/distribution/oci/remote/remote.go index 6fb1dcf1..aecd988f 100644 --- a/pkg/distribution/oci/remote/remote.go +++ b/pkg/distribution/oci/remote/remote.go @@ -693,8 +693,21 @@ func (l *remoteLayer) MediaType() (oci.MediaType, error) { return l.desc.MediaType, nil } +// syncWriter is a thread-safe wrapper around io.Writer for concurrent writes +type syncWriter struct { + w io.Writer + mu sync.Mutex +} + +// Write implements io.Writer interface with mutex protection +func (sw *syncWriter) Write(p []byte) (n int, err error) { + sw.mu.Lock() + defer sw.mu.Unlock() + return sw.w.Write(p) +} + // Write pushes an image to a registry. -func Write(ref reference.Reference, img oci.Image, opts ...Option) error { +func Write(ref reference.Reference, img oci.Image, w io.Writer, opts ...Option) error { o := makeOptions(opts...) // Pre-authorize with push scope to ensure we have the right permissions @@ -724,8 +737,14 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { totalSize += size } - var completed int64 + // Create a thread-safe writer wrapper for concurrent progress reporting + var safeWriter io.Writer + if w != nil { + safeWriter = &syncWriter{w: w} + } + for _, layer := range layers { + var completed int64 digest, err := layer.Digest() if err != nil { return fmt.Errorf("getting layer digest: %w", err) @@ -747,6 +766,13 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { Size: size, } + var pr *progress.Reporter + var progressChan chan<- oci.Update + if safeWriter != nil { + pr = progress.NewProgressReporter(safeWriter, progress.PushMsg, size, layer, "push") + progressChan = pr.Updates() + } + rc, err := layer.Compressed() if err != nil { return fmt.Errorf("getting layer content: %w", err) @@ -759,29 +785,33 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { // If already exists, continue if errdefs.IsAlreadyExists(err) || strings.Contains(err.Error(), "already exists") { completed += size - if o.progress != nil { - o.progress <- oci.Update{ + if progressChan != nil { + progressChan <- oci.Update{ Complete: completed, - Total: totalSize, + Total: size, } + closeProgress(progressChan) + closeReporter(pr) } continue } - closeProgress(o.progress) + closeProgress(progressChan) + closeReporter(pr) return fmt.Errorf("pushing layer: %w", err) } // Wrap the reader with progress tracking to report incremental upload progress // Uses the shared progress.Reader from internal/progress package var reader io.Reader = rc - if o.progress != nil { - reader = progress.NewReaderWithOffset(rc, o.progress, completed) + if progressChan != nil { + reader = progress.NewReaderWithOffset(rc, progressChan, completed) } if _, err := io.Copy(cw, reader); err != nil { cw.Close() rc.Close() - closeProgress(o.progress) + closeProgress(progressChan) + closeReporter(pr) return fmt.Errorf("writing layer: %w", err) } @@ -789,27 +819,30 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { cw.Close() rc.Close() if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") { - closeProgress(o.progress) + closeProgress(progressChan) + closeReporter(pr) return fmt.Errorf("committing layer: %w", err) } // If it already exists, we still want to update progress completed += size - if o.progress != nil { - o.progress <- oci.Update{ + if progressChan != nil { + progressChan <- oci.Update{ Complete: completed, - Total: totalSize, + Total: size, } } } else { // Successfully committed, update progress completed += size - if o.progress != nil { - o.progress <- oci.Update{ + if progressChan != nil { + progressChan <- oci.Update{ Complete: completed, - Total: totalSize, + Total: size, } } } + closeProgress(progressChan) + closeReporter(pr) cw.Close() rc.Close() } @@ -834,20 +867,17 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { cw, err := pusher.Push(o.ctx, configDesc) if err != nil { if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") { - closeProgress(o.progress) return fmt.Errorf("pushing config: %w", err) } // If it already exists, we don't have a writer to close, just continue } else { if _, err := cw.Write(rawConfig); err != nil { cw.Close() - closeProgress(o.progress) return fmt.Errorf("writing config: %w", err) } if err := cw.Commit(o.ctx, int64(len(rawConfig)), configDesc.Digest); err != nil { cw.Close() if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") { - closeProgress(o.progress) return fmt.Errorf("committing config: %w", err) } } @@ -857,19 +887,16 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { // Push manifest rawManifest, err := img.RawManifest() if err != nil { - closeProgress(o.progress) return fmt.Errorf("getting manifest: %w", err) } manifest, err := img.Manifest() if err != nil { - closeProgress(o.progress) return fmt.Errorf("getting manifest object: %w", err) } manifestDigest, err := img.Digest() if err != nil { - closeProgress(o.progress) return fmt.Errorf("getting manifest digest: %w", err) } @@ -882,24 +909,18 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { cw, err = pusher.Push(o.ctx, manifestDesc) if err != nil { if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") { - closeProgress(o.progress) return fmt.Errorf("pushing manifest: %w", err) } - // If it already exists, we don't have a writer to close, just continue - // If it already exists, we still want to close progress and return success - closeProgress(o.progress) return nil } if _, err := cw.Write(rawManifest); err != nil { cw.Close() - closeProgress(o.progress) return fmt.Errorf("writing manifest: %w", err) } if err := cw.Commit(o.ctx, int64(len(rawManifest)), manifestDesc.Digest); err != nil { cw.Close() - closeProgress(o.progress) if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") { return fmt.Errorf("committing manifest: %w", err) } @@ -908,9 +929,6 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error { } cw.Close() - // Close progress channel to signal completion - closeProgress(o.progress) - return nil } @@ -921,6 +939,15 @@ func closeProgress(ch chan<- oci.Update) { } } +// closeReporter safely closes the progress reporter if not nil +func closeReporter(pr *progress.Reporter) { + if pr != nil { + if waitErr := pr.Wait(); waitErr != nil { + fmt.Printf("reporter finished with non-fatal error: %v\n", waitErr) + } + } +} + // Ensure remoteImage is cleaned up properly func (i *remoteImage) Close() error { // The local content store doesn't expose its root path, so cleanup diff --git a/pkg/distribution/registry/client.go b/pkg/distribution/registry/client.go index 82fca7a8..493f6712 100644 --- a/pkg/distribution/registry/client.go +++ b/pkg/distribution/registry/client.go @@ -9,7 +9,6 @@ import ( "strings" "sync" - "github.com/docker/model-runner/pkg/distribution/internal/progress" "github.com/docker/model-runner/pkg/distribution/oci" "github.com/docker/model-runner/pkg/distribution/oci/authn" "github.com/docker/model-runner/pkg/distribution/oci/reference" @@ -264,15 +263,12 @@ func (t *Target) Write(ctx context.Context, model types.ModelArtifact, progressW } imageSize += size } - pr := progress.NewProgressReporter(progressWriter, progress.PushMsg, imageSize, nil) - defer pr.Wait() // Set up authentication options authOpts := []remote.Option{ remote.WithContext(ctx), remote.WithTransport(t.transport), remote.WithUserAgent(t.userAgent), - remote.WithProgress(pr.Updates()), remote.WithPlainHTTP(t.plainHTTP), } @@ -283,7 +279,7 @@ func (t *Target) Write(ctx context.Context, model types.ModelArtifact, progressW authOpts = append(authOpts, remote.WithAuthFromKeychain(t.keychain)) } - if err := remote.Write(t.reference, model, authOpts...); err != nil { + if err := remote.Write(t.reference, model, progressWriter, authOpts...); err != nil { return fmt.Errorf("write to registry %q: %w", t.reference.String(), err) } return nil diff --git a/pkg/distribution/tarball/target.go b/pkg/distribution/tarball/target.go index 05934b53..b4961c5d 100644 --- a/pkg/distribution/tarball/target.go +++ b/pkg/distribution/tarball/target.go @@ -117,7 +117,7 @@ func (t *Target) addLayer(layer oci.Layer, tw *tar.Writer, progressWriter io.Wri if progressWriter != nil { pr = progress.NewProgressReporter(progressWriter, func(update oci.Update) string { return fmt.Sprintf("Transferred: %.2f MB", float64(update.Complete)/1024/1024) - }, imageSize, layer) + }, imageSize, layer, "") progressChan = pr.Updates() defer func() { close(progressChan)