Update slices of the tensor in-place with weighted sum.
#include <utility_ops.h>
Public Member Functions | |
USE_SIMPLE_CTOR_DTOR (ScatterWeightedSumOp) | |
bool | RunOnDevice () override |
Public Member Functions inherited from caffe2::Operator< Context > | |
Operator (const OperatorDef &operator_def, Workspace *ws) | |
Operator (const c10::FunctionSchema &fn_schema, std::vector< c10::IValue > inputs, std::vector< at::Tensor > outputs) | |
const Tensor & | Input (int idx, DeviceType type=Context::GetDeviceType()) |
Retrieve a non-owning reference to the input at position 'idx' for this operator. | |
Tensor | XOutput (int idx, at::IntArrayRef dims, at::TensorOptions options) |
XOutput is a modernized version of Output which returns a Tensor rather than a Tensor* (the raw pointer in the latter case is useless, as Tensor is a pointer type.) | |
Public Member Functions inherited from caffe2::OperatorBase | |
OperatorBase (const OperatorDef &operator_def, Workspace *ws) | |
OperatorBase (const c10::FunctionSchema &schema, std::vector< c10::IValue > inputs, std::vector< at::Tensor > outputs) | |
bool | isLegacyOperator () const |
Return true if the operator was instantiated with OperatorDef. New operators should be instantiated with FunctionSchema. | |
const c10::FunctionSchema & | getFunctionSchema () const |
bool | HasArgument (const string &name) const |
Checks if the operator has an argument of the given name. | |
template<typename T > | |
T | GetSingleArgument (const string &name, const T &default_value) const |
template<typename T > | |
bool | HasSingleArgumentOfType (const string &name) const |
template<typename T > | |
vector< T > | GetVectorFromIValueList (const c10::IValue &value) const |
template<typename T > | |
vector< T > | GetRepeatedArgument (const string &name, const vector< T > &default_value={}) const |
template<typename T > | |
const T & | Input (int idx) |
template<typename T > | |
const T & | Input (int idx, DeviceType type) |
template<typename T > | |
T * | Output (int idx) |
template<typename T > | |
T * | Output (int idx, DeviceType type) |
Tensor | XOutputTensor (int idx, at::IntArrayRef dims, at::TensorOptions options) |
void | SetOutputTensor (int idx, Tensor tensor) |
Tensor | OutputTensorOrUndefined (int idx) |
Tensor * | OutputTensor (int idx, at::IntArrayRef dims, at::TensorOptions options) |
Tensor * | OutputTensorCopyFrom (int idx, at::TensorOptions options, const Tensor &src, bool async=false) |
Tensor * | OutputTensorAlias (int idx, const Tensor &src) |
template<typename T > | |
T * | Output (int idx, T *allocated) |
const Blob & | InputBlob (int idx) |
Blob * | OutputBlob (int idx) |
bool | IsInputOutputAlias (int i, int j) |
template<typename T > | |
bool | InputIsType (int idx) |
bool | InputIsTensorType (int idx, DeviceType device_type) |
template<typename T > | |
bool | OutputIsType (int idx) |
bool | OutputIsTensorType (int idx, DeviceType type) |
int | InputSize () const |
int | OutputSize () const |
const vector< const Blob * > & | Inputs () const |
const vector< Blob * > & | Outputs () |
vector< TensorShape > | InputTensorShapes () const |
virtual void | WaitEvent (const Event &ev, int=-1) |
void | Wait (const OperatorBase &other, int stream_id=-1) |
virtual void | WaitEvents (const std::vector< const Event * > &events, int=-1) |
virtual void | Finish () |
virtual bool | Run (int=0) |
virtual bool | HasAsyncPart () const |
virtual bool | SupportsAsyncScheduling () const |
virtual bool | RunAsync (int stream_id=0) |
virtual void | AddRelatedBlobInfo (EnforceNotMet *err) |
const OperatorDef & | debug_def () const |
void | set_debug_def (const std::shared_ptr< const OperatorDef > &operator_def) |
bool | has_debug_def () const |
void | RecordLastFailedOpNetPosition () |
int | net_position () const |
void | set_net_position (int idx) |
const DeviceOption & | device_option () const |
const Event & | event () const |
Event & | event () |
void | ResetEvent () |
void | DisableEvent () |
bool | IsEventDisabled () const |
virtual void | SyncDeviceBarrierForObservers () |
virtual bool | IsStreamFree (int) const |
const std::string & | type () const |
void | annotate_engine (const std::string &engine) |
const std::string & | engine () const |
void | SetExecutorHelper (ExecutorHelper *helper) |
ExecutorHelper * | GetExecutorHelper () const |
std::vector< at::Tensor > | move_newstyle_outputs ()&& |
template<> | |
NetDef | GetSingleArgument (const std::string &name, const NetDef &default_value) const |
template<> | |
vector< int > | GetVectorFromIValueList (const c10::IValue &value) const |
template<> | |
vector< float > | GetVectorFromIValueList (const c10::IValue &value) const |
template<> | |
vector< string > | GetVectorFromIValueList (const c10::IValue &value) const |
Public Member Functions inherited from caffe2::Observable< OperatorBase > | |
Observable (Observable &&)=default | |
Observable & | operator= (Observable &&)=default |
C10_DISABLE_COPY_AND_ASSIGN (Observable) | |
const Observer * | AttachObserver (std::unique_ptr< Observer > observer) |
std::unique_ptr< Observer > | DetachObserver (const Observer *observer_ptr) |
Returns a unique_ptr to the removed observer. | |
virtual size_t | NumObservers () |
void | StartAllObservers () |
void | StopAllObservers () |
Data Fields | |
USE_OPERATOR_CONTEXT_FUNCTIONS | |
USE_DISPATCH_HELPER | |
Additional Inherited Members | |
Public Types inherited from caffe2::Observable< OperatorBase > | |
using | Observer = ObserverBase< OperatorBase > |
Static Public Attributes inherited from caffe2::OperatorBase | |
static const int | kNoNetPositionSet = -1 |
Protected Member Functions inherited from caffe2::OperatorBase | |
virtual void | RecordEvent (const char *=nullptr) |
void | SetEventFinished (const char *err_msg=nullptr) |
void | SetEventFinishedWithException (const char *err_msg=nullptr) |
std::string | getErrorMsg () |
C10_DISABLE_COPY_AND_ASSIGN (OperatorBase) | |
Protected Attributes inherited from caffe2::OperatorBase | |
std::unique_ptr< Event > | event_ |
Protected Attributes inherited from caffe2::Observable< OperatorBase > | |
std::vector< std::unique_ptr< Observer > > | observers_list_ |
Update slices of the tensor in-place with weighted sum.
ScatterWeightedSumOp is similar to WeightedSum and computes the weighted sum of several tensors. The first tensor must be updated in-place (it is both input 0 and the output), and only the slices along its first dimension selected by INDICES are updated.
Input:
  X_0 - tensor to be updated
  weight_0 - scalar weight for X_0, applied only to the slices affected
  INDICES - 1-D list of indices on the first dimension of X_0 that need to be updated
  X_1 - update slices, has to have shape of len(INDICES) + shape(X_0)[1:]
  weight_1 - scalar weight for the X_1 update
  X_2, weight_2, ...
Output: X_0 - has to be exactly the same tensor as the input 0
Note: The op largely ignores the exact shapes of its input arguments and considers only their sizes. This is done for performance, to avoid unnecessary reshapes. Only the first dimension of X_0 matters; call it N. If M is the total size of X_0 and K is the size of INDICES, then each X_i is assumed to have shape K x (M / N), regardless of its real shape.
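The size-based layout in the note above can be illustrated with a minimal sketch on flat buffers. This is not the operator's implementation; `ScatterWeightedSumSketch` and `WeightedUpdate` are hypothetical names, and plain `std::vector<float>` stands in for tensors. It assumes the indexed slices are first scaled by weight_0 and the updates are then added one input at a time, which is consistent with the notes in this description.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper type: one (X_i, weight_i) pair for i >= 1.
struct WeightedUpdate {
  std::vector<float> data;  // treated as K x (M / N), row-major
  float weight;             // scalar weight_i
};

// Sketch only. x0 is treated as N rows of M / N values each; the real
// shapes of the inputs are ignored and only their sizes are checked.
void ScatterWeightedSumSketch(std::vector<float>& x0,
                              std::size_t n,
                              float weight0,
                              const std::vector<std::int64_t>& indices,
                              const std::vector<WeightedUpdate>& updates) {
  assert(n > 0 && x0.size() % n == 0);
  const std::size_t block_size = x0.size() / n;  // M / N
  const std::size_t k = indices.size();          // K

  // Scale every indexed slice of X_0 by weight_0. A duplicated index
  // is scaled once per occurrence.
  for (std::size_t j = 0; j < k; ++j) {
    assert(indices[j] >= 0 && static_cast<std::size_t>(indices[j]) < n);
    float* row = x0.data() + static_cast<std::size_t>(indices[j]) * block_size;
    for (std::size_t c = 0; c < block_size; ++c) {
      row[c] *= weight0;
    }
  }

  // Add weight_i * X_i[j] into slice INDICES[j], one input after another.
  for (const WeightedUpdate& u : updates) {
    assert(u.data.size() == k * block_size);
    for (std::size_t j = 0; j < k; ++j) {
      float* row = x0.data() + static_cast<std::size_t>(indices[j]) * block_size;
      const float* src = u.data.data() + j * block_size;
      for (std::size_t c = 0; c < block_size; ++c) {
        row[c] += u.weight * src[c];
      }
    }
  }
}
```

For example, with X_0 holding 6 values, N = 3, and INDICES = {0, 2}, each X_i must hold 2 x 2 = 4 values, and row 1 of X_0 is left untouched.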
Note: Each update in INDICES is applied independently, which means that if INDICES contains duplicated elements, the corresponding slice of X_0 is scaled multiple times. Collapse INDICES manually beforehand if necessary.
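One possible way to collapse duplicates beforehand is to sum the update slices that share an index, so that each slice of X_0 is scaled by weight_0 only once. The sketch below assumes that interpretation of "collapsing"; `CollapseDuplicateIndices` and `CollapsedUpdate` are hypothetical names, not part of caffe2.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical result type: unique indices plus the summed update rows.
struct CollapsedUpdate {
  std::vector<std::int64_t> indices;  // unique, in first-seen order
  std::vector<float> data;            // indices.size() x block_size, row-major
};

// Sketch only: merges rows of one update tensor X_i that target the same index.
CollapsedUpdate CollapseDuplicateIndices(
    const std::vector<std::int64_t>& indices,
    const std::vector<float>& data,
    std::size_t block_size) {
  assert(data.size() == indices.size() * block_size);
  CollapsedUpdate out;
  std::map<std::int64_t, std::size_t> row_of;  // index -> row in out.data

  for (std::size_t j = 0; j < indices.size(); ++j) {
    std::size_t row;
    auto it = row_of.find(indices[j]);
    if (it == row_of.end()) {
      // First occurrence: open a new zero-filled output row.
      row = out.indices.size();
      row_of.emplace(indices[j], row);
      out.indices.push_back(indices[j]);
      out.data.resize(out.data.size() + block_size, 0.0f);
    } else {
      row = it->second;
    }
    // Accumulate this update slice into the row for its index.
    for (std::size_t c = 0; c < block_size; ++c) {
      out.data[row * block_size + c] += data[j * block_size + c];
    }
  }
  return out;
}
```

When there are several update inputs X_1, X_2, ..., each would need to be collapsed against the same INDICES so that all of them end up with the same row order.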
Note: Updates are applied sequentially, input by input, which might have undesired consequences if the input tensor is accessed concurrently by a different op (e.g. when doing Hogwild). Other threads might see intermediate results even at the level of an individual slice, e.g. X_0 scaled by weight_0 but without any updates applied.
For now this really works only on CPU, because of the INDICES access.
Definition at line 525 of file utility_ops.h.