Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/cli/desktop/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ type ProgressMessage struct {
Total uint64 `json:"total"`
Pulled uint64 `json:"pulled"` // Deprecated: use Layer.Current
Layer Layer `json:"layer"` // Current layer information
Mode string `json:"mode"` // "push", "pull"
}

type Layer struct {
Expand Down
3 changes: 1 addition & 2 deletions cmd/cli/desktop/progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/go-units"
"github.com/docker/model-runner/cmd/cli/pkg/standalone"
"github.com/docker/model-runner/pkg/distribution/oci"
)

// DisplayProgress displays progress messages from a model pull/push operation
Expand Down Expand Up @@ -157,7 +156,7 @@ func writeDockerProgress(w io.Writer, msg *ProgressMessage, layerStatus map[stri
}

// Detect if this is a push operation based on the sentinel layer ID
isPush := layerID == oci.UploadingLayerID
isPush := msg.Mode == "push"

// Determine status based on progress
var status string
Expand Down
4 changes: 2 additions & 2 deletions pkg/distribution/huggingface/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func (d *Downloader) downloadFileWithProgress(ctx context.Context, file RepoFile

// Write final progress for this file (100% complete)
if progressWriter != nil {
_ = progress.WriteProgress(progressWriter, "", totalImageSize, fileSize, fileSize, fileID)
_ = progress.WriteProgress(progressWriter, "", totalImageSize, fileSize, fileSize, fileID, "")
}

return localPath, nil
Expand All @@ -208,7 +208,7 @@ func (pr *progressReader) Read(p []byte) (n int, err error) {

// Report progress periodically (every 1MB or when complete)
if pr.progressWriter != nil && (pr.bytesRead-pr.lastReported >= progress.MinBytesForUpdate || pr.bytesRead == pr.fileSize) {
_ = progress.WriteProgress(pr.progressWriter, "", pr.totalImageSize, pr.fileSize, pr.bytesRead, pr.fileID)
_ = progress.WriteProgress(pr.progressWriter, "", pr.totalImageSize, pr.fileSize, pr.bytesRead, pr.fileID, "")
pr.lastReported = pr.bytesRead
}
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/distribution/huggingface/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
func BuildModel(ctx context.Context, client *Client, repo, revision, tag string, tempDir string, progressWriter io.Writer) (types.ModelArtifact, error) {
// List files in the repository
if progressWriter != nil {
_ = progress.WriteProgress(progressWriter, "Fetching file list...", 0, 0, 0, "")
_ = progress.WriteProgress(progressWriter, "Fetching file list...", 0, 0, 0, "", "")
}

files, err := client.ListFiles(ctx, repo, revision)
Expand All @@ -47,9 +47,9 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string,

if progressWriter != nil {
if tag == "" || tag == "latest" || tag == "main" {
_ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization (default)", DefaultGGUFQuantization), 0, 0, 0, "")
_ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization (default)", DefaultGGUFQuantization), 0, 0, 0, "", "")
} else {
_ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization", tag), 0, 0, 0, "")
_ = progress.WriteProgress(progressWriter, fmt.Sprintf("Selected %s quantization", tag), 0, 0, 0, "", "")
}
}
}
Expand All @@ -64,7 +64,7 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string,
totalSize := TotalSize(allFiles)
msg := fmt.Sprintf("Found %d files (%.2f MB total)",
len(allFiles), float64(totalSize)/1024/1024)
_ = progress.WriteProgress(progressWriter, msg, uint64(totalSize), 0, 0, "")
_ = progress.WriteProgress(progressWriter, msg, uint64(totalSize), 0, 0, "", "")
}

// Step 3: Download all files
Expand All @@ -76,7 +76,7 @@ func BuildModel(ctx context.Context, client *Client, repo, revision, tag string,

// Step 4: Build the model artifact
if progressWriter != nil {
_ = progress.WriteProgress(progressWriter, "Building model artifact...", 0, 0, 0, "")
_ = progress.WriteProgress(progressWriter, "Building model artifact...", 0, 0, 0, "", "")
}

model, err := buildModelFromFiles(result.LocalPaths, weightFiles, configFiles, tempDir)
Expand Down
39 changes: 19 additions & 20 deletions pkg/distribution/internal/progress/reporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type Message struct {
Total uint64 `json:"total"`
Pulled uint64 `json:"pulled"` // Deprecated: use Layer.Current
Layer Layer `json:"layer"` // Current layer information
Mode string `json:"mode"` // "push", "pull"
}

type Reporter struct {
Expand All @@ -39,6 +40,7 @@ type Reporter struct {
format progressF
layer oci.Layer
imageSize uint64
mode string
}

type progressF func(update oci.Update) string
Expand All @@ -51,14 +53,15 @@ func PushMsg(update oci.Update) string {
return fmt.Sprintf("Uploaded: %.2f MB", float64(update.Complete)/1024/1024)
}

func NewProgressReporter(w io.Writer, msgF progressF, imageSize int64, layer oci.Layer) *Reporter {
func NewProgressReporter(w io.Writer, msgF progressF, imageSize int64, layer oci.Layer, mode string) *Reporter {
return &Reporter{
out: w,
progress: make(chan oci.Update, 1),
done: make(chan struct{}),
format: msgF,
layer: layer,
imageSize: safeUint64(imageSize),
mode: mode,
}
}

Expand All @@ -84,31 +87,26 @@ func (r *Reporter) Updates() chan<- oci.Update {
now := time.Now()
var layerSize uint64
var layerID string
if r.layer != nil { // In case of Pull
id, err := r.layer.DiffID()
if err != nil {
r.err = err
continue
}
layerID = id.String()
size, err := r.layer.Size()
if err != nil {
r.err = err
continue
}
layerSize = safeUint64(size)
} else { // In case of Push there is no layer yet
// Use imageSize as layer is not known at this point
layerSize = r.imageSize
layerID = oci.UploadingLayerID // Fake ID for push operations to enable progress display
id, err := r.layer.DiffID()
if err != nil {
r.err = err
continue
}
layerID = id.String()
size, err := r.layer.Size()
if err != nil {
r.err = err
continue
}
layerSize = safeUint64(size)

incrementalBytes := p.Complete - lastComplete

// Only update if enough time has passed or enough bytes downloaded or finished
if now.Sub(lastUpdate) >= UpdateInterval ||
incrementalBytes >= MinBytesForUpdate ||
safeUint64(p.Complete) == layerSize {
if err := WriteProgress(r.out, r.format(p), r.imageSize, layerSize, safeUint64(p.Complete), layerID); err != nil {
if err := WriteProgress(r.out, r.format(p), r.imageSize, layerSize, safeUint64(p.Complete), layerID, r.mode); err != nil {
r.err = err
}
lastUpdate = now
Expand All @@ -127,7 +125,7 @@ func (r *Reporter) Wait() error {
}

// WriteProgress writes a progress update message
func WriteProgress(w io.Writer, msg string, imageSize, layerSize, current uint64, layerID string) error {
func WriteProgress(w io.Writer, msg string, imageSize, layerSize, current uint64, layerID string, mode string) error {
return write(w, Message{
Type: "progress",
Message: msg,
Expand All @@ -138,6 +136,7 @@ func WriteProgress(w io.Writer, msg string, imageSize, layerSize, current uint64
Size: layerSize,
Current: current,
},
Mode: mode,
})
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/distribution/internal/progress/reporter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestMessages(t *testing.T) {
layer1 := newMockLayer(2016)
layer2 := newMockLayer(1)

err := WriteProgress(&buf, PullMsg(update), uint64(layer1.size+layer2.size), uint64(layer1.size), uint64(update.Complete), layer1.diffID)
err := WriteProgress(&buf, PullMsg(update), uint64(layer1.size+layer2.size), uint64(layer1.size), uint64(update.Complete), layer1.diffID, "")
if err != nil {
t.Fatalf("Failed to write progress message: %v", err)
}
Expand Down Expand Up @@ -224,7 +224,7 @@ func TestProgressEmissionScenarios(t *testing.T) {
var buf bytes.Buffer
layer := newMockLayer(tt.layerSize)

reporter := NewProgressReporter(&buf, PullMsg, 0, layer)
reporter := NewProgressReporter(&buf, PullMsg, 0, layer, "")
updates := reporter.Updates()

// Send updates with delays
Expand Down
2 changes: 1 addition & 1 deletion pkg/distribution/internal/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ func (s *LocalStore) Write(mdl oci.Image, tags []string, w io.Writer, opts ...Wr
var pr *progress.Reporter
var progressChan chan<- oci.Update
if safeWriter != nil {
pr = progress.NewProgressReporter(safeWriter, progress.PullMsg, imageSize, l)
pr = progress.NewProgressReporter(safeWriter, progress.PullMsg, imageSize, l, "pull")
progressChan = pr.Updates()
}

Expand Down
Loading