Skip to content

Commit

Permalink
adding a header to the upload
Browse files Browse the repository at this point in the history
  • Loading branch information
maxwellflitton committed Dec 6, 2023
1 parent 0f9f5e4 commit 45c773e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
13 changes: 10 additions & 3 deletions src/python_apis/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use surrealml_core::storage::surml_file::SurMlFile;
use surrealml_core::storage::header::normalisers::wrapper::NormaliserType;
use std::fs::File;
use std::io::Read;
use hyper::{Body, Request};
use hyper::{Body, Request, Method};
use hyper::header::CONTENT_TYPE;
use hyper::{Client, Uri};

use crate::python_state::{PYTHON_STATE, generate_unique_id};
Expand Down Expand Up @@ -236,8 +237,14 @@ pub fn upload_model(file_path: String, url: String, chunk_size: usize) {
let uri: Uri = url.parse().unwrap();
let generator = StreamAdapter::new(chunk_size, file_path);
let body = Body::wrap_stream(generator);
let req = Request::post(uri).body(body).unwrap();
let tokio_runtime = tokio::runtime::Builder::new_current_thread().build().unwrap();

let req = Request::builder()
.method(Method::POST)
.uri(uri)
.header(CONTENT_TYPE, "application/octet-stream")
.body(body).unwrap();

let tokio_runtime = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build().unwrap();
tokio_runtime.block_on( async move {
let _response = client.request(req).await.unwrap();
});
Expand Down
1 change: 1 addition & 0 deletions src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ async fn root(mut stream: BodyStream) -> &'static str {
let chunk = chunk.unwrap();
buffer.extend_from_slice(&chunk);
}
println!("Buffer length: {:?}", buffer.len());
let mut file = SurMlFile::from_bytes(buffer).unwrap();

// check some of the values in the header
Expand Down
6 changes: 5 additions & 1 deletion surrealml/surml_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from surrealml.rust_surrealml import load_cached_raw_model, add_column, add_output, add_normaliser, save_model, \
add_name, load_model, add_description, add_version, to_bytes, add_engine, add_author, add_origin
from surrealml.rust_surrealml import raw_compute, buffered_compute
from surrealml.rust_surrealml import raw_compute, buffered_compute, upload_model

from surrealml.model_cache import SkLearnModelCache
from surrealml.engine_enum import Engine
Expand Down Expand Up @@ -153,6 +153,10 @@ def load(path):
self = SurMlFile()
self.file_id = load_model(path)
return self

@staticmethod
def upload(path: str, url: str, chunk_size: int) -> None:
upload_model(path, url, chunk_size)

def raw_compute(self, input_vector, dims=None):
"""
Expand Down

0 comments on commit 45c773e

Please sign in to comment.