URL: https://pypi.org/project/torchinfo/

torchinfo 1.8.0

pip install torchinfo

Model summary in PyTorch, based off of the original torchsummary.

Verified details

These details have been verified by PyPI
Maintainers
tyleryep

Unverified details

These details have not been verified by PyPI
Meta
  • License: MIT License (MIT)
  • Author: Tyler Yep @tyleryep
  • Tags: torch, pytorch, torchsummary, torch-summary, summary, keras, deep-learning, ml, torchinfo, torch-info, visualize, model, statistics, layer, stats
  • Requires: Python >=3.7

Project description

torchinfo

(formerly torch-summary)

Torchinfo provides information complementary to what print(your_model) shows in PyTorch. It is similar to TensorFlow's model.summary() API for viewing a visualization of the model, which is helpful when debugging your network. This project implements similar functionality in PyTorch behind a clean, simple interface that you can use in your own projects.

This is a completely rewritten version of the original torchsummary and torchsummaryX projects by @sksq96 and @nmhkahn. This project addresses all of the issues and pull requests left on the original projects by introducing a completely new API.

Supports PyTorch versions 1.4.0+.

Usage

pip install torchinfo

Alternatively, via conda:

conda install -c conda-forge torchinfo

How To Use

from torchinfo import summary

model = ConvNet()
batch_size = 16
summary(model, input_size=(batch_size, 1, 28, 28))
================================================================================================================
Layer (type:depth-idx)                  Input Shape       Output Shape      Param #           Mult-Adds
================================================================================================================
SingleInputNet                          [7, 1, 28, 28]    [7, 10]           --                --
├─Conv2d: 1-1                           [7, 1, 28, 28]    [7, 10, 24, 24]   260               1,048,320
├─Conv2d: 1-2                           [7, 10, 12, 12]   [7, 20, 8, 8]     5,020             2,248,960
├─Dropout2d: 1-3                        [7, 20, 8, 8]     [7, 20, 8, 8]     --                --
├─Linear: 1-4                           [7, 320]          [7, 50]           16,050            112,350
├─Linear: 1-5                           [7, 50]           [7, 10]           510               3,570
================================================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 3.41
================================================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.40
Params size (MB): 0.09
Estimated Total Size (MB): 0.51
================================================================================================================
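
The Param # and Mult-Adds columns in the table above can be checked by hand. The sketch below is plain Python, not torchinfo code. It assumes 5x5 kernels with stride 1 and no padding (inferred from the 28 to 24 and 12 to 8 shape changes, not stated here), and the helper names are made up for illustration. Under those assumptions, each Mult-Adds figure in this table equals the layer's parameter count times the number of output positions times the batch size.

```python
# Assumed hyperparameters: 5x5 kernels, stride 1, no padding (inferred from
# the shapes in the table, not given in the README). Batch size is 7.
BATCH = 7


def conv2d_params(in_channels, out_channels, kernel):
    """Weights plus one bias per output channel."""
    return in_channels * out_channels * kernel * kernel + out_channels


def linear_params(in_features, out_features):
    """Weight matrix plus bias vector."""
    return in_features * out_features + out_features


# (row label, parameter count, output positions per sample)
layers = [
    ("Conv2d: 1-1", conv2d_params(1, 10, 5), 24 * 24),
    ("Conv2d: 1-2", conv2d_params(10, 20, 5), 8 * 8),
    ("Linear: 1-4", linear_params(320, 50), 1),
    ("Linear: 1-5", linear_params(50, 10), 1),
]

total_params = sum(params for _, params, _ in layers)

# Reproduces the Mult-Adds column: params * output positions * batch size.
mult_adds = {name: params * positions * BATCH for name, params, positions in layers}
total_mult_adds = sum(mult_adds.values())

print(total_params)                     # 21840
print(mult_adds["Conv2d: 1-1"])         # 1048320
print(round(total_mult_adds / 1e6, 2))  # 3.41
```

Dropout2d has no parameters, so it contributes nothing to either column.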

Note: if you are using a Jupyter Notebook or Google Colab, summary(model, ...) must be the returned value of the cell. If it is not, you should wrap the summary in a print(), e.g. print(summary(model, ...)). See tests/jupyter_test.ipynb for examples.

This version now supports:

  • RNNs, LSTMs, and other recursive layers
  • Branching output used to explore model layers using specified depths
  • Returns ModelStatistics object containing all summary data fields
  • Configurable rows/columns
  • Jupyter Notebook / Google Colab

Other new features:

  • Verbose mode to show weights and bias layers
  • Accepts either input data or simply the input shape!
  • Customizable line widths and batch dimension
  • Comprehensive unit/output testing, linting, and code coverage testing

Community Contributions:

  • Sequentials & ModuleLists (thanks to @roym899)
  • Improved Mult-Add calculations (thanks to @TE-StefanUhlich, @zmzhang2000)
  • Dict/Misc input data (thanks to @e-dorigatti)
  • Pruned layer support (thanks to @MajorCarrot)

Documentation

def summary(
    model: nn.Module,
    input_size: Optional[INPUT_SIZE_TYPE] = None,
    input_data: Optional[INPUT_DATA_TYPE] = None,
    batch_dim: Optional[int] = None,
    cache_forward_pass: Optional[bool] = None,
    col_names: Optional[Iterable[str]] = None,
    col_width: int = 25,
    depth: int = 3,
    device: Optional[torch.device] = None,
    dtypes: Optional[List[torch.dtype]] = None,
    mode: str | None = None,
    row_settings: Optional[Iterable[str]] = None,
    verbose: int = 1,
    **kwargs: Any,
) -> ModelStatistics:
    """
    Summarize the given PyTorch model. Summarized information includes:
        1) Layer names,
        2) input/output shapes,
        3) kernel shape,
        4) # of parameters,
        5) # of operations (Mult-Adds),
        6) whether layer is trainable

    NOTE: If neither input_data nor input_size is provided, no forward pass through
    the network is performed, and the provided model information is limited to
    layer names.

    Args:
        model (nn.Module):
            PyTorch model to summarize. The model should be fully in either train()
            or eval() mode. If layers are not all in the same mode, running summary
            may have side effects on batchnorm or dropout statistics. If you
            encounter an issue with this, please open a GitHub issue.

        input_size (Sequence of Sizes):
            Shape of input data as a List/Tuple/torch.Size
            (dtypes must match model input, default is FloatTensors).
            You should include batch size in the tuple.
            Default: None

        input_data (Sequence of Tensors):
            Arguments for the model's forward pass (dtypes inferred).
            If the forward() function takes several parameters, pass in a list of
            args or a dict of kwargs (if your forward() function takes in a dict
            as its only argument, wrap it in a list).
            Default: None

        batch_dim (int):
            Batch dimension of input data. If batch_dim is None, assume
            input_data / input_size contains the batch dimension, which is used
            in all calculations. Else, expand all tensors to contain the batch_dim.
            Specifying batch_dim can be a runtime optimization, since if batch_dim
            is specified, torchinfo uses a batch size of 1 for the forward pass.
            Default: None

        cache_forward_pass (bool):
            If True, cache the run of the forward() function using the model
            class name as the key. If the forward pass is an expensive operation,
            this can make it easier to modify the formatting of your model
            summary, e.g. changing the depth or enabled column types, especially
            in Jupyter Notebooks.
            WARNING: Modifying the model architecture or input data/input size when
            this feature is enabled does not invalidate the cache or re-run the
            forward pass, and can cause incorrect summaries as a result.
            Default: False

        col_names (Iterable[str]):
            Specify which columns to show in the output. Currently supported: (
                "input_size",
                "output_size",
                "num_params",
                "params_percent",
                "kernel_size",
                "mult_adds",
                "trainable",
            )
            Default: ("output_size", "num_params")
            If input_data / input_size are not provided, only "num_params" is used.

        col_width (int):
            Width of each column.
            Default: 25

        depth (int):
            Depth of nested layers to display (e.g. Sequentials).
            Nested layers below this depth will not be displayed in the summary.
            Default: 3

        device (torch.device):
            Uses this torch device for model and input_data.
            If not specified, uses the device of input_data if given, or the
            parameters of the model. Otherwise, uses the result of
            torch.cuda.is_available().
            Default: None

        dtypes (List[torch.dtype]):
            If you use input_size, torchinfo assumes your input uses FloatTensors.
            If your model uses a different data type, specify that dtype.
            For multiple inputs, specify the size of both inputs, and
            also specify the types of each parameter here.
            Default: None

        mode (str):
            Either "train" or "eval", which determines whether we call
            model.train() or model.eval() before calling summary().
            Default: "eval".

        row_settings (Iterable[str]):
            Specify which features to show in a row. Currently supported: (
                "ascii_only",
                "depth",
                "var_names",
            )
            Default: ("depth",)

        verbose (int):
            0 (quiet): No output
            1 (default): Print model summary
            2 (verbose): Show weight and bias layers in full detail
            Default: 1
            If using a Jupyter Notebook or Google Colab, the default is 0.

        **kwargs:
            Other arguments used in `model.forward` function. Passing *args is no
            longer supported.

    Return:
        ModelStatistics object
            See torchinfo/model_statistics.py for more information.
    """

Examples

Get Model Summary as String

from torchinfo import summary

model_stats = summary(your_model, (1, 3, 28, 28), verbose=0)
summary_str = str(model_stats)
# summary_str contains the string representation of the summary!

Explore Different Configurations

import torch
from torch import nn
from torchinfo import summary


class LSTMNet(nn.Module):
    def __init__(self, vocab_size=20, embed_dim=300, hidden_dim=512, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embed = self.embedding(x)
        out, hidden = self.encoder(embed)
        out = self.decoder(out)
        # Flatten (batch, seq_len, vocab) into (batch * seq_len, vocab).
        out = out.view(-1, out.size(2))
        return out, hidden


summary(
    LSTMNet(),
    (1, 100),
    dtypes=[torch.long],  # the embedding layer expects integer token indices
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)
========================================================================================================================
Layer (type (var_name))                 Kernel Shape        Output Shape        Param #             Mult-Adds
========================================================================================================================
LSTMNet (LSTMNet)                       --                  [100, 20]           --                  --
├─Embedding (embedding)                 --                  [1, 100, 300]       6,000               6,000
│    └─weight                           [300, 20]                               └─6,000
├─LSTM (encoder)                        --                  [1, 100, 512]       3,768,320           376,832,000
│    └─weight_ih_l0                     [2048, 300]                             ├─614,400
│    └─weight_hh_l0                     [2048, 512]                             ├─1,048,576
│    └─bias_ih_l0                       [2048]                                  ├─2,048
│    └─bias_hh_l0                       [2048]                                  ├─2,048
│    └─weight_ih_l1                     [2048, 512]                             ├─1,048,576
│    └─weight_hh_l1                     [2048, 512]                             ├─1,048,576
│    └─bias_ih_l1                       [2048]                                  ├─2,048
│    └─bias_hh_l1                       [2048]                                  └─2,048
├─Linear (decoder)                      --                  [1, 100, 20]        10,260              10,260
│    └─weight                           [512, 20]                               ├─10,240
│    └─bias                             [20]                                    └─20
========================================================================================================================
Total params: 3,784,580
Trainable params: 3,784,580
Non-trainable params: 0
Total mult-adds (M): 376.85
========================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.67
Params size (MB): 15.14
Estimated Total Size (MB): 15.80
========================================================================================================================
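
The per-tensor counts in the verbose output above follow from the LSTM weight shapes: each layer stacks four gates, so its weight matrices have 4 * hidden_dim rows. A minimal plain-Python check, using a hypothetical helper name and the default sizes from LSTMNet:

```python
def lstm_layer_params(input_size, hidden_size):
    """Parameter counts for one LSTM layer (four stacked gates)."""
    gate_rows = 4 * hidden_size
    return {
        "weight_ih": gate_rows * input_size,   # input-to-hidden weights
        "weight_hh": gate_rows * hidden_size,  # hidden-to-hidden weights
        "bias_ih": gate_rows,
        "bias_hh": gate_rows,
    }


vocab_size, embed_dim, hidden_dim = 20, 300, 512

layer0 = lstm_layer_params(embed_dim, hidden_dim)   # fed by the embedding
layer1 = lstm_layer_params(hidden_dim, hidden_dim)  # fed by layer 0

embedding_params = vocab_size * embed_dim
lstm_params = sum(layer0.values()) + sum(layer1.values())
decoder_params = hidden_dim * vocab_size + vocab_size  # weights plus bias

print(layer0["weight_ih"])                              # 614400
print(lstm_params)                                      # 3768320
print(embedding_params + lstm_params + decoder_params)  # 3784580
```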

ResNet

import torchvision

model = torchvision.models.resnet152()
summary(model, (1, 3, 224, 224), depth=3)
==========================================================================================
Layer (type:depth-idx)                  Output Shape             Param #
==========================================================================================
ResNet                                  [1, 1000]                --
├─Conv2d: 1-1                           [1, 64, 112, 112]        9,408
├─BatchNorm2d: 1-2                      [1, 64, 112, 112]        128
├─ReLU: 1-3                             [1, 64, 112, 112]        --
├─MaxPool2d: 1-4                        [1, 64, 56, 56]          --
├─Sequential: 1-5                       [1, 256, 56, 56]         --
│    └─Bottleneck: 2-1                  [1, 256, 56, 56]         --
│    │    └─Conv2d: 3-1                 [1, 64, 56, 56]          4,096
│    │    └─BatchNorm2d: 3-2            [1, 64, 56, 56]          128
│    │    └─ReLU: 3-3                   [1, 64, 56, 56]          --
│    │    └─Conv2d: 3-4                 [1, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-5            [1, 64, 56, 56]          128
│    │    └─ReLU: 3-6                   [1, 64, 56, 56]          --
│    │    └─Conv2d: 3-7                 [1, 256, 56, 56]         16,384
│    │    └─BatchNorm2d: 3-8            [1, 256, 56, 56]         512
│    │    └─Sequential: 3-9             [1, 256, 56, 56]         16,896
│    │    └─ReLU: 3-10                  [1, 256, 56, 56]         --
│    └─Bottleneck: 2-2                  [1, 256, 56, 56]         --

  ...
  ...
  ...

├─AdaptiveAvgPool2d: 1-9                [1, 2048, 1, 1]          --
├─Linear: 1-10                          [1, 1000]                2,049,000
==========================================================================================
Total params: 60,192,808
Trainable params: 60,192,808
Non-trainable params: 0
Total mult-adds (G): 11.51
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 360.87
Params size (MB): 240.77
Estimated Total Size (MB): 602.25
==========================================================================================
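
The size figures in the footer above can be sanity-checked with simple arithmetic. The sketch below assumes 4-byte float32 values and 1 MB = 10^6 bytes. Both are assumptions that happen to reproduce the reported input and parameter sizes, not something stated on this page. The forward/backward figure is copied from the table rather than recomputed, since it depends on every layer's output shape.

```python
BYTES_PER_FLOAT32 = 4  # assumption: all tensors are float32


def megabytes(num_elements, bytes_per_element=BYTES_PER_FLOAT32):
    """Size in MB, taking 1 MB = 1e6 bytes (assumed convention)."""
    return num_elements * bytes_per_element / 1e6


input_mb = megabytes(1 * 3 * 224 * 224)  # one 3x224x224 image
params_mb = megabytes(60_192_808)        # "Total params" from the table
forward_backward_mb = 360.87             # copied from the table, not derived

total_mb = input_mb + forward_backward_mb + params_mb

print(round(input_mb, 2))   # 0.6
print(round(params_mb, 2))  # 240.77
print(round(total_mb, 2))   # 602.24
```

The sum lands within 0.01 MB of the reported 602.25. The gap comes from using the already-rounded forward/backward value here.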

Multiple Inputs w/ Different Data Types

import torch
import torch.nn.functional as F
from torch import nn
from torchinfo import summary


class MultipleInputNetDifferentDtypes(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1a = nn.Linear(300, 50)
        self.fc1b = nn.Linear(50, 10)

        self.fc2a = nn.Linear(300, 50)
        self.fc2b = nn.Linear(50, 10)

    def forward(self, x1, x2):
        x1 = F.relu(self.fc1a(x1))
        x1 = self.fc1b(x1)
        # x2 arrives as a long tensor and is cast before the linear layers.
        x2 = x2.type(torch.float)
        x2 = F.relu(self.fc2a(x2))
        x2 = self.fc2b(x2)
        x = torch.cat((x1, x2), 0)
        return F.log_softmax(x, dim=1)


model = MultipleInputNetDifferentDtypes()
summary(model, [(1, 300), (1, 300)], dtypes=[torch.float, torch.long])

Alternatively, you can also pass in the input_data itself, and torchinfo will automatically infer the data types.

input_data = torch.randn(1, 300)
other_input_data = torch.randn(1, 300).long()
model = MultipleInputNetDifferentDtypes()

# One list entry per forward() argument; this model takes two.
summary(model, input_data=[input_data, other_input_data])

Sequentials & ModuleLists

from torch import nn
from torchinfo import summary


class ContainerModule(nn.Module):

    def __init__(self):
        super().__init__()
        self._layers = nn.ModuleList()
        self._layers.append(nn.Linear(5, 5))
        self._layers.append(ContainerChildModule())
        self._layers.append(nn.Linear(5, 5))

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)
        return x


class ContainerChildModule(nn.Module):

    def __init__(self):
        super().__init__()
        self._sequential = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
        self._between = nn.Linear(5, 5)

    def forward(self, x):
        out = self._sequential(x)
        out = self._between(out)
        # Reusing the same layers shows up as "(recursive)" in the summary.
        for l in self._sequential:
            out = l(out)

        out = self._sequential(x)
        for l in self._sequential:
            out = l(out)
        return out


summary(ContainerModule(), (1, 5))
==========================================================================================
Layer (type:depth-idx)                  Output Shape             Param #
==========================================================================================
ContainerModule                         [1, 5]                   --
├─ModuleList: 1-1                       --                       --
│    └─Linear: 2-1                      [1, 5]                   30
│    └─ContainerChildModule: 2-2        [1, 5]                   --
│    │    └─Sequential: 3-1             [1, 5]                   --
│    │    │    └─Linear: 4-1            [1, 5]                   30
│    │    │    └─Linear: 4-2            [1, 5]                   30
│    │    └─Linear: 3-2                 [1, 5]                   30
│    │    └─Sequential: 3-3             --                       (recursive)
│    │    │    └─Linear: 4-3            [1, 5]                   (recursive)
│    │    │    └─Linear: 4-4            [1, 5]                   (recursive)
│    │    └─Sequential: 3-4             [1, 5]                   (recursive)
│    │    │    └─Linear: 4-5            [1, 5]                   (recursive)
│    │    │    └─Linear: 4-6            [1, 5]                   (recursive)
│    │    │    └─Linear: 4-7            [1, 5]                   (recursive)
│    │    │    └─Linear: 4-8            [1, 5]                   (recursive)
│    └─Linear: 2-3                      [1, 5]                   30
==========================================================================================
Total params: 150
Trainable params: 150
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================

Contributing

All issues and pull requests are much appreciated! If you are wondering how to build the project:

  • torchinfo is actively developed using the latest version of Python.
  • Changes should be backward compatible to Python 3.7, and will follow Python's End-of-Life guidance for old versions.
  • Run pip install -r requirements-dev.txt. We use the latest versions of all dev packages.
  • Run pre-commit install.
  • To use auto-formatting tools, use pre-commit run -a.
  • To run unit tests, run pytest.
  • To update the expected output files, run pytest --overwrite.
  • To skip output file tests, use pytest --no-output.

References

  • Thanks to @sksq96, @nmhkahn, and @sangyx for providing the inspiration for this project.
  • For Model Size Estimation @jacobkimmel (details here)

Download files

Source Distribution

torchinfo-1.8.0.tar.gz (25.9 kB)

  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.10.6

Hashes for torchinfo-1.8.0.tar.gz

Algorithm    Hash digest
SHA256       72e94b0e9a3e64dc583a8e5b7940b8938a1ac0f033f795457f27e6f4e7afa2e9
MD5          9e55abc36fa0ce929beefde5e4153cf1
BLAKE2b-256  53d92b811d1c0812e9ef23e6cf2dbe022becbe6c5ab065e33fd80ee05c0cd996

Built Distribution

torchinfo-1.8.0-py3-none-any.whl (23.4 kB)

  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.10.6

Hashes for torchinfo-1.8.0-py3-none-any.whl

Algorithm    Hash digest
SHA256       2e911c2918603f945c26ff21a3a838d12709223dc4ccf243407bce8b6e897b46
MD5          62ab1041f930012f5a50d0f95c764b15
BLAKE2b-256  7225973bd6128381951b23cdcd8a9870c6dcfc5606cb864df8eabd82e529f9c1
