|
11 | 11 | #include <memory>
|
12 | 12 | #include <optional>
|
13 | 13 | #include <cassert>
|
| 14 | +#include <fstream> |
14 | 15 |
|
15 | 16 | #ifndef TENSOR_CUH
|
16 | 17 | #define TENSOR_CUH
|
@@ -250,6 +251,19 @@ public:
|
250 | 251 | */
|
251 | 252 | static DTensor<T> createRandomTensor(size_t numRows, size_t numCols, size_t numMats, T low, T hi);
|
252 | 253 |
|
| 254 | + /** |
| 255 | + * Parse data from text file and create an instance of DTensor |
| 256 | + * |
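| 257 | + * This static function reads data from a text file, creates a DTensor and uploads the data to the device. The file must list the number of rows, columns and matrices on its first three lines, followed by one element per line. |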
| 258 | + * |
| 259 | + * @param path_to_file path to file as string |
| 260 | + * @param mode storage mode (default: StorageMode::defaultMajor) |
| 261 | + * @return instance of DTensor |
| 262 | + * |
| 263 | + * @throws std::invalid_argument if the file is not found |
| 264 | + */ |
| 265 | + static DTensor<T> parseFromTextFile(std::string path_to_file, StorageMode mode = StorageMode::defaultMajor); |
| 266 | + |
253 | 267 | /**
|
254 | 268 | * Constructs a DTensor object.
|
255 | 269 | */
|
@@ -487,6 +501,13 @@ public:
|
487 | 501 | */
|
488 | 502 | void reshape(size_t newNumRows, size_t newNumCols, size_t newNumMats = 1);
|
489 | 503 |
|
| 504 | + /** |
| 505 | + * Saves the current instance of DTensor to a (text) file |
| 506 | + * |
| 507 | + * @param pathToFile path of the text file to write; an existing file is overwritten |
| 508 | + */ |
| 509 | + void saveToFile(std::string pathToFile); |
| 510 | + |
490 | 511 | /* ------------- OPERATORS ------------- */
|
491 | 512 |
|
492 | 513 | DTensor &operator=(const DTensor &other);
|
@@ -564,6 +585,80 @@ DTensor<T> DTensor<T>::createRandomTensor(size_t numRows, size_t numCols, size_t
|
564 | 585 | throw std::invalid_argument("[createRandomTensor] unsupported type T");
|
565 | 586 | }
|
566 | 587 |
|
| 588 | + |
/**
 * Raw contents of a tensor text file.
 *
 * All dimensions default to zero so that a default-constructed instance
 * describes an empty tensor instead of holding indeterminate values.
 *
 * @tparam T type of the tensor elements
 */
template<typename T>
struct data_t {
    size_t numRows = 0;   ///< number of rows (first header line)
    size_t numCols = 0;   ///< number of columns (second header line)
    size_t numMats = 0;   ///< number of matrices (third header line)
    std::vector<T> data;  ///< elements in the order in which they appear in the file
};

/**
 * Reads the next line of a stream and interprets it as a tensor dimension.
 *
 * @param in stream positioned at the header line to read
 * @param what description of the dimension, used in error messages
 * @return the parsed, non-negative dimension
 *
 * @throws std::invalid_argument if the line is missing, negative, not a number or too large
 */
inline size_t readTensorDimension(std::istream &in, const std::string &what) {
    std::string line;
    if (!std::getline(in, line)) {
        throw std::invalid_argument("[vectorFromFile] missing " + what + " in file header");
    }
    // std::stoull silently wraps negative input around, so reject it up front
    const auto firstChar = line.find_first_not_of(" \t");
    if (firstChar == std::string::npos || line[firstChar] == '-') {
        throw std::invalid_argument("[vectorFromFile] invalid " + what + " in file header");
    }
    try {
        return static_cast<size_t>(std::stoull(line));
    } catch (const std::logic_error &) {
        // std::stoull reports both "not a number" and "too large" via std::logic_error subclasses
        throw std::invalid_argument("[vectorFromFile] invalid " + what + " in file header");
    }
}

/**
 * Converts one line of text to a tensor element of type T.
 *
 * Floating-point types use the matching std::sto* function; every integral
 * type goes through std::stoll/std::stoull followed by a range check.
 *
 * @tparam T element type
 * @param text line of text holding a single value
 * @return the parsed value
 *
 * @throws std::invalid_argument if the text is not a number or T is not an arithmetic type
 * @throws std::out_of_range if the value does not fit in T
 */
template<typename T>
T parseTensorElement(const std::string &text) {
    if constexpr (std::is_same_v<T, float>) {
        return std::stof(text);
    } else if constexpr (std::is_same_v<T, double>) {
        return std::stod(text);
    } else if constexpr (std::is_same_v<T, long double>) {
        return std::stold(text);
    } else if constexpr (std::is_integral_v<T> && std::is_signed_v<T>) {
        const long long value = std::stoll(text);
        if (value < static_cast<long long>(std::numeric_limits<T>::min())
            || value > static_cast<long long>(std::numeric_limits<T>::max())) {
            throw std::out_of_range("[vectorFromFile] value out of range for type T");
        }
        return static_cast<T>(value);
    } else if constexpr (std::is_integral_v<T>) {
        const unsigned long long value = std::stoull(text);
        if (value > static_cast<unsigned long long>(std::numeric_limits<T>::max())) {
            throw std::out_of_range("[vectorFromFile] value out of range for type T");
        }
        return static_cast<T>(value);
    } else {
        // same policy as createRandomTensor: fail loudly instead of returning zeros
        throw std::invalid_argument("[vectorFromFile] unsupported type T");
    }
}

/**
 * Reads the dimensions and the elements of a tensor from a text file.
 *
 * Expected layout: the number of rows, columns and matrices on the first
 * three lines, followed by one element per line. Lines beyond the announced
 * number of elements are ignored.
 *
 * @tparam T element type
 * @param path_to_file path to the text file
 * @return dimensions and elements read from the file
 *
 * @throws std::invalid_argument if the file cannot be opened, the header is
 *         malformed, an element is not a number, or the file holds fewer
 *         elements than the header announces
 * @throws std::out_of_range if an element does not fit in T
 */
template<typename T>
data_t<T> vectorFromFile(std::string path_to_file) {
    std::ifstream file(path_to_file, std::ios::in);
    if (!file.is_open()) { throw std::invalid_argument("the file you provided does not exist"); }

    data_t<T> dataStruct;
    dataStruct.numRows = readTensorDimension(file, "number of rows");
    dataStruct.numCols = readTensorDimension(file, "number of columns");
    dataStruct.numMats = readTensorDimension(file, "number of matrices");

    const size_t numElements = dataStruct.numRows * dataStruct.numCols * dataStruct.numMats;
    dataStruct.data.reserve(numElements);

    // Checking the element count first guarantees that nothing is stored
    // when the header announces an empty tensor.
    std::string line;
    while (dataStruct.data.size() < numElements && std::getline(file, line)) {
        dataStruct.data.push_back(parseTensorElement<T>(line));
    }

    if (dataStruct.data.size() != numElements) {
        throw std::invalid_argument("[vectorFromFile] expected " + std::to_string(numElements)
                                    + " elements but found " + std::to_string(dataStruct.data.size()));
    }
    return dataStruct;  // the stream is closed by its destructor
}
| 641 | + |
| 642 | +template<typename T> |
| 643 | +DTensor<T> DTensor<T>::parseFromTextFile(std::string path_to_file, |
| 644 | + StorageMode mode) { |
| 645 | + auto parsedData = vectorFromFile<T>(path_to_file); |
| 646 | + DTensor<T> tensorFromData(parsedData.data, parsedData.numRows, parsedData.numCols, parsedData.numMats); |
| 647 | + return tensorFromData; |
| 648 | +} |
| 649 | + |
| 650 | +template<typename T> |
| 651 | +void DTensor<T>::saveToFile(std::string pathToFile) { |
| 652 | + std::ofstream file(pathToFile); |
| 653 | + file << numRows() << std::endl << numCols() << std::endl << numMats() << std::endl; |
| 654 | + std::vector<T> myData(numEl()); download(myData); |
| 655 | + if constexpr (std::is_floating_point<T>::value) { |
| 656 | + int prec = std::numeric_limits<T>::max_digits10 - 1; |
| 657 | + file << std::setprecision(prec); |
| 658 | + } |
| 659 | + for(const T& el : myData) file << el << std::endl; |
| 660 | +} |
| 661 | + |
567 | 662 | template<typename T>
|
568 | 663 | void DTensor<T>::reshape(size_t newNumRows, size_t newNumCols, size_t newNumMats) {
|
569 | 664 | if (m_numRows == newNumRows && m_numCols == newNumCols && m_numMats == newNumMats) return;
|
|
0 commit comments