Open3D (C++ API)  0.19.0
Loading...
Searching...
No Matches
TorchHelper.h File Reference
#include <torch/script.h>
#include <sstream>
#include <type_traits>
#include "open3d/ml/ShapeChecking.h"

Go to the source code of this file.

Macros

#define CHECK_CUDA(x)
#define CHECK_CONTIGUOUS(x)
#define CHECK_TYPE(x, type)
#define CHECK_SAME_DEVICE_TYPE(...)
#define CHECK_SAME_DTYPE(...)
#define CHECK_SHAPE(tensor, ...)
#define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...)
#define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...)
#define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...)
#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...)

Typedefs

typedef std::remove_const< decltype(torch::kInt32)>::type TorchDtype_t

Functions

template<class T>
TorchDtype_t ToTorchDtype ()
template<>
TorchDtype_t ToTorchDtype< uint8_t > ()
template<>
TorchDtype_t ToTorchDtype< int8_t > ()
template<>
TorchDtype_t ToTorchDtype< int16_t > ()
template<>
TorchDtype_t ToTorchDtype< int32_t > ()
template<>
TorchDtype_t ToTorchDtype< int64_t > ()
template<>
TorchDtype_t ToTorchDtype< float > ()
template<>
TorchDtype_t ToTorchDtype< double > ()
template<class T, class TDtype>
bool CompareTorchDtype (const TDtype &t)
bool SameDeviceType (std::initializer_list< torch::Tensor > tensors)
bool SameDtype (std::initializer_list< torch::Tensor > tensors)
std::string TensorInfoStr (std::initializer_list< torch::Tensor > tensors)
torch::Tensor CreateTempTensor (const int64_t size, const torch::Device &device, void **ptr=nullptr)
std::vector< open3d::ml::op_util::DimValue > GetShapeVector (torch::Tensor tensor)
template<open3d::ml::op_util::CSOpt Opt = open3d::ml::op_util::CSOpt::NONE, class TDimX, class... TArgs>
std::tuple< bool, std::string > CheckShape (torch::Tensor tensor, TDimX &&dimex, TArgs &&... args)

Macro Definition Documentation

◆ CHECK_CONTIGUOUS

#define CHECK_CONTIGUOUS ( x)
Value:
do { \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") \
} while (0)

◆ CHECK_CUDA

#define CHECK_CUDA ( x)
Value:
do { \
TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") \
} while (0)

◆ CHECK_SAME_DEVICE_TYPE

#define CHECK_SAME_DEVICE_TYPE ( ...)
Value:
do { \
if (!SameDeviceType({__VA_ARGS__})) { \
TORCH_CHECK( \
false, \
#__VA_ARGS__ \
" must all have the same device type but got " + \
TensorInfoStr({__VA_ARGS__})) \
} \
} while (0)
std::string TensorInfoStr(std::initializer_list< torch::Tensor > tensors)
Definition TorchHelper.h:120
bool SameDeviceType(std::initializer_list< torch::Tensor > tensors)
Definition TorchHelper.h:95

◆ CHECK_SAME_DTYPE

#define CHECK_SAME_DTYPE ( ...)
Value:
do { \
if (!SameDtype({__VA_ARGS__})) { \
TORCH_CHECK(false, \
#__VA_ARGS__ \
" must all have the same dtype but got " + \
TensorInfoStr({__VA_ARGS__})) \
} \
} while (0)
bool SameDtype(std::initializer_list< torch::Tensor > tensors)
Definition TorchHelper.h:108

◆ CHECK_SHAPE

#define CHECK_SHAPE(tensor, ...)
Value:
do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = CheckShape(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)
std::tuple< bool, std::string > CheckShape(torch::Tensor tensor, TDimX &&dimex, TArgs &&... args)
Definition TorchHelper.h:158

◆ CHECK_SHAPE_COMBINE_FIRST_DIMS

#define CHECK_SHAPE_COMBINE_FIRST_DIMS(tensor, ...)
Value:
do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)

◆ CHECK_SHAPE_COMBINE_LAST_DIMS

#define CHECK_SHAPE_COMBINE_LAST_DIMS(tensor, ...)
Value:
do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)

◆ CHECK_SHAPE_IGNORE_FIRST_DIMS

#define CHECK_SHAPE_IGNORE_FIRST_DIMS(tensor, ...)
Value:
do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)

◆ CHECK_SHAPE_IGNORE_LAST_DIMS

#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor, ...)
Value:
do { \
bool cs_success_; \
std::string cs_errstr_; \
std::tie(cs_success_, cs_errstr_) = \
CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__); \
TORCH_CHECK(cs_success_, \
"invalid shape for '" #tensor "', " + cs_errstr_) \
} while (0)

◆ CHECK_TYPE

#define CHECK_TYPE(x, type)
Value:
do { \
TORCH_CHECK(x.dtype() == torch::type, #x " must have type " #type) \
} while (0)

Typedef Documentation

◆ TorchDtype_t

typedef std::remove_const<decltype(torch::kInt32)>::type TorchDtype_t

Function Documentation

◆ CheckShape()

template<open3d::ml::op_util::CSOpt Opt = open3d::ml::op_util::CSOpt::NONE, class TDimX, class... TArgs>
std::tuple< bool, std::string > CheckShape (torch::Tensor tensor, TDimX && dimex, TArgs &&... args)

◆ CompareTorchDtype()

template<class T, class TDtype>
bool CompareTorchDtype ( const TDtype & t)
inline

◆ CreateTempTensor()

torch::Tensor CreateTempTensor (const int64_t size, const torch::Device & device, void ** ptr = nullptr)
inline

◆ GetShapeVector()

std::vector< open3d::ml::op_util::DimValue > GetShapeVector ( torch::Tensor tensor)
inline

◆ SameDeviceType()

bool SameDeviceType ( std::initializer_list< torch::Tensor > tensors)
inline

◆ SameDtype()

bool SameDtype ( std::initializer_list< torch::Tensor > tensors)
inline

◆ TensorInfoStr()

std::string TensorInfoStr ( std::initializer_list< torch::Tensor > tensors)
inline

◆ ToTorchDtype()

template<class T>
TorchDtype_t ToTorchDtype ( )
inline

◆ ToTorchDtype< double >()

template<>
TorchDtype_t ToTorchDtype< double > ( )
inline

◆ ToTorchDtype< float >()

template<>
TorchDtype_t ToTorchDtype< float > ( )
inline

◆ ToTorchDtype< int16_t >()

template<>
TorchDtype_t ToTorchDtype< int16_t > ( )
inline

◆ ToTorchDtype< int32_t >()

template<>
TorchDtype_t ToTorchDtype< int32_t > ( )
inline

◆ ToTorchDtype< int64_t >()

template<>
TorchDtype_t ToTorchDtype< int64_t > ( )
inline

◆ ToTorchDtype< int8_t >()

template<>
TorchDtype_t ToTorchDtype< int8_t > ( )
inline

◆ ToTorchDtype< uint8_t >()

template<>
TorchDtype_t ToTorchDtype< uint8_t > ( )
inline