|
5 | 5 | "encoding/json"
|
6 | 6 | "errors"
|
7 | 7 | "fmt"
|
| 8 | + "io" |
| 9 | + "net/http" |
| 10 | + "net/url" |
8 | 11 | "os"
|
9 | 12 | "os/signal"
|
10 | 13 | "path/filepath"
|
@@ -341,24 +344,65 @@ func writeOutput(outputPath string, output []byte) error {
|
341 | 344 | }
|
342 | 345 |
|
343 | 346 | func writeDataURLOutput(outputString string, outputPath string, addExtension bool) error {
|
344 |
| - dataurlObj, err := dataurl.DecodeString(outputString) |
345 |
| - if err != nil { |
346 |
| - return fmt.Errorf("Failed to decode dataurl: %w", err) |
| 347 | + var output []byte |
| 348 | + var contentType string |
| 349 | + |
| 350 | + if httpURL, ok := getHTTPURL(outputString); ok { |
| 351 | + resp, err := http.Get(httpURL.String()) |
| 352 | + if err != nil { |
| 353 | + return fmt.Errorf("Failed to fetch URL: %w", err) |
| 354 | + } |
| 355 | + defer resp.Body.Close() |
| 356 | + |
| 357 | + output, err = io.ReadAll(resp.Body) |
| 358 | + if err != nil { |
| 359 | + return fmt.Errorf("Failed to read response: %w", err) |
| 360 | + } |
| 361 | + contentType = resp.Header.Get("Content-Type") |
| 362 | + contentType = useExtensionIfUnknownContentType(contentType, output, outputString) |
| 363 | + |
| 364 | + } else { |
| 365 | + dataurlObj, err := dataurl.DecodeString(outputString) |
| 366 | + if err != nil { |
| 367 | + return fmt.Errorf("Failed to decode dataurl: %w", err) |
| 368 | + } |
| 369 | + output = dataurlObj.Data |
| 370 | + contentType = dataurlObj.ContentType() |
347 | 371 | }
|
348 |
| - output := dataurlObj.Data |
349 | 372 |
|
350 | 373 | if addExtension {
|
351 |
| - extension := mime.ExtensionByType(dataurlObj.ContentType()) |
352 |
| - if extension != "" { |
353 |
| - outputPath += extension |
| 374 | + if ext := mime.ExtensionByType(contentType); ext != "" { |
| 375 | + outputPath += ext |
354 | 376 | }
|
355 | 377 | }
|
356 | 378 |
|
357 |
| - if err := writeOutput(outputPath, output); err != nil { |
358 |
| - return err |
| 379 | + return writeOutput(outputPath, output) |
| 380 | +} |
| 381 | + |
| 382 | +func getHTTPURL(str string) (*url.URL, bool) { |
| 383 | + u, err := url.Parse(str) |
| 384 | + if err == nil && (u.Scheme == "http" || u.Scheme == "https") { |
| 385 | + return u, true |
359 | 386 | }
|
| 387 | + return nil, false |
| 388 | +} |
360 | 389 |
|
361 |
| - return nil |
| 390 | +func useExtensionIfUnknownContentType(contentType string, content []byte, filename string) string { |
| 391 | + // If contentType is empty or application/octet-string, first attempt to get the |
| 392 | + // content type from the file extension, and if that fails, try to guess it from |
| 393 | + // the content itself. |
| 394 | + |
| 395 | + if contentType == "" || contentType == "application/octet-stream" { |
| 396 | + if ext := filepath.Ext(filename); ext != "" { |
| 397 | + if mimeType := mime.TypeByExtension(ext); mimeType != "" { |
| 398 | + return mimeType |
| 399 | + } |
| 400 | + } |
| 401 | + if detected := http.DetectContentType(content); detected != "" { |
| 402 | + return detected |
| 403 | + } |
| 404 | + } |
| 405 | + return contentType |
362 | 406 | }
|
363 | 407 |
|
364 | 408 | func parseInputFlags(inputs []string) (predict.Inputs, error) {
|
|
0 commit comments