Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 447b481

Browse filesBrowse files
qxy11pytorchmergebot
authored andcommitted
[AOTI] Save data sizes to constants_info (#154534)
Differential Revision: D75223179 Pull Request resolved: #154534 Approved by: https://github.com/muchulee8
1 parent 9c7ed3e commit 447b481
Copy full SHA for 447b481

File tree

Expand file treeCollapse file tree

3 files changed

+26
-0
lines changed
Filter options
Expand file treeCollapse file tree

3 files changed

+26
-0
lines changed

‎torch/_inductor/codegen/aoti_runtime/interface.cpp

Copy file name to clipboardExpand all lines: torch/_inductor/codegen/aoti_runtime/interface.cpp
+11Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,17 @@ AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(
207207
{ *dtype = container->constant_dtype(idx); })
208208
}
209209

210+
AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize(
211+
AOTInductorModelContainerHandle container_handle,
212+
size_t idx,
213+
size_t* data_size) {
214+
auto* container =
215+
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
216+
container_handle);
217+
CONVERT_EXCEPTION_TO_ERROR_CODE(
218+
{ *data_size = container->constant_data_size(idx); })
219+
}
220+
210221
AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap(
211222
AOTInductorModelContainerHandle container_handle,
212223
AOTInductorConstantMapHandle constant_map_handle,

‎torch/csrc/inductor/aoti_runtime/interface.h

Copy file name to clipboardExpand all lines: torch/csrc/inductor/aoti_runtime/interface.h
+8Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,14 @@ AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(
117117
size_t idx,
118118
int32_t* dtype);
119119

120+
// Retrieves a constant's data size.
121+
// idx is the index of the internal's constants.
122+
// Need idx < num_constants from AOTInductorModelContainerGetNumConstants
123+
AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize(
124+
AOTInductorModelContainerHandle container_handle,
125+
size_t idx,
126+
size_t* data_size);
127+
120128
// Extract the constants that is being used in the container.
121129
AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap(
122130
AOTInductorModelContainerHandle container_handle,

‎torch/csrc/inductor/aoti_runtime/model_container.h

Copy file name to clipboardExpand all lines: torch/csrc/inductor/aoti_runtime/model_container.h
+7Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,13 @@ class AOTInductorModelContainer {
232232
return models_[0]->constant_from_folded(static_cast<int64_t>(idx));
233233
}
234234

235+
size_t constant_data_size(size_t idx) const {
236+
if (this->num_models() == 0) {
237+
throw std::runtime_error("No available models in container!");
238+
}
239+
return models_[0]->constant_data_size(static_cast<int64_t>(idx));
240+
}
241+
235242
// retrieve type of constants_info_[idx]
236243
int32_t constant_type(size_t idx) const {
237244
if (this->num_models() == 0) {

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.