Skip to content

Commit 435cefe

Browse files
authoredDec 4, 2024··
Merge pull request #68 from GPUEngineering/hf/bt-data-storage
Proper interoperability in data formatting (C++ and Python)
2 parents 5257d5d + 2f3a5f5 commit 435cefe

File tree

5 files changed

+41
-6
lines changed

5 files changed

+41
-6
lines changed
 

‎CHANGELOG.md

+10
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
<!-- ---------------------
9+
v1.7.1
10+
--------------------- -->
11+
## v1.7.1 - 4-12-2024
12+
13+
### Fixed
14+
15+
- Compatibility between Python and C++ in how the data is stored in bt files
16+
17+
818
<!-- ---------------------
919
v1.7.0
1020
--------------------- -->

‎python/VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.7.0
1+
1.7.1

‎python/gputils_api/gputils_api.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ def read_array_from_gputils_binary_file(path, dt=np.dtype('d')):
1414
nc = int.from_bytes(f.read(8), byteorder='little', signed=False) # read number of columns
1515
nm = int.from_bytes(f.read(8), byteorder='little', signed=False) # read number of matrices
1616
dat = np.fromfile(f, dtype=np.dtype(dt)) # read data
17-
dat = dat.reshape((nr, nc, nm)) # reshape
17+
18+
if nm >= 2: # if we actually have a 3D tensor (not a matrix or a vector)
19+
dat = dat.reshape((nm, nc, nr)).swapaxes(0, 2) # I'll explain this to you when you grow up
20+
else:
21+
dat = dat.reshape((nr, nc, nm)) # reshape
1822
return dat
1923

2024

@@ -27,6 +31,7 @@ def write_array_to_gputils_binary_file(x, path):
2731
:raises ValueError: if `x` has more than 3 dimensions
2832
:raises ValueError: if the file name specified `path` does not have the .bt extension
2933
"""
34+
3035
if not path.endswith(".bt"):
3136
raise ValueError("The file must have the .bt extension")
3237
x_shape = x.shape
@@ -36,8 +41,12 @@ def write_array_to_gputils_binary_file(x, path):
3641
nr = x_shape[0]
3742
nc = x_shape[1] if x_dims >= 2 else 1
3843
nm = x_shape[2] if x_dims == 3 else 1
44+
if x_dims == 3:
45+
x = x.swapaxes(0, 2).reshape(-1) # column-major storage; axis 2 last
46+
else:
47+
x = x.T.reshape(-1) # column-major storage
3948
with open(path, 'wb') as f:
4049
f.write(nr.to_bytes(8, 'little')) # write number of rows
4150
f.write(nc.to_bytes(8, 'little')) # write number of columns
4251
f.write(nm.to_bytes(8, 'little')) # write number of matrices
43-
x.reshape(nr*nc*nm, 1).tofile(f) # write data
52+
x.tofile(f) # write data

‎python/test/test.py

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def setUpClass(cls):
3737

3838
a = np.linspace(-100, 100, 4 * 5).reshape((4, 5)).astype('d')
3939
gpuapi.write_array_to_gputils_binary_file(a, os.path.join(base_dir, 'a_d.bt'))
40+
4041
gpuapi.write_array_to_gputils_binary_file(cls._B, os.path.join(base_dir, 'b_d.bt'))
4142

4243
def __test_read_eye(self, dt):

‎test/testTensor.cu

+18-3
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,24 @@ TEST_F(TensorTest, parseTensorFromFileBinary) {
183183
TEST_F(TensorTest, parseTensorFromBinaryPython) {
184184
std::string fName = "../../python/b_d.bt";
185185
DTensor<double> b = DTensor<double>::parseFromFile(fName);
186-
std::vector<double> vb(12);
187-
b.download(vb);
188-
for (size_t i = 0; i < 12; i++) EXPECT_NEAR(i + 1., vb[i], PRECISION_HIGH);
186+
for (size_t i=0; i<3; i++) {
187+
for (size_t j=0; j<3; j++) {
188+
EXPECT_NEAR(1 + 2*j + 6*i, b(i, j, 0), PRECISION_HIGH);
189+
EXPECT_NEAR(2 + 2*j + 6*i, b(i, j, 1), PRECISION_HIGH);
190+
}
191+
}
192+
}
193+
194+
195+
/* ---------------------------------------
196+
* Parse not existing file
197+
* --------------------------------------- */
198+
199+
TEST_F(TensorTest, parseTensorFromNonexistentFile) {
200+
std::string fName = "../../python/whatever.bt";
201+
EXPECT_THROW(DTensor<double> b = DTensor<double>::parseFromFile(fName, rowMajor), std::invalid_argument);
202+
std::string fName2 = "../../python/whatever.txt";
203+
EXPECT_THROW(DTensor<double> b = DTensor<double>::parseFromFile(fName2, rowMajor), std::invalid_argument);
189204
}
190205

191206

0 commit comments

Comments
 (0)
Please sign in to comment.