Apache Singa
A General Distributed Deep Learning Library
snapshot.h
1 /************************************************************
2 *
3 * Licensed to the Apache Software Foundation (ASF) under one
4 * or more contributor license agreements. See the NOTICE file
5 * distributed with this work for additional information
6 * regarding copyright ownership. The ASF licenses this file
7 * to you under the Apache License, Version 2.0 (the
8 * "License"); you may not use this file except in compliance
9 * with the License. You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing,
14 * software distributed under the License is distributed on an
15 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 * KIND, either express or implied. See the License for the
17 * specific language governing permissions and limitations
18 * under the License.
19 *
20 *************************************************************/
21 
22 #ifndef SINGA_UTILS_SNAPSHOT_H_
23 #define SINGA_UTILS_SNAPSHOT_H_
24 
25 #include "singa/io/reader.h"
26 #include "singa/io/writer.h"
27 #include "singa/utils/logging.h"
28 #include "singa/proto/core.pb.h"
29 #include "singa/core/tensor.h"
30 
31 #include <string>
32 #include <unordered_set>
33 #include <unordered_map>
34 #include <memory>
35 
36 namespace singa {
44 class Snapshot {
45  public:
46  enum Mode { kRead, kWrite };
53  Snapshot(const std::string& prefix, Mode mode, int max_param_size = 10);
54  ~Snapshot() {}
56  std::vector<std::pair<std::string, Tensor>> Read();
58  std::vector<std::pair<std::string, Shape>> ReadShape();
60  Tensor Read(const std::string& Key);
62  Shape ReadShape(const std::string& key);
66  void Write(const std::string& key, const Tensor& param);
68  int version() const {
69  return version_;
70  }
71 
72  private:
74  int version_ = 0;
75  std::string prefix_;
76  Mode mode_;
77  std::unique_ptr<io::BinFileWriter> bin_writer_ptr_;
78  std::unique_ptr<io::Writer> text_writer_ptr_;
79  std::unique_ptr<io::BinFileReader> bin_reader_ptr_;
81  std::unordered_set<std::string> param_names_;
83  std::unordered_map<std::string, Tensor> param_map_;
84 };
85 } // namespace singa
86 
87 #endif // SINGA_UTILS_SNAPSHOT_H_
std::vector< std::pair< std::string, Tensor > > Read()
Read parameters saved as tensors from checkpoint file.
Snapshot(const std::string &prefix, Mode mode, int max_param_size=10)
<prefix>.model is the binary file storing the parameter key-value pairs.
int version() const
Available for SINGA versions newer than 1.0.1.
Definition: snapshot.h:68
A Tensor instance is a multi-dimensional array resident on a Device (default device is the host CPU)...
Definition: tensor.h:56
std::vector< std::pair< std::string, Shape > > ReadShape()
Read parameter shapes from description file.
void Write(const std::string &key, const Tensor &param)
Serialize a parameter and write it out.
Snapshot management: reads and writes parameter checkpoint files.
Definition: snapshot.h:44
Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements...
Definition: common.h:48