Skip to content

Commit 272e9de

Browse files
Handle URL Path in cog predict
On replicate.com, a model with return type `Path` can return a URL, and it is handled as though the model returned a local file path. Locally, `cog predict` will fail with ``` ⅹ Failed to write output: Failed to decode dataurl: missing data prefix ``` This change makes `cog predict` download and save the output URL. The `useExtensionIfUnknownContentType` is mostly needed because .webp is incorrectly returned by replicate.delivery as application/octet-stream.
1 parent 9681128 commit 272e9de

File tree

1 file changed

+54
-10
lines changed

1 file changed

+54
-10
lines changed

Diff for: pkg/cli/predict.go

+54-10
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"io"
9+
"net/http"
10+
"net/url"
811
"os"
912
"os/signal"
1013
"path/filepath"
@@ -341,24 +344,65 @@ func writeOutput(outputPath string, output []byte) error {
341344
}
342345

343346
func writeDataURLOutput(outputString string, outputPath string, addExtension bool) error {
344-
dataurlObj, err := dataurl.DecodeString(outputString)
345-
if err != nil {
346-
return fmt.Errorf("Failed to decode dataurl: %w", err)
347+
var output []byte
348+
var contentType string
349+
350+
if httpURL, ok := getHTTPURL(outputString); ok {
351+
resp, err := http.Get(httpURL.String())
352+
if err != nil {
353+
return fmt.Errorf("Failed to fetch URL: %w", err)
354+
}
355+
defer resp.Body.Close()
356+
357+
output, err = io.ReadAll(resp.Body)
358+
if err != nil {
359+
return fmt.Errorf("Failed to read response: %w", err)
360+
}
361+
contentType = resp.Header.Get("Content-Type")
362+
contentType = useExtensionIfUnknownContentType(contentType, output, outputString)
363+
364+
} else {
365+
dataurlObj, err := dataurl.DecodeString(outputString)
366+
if err != nil {
367+
return fmt.Errorf("Failed to decode dataurl: %w", err)
368+
}
369+
output = dataurlObj.Data
370+
contentType = dataurlObj.ContentType()
347371
}
348-
output := dataurlObj.Data
349372

350373
if addExtension {
351-
extension := mime.ExtensionByType(dataurlObj.ContentType())
352-
if extension != "" {
353-
outputPath += extension
374+
if ext := mime.ExtensionByType(contentType); ext != "" {
375+
outputPath += ext
354376
}
355377
}
356378

357-
if err := writeOutput(outputPath, output); err != nil {
358-
return err
379+
return writeOutput(outputPath, output)
380+
}
381+
382+
func getHTTPURL(str string) (*url.URL, bool) {
383+
u, err := url.Parse(str)
384+
if err == nil && (u.Scheme == "http" || u.Scheme == "https") {
385+
return u, true
359386
}
387+
return nil, false
388+
}
360389

361-
return nil
390+
func useExtensionIfUnknownContentType(contentType string, content []byte, filename string) string {
391+
// If contentType is empty or application/octet-string, first attempt to get the
392+
// content type from the file extension, and if that fails, try to guess it from
393+
// the content itself.
394+
395+
if contentType == "" || contentType == "application/octet-stream" {
396+
if ext := filepath.Ext(filename); ext != "" {
397+
if mimeType := mime.TypeByExtension(ext); mimeType != "" {
398+
return mimeType
399+
}
400+
}
401+
if detected := http.DetectContentType(content); detected != "" {
402+
return detected
403+
}
404+
}
405+
return contentType
362406
}
363407

364408
func parseInputFlags(inputs []string) (predict.Inputs, error) {

0 commit comments

Comments
 (0)