Caffe2 - C++ API
A deep learning, cross platform ML framework
caffe2::AbstractUnsortedSegmentOp< T, SIndex, Context, Reducer, SparseFused, InputAccessor > Class Template Reference

Unsorted segment reduction op with optional fused embedding lookup.

#include <segment_reduction_op.h>

Inheritance for caffe2::AbstractUnsortedSegmentOp< T, SIndex, Context, Reducer, SparseFused, InputAccessor >:
derives from caffe2::Operator< Context >, which derives from caffe2::OperatorBase, which derives from caffe2::Observable< OperatorBase >.

Public Types

enum  { INDICES = Reducer::kInputCount, SEGMENT_IDS = Reducer::kInputCount + (SparseFused ? 1 : 0) }
 
- Public Types inherited from caffe2::Observable< OperatorBase >
using Observer = ObserverBase< OperatorBase >
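As a compile-time illustration of the enum above, the sketch below copies its two formulas into a standalone template. The reducer structs are hypothetical stand-ins that define only kInputCount. Following the Inputs description in the Detailed Description, it assumes that kInputCount counts DATA plus the P auxiliary reducer inputs (so 1 for a plain sum and 2 for a weighted sum).

```cpp
// Hypothetical reducers: only kInputCount matters here.
// Assumption: kInputCount = DATA + P auxiliary reducer inputs.
struct SumLikeReducer { static constexpr int kInputCount = 1; };          // P = 0
struct WeightedSumLikeReducer { static constexpr int kInputCount = 2; };  // P = 1 (weights)

// Same formulas as the INDICES / SEGMENT_IDS enum of the class.
template <class Reducer, bool SparseFused>
struct InputLayout {
  enum {
    INDICES = Reducer::kInputCount,
    SEGMENT_IDS = Reducer::kInputCount + (SparseFused ? 1 : 0)
  };
};

// Fused lookup: INDICES follows the reducer inputs, SEGMENT_IDS follows INDICES.
static_assert(InputLayout<SumLikeReducer, true>::INDICES == 1, "");
static_assert(InputLayout<SumLikeReducer, true>::SEGMENT_IDS == 2, "");
// No fused lookup: SEGMENT_IDS takes the slot that INDICES would have used.
static_assert(InputLayout<WeightedSumLikeReducer, false>::SEGMENT_IDS == 2, "");
```

When SparseFused is false the INDICES enumerator still exists but simply equals SEGMENT_IDS; no input is read at that position as indices.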
 

Public Member Functions

 AbstractUnsortedSegmentOp (const OperatorDef &operator_def, Workspace *ws)
 
bool RunOnDevice () override
 
template<typename IndexType >
bool DoRunWithType ()
 
template<typename IndexType , int FixedSize>
bool DoRunWithValue ()
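The three member functions above appear to form a two-level dispatch: RunOnDevice picks an index type at run time and calls DoRunWithType<IndexType>, which in turn picks a FixedSize specialization of DoRunWithValue. The standalone mock below is only a sketch of that pattern under those assumptions. IndexDtype, SketchOp, the block_size field, and the choice between FixedSize = 1 and -1 (generic) are all hypothetical; the real operator's dispatch lists are not documented on this page, and the mock uses no Caffe2 dispatch helpers or tensor types.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical runtime tag for the element type of the INDICES tensor.
enum class IndexDtype { Int32, Int64 };

struct SketchOp {
  IndexDtype indices_dtype;  // runtime type of the index input
  int64_t block_size;        // hypothetical: elements per DATA row
  std::string chosen;        // records which instantiation ran

  // Level 1: runtime dtype -> compile-time IndexType.
  bool RunOnDevice() {
    switch (indices_dtype) {
      case IndexDtype::Int32: return DoRunWithType<int32_t>();
      case IndexDtype::Int64: return DoRunWithType<int64_t>();
    }
    throw std::runtime_error("unsupported index type");
  }

  // Level 2: runtime block size -> compile-time FixedSize (-1 = generic path).
  template <typename IndexType>
  bool DoRunWithType() {
    assert(block_size > 0);
    if (block_size == 1) return DoRunWithValue<IndexType, 1>();
    return DoRunWithValue<IndexType, -1>();
  }

  // A real kernel would run here; the sketch only records the instantiation.
  template <typename IndexType, int FixedSize>
  bool DoRunWithValue() {
    chosen = std::to_string(sizeof(IndexType) * 8) + "-bit indices, FixedSize=" +
             std::to_string(FixedSize);
    return true;
  }
};
```

The point of the pattern is that the inner loop in DoRunWithValue is compiled once per (IndexType, FixedSize) pair, so a small fixed block size can be unrolled by the compiler while larger sizes fall back to a generic loop.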
 
- Public Member Functions inherited from caffe2::Operator< Context >
 Operator (const OperatorDef &operator_def, Workspace *ws)
 
const Tensor< Context > & Input (int idx)
 
Tensor< Context > * Output (int idx)
 
void WaitEvent (const Event &ev, int stream_id=-1) final
 
void WaitEvents (const std::vector< const Event * > &events, int stream_id=-1) final
 
bool Run (int stream_id=0) final
 
bool RunAsync (int stream_id=0) final
 
bool IsStreamFree (int stream_id) const override
 
bool HasAsyncPart () const override
 
bool SupportsAsyncScheduling () const override
 
- Public Member Functions inherited from caffe2::OperatorBase
 OperatorBase (const OperatorDef &operator_def, Workspace *ws)
 
bool HasArgument (const string &name) const
 Checks if the operator has an argument of the given name.
 
template<typename T >
T GetSingleArgument (const string &name, const T &default_value) const
 
template<typename T >
bool HasSingleArgumentOfType (const string &name) const
 
template<typename T >
vector< T > GetRepeatedArgument (const string &name, const vector< T > &default_value={}) const
 
template<typename T >
const T & Input (int idx)
 
template<typename T >
T * Output (int idx)
 
const Blob & InputBlob (int idx)
 
Blob * OutputBlob (int idx)
 
template<typename T >
bool InputIsType (int idx)
 
template<typename T >
bool OutputIsType (int idx)
 
int InputSize ()
 
int OutputSize ()
 
const vector< const Blob * > & Inputs () const
 
const vector< Blob * > & Outputs ()
 
vector< TensorShape > InputTensorShapes ()
 
void Wait (const OperatorBase &other, int stream_id=-1)
 
virtual void Finish ()
 
virtual void AddRelatedBlobInfo (EnforceNotMet *err)
 
const OperatorDef & debug_def () const
 
void set_debug_def (const std::shared_ptr< const OperatorDef > &operator_def)
 
bool has_debug_def () const
 
void RecordLastFailedOpNetPosition ()
 
int net_position () const
 
void set_net_position (int idx)
 
const DeviceOption & device_option () const
 
const Event & event () const
 
Event & event ()
 
void ResetEvent ()
 
void DisableEvent ()
 
bool IsEventDisabled () const
 
const std::string & type ()
 
void annotate_engine (const std::string &engine)
 
const std::string & engine () const
 
- Public Member Functions inherited from caffe2::Observable< OperatorBase >
const Observer * AttachObserver (std::unique_ptr< Observer > observer)
 
std::unique_ptr< Observer > DetachObserver (const Observer *observer_ptr)
 Returns a unique_ptr to the removed observer.
 
virtual size_t NumObservers ()
 
void StartAllObservers ()
 
void StopAllObservers ()
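The AttachObserver/DetachObserver signatures above imply an ownership hand-off: attaching moves a std::unique_ptr into the observable and returns a non-owning pointer, and detaching gives ownership back. The mock below sketches that contract with hypothetical SketchObserver and SketchObservable classes. It is not Caffe2's Observable, and returning nullptr for a pointer that was never attached is this sketch's own choice.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for ObserverBase<OperatorBase>.
struct SketchObserver {
  virtual ~SketchObserver() = default;
};

class SketchObservable {
 public:
  // Takes ownership; the returned raw pointer is only a handle for later detaching.
  const SketchObserver* AttachObserver(std::unique_ptr<SketchObserver> observer) {
    assert(observer != nullptr);
    const SketchObserver* handle = observer.get();
    observers_list_.push_back(std::move(observer));
    return handle;
  }

  // Gives ownership back to the caller, or returns nullptr if the handle is unknown.
  std::unique_ptr<SketchObserver> DetachObserver(const SketchObserver* observer_ptr) {
    auto it = std::find_if(observers_list_.begin(), observers_list_.end(),
                           [observer_ptr](const std::unique_ptr<SketchObserver>& p) {
                             return p.get() == observer_ptr;
                           });
    if (it == observers_list_.end()) return nullptr;
    std::unique_ptr<SketchObserver> removed = std::move(*it);
    observers_list_.erase(it);
    return removed;
  }

  std::size_t NumObservers() const { return observers_list_.size(); }

 private:
  std::vector<std::unique_ptr<SketchObserver>> observers_list_;
};
```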
 

Data Fields

 USE_OPERATOR_CONTEXT_FUNCTIONS
 

Static Public Attributes

static constexpr int kSelfInputs = SparseFused ? 2 : 1
 
static constexpr int kNumInputs = Reducer::kInputCount + kSelfInputs
 
- Static Public Attributes inherited from caffe2::OperatorBase
static constexpr int kNoNetPositionSet = -1
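The two static attributes of the class fix the total input count. The compile-time sketch below reuses the formulas for kSelfInputs, kNumInputs and the SEGMENT_IDS enumerator with hypothetical one- and two-input reducers (kInputCount again assumed to count DATA plus the P auxiliary inputs) and checks that SEGMENT_IDS is always the last input in either mode.

```cpp
// Hypothetical reducers; kInputCount = DATA + P auxiliary inputs (assumption).
struct OneInputReducer { static constexpr int kInputCount = 1; };  // P = 0
struct TwoInputReducer { static constexpr int kInputCount = 2; };  // P = 1

template <class Reducer, bool SparseFused>
struct InputCounts {
  // Same formulas as kSelfInputs / kNumInputs: [INDICES,] SEGMENT_IDS.
  static constexpr int kSelfInputs = SparseFused ? 2 : 1;
  static constexpr int kNumInputs = Reducer::kInputCount + kSelfInputs;
  // Same formula as the SEGMENT_IDS enumerator.
  static constexpr int kSegmentIds = Reducer::kInputCount + (SparseFused ? 1 : 0);
};

// SEGMENT_IDS is the last input whichever mode is chosen.
static_assert(InputCounts<OneInputReducer, true>::kSegmentIds ==
                  InputCounts<OneInputReducer, true>::kNumInputs - 1, "");
static_assert(InputCounts<TwoInputReducer, false>::kSegmentIds ==
                  InputCounts<TwoInputReducer, false>::kNumInputs - 1, "");
// DATA, weights, INDICES, SEGMENT_IDS
static_assert(InputCounts<TwoInputReducer, true>::kNumInputs == 4, "");
```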
 

Additional Inherited Members

- Protected Member Functions inherited from caffe2::Operator< Context >
void RecordEvent (const char *err_msg=nullptr) final
 
std::string getErrorMsg ()
 
- Protected Member Functions inherited from caffe2::OperatorBase
 DISABLE_COPY_AND_ASSIGN (OperatorBase)
 
- Protected Attributes inherited from caffe2::Operator< Context >
Context context_
 
- Protected Attributes inherited from caffe2::OperatorBase
std::unique_ptr< Event > event_
 
- Protected Attributes inherited from caffe2::Observable< OperatorBase >
std::vector< std::unique_ptr< Observer > > observers_list_
 

Detailed Description

template<typename T, typename SIndex, class Context, class Reducer, bool SparseFused = true, class InputAccessor = BaseInputAccessor<T>>
class caffe2::AbstractUnsortedSegmentOp< T, SIndex, Context, Reducer, SparseFused, InputAccessor >

Unsorted segment reduction op with optional fused embedding lookup.

Base implementation for SparseUnsortedSegmentXXX (SparseFused == true) and UnsortedSegmentXXX (SparseFused == false); the SparseFused static argument selects between the two.

Unlike the sorted version, it allows "gaps" in the segment ids.
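To make the "gaps" behavior concrete, here is a minimal sketch of a non-fused unsorted segment sum on a row-major 2-D array. It assumes a plain sum reducer whose output rows start at zero, so a segment id that never appears produces an all-zero row; what other reducers emit for an empty segment is not stated on this page. The function name is hypothetical and this is not Caffe2's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: DATA is row-major with `cols` columns, one segment id per row.
// Ids may be in any order and need not be contiguous.
std::vector<float> UnsortedSegmentSumSketch(const std::vector<float>& data,
                                            std::size_t cols,
                                            const std::vector<int64_t>& segment_ids) {
  const std::size_t rows = segment_ids.size();
  assert(data.size() == rows * cols);  // one segment id per DATA row
  int64_t k = 0;  // number of output rows: max segment id + 1
  for (int64_t id : segment_ids) {
    assert(id >= 0);
    if (id + 1 > k) k = id + 1;
  }
  // Every output row starts at zero, so unused ids ("gaps") stay zero.
  std::vector<float> out(static_cast<std::size_t>(k) * cols, 0.0f);
  for (std::size_t r = 0; r < rows; ++r) {
    const std::size_t dst = static_cast<std::size_t>(segment_ids[r]) * cols;
    for (std::size_t c = 0; c < cols; ++c) out[dst + c] += data[r * cols + c];
  }
  return out;
}
```

For example, three rows with segment ids {3, 0, 3} give a four-row output in which rows 1 and 2 are all zeros.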

Inputs:
  0: DATA - input embedding to do lookups in
  1..P: AUX_ARG_<I> - optional additional arguments passed to the reducer; each should have the same first dimension as SEGMENT_IDS (e.g. the scalar weights in WeightedSum)
  P+1: INDICES (only if SparseFused == true) - 1-D vector of indices to look up in DATA; should have the same length as SEGMENT_IDS
  P+1 (if SparseFused == false) or P+2 (if SparseFused == true): SEGMENT_IDS - 1-D vector of unsorted segment ids
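With SparseFused == true, each position i pairs a row lookup with a segment assignment. The sketch below assumes a sum reducer and shows the resulting semantics, out[SEGMENT_IDS[i]] += DATA[INDICES[i]], on a row-major 2-D array. The helper name is hypothetical, and its bounds checks are assumptions of the sketch rather than documented behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: fused lookup + unsorted segment sum.
// DATA is row-major with `cols` columns.
std::vector<float> SparseUnsortedSegmentSumSketch(
    const std::vector<float>& data, std::size_t cols,
    const std::vector<int64_t>& indices,
    const std::vector<int64_t>& segment_ids) {
  assert(cols > 0 && data.size() % cols == 0);
  assert(indices.size() == segment_ids.size());  // same length, as documented
  const int64_t data_rows = static_cast<int64_t>(data.size() / cols);
  int64_t k = 0;  // number of output rows: max segment id + 1
  for (int64_t id : segment_ids) {
    assert(id >= 0);
    if (id + 1 > k) k = id + 1;
  }
  std::vector<float> out(static_cast<std::size_t>(k) * cols, 0.0f);
  for (std::size_t i = 0; i < indices.size(); ++i) {
    assert(indices[i] >= 0 && indices[i] < data_rows);  // lookup stays inside DATA
    const std::size_t src = static_cast<std::size_t>(indices[i]) * cols;
    const std::size_t dst = static_cast<std::size_t>(segment_ids[i]) * cols;
    for (std::size_t c = 0; c < cols; ++c) out[dst + c] += data[src + c];
  }
  return out;
}
```

Fusing the lookup avoids materializing the gathered rows DATA[INDICES] as a separate tensor before the reduction.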

Args: num_segments - overrides the first dimension of the output. If not set, it is inferred from the SEGMENT_IDS tensor.
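The rule for choosing the output's first dimension can be sketched as follows. Using -1 to mean "num_segments not set", returning 0 for empty input, and asserting that every id fits below an explicit num_segments are all assumptions of this hypothetical helper; the page does not say how out-of-range ids are handled.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: choose K, the first output dimension.
// num_segments < 0 stands for "argument not set" (assumption of this sketch).
int64_t InferNumSegmentsSketch(const std::vector<int64_t>& segment_ids,
                               int64_t num_segments = -1) {
  if (num_segments >= 0) {
    // Explicit value overrides inference; ids are assumed to fit inside it.
    for (int64_t id : segment_ids) assert(id >= 0 && id < num_segments);
    return num_segments;
  }
  int64_t k = 0;  // inferred: max segment id + 1 (0 for empty input)
  for (int64_t id : segment_ids) {
    assert(id >= 0);
    if (id + 1 > k) k = id + 1;
  }
  return k;
}
```

Setting num_segments is useful when a batch happens not to contain the largest segment id but the output must still have a fixed number of rows.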

Output: a tensor whose first dimension is K, where K is the maximum segment id + 1 (or num_segments, if that argument is set). The remaining dimensions are determined by the reducer but usually match the trailing dimensions of DATA.
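For the common case described above, where the reducer preserves DATA's trailing dimensions, the output shape is DATA's shape with the first dimension replaced by K. The hypothetical helper below sketches only that case; reducers that change the trailing shape are not modeled.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: output dims when the reducer keeps DATA's trailing dims.
std::vector<int64_t> OutputShapeSketch(const std::vector<int64_t>& data_dims,
                                       int64_t k) {
  assert(!data_dims.empty());  // DATA needs at least a first dimension
  assert(k >= 0);
  std::vector<int64_t> out_dims = data_dims;
  out_dims[0] = k;  // first dimension becomes the number of segments
  return out_dims;
}
```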

Definition at line 1010 of file segment_reduction_op.h.


The documentation for this class was generated from the following file: segment_reduction_op.h