19 #ifndef OST_QA_MULTI_CLASSIFIER_HH
20 #define OST_QA_MULTI_CLASSIFIER_HH
29 #include <boost/shared_ptr.hpp>
34 namespace ost {
namespace qa {
43 template <
typename DS>
46 template <
typename T1,
55 template <
typename T1,
81 struct If<true, T, F> {
87 struct If<false, T, F> {
93 template <
typename T1,
96 template <
typename T1,
101 template <
typename T>
110 template <
typename C,
typename T,
typename F>
113 template <
typename C,
117 typedef typename If<IsEqual<NullType, C>::Value, T, F>
::Type Type;
126 : number_of_classes_(number_of_classes) {
129 : number_of_classes_(0) {}
132 return number_of_classes_;
144 lower_bound_(lower_bound) {
148 assert(this->GetNumberOfClasses()>idx);
156 template <
typename DS>
159 ds & number_of_classes_;
173 lower_bound_(lower_bound),
174 upper_bound_(upper_bound) {
177 Real factor=(value-lower_bound_)/(upper_bound_-lower_bound_);
180 assert(this->GetNumberOfClasses()>idx);
185 lower_bound_(0), upper_bound_(1) {
187 template <
typename DS>
190 ds & number_of_classes_;
200 template <
typename T>
215 #if OST_DOUBLE_PRECISION
230 template <
typename I>
234 template <
typename C,
typename T,
typename I>
237 template <
typename C,
typename I>
245 template <
typename C,
typename T,
typename I>
248 index[i]=classifier.GetIndexOf(value);
251 template <
typename T>
252 struct NumberOfClasses;
262 template <
typename T>
265 return t.GetNumberOfClasses();
270 template <
typename V,
typename T1,
271 typename T2=impl::NullType,
272 typename T3=impl::NullType,
273 typename T4=impl::NullType,
274 typename T5=impl::NullType,
275 typename T6=impl::NullType,
276 typename T7=impl::NullType>
283 typedef Classifier<T1>
C1;
284 typedef Classifier<T2>
C2;
285 typedef Classifier<T3>
C3;
286 typedef Classifier<T4>
C4;
287 typedef Classifier<T5>
C5;
288 typedef Classifier<T6>
C6;
289 typedef Classifier<T7>
C7;
309 : classifier1_(c1), classifier2_(c2), classifier3_(c3),
310 classifier4_(c4), classifier5_(c5), classifier6_(c6),
312 this->ExtractNumberOfClasses();
314 uint32_t total=this->CalculateNumberOfBuckets();
315 buckets_.resize(total, initial_value);
320 memset(number_of_classes_, 0,
sizeof(number_of_classes_));
323 template <
typename DS>
334 this->ExtractNumberOfClasses();
340 : classifier1_(rhs.classifier1_), classifier2_(rhs.classifier2_),
341 classifier3_(rhs.classifier3_), classifier4_(rhs.classifier4_),
342 classifier5_(rhs.classifier5_), classifier6_(rhs.classifier6_),
343 classifier7_(rhs.classifier7_) {
344 this->ExtractNumberOfClasses();
345 uint32_t total=this->CalculateNumberOfBuckets();
346 buckets_.resize(total);
347 memcpy(&buckets_.front(), &rhs.buckets_.front(),
sizeof(V)*total);
351 return static_cast<uint32_t>(buckets_.size());
355 T1 x1=T1(), T2 x2=T2(),
356 T3 x3=T3(), T4 x4=T4(),
357 T5 x5=T5(), T6 x6=T6(),
359 IndexType index=this->FindBucket(x1, x2, x3, x4, x5, x6, x7);
360 uint32_t linear_index=this->LinearizeBucketIndex(index);
361 buckets_[linear_index]+=value;
365 T3 x3=T3(), T4 x4=T4(),
366 T5 x5=T5(), T6 x6=T6(), T7 x7=T7())
const {
367 IndexType index=this->FindBucket(x1, x2, x3, x4, x5, x6, x7);
368 uint32_t linear_index=this->LinearizeBucketIndex(index);
369 return buckets_[linear_index];
374 return buckets_[this->LinearizeBucketIndex(index)];
379 buckets_[this->LinearizeBucketIndex(index)]=value;
384 IndexType FindBucket(T1 x1=T1(), T2 x2=T2(), T3 x3=T3(),
385 T4 x4=T4(), T5 x5=T5(), T6 x6=T6(),
391 IndexType> find_index_1(classifier1_, 0, x1, index);
393 IndexType> find_index_2(classifier2_, 1, x2, index);
395 IndexType> find_index_3(classifier3_, 2, x3, index);
397 IndexType> find_index_4(classifier4_, 3, x4, index);
399 IndexType> find_index_5(classifier5_, 4, x5, index);
401 IndexType> find_index_6(classifier6_, 5, x6, index);
403 IndexType> find_index_7(classifier7_, 6, x7, index);
409 buckets_[this->LinearizeBucketIndex(index)]+=value;
412 void ExtractNumberOfClasses()
423 uint32_t LinearizeBucketIndex(
const IndexType& index)
const
428 linear_index+=factor*index[i];
429 factor*=number_of_classes_[i];
434 uint32_t CalculateNumberOfBuckets()
const
438 total*=number_of_classes_[i];
450 std::vector<ValueType> buckets_;