From 45c773ee4a0de354800d0090b0053eaf24d60030 Mon Sep 17 00:00:00 2001 From: maxwellflitton Date: Wed, 6 Dec 2023 15:49:16 +0000 Subject: [PATCH] adding a header to the upload --- src/python_apis/storage.rs | 13 ++++++++++--- src/transport.rs | 1 + surrealml/surml_file.py | 6 +++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/python_apis/storage.rs b/src/python_apis/storage.rs index 4fffb6b..bd5e288 100644 --- a/src/python_apis/storage.rs +++ b/src/python_apis/storage.rs @@ -12,7 +12,8 @@ use surrealml_core::storage::surml_file::SurMlFile; use surrealml_core::storage::header::normalisers::wrapper::NormaliserType; use std::fs::File; use std::io::Read; -use hyper::{Body, Request}; +use hyper::{Body, Request, Method}; +use hyper::header::CONTENT_TYPE; use hyper::{Client, Uri}; use crate::python_state::{PYTHON_STATE, generate_unique_id}; @@ -236,8 +237,14 @@ pub fn upload_model(file_path: String, url: String, chunk_size: usize) { let uri: Uri = url.parse().unwrap(); let generator = StreamAdapter::new(chunk_size, file_path); let body = Body::wrap_stream(generator); - let req = Request::post(uri).body(body).unwrap(); - let tokio_runtime = tokio::runtime::Builder::new_current_thread().build().unwrap(); + + let req = Request::builder() + .method(Method::POST) + .uri(uri) + .header(CONTENT_TYPE, "application/octet-stream") + .body(body).unwrap(); + + let tokio_runtime = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build().unwrap(); tokio_runtime.block_on( async move { let _response = client.request(req).await.unwrap(); }); diff --git a/src/transport.rs b/src/transport.rs index 8602309..9c0577e 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -22,6 +22,7 @@ async fn root(mut stream: BodyStream) -> &'static str { let chunk = chunk.unwrap(); buffer.extend_from_slice(&chunk); } + println!("Buffer length: {:?}", buffer.len()); let mut file = SurMlFile::from_bytes(buffer).unwrap(); // check some of the values in the header diff --git a/surrealml/surml_file.py b/surrealml/surml_file.py index 98dd989..356ff48 100644 --- a/surrealml/surml_file.py +++ b/surrealml/surml_file.py @@ -7,7 +7,7 @@ import torch from surrealml.rust_surrealml import load_cached_raw_model, add_column, add_output, add_normaliser, save_model, \ add_name, load_model, add_description, add_version, to_bytes, add_engine, add_author, add_origin -from surrealml.rust_surrealml import raw_compute, buffered_compute +from surrealml.rust_surrealml import raw_compute, buffered_compute, upload_model from surrealml.model_cache import SkLearnModelCache from surrealml.engine_enum import Engine @@ -153,6 +153,10 @@ def load(path): self = SurMlFile() self.file_id = load_model(path) return self + + @staticmethod + def upload(path: str, url: str, chunk_size: int) -> None: + upload_model(path, url, chunk_size) def raw_compute(self, input_vector, dims=None): """