Caffe2 - C++ API
A deep learning, cross platform ML framework
caffe2::ScatterWeightedSumOp< T, Context > Class Template Reference

Update slices of the tensor in-place with weighted sum.

#include <utility_ops.h>

Inheritance diagram for caffe2::ScatterWeightedSumOp< T, Context >: derives from caffe2::Operator< Context >, which derives from caffe2::OperatorBase, which derives from caffe2::Observable< OperatorBase >.

Public Member Functions

 USE_SIMPLE_CTOR_DTOR (ScatterWeightedSumOp)
 
bool RunOnDevice () override
 
- Public Member Functions inherited from caffe2::Operator< Context >
 Operator (const OperatorDef &operator_def, Workspace *ws)
 
const Tensor< Context > & Input (int idx)
 
Tensor< Context > * Output (int idx)
 
void WaitEvent (const Event &ev, int stream_id=-1) final
 
void WaitEvents (const std::vector< const Event * > &events, int stream_id=-1) final
 
bool Run (int stream_id=0) final
 
bool RunAsync (int stream_id=0) final
 
bool IsStreamFree (int stream_id) const override
 
bool HasAsyncPart () const override
 
bool SupportsAsyncScheduling () const override
 
- Public Member Functions inherited from caffe2::OperatorBase
 OperatorBase (const OperatorDef &operator_def, Workspace *ws)
 
bool HasArgument (const string &name) const
 Checks if the operator has an argument of the given name.
 
template<typename T >
T GetSingleArgument (const string &name, const T &default_value) const
 
template<typename T >
bool HasSingleArgumentOfType (const string &name) const
 
template<typename T >
vector< T > GetRepeatedArgument (const string &name, const vector< T > &default_value={}) const
 
template<typename T >
const T & Input (int idx)
 
template<typename T >
T * Output (int idx)
 
const Blob & InputBlob (int idx)
 
Blob * OutputBlob (int idx)
 
template<typename T >
bool InputIsType (int idx)
 
template<typename T >
bool OutputIsType (int idx)
 
int InputSize ()
 
int OutputSize ()
 
const vector< const Blob * > & Inputs () const
 
const vector< Blob * > & Outputs ()
 
vector< TensorShape > InputTensorShapes ()
 
void Wait (const OperatorBase &other, int stream_id=-1)
 
virtual void Finish ()
 
virtual void AddRelatedBlobInfo (EnforceNotMet *err)
 
const OperatorDef & debug_def () const
 
void set_debug_def (const std::shared_ptr< const OperatorDef > &operator_def)
 
bool has_debug_def () const
 
void RecordLastFailedOpNetPosition ()
 
int net_position () const
 
void set_net_position (int idx)
 
const DeviceOption & device_option () const
 
const Event & event () const
 
Event & event ()
 
void ResetEvent ()
 
void DisableEvent ()
 
bool IsEventDisabled () const
 
const std::string & type ()
 
void annotate_engine (const std::string &engine)
 
const std::string & engine () const
 
- Public Member Functions inherited from caffe2::Observable< OperatorBase >
const Observer * AttachObserver (std::unique_ptr< Observer > observer)
 
std::unique_ptr< Observer > DetachObserver (const Observer *observer_ptr)
 Returns a unique_ptr to the removed observer.
 
virtual size_t NumObservers ()
 
void StartAllObservers ()
 
void StopAllObservers ()
 

Data Fields

 USE_OPERATOR_CONTEXT_FUNCTIONS
 
 USE_DISPATCH_HELPER
 

Additional Inherited Members

- Public Types inherited from caffe2::Observable< OperatorBase >
using Observer = ObserverBase< OperatorBase >
 
- Static Public Attributes inherited from caffe2::OperatorBase
static constexpr int kNoNetPositionSet = -1
 
- Protected Member Functions inherited from caffe2::Operator< Context >
void RecordEvent (const char *err_msg=nullptr) final
 
std::string getErrorMsg ()
 
- Protected Member Functions inherited from caffe2::OperatorBase
 DISABLE_COPY_AND_ASSIGN (OperatorBase)
 
- Protected Attributes inherited from caffe2::Operator< Context >
Context context_
 
- Protected Attributes inherited from caffe2::OperatorBase
std::unique_ptr< Event > event_
 
- Protected Attributes inherited from caffe2::Observable< OperatorBase >
std::vector< std::unique_ptr< Observer > > observers_list_
 

Detailed Description

template<typename T, class Context>
class caffe2::ScatterWeightedSumOp< T, Context >

Update slices of the tensor in-place with weighted sum.

ScatterWeightedSumOp is similar to WeightedSum and computes the weighted sum of several tensors. The first tensor is updated in place, and only its slices along the first dimension that are indexed by INDICES are modified.
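The two-phase procedure below is a minimal sketch of these semantics on plain std::vector buffers instead of Caffe2 tensors. The names scatter_weighted_sum_ref and WeightedUpdate are hypothetical, and the ordering (every indexed slice is scaled by weight_0 before any update input is added) is inferred from the notes on sequential updates in this section, not copied from utility_ops.h.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical reference sketch, not the Caffe2 implementation.
// One (X_i, weight_i) pair; X_i is a flat buffer holding K slices.
struct WeightedUpdate {
  const std::vector<float>* slices;  // X_i, viewed as K x block
  float weight;                      // weight_i
};

// x0 is X_0 as a flat row-major buffer whose first dimension is n.
void scatter_weighted_sum_ref(std::vector<float>& x0, std::int64_t n, float weight0,
                              const std::vector<std::int64_t>& indices,
                              const std::vector<WeightedUpdate>& updates) {
  assert(n > 0 && static_cast<std::int64_t>(x0.size()) % n == 0);
  const std::int64_t block = static_cast<std::int64_t>(x0.size()) / n;
  const std::int64_t k = static_cast<std::int64_t>(indices.size());

  // Phase 1: scale each indexed slice of X_0 by weight_0 (once per occurrence).
  for (std::int64_t i = 0; i < k; ++i) {
    const std::int64_t row = indices[i];
    assert(row >= 0 && row < n);
    for (std::int64_t j = 0; j < block; ++j) x0[row * block + j] *= weight0;
  }
  // Phase 2: for each input pair, add weight_i * (slice i of X_i) into row INDICES[i].
  for (const WeightedUpdate& u : updates) {
    assert(static_cast<std::int64_t>(u.slices->size()) == k * block);
    for (std::int64_t i = 0; i < k; ++i) {
      const std::int64_t row = indices[i];
      for (std::int64_t j = 0; j < block; ++j)
        x0[row * block + j] += u.weight * (*u.slices)[i * block + j];
    }
  }
}
```

For example, with X_0 = {1, 1, 2, 2, 3, 3} viewed as 3 x 2, weight_0 = 0.5, INDICES = {2, 0}, X_1 = {10, 10, 20, 20} and weight_1 = 1, rows 2 and 0 become 0.5 * 3 + 10 = 11.5 and 0.5 * 1 + 20 = 20.5, while row 1 is left untouched.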

Input:
- X_0: tensor to be updated
- weight_0: scalar weight for X_0, applied only to the affected slices
- INDICES: 1-D list of indices along the first dimension of X_0 that need to be updated
- X_1: update slices; must have shape (len(INDICES),) + shape(X_0)[1:], i.e. one slice per index
- weight_1: scalar weight for the X_1 update
- X_2, weight_2, ...: further (update, weight) pairs with the same requirements
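Because the inputs alternate as (update, weight) pairs after INDICES, a valid call has an odd number of inputs. The helper below is a hypothetical sketch (check_scatter_inputs is not a Caffe2 function) that checks the listed layout from element counts alone; the odd-count rule is derived from the list above rather than quoted from the implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper (not part of Caffe2): validates the input layout from counts.
// sizes = {numel(X_0), numel(weight_0), numel(INDICES), numel(X_1), numel(weight_1), ...}
// x0_dim0 = N, the first dimension of X_0. Returns "" if consistent, else a message.
std::string check_scatter_inputs(const std::vector<std::int64_t>& sizes,
                                 std::int64_t x0_dim0) {
  assert(x0_dim0 > 0);
  // X_0, weight_0, INDICES, then (X_i, weight_i) pairs: an odd count.
  if (sizes.size() < 3 || sizes.size() % 2 == 0) return "bad number of inputs";
  if (sizes[0] % x0_dim0 != 0) return "numel(X_0) is not a multiple of N";
  if (sizes[1] != 1) return "weight_0 must be a scalar";
  const std::int64_t block = sizes[0] / x0_dim0;  // M / N
  const std::int64_t k = sizes[2];                // len(INDICES)
  for (std::size_t i = 3; i + 1 < sizes.size(); i += 2) {
    const std::string n = std::to_string((i - 1) / 2);  // pair number: 1, 2, ...
    if (sizes[i] != k * block) return "X_" + n + " must hold len(INDICES) slices";
    if (sizes[i + 1] != 1) return "weight_" + n + " must be a scalar";
  }
  return "";
}
```

For an X_0 of shape [4, 3] and two indices, check_scatter_inputs({12, 1, 2, 6, 1}, 4) accepts the layout, while an X_1 with 5 elements is rejected because it cannot hold two slices of 3 elements.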

Output:
- X_0: must be exactly the same tensor as input 0, since the op updates it in place

Note: The op largely ignores the exact shapes of its inputs and looks only at their sizes; this is done for performance, to avoid unnecessary reshapes. Only the first dimension of X_0 matters; call it N. If M is the total size of X_0 and K is the size of INDICES, then each X_i is assumed to have shape K x (M / N), regardless of its actual shape.
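The size-only interpretation can be sketched as follows. FlatView, update_view and the two offset helpers are hypothetical names, and the sketch assumes row-major storage, so slice i of each X_i starts at i * (M / N) and the X_0 slice it updates starts at INDICES[i] * (M / N).

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Total number of elements in a tensor with the given shape.
std::int64_t numel(const std::vector<std::int64_t>& shape) {
  return std::accumulate(shape.begin(), shape.end(), std::int64_t{1},
                         std::multiplies<std::int64_t>());
}

// Hypothetical view of an update input X_i as a K x (M / N) matrix.
struct FlatView {
  std::int64_t rows;  // K = len(INDICES)
  std::int64_t cols;  // M / N, the size of one slice
};

// Only N = shape(X_0)[0] and M = numel(X_0) are used; the rest of the shape is ignored.
FlatView update_view(const std::vector<std::int64_t>& x0_shape, std::int64_t k) {
  assert(!x0_shape.empty() && x0_shape[0] > 0);
  const std::int64_t m = numel(x0_shape);
  const std::int64_t n = x0_shape[0];
  return FlatView{k, m / n};
}

// Start of slice i inside X_i (i = position in INDICES).
std::int64_t update_offset(const FlatView& v, std::int64_t i) { return i * v.cols; }

// Start of the X_0 slice that slice i updates (row = INDICES[i]).
std::int64_t target_offset(const FlatView& v, std::int64_t row) { return row * v.cols; }
```

For an X_0 of shape [10, 2, 3] and K = 4, every X_i is read as a 4 x 6 matrix, so update inputs of shape [4, 2, 3], [4, 6] or [24] are all handled identically.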

Note: Each entry in INDICES is applied independently, so if INDICES contains duplicates, the corresponding slice of X_0 is scaled multiple times. If this is not desired, collapse the duplicate entries of INDICES manually beforehand.
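As a worked example of the duplicate behavior: with INDICES = {3, 1, 3} and weight_0 = 0.5, row 3 of X_0 is multiplied by 0.5 twice (0.25 in total), and then both of its update slices are added. If the intended result is to scale each row once and add all of its updates, one way to collapse INDICES beforehand is sketched below. collapse_duplicates is a hypothetical helper, it assumes summing the update slices is the desired merge, and with several update inputs each X_i has to be collapsed against the same INDICES.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical result: unique indices (ascending) and one summed slice per index.
struct Collapsed {
  std::vector<std::int64_t> indices;
  std::vector<float> slices;
};

// Merges update slices that share an index by summing them (an assumed policy).
Collapsed collapse_duplicates(const std::vector<std::int64_t>& indices,
                              const std::vector<float>& slices, std::int64_t block) {
  assert(block > 0);
  const std::size_t b = static_cast<std::size_t>(block);
  assert(slices.size() == indices.size() * b);
  std::map<std::int64_t, std::vector<float>> sums;  // std::map keeps indices sorted
  for (std::size_t i = 0; i < indices.size(); ++i) {
    std::vector<float>& acc = sums[indices[i]];
    acc.resize(b, 0.0f);  // no-op after the first occurrence of this index
    for (std::size_t j = 0; j < b; ++j) acc[j] += slices[i * b + j];
  }
  Collapsed out;
  for (const auto& entry : sums) {
    out.indices.push_back(entry.first);
    out.slices.insert(out.slices.end(), entry.second.begin(), entry.second.end());
  }
  return out;
}
```

With INDICES = {3, 1, 3}, a slice size of 2 and X_1 = {1, 2, 10, 20, 100, 200}, the collapsed inputs are INDICES = {1, 3} and X_1 = {10, 20, 101, 202}, so row 3 is scaled only once.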

Note: Updates are applied sequentially, one input at a time, which can have undesired consequences if the tensor is accessed concurrently by a different op (e.g. in Hogwild training). Other threads may observe intermediate results even at the level of an individual slice, for example X_0 already scaled by weight_0 but without any updates applied.
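To make the intermediate states concrete, the sketch below follows a single slice through the phases described above and records it after each one. trace_slice_states is a hypothetical helper; it assumes scaling by weight_0 precedes all update inputs, and it only records states at phase boundaries, whereas a concurrent reader could also catch a slice in which only some elements of a phase have been written.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: returns the states of one X_0 slice that another thread
// could observe between phases of the op (not a model of real memory ordering).
std::vector<std::vector<float>> trace_slice_states(
    std::vector<float> slice, float weight0,
    const std::vector<std::vector<float>>& update_slices,
    const std::vector<float>& weights) {
  assert(update_slices.size() == weights.size());
  std::vector<std::vector<float>> states;
  for (float& v : slice) v *= weight0;  // scaled by weight_0, no updates yet
  states.push_back(slice);
  for (std::size_t u = 0; u < update_slices.size(); ++u) {  // one input at a time
    assert(update_slices[u].size() == slice.size());
    for (std::size_t j = 0; j < slice.size(); ++j)
      slice[j] += weights[u] * update_slices[u][j];
    states.push_back(slice);  // state after update input u + 1
  }
  return states;
}
```

For a slice {4, 8} with weight_0 = 0.5 and two updates {1, 1} (weight 2) and {10, 10} (weight 1), a reader may see {2, 4}, then {4, 6}, and only at the end the final {14, 16}.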

For now, the op really works only on CPU, because of the way it accesses INDICES.

Definition at line 475 of file utility_ops.h.

