Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)  0.19.0
Performance library for Deep Learning
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
mkldnn.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef MKLDNN_HPP
18 #define MKLDNN_HPP
19 
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
21 #include <stdlib.h>
22 #include <memory>
23 #include <vector>
24 #include <algorithm>
25 #include <iterator>
26 #include <string>
27 
28 #include "mkldnn.h"
29 #endif
30 
31 namespace mkldnn {
32 
35 
38 
/// Primary traits template for C handle types wrapped by mkldnn::handle.
/// It is intentionally empty: every wrapped C handle type supplies a
/// specialization exposing a static `destructor` used to release the object.
template <typename T> struct handle_traits {};

/// Shared-ownership wrapper around an opaque C API handle.
///
/// @tparam T      Pointer type of the wrapped C handle.
/// @tparam traits Traits type providing `traits::destructor(T)`, which is
///                called to release an owned handle once the last copy of
///                the wrapper referring to it goes away.
///
/// Copying a handle shares the underlying C object. A handle can also be
/// attached in "weak" (non-owning) mode, in which case `traits::destructor`
/// is never invoked for it.
template <typename T, typename traits = handle_traits<T>> class handle {
public:
    /// Wraps @p t. When @p weak is true the wrapper does not take ownership
    /// and will not release @p t.
    handle(T t = 0, bool weak = false) : _data(nullptr) { reset(t, weak); }

    /// Shares ownership of the C handle held by @p other.
    handle(const handle &other) : _data(other._data) {}

    /// Drops the current reference and shares the C handle held by @p other.
    handle &operator=(const handle &other) {
        _data = other._data;
        return *this;
    }

    /// Replaces the wrapped C handle with @p t.
    /// @param t    New C handle (may be null).
    /// @param weak If true, @p t is attached without ownership.
    void reset(T t, bool weak = false) {
        using destructor_result = decltype(traits::destructor(0));
        if (weak) {
            // Non-owning attachment: the deleter does nothing.
            _data.reset(t, [](T) { return destructor_result(0); });
        } else {
            _data.reset(t, traits::destructor);
        }
    }

    /// Returns the raw C handle (not owned by the caller).
    T get() const { return _data.get(); }

    /// Two wrappers are equal when they refer to the same raw C handle.
    bool operator==(const handle &other) const {
        return _data.get() == other._data.get();
    }
    bool operator!=(const handle &other) const { return !operator==(other); }

protected:
    /// Compares the wrapped raw pointer against a bare C handle.
    bool operator==(const T other) const { return _data.get() == other; }
    bool operator!=(const T other) const { return !operator==(other); }

private:
    std::shared_ptr<typename std::remove_pointer<T>::type> _data;

    // Construction/assignment from a const rvalue handle is explicitly
    // deleted. NOTE(review): this effectively disables move semantics for
    // the wrapper; confirm that is still intended.
    handle(const handle &&) = delete;
    handle &operator=(const handle &&other) = delete;
};
90 
91 #ifndef DOXYGEN_SHOULD_SKIP_THIS
92 template <> struct handle_traits<mkldnn_primitive_desc_t> {
93  static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
94 };
95 
96 template <> struct handle_traits<mkldnn_primitive_t> {
97  static constexpr auto destructor = &mkldnn_primitive_destroy;
98 };
99 
100 template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
101  static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
102 };
103 #endif
104 
106 class primitive: public handle<mkldnn_primitive_t> {
107  friend struct error;
108  friend struct stream;
109  friend class primitive_at;
110  using handle::handle;
111 public:
113  enum class kind {
116  view = mkldnn_view,
120  sum = mkldnn_sum,
127  lrn = mkldnn_lrn,
130  rnn = mkldnn_rnn,
131  };
132 
134  struct at {
142 
143  at(const primitive &aprimitive, size_t at = 0)
144  : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
146  inline operator primitive() const;
147  };
148 
151  // TODO: use the C++ API wrapper structure.
152 };
153 
155  return static_cast<mkldnn_primitive_kind_t>(akind);
156 }
161 struct error: public std::exception {
163  std::string message;
165 
172 
173  error(mkldnn_status_t astatus, std::string amessage,
174  mkldnn_primitive_t aerror_primitive = 0)
175  : status(astatus)
176  , message(amessage)
177  , error_primitive(aerror_primitive, true)
178  {}
179 
187 
189  const std::string &message,
191  {
192  if (status != mkldnn_success) {
193  if (nullptr != error_primitive)
194  throw error(status, message, *error_primitive);
195  else
196  throw error(status, message, nullptr);
197  }
198  }
199 };
200 
201 inline primitive::at::operator primitive() const {
204  mkldnn_primitive_get_output(data.primitive,
205  data.output_index, &output),
206  "could not get an output primitive");
207  return primitive(const_cast<mkldnn_primitive_t>(output), true);
208 }
209 
213  "could not get primitive descriptor by primitive");
214  return pd;
215 }
217 
222 
226 };
227 
229  return static_cast<mkldnn_round_mode_t>(mode);
230 }
231 
234 };
235 
237  return static_cast<mkldnn_padding_kind_t>(kind);
238 }
239 
240 enum prop_kind {
249 };
250 
252  return static_cast<mkldnn_prop_kind_t>(kind);
253 }
254 
255 enum algorithm {
282 };
283 
285  return static_cast<mkldnn_alg_kind_t>(aalgorithm);
286 }
287 
292 };
293 
295  batch_normalization_flag aflag) {
296  return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
297 }
298 
305 };
306 
308  return static_cast<mkldnn_rnn_direction_t>(adir);
309 }
310 
311 enum query {
313 
316 
319 
322 
324 
337 
347 };
348 
350  return static_cast<mkldnn_query_t>(aquery);
351 }
352 
354 
360 
361 #ifndef DOXYGEN_SHOULD_SKIP_THIS
362 template <> struct handle_traits<mkldnn_post_ops_t> {
363  static constexpr auto destructor = &mkldnn_post_ops_destroy;
364 };
365 #endif
366 
367 struct post_ops: public handle<mkldnn_post_ops_t> {
369  mkldnn_post_ops_t result;
371  "could not create post operation sequence");
372  reset(result);
373  }
374 
375  int len() const { return mkldnn_post_ops_len(get()); }
376 
377  primitive::kind kind(int index) const {
380  "post_ops index is out of range");
381  return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
382  index));
383  }
384 
385  void append_sum(float scale = 1.) {
387  "could not append sum");
388  }
389 
390  void get_params_sum(int index, float &scale) const {
392  "could not get sum params");
393  }
394 
395  void append_eltwise(float scale, algorithm alg, float alpha,
396  float beta) {
398  convert_to_c(alg), alpha, beta),
399  "could not append eltwise");
400  }
401 
402  void get_params_eltwise(int index, float &scale, algorithm &alg,
403  float &alpha, float &beta) const {
404  mkldnn_alg_kind_t c_alg;
406  &scale, &c_alg, &alpha, &beta),
407  "could not get eltwise params");
408  alg = static_cast<algorithm>(c_alg);
409  }
410 };
411 
412 #ifndef DOXYGEN_SHOULD_SKIP_THIS
413 template <> struct handle_traits<mkldnn_primitive_attr_t> {
414  static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
415 };
416 #endif
417 
418 struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
422  "could not create a primitive attr");
423  reset(result);
424  }
425 
427  mkldnn_round_mode_t result;
429  get(), &result), "could not get int output round mode");
430  return round_mode(result);
431  }
432 
435  get(), mkldnn::convert_to_c(mode)),
436  "could not set int output round mode");
437  }
438 
439  void get_output_scales(int &mask, std::vector<float> &scales) const
440  {
441  int count, c_mask;
442  const float *c_scales;
444  &count, &c_mask, &c_scales),
445  "could not get int output scales");
446  scales.resize(count);
447 
448  mask = c_mask;
449  for (int c = 0; c < count; ++c)
450  scales[c] = c_scales[c];
451  }
452 
453  void set_output_scales(int mask, const std::vector<float> &scales)
454  {
456  (int)scales.size(), mask, &scales[0]),
457  "could not set int output scales");
458  }
459 
460  const post_ops get_post_ops() const {
461  post_ops result;
462  const_mkldnn_post_ops_t c_result;
464  "could not get post operation sequence");
465  result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
466  return result;
467  }
468 
469  void set_post_ops(post_ops ops) {
471  "could not set post operation sequence");
472  }
473 
474  void set_rnn_data_qparams(const float scale, const float shift)
475  {
477  scale, shift), "could not set rnn data int scale/shift");
478  }
479 
480  void set_rnn_weights_qparams(int mask, const std::vector<float> &scales)
481  {
483  (int)scales.size(), mask, &scales[0]),
484  "could not set rnn weights int scales");
485  }
486 };
487 
489 
495 
496 #ifndef DOXYGEN_SHOULD_SKIP_THIS
497 template <> struct handle_traits<mkldnn_engine_t> {
498  static constexpr auto destructor = &mkldnn_engine_destroy;
499 };
500 #endif
501 
503 struct engine: public handle<mkldnn_engine_t> {
504  friend class primitive;
505  // gcc bug??? using handle::handle;
506 
508  enum kind {
513  };
514 
518 
519  static size_t get_count(kind akind) {
520  return mkldnn_engine_get_count(convert_to_c(akind));
521  }
522 
528 
529  engine(kind akind, size_t index) {
530  mkldnn_engine_t aengine;
532  mkldnn_engine_create(&aengine,
533  convert_to_c(akind), index),
534  "could not create an engine");
535  reset(aengine);
536  }
537 
538  explicit engine(const mkldnn_engine_t& aengine)
539  : handle(aengine, true) {}
540 
542  mkldnn_engine_t engine_q;
545  mkldnn::convert_to_c(eengine), 0, &engine_q),
546  "could not get engine from primitive_desc");
547  reset(engine_q, true);
548  }
549 
550  template <class primitive_desc>
551  static engine query(const primitive_desc &pd) {
552  mkldnn_engine_t engine_q;
555  mkldnn::convert_to_c(eengine), 0, &engine_q),
556  "could not get engine from primitive_desc");
557 
558  return engine(engine_q);
559  }
560 
561 private:
562  static mkldnn_engine_kind_t convert_to_c(kind akind) {
563  return static_cast<mkldnn_engine_kind_t>(akind);
564  }
565 };
566 
568 
571 
577 
579 struct memory: public primitive {
580  private:
581  std::shared_ptr<char> _handle;
582 
583  public:
584  typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
585 
586  template <typename T> static void validate_dims(std::vector<T> v) {
587  if (v.size() > TENSOR_MAX_DIMS)
589  "invalid dimensions");
590  }
591 
594  enum data_type {
601  };
602 
605  enum format {
752  };
753 
755  struct desc {
756  friend struct memory;
759 
765  desc(dims adims, data_type adata_type,
766  format aformat) {
767  validate_dims(adims);
769  mkldnn_memory_desc_init(&data, (int)adims.size(),
770  adims.size() == 0 ? nullptr : &adims[0],
771  convert_to_c(adata_type), convert_to_c(aformat)),
772  "could not initialize a memory descriptor");
773  }
774 
778  desc(const mkldnn_memory_desc_t &adata): data(adata) {}
779  };
780 
782  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
783  friend struct memory;
784 
785  // TODO: make private
787 
789  primitive_desc(const desc &adesc, const engine &aengine) {
793  &adesc.data, aengine.get()),
794  "could not initialize a memory primitive descriptor");
795  reset(result);
796  }
797 
801  return memory::desc(*memory_d); }
802 
805  size_t get_size() const {
807  }
808 
809  bool operator==(const primitive_desc &other) const {
810  return (0 == mkldnn_memory_primitive_desc_equal(get(),
811  other.get())) ? false : true;
812  }
813 
814  bool operator!=(const primitive_desc &other) const {
815  return !operator==(other);
816  }
817 
818  engine get_engine() { return engine::query(*this); }
819  };
820 
824  memory(const primitive &aprimitive): primitive(aprimitive) {}
828  memory(const primitive_desc &adesc) {
829  mkldnn_primitive_t result;
831  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
832  "could not create a memory primitive");
833  reset(result);
834  auto _malloc = [](size_t size, int alignment) {
835  void *ptr;
836 #ifdef _WIN32
837  ptr = _aligned_malloc(size, alignment);
838  int rc = ((ptr)? 0 : errno);
839 #else
840  int rc = ::posix_memalign(&ptr, alignment, size);
841 #endif /* _WIN32 */
842  return (rc == 0) ? (char*)ptr : nullptr;
843  };
844  auto _free = [](char* p) {
845 #ifdef _WIN32
846  _aligned_free((void*)p);
847 #else
848  ::free((void*)p);
849 #endif /* _WIN32 */
850  };
851  _handle.reset(_malloc(adesc.get_size(), 4096), _free);
852  set_data_handle(_handle.get());
853  }
854 
855  memory(const primitive_desc &adesc, void *ahandle) {
856  mkldnn_primitive_t result;
858  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
859  "could not create a memory primitive");
860  reset(result);
861  set_data_handle(ahandle);
862  }
863 
866  primitive_desc adesc;
869  &cdesc),
870  "could not get primitive descriptor from a memory primitive");
871  /* FIXME: no const_cast should be here */
872  adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
873  return adesc;
874  }
875 
878  inline void *get_data_handle() const {
879  void *handle;
881  "could not get native handle");
882  return handle;
883  }
884 
885  inline void set_data_handle(void *handle) const {
887  "could not set native handle");
888  }
889 
890  // Must go away or be private:
892  return static_cast<mkldnn_data_type_t>(adata_type);
893  }
895  return static_cast<mkldnn_memory_format_t>(aformat);
896  }
897 };
898 
900  auto zero = mkldnn_memory_desc_t();
901  zero.primitive_kind = mkldnn_memory;
902  return memory::desc(zero);
903 }
904 
905 inline memory null_memory(engine eng) {
907  return memory({zero, eng}, nullptr);
908 }
909 
911  &aprimitive_desc, int n_inputs, int n_outputs,
912  const std::string &prim_name) {
913  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
914  aprimitive_desc, mkldnn_query_num_of_inputs_s32, 0);
915  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
916  aprimitive_desc, mkldnn_query_num_of_outputs_s32, 0);
917  if (n_outputs_expected > n_outputs ) {
918  std::string message = "could not create " + prim_name +
919  " primitive, not enought output parameters";
920  throw error(mkldnn_invalid_arguments, message, nullptr);
921  }
922  if (n_inputs_expected > n_inputs ) {
923  std::string message = "could not create " + prim_name +
924  " primitive, not enought input parameters";
925  throw error(mkldnn_invalid_arguments, message, nullptr);
926  }
927 }
928 
929 
930 inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
931  const_mkldnn_primitive_desc_t aprimitive_pd;
932  mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);
934  aprimitive_pd);
935 
936  return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
937 }
938 
940  return a == memory::convert_to_c(b);
941 }
943  return !(a == b);
944 }
946  return b == a;
947 }
949  return !(a == b);
950 }
951 
953  return a == memory::convert_to_c(b);
954 }
956  return !(a == b);
957 }
959  return b == a;
960 }
962  return !(a == b);
963 }
964 
966 
972 
973 struct reorder : public primitive {
974  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
976  const memory::primitive_desc &output) {
979  &result, input.get(), output.get()),
980  "could not create a reorder primitive descriptor");
981  reset(result);
982  }
983 
985  const memory::primitive_desc &output,
986  const primitive_attr &aattr) {
989  &result, input.get(), output.get(), aattr.get()),
990  "could not create a reorder primitive descriptor");
991  reset(result);
992  }
993 
994  engine get_engine() { return engine::query(*this); }
995  };
996 
997  reorder(const primitive_desc &aprimitive_desc,
998  const primitive::at &input, const memory &output) {
999  mkldnn_primitive_t result;
1000  mkldnn_primitive_at_t inputs[] = { input.data };
1001  const_mkldnn_primitive_t outputs[] = { output.get() };
1003  aprimitive_desc.get(), inputs, outputs),
1004  "could not create a reorder primitive");
1005  reset(result);
1006  }
1007 
1008  reorder(const primitive::at &input, const memory &output) {
1009  auto input_mpd = memory(input).get_primitive_desc();
1010  auto output_mpd = output.get_primitive_desc();
1011 
1012  auto reorder_d = primitive_desc(input_mpd, output_mpd);
1013 
1014  mkldnn_primitive_t result;
1015  mkldnn_primitive_at_t inputs[] = { input.data };
1016  const_mkldnn_primitive_t outputs[] = { output.get() };
1018  reorder_d.get(), inputs, outputs),
1019  "could not create a reorder primitive");
1020  reset(result);
1021  }
1022 };
1023 
1025 
1031 
1032 struct view : public primitive {
1033  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1035  memory::dims offsets) {
1036  mkldnn_primitive_desc_t result;
1037 
1039  &result, input.get(), &dims[0], &offsets[0]),
1040  "could not create a view primitive descriptor");
1041  reset(result);
1042  }
1043 
1045  memory::primitive_desc adesc;
1047  const_mkldnn_primitive_desc_t const_cdesc =
1051  const_cdesc),
1052  "could not clone a dst primitive descriptor");
1053  adesc.reset(cdesc);
1054  return adesc;
1055  }
1056 
1057  engine get_engine() { return engine::query(*this); }
1058  };
1059 
1060  view(const primitive_desc &view_pd, primitive::at input) {
1061  mkldnn_primitive_t result;
1062  mkldnn_primitive_at_t inputs[] = { input.data };
1064  view_pd.get(), inputs, nullptr),
1065  "could not create a view primitive");
1066  reset(result);
1067  }
1068 
1069  view(memory input, memory::dims dims, memory::dims offsets) {
1070  mkldnn_primitive_t result;
1071  primitive_desc view_pd(input.get_primitive_desc(), dims,
1072  offsets);
1073  mkldnn_primitive_at_t inputs[] = { primitive::at(input).data };
1075  view_pd.get(), inputs, nullptr),
1076  "could not create a view primitive");
1077  reset(result);
1078  }
1079 };
1080 
1082 
1088 
1089 struct concat : public primitive {
1090  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1091  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1092  std::vector<memory::primitive_desc> inputs) {
1093  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1094  c_api_inputs.reserve(inputs.size());
1095  auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
1096  std::transform(inputs.begin(), inputs.end(),
1097  std::back_inserter(c_api_inputs), convert_to_c);
1098  return c_api_inputs;
1099  }
1100 
1101  primitive_desc(const memory::desc &output, int concat_dimension,
1102  std::vector<memory::primitive_desc> inputs) {
1103  mkldnn_primitive_desc_t result;
1104 
1105  auto c_api_inputs = cpp_to_c(inputs);
1106 
1108  &result, &output.data, (int)c_api_inputs.size(),
1109  concat_dimension, &c_api_inputs[0]),
1110  "could not create a concat primitive descriptor");
1111  reset(result);
1112  }
1113 
1114  primitive_desc(int concat_dimension,
1115  std::vector<memory::primitive_desc> inputs) {
1116  mkldnn_primitive_desc_t result;
1117 
1118  auto c_api_inputs = cpp_to_c(inputs);
1119 
1121  &result, nullptr, (int)c_api_inputs.size(),
1122  concat_dimension, &c_api_inputs[0]),
1123  "could not create a concat primitive descriptor");
1124  reset(result);
1125  }
1126 
1128  memory::primitive_desc adesc;
1130  const_mkldnn_primitive_desc_t const_cdesc =
1133  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1134  "could not clone a dst primitive descriptor");
1135  adesc.reset(cdesc);
1136  return adesc;
1137  }
1138 
1139  engine get_engine() { return engine::query(*this); }
1140  };
1141 
1142  concat(const primitive_desc &concat_pd,
1143  std::vector<primitive::at> &inputs, const memory &output) {
1144  mkldnn_primitive_t result;
1145 
1146  std::vector<mkldnn_primitive_at_t> p_inputs;
1147  for (size_t i = 0; i < inputs.size(); i++)
1148  p_inputs.push_back(inputs[i].data);
1149  const_mkldnn_primitive_t outputs[] = { output.get() };
1150 
1152  concat_pd.get(), &p_inputs[0], outputs),
1153  "could not create a concat primitive");
1154  reset(result);
1155  }
1156 };
1157 
1159 
1165 
1166 struct sum : public primitive {
1167  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1168  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1169  std::vector<memory::primitive_desc> inputs) {
1170  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1171  c_api_inputs.reserve(inputs.size());
1172  auto convert_to_c = [](memory::primitive_desc d) { return d.get();};
1173  std::transform(inputs.begin(), inputs.end(),
1174  std::back_inserter(c_api_inputs), convert_to_c);
1175  return c_api_inputs;
1176  }
1177 
1179  const std::vector<float> &scales,
1180  std::vector<memory::primitive_desc> inputs) {
1181  mkldnn_primitive_desc_t result;
1182 
1183  auto c_api_inputs = cpp_to_c(inputs);
1184 
1186  scales.size() == inputs.size() ? mkldnn_success
1188  "number of scales not equal to number of inputs");
1189 
1191  &result, &output.data, (int)c_api_inputs.size(),
1192  &scales[0], &c_api_inputs[0]),
1193  "could not create a sum primitive descriptor");
1194  reset(result);
1195  }
1196 
1197  primitive_desc(const std::vector<float> &scales,
1198  std::vector<memory::primitive_desc> inputs) {
1199  mkldnn_primitive_desc_t result;
1200 
1201  auto c_api_inputs = cpp_to_c(inputs);
1202 
1204  scales.size() == inputs.size() ? mkldnn_success
1206  "number of scales not equal to number of inputs");
1207 
1209  &result, nullptr, (int)c_api_inputs.size(), &scales[0],
1210  &c_api_inputs[0]),
1211  "could not create a sum primitive descriptor");
1212  reset(result);
1213  }
1214 
1216  memory::primitive_desc adesc;
1218  const_mkldnn_primitive_desc_t const_cdesc =
1222  const_cdesc),
1223  "could not clone a dst primitive descriptor");
1224  adesc.reset(cdesc);
1225  return adesc;
1226  }
1227 
1228  engine get_engine() { return engine::query(*this); }
1229  };
1230 
1231  sum(const primitive_desc &sum_pd,
1232  std::vector<primitive::at> &inputs, const memory &output) {
1233  mkldnn_primitive_t result;
1234 
1235  std::vector<mkldnn_primitive_at_t> p_inputs;
1236  for (size_t i = 0; i < inputs.size(); i++)
1237  p_inputs.push_back(inputs[i].data);
1238  const_mkldnn_primitive_t outputs[] = { output.get() };
1239 
1241  sum_pd.get(), &p_inputs[0], outputs),
1242  "could not create a sum primitive");
1243  reset(result);
1244  }
1245 };
1246 
1248 
1250 
1253 
1256 
1258 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1260  const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
1261  mkldnn_primitive_desc_iterator_t iterator = nullptr;
1263  &iterator, desc, attr ? attr->get() : nullptr, e.get(),
1264  hint_fwd_pd);
1265  error::wrap_c_api(status,
1266  "could not create a primitive descriptor iterator");
1267  pd_iterator.reset(iterator);
1268  fetch_impl();
1269  }
1270 
1271  engine get_engine() { return engine::query(*this); }
1272 
1274  const_mkldnn_primitive_attr_t const_cattr;
1276  "could not get attributes");
1278  error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr),
1279  "could not clone attributes");
1280 
1281  primitive_attr attr;
1282  attr.reset(cattr);
1283  return attr;
1284  }
1285 
1287  const char *impl_info_str() const {
1288  const char *res;
1290  mkldnn_query_impl_info_str, 0, &res),
1291  "could not query implementation info string");
1292  return res;
1293  }
1294 
1301  bool next_impl() {
1303  pd_iterator.get());
1304  if (status == mkldnn_iterator_ends) return false;
1305  error::wrap_c_api(status, "primitive descriptor iterator next failed");
1306 
1307  fetch_impl();
1308  return true;
1309  }
1310 
1312  memory::primitive_desc query_mpd(query what, int idx = 0) const {
1313  std::vector<query> valid_w{input_pd, output_pd, src_pd, diff_src_pd,
1315  if (!std::any_of(valid_w.cbegin(), valid_w.cend(),
1316  [=](query q) { return what == q; }))
1317  throw error(mkldnn_invalid_arguments, "invalid memory query");
1318 
1319  const_mkldnn_primitive_desc_t const_cdesc
1321  mkldnn::convert_to_c(what), idx);
1322 
1323  // TODO: is there a better way to inform about this?
1324  if (const_cdesc == nullptr)
1325  throw error(mkldnn_not_required, "queried memory is not required");
1326 
1328  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1329  "could not clone a memory primitive descriptor");
1330 
1332  ret.reset(cdesc);
1333  return ret;
1334  }
1335 
1336  // register specialized queries, e.g. src_primitive_desc()
1337 # define REG_QUERY_MPD(name, what, idx) \
1338  memory::primitive_desc name ## _primitive_desc() const \
1339  { return query_mpd(what ## _pd, idx); }
1340 
1341  private:
1342  handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
1343  void fetch_impl() {
1345  pd_iterator.get());
1347  "could not fetch a primitive descriptor from the iterator");
1348  reset(pd);
1349  }
1350 };
1351 
1353 
1359 
1361  struct desc {
1363  desc(prop_kind aprop_kind, algorithm aalgorithm,
1364  const memory::desc &src_desc,
1365  const memory::desc &weights_desc,
1366  const memory::desc &bias_desc,
1367  const memory::desc &dst_desc,
1368  const memory::dims strides,
1369  const memory::dims padding_l,
1370  const memory::dims padding_r,
1371  const padding_kind apadding_kind) {
1372  memory::validate_dims(strides);
1373  memory::validate_dims(padding_l);
1374  memory::validate_dims(padding_r);
1376  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1377  &src_desc.data, &weights_desc.data, &bias_desc.data,
1378  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1379  mkldnn::convert_to_c(apadding_kind)),
1380  "could not create a convolution forward descriptor");
1381  }
1382  desc(prop_kind aprop_kind, algorithm aalgorithm,
1383  const memory::desc &src_desc,
1384  const memory::desc &weights_desc,
1385  const memory::desc &dst_desc,
1386  const memory::dims strides,
1387  const memory::dims padding_l,
1388  const memory::dims padding_r,
1389  const padding_kind apadding_kind) {
1390  memory::validate_dims(strides);
1391  memory::validate_dims(padding_l);
1392  memory::validate_dims(padding_r);
1394  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1395  &src_desc.data, &weights_desc.data, nullptr,
1396  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1397  mkldnn::convert_to_c(apadding_kind)),
1398  "could not create a convolution forward descriptor");
1399  }
1400  desc(prop_kind aprop_kind, algorithm aalgorithm,
1401  const memory::desc &src_desc,
1402  const memory::desc &weights_desc,
1403  const memory::desc &bias_desc,
1404  const memory::desc &dst_desc,
1405  const memory::dims strides,
1406  const memory::dims dilates,
1407  const memory::dims padding_l,
1408  const memory::dims padding_r,
1409  const padding_kind apadding_kind) {
1410  memory::validate_dims(strides);
1411  memory::validate_dims(dilates);
1412  memory::validate_dims(padding_l);
1413  memory::validate_dims(padding_r);
1416  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1417  &src_desc.data, &weights_desc.data, &bias_desc.data,
1418  &dst_desc.data, &strides[0], &dilates[0],
1419  &padding_l[0], &padding_r[0],
1420  mkldnn::convert_to_c(apadding_kind)),
1421  "could not create a dilated convolution forward descriptor");
1422  }
1423  desc(prop_kind aprop_kind, algorithm aalgorithm,
1424  const memory::desc &src_desc,
1425  const memory::desc &weights_desc,
1426  const memory::desc &dst_desc,
1427  const memory::dims strides,
1428  const memory::dims dilates,
1429  const memory::dims padding_l,
1430  const memory::dims padding_r,
1431  const padding_kind apadding_kind) {
1432  memory::validate_dims(strides);
1433  memory::validate_dims(dilates);
1434  memory::validate_dims(padding_l);
1435  memory::validate_dims(padding_r);
1438  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1439  &src_desc.data, &weights_desc.data, nullptr,
1440  &dst_desc.data, &strides[0], &dilates[0],
1441  &padding_l[0], &padding_r[0],
1442  mkldnn::convert_to_c(apadding_kind)),
1443  "could not create a dilated convolution forward descriptor");
1444  }
1445  };
1446 
1448  primitive_desc(const desc &desc, const engine &e)
1449  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1450 
1451  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1452  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1453 
1454  REG_QUERY_MPD(src, src, 0);
1455  REG_QUERY_MPD(weights, weights, 0);
1456  REG_QUERY_MPD(bias, weights, 1);
1457  REG_QUERY_MPD(dst, dst, 0);
1458  };
1459 
1460  convolution_forward(const primitive_desc &aprimitive_desc,
1461  const primitive::at &src, const primitive::at &weights,
1462  const primitive::at &bias, const memory &dst) {
1463  mkldnn_primitive_t result;
1464  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1465  bias.data };
1466  const_mkldnn_primitive_t outputs[] = { dst.get() };
1468  aprimitive_desc.get(), inputs, outputs),
1469  "could not create a convolution forward bias primitive");
1470  reset(result);
1471  }
1472 
1473  convolution_forward(const primitive_desc &aprimitive_desc,
1474  const primitive::at &src, const primitive::at &weights,
1475  const memory &dst) {
1476  mkldnn_primitive_t result;
1477  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1478  const_mkldnn_primitive_t outputs[] = { dst.get() };
1479  check_num_parameters(aprimitive_desc.get(), 2, 1,
1480  "convolution forward");
1482  aprimitive_desc.get(), inputs, outputs),
1483  "could not create a convolution forward primitive");
1484  reset(result);
1485  }
1486 };
1487 
1489  struct desc {
1491  desc(algorithm aalgorithm,
1492  const memory::desc &diff_src_desc,
1493  const memory::desc &weights_desc,
1494  const memory::desc &diff_dst_desc,
1495  const memory::dims strides,
1496  const memory::dims padding_l,
1497  const memory::dims padding_r,
1498  const padding_kind apadding_kind) {
1499  memory::validate_dims(strides);
1500  memory::validate_dims(padding_l);
1501  memory::validate_dims(padding_r);
1503  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1504  &weights_desc.data, &diff_dst_desc.data,
1505  &strides[0], &padding_l[0], &padding_r[0],
1506  mkldnn::convert_to_c(apadding_kind)),
1507  "could not create a convolution backward data descriptor");
1508  }
1509  desc(algorithm aalgorithm,
1510  const memory::desc &diff_src_desc,
1511  const memory::desc &weights_desc,
1512  const memory::desc &diff_dst_desc,
1513  const memory::dims strides,
1514  const memory::dims dilates,
1515  const memory::dims padding_l,
1516  const memory::dims padding_r,
1517  const padding_kind apadding_kind) {
1518  memory::validate_dims(strides);
1519  memory::validate_dims(dilates);
1520  memory::validate_dims(padding_l);
1521  memory::validate_dims(padding_r);
1524  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1525  &weights_desc.data, &diff_dst_desc.data,
1526  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1527  mkldnn::convert_to_c(apadding_kind)),
1528  "could not create a convolution backward data descriptor");
1529  }
1530  };
1531 
1533  primitive_desc(const desc &desc, const engine &e,
1534  const convolution_forward::primitive_desc &hint_fwd_pd)
1535  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1536 
1537  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1538  const convolution_forward::primitive_desc &hint_fwd_pd)
1539  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1540 
1541  REG_QUERY_MPD(diff_src, diff_src, 0);
1542  REG_QUERY_MPD(weights, weights, 0);
1543  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1544  };
1545 
1547  const primitive::at &diff_dst, const primitive::at &weights,
1548  const memory &diff_src) {
1549  mkldnn_primitive_t result;
1550  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1551  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1552  check_num_parameters(aprimitive_desc.get(), 2, 1,
1553  "convolution backward data");
1555  aprimitive_desc.get(), inputs, outputs),
1556  "could not create a convolution backward data primitive");
1557  reset(result);
1558  }
1559 };
1560 
1562  struct desc {
1564  desc(algorithm aalgorithm,
1565  const memory::desc &src_desc,
1566  const memory::desc &diff_weights_desc,
1567  const memory::desc &diff_bias_desc,
1568  const memory::desc &diff_dst_desc,
1569  const memory::dims strides,
1570  const memory::dims padding_l,
1571  const memory::dims padding_r,
1572  const padding_kind apadding_kind) {
1573  memory::validate_dims(strides);
1574  memory::validate_dims(padding_l);
1575  memory::validate_dims(padding_r);
1577  &data, convert_to_c(aalgorithm), &src_desc.data,
1578  &diff_weights_desc.data, &diff_bias_desc.data,
1579  &diff_dst_desc.data,
1580  &strides[0], &padding_l[0], &padding_r[0],
1581  mkldnn::convert_to_c(apadding_kind)),
1582  "could not create a convolution backward weights descriptor");
1583  }
1584  desc(algorithm aalgorithm,
1585  const memory::desc &src_desc,
1586  const memory::desc &diff_weights_desc,
1587  const memory::desc &diff_dst_desc,
1588  const memory::dims strides,
1589  const memory::dims padding_l,
1590  const memory::dims padding_r,
1591  const padding_kind apadding_kind) {
1592  memory::validate_dims(strides);
1593  memory::validate_dims(padding_l);
1594  memory::validate_dims(padding_r);
1596  &data, convert_to_c(aalgorithm), &src_desc.data,
1597  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1598  &strides[0], &padding_l[0], &padding_r[0],
1599  mkldnn::convert_to_c(apadding_kind)),
1600  "could not create a convolution backward weights descriptor");
1601  }
1602  desc(algorithm aalgorithm,
1603  const memory::desc &src_desc,
1604  const memory::desc &diff_weights_desc,
1605  const memory::desc &diff_bias_desc,
1606  const memory::desc &diff_dst_desc,
1607  const memory::dims strides,
1608  const memory::dims dilates,
1609  const memory::dims padding_l,
1610  const memory::dims padding_r,
1611  const padding_kind apadding_kind) {
1612  memory::validate_dims(strides);
1613  memory::validate_dims(dilates);
1614  memory::validate_dims(padding_l);
1615  memory::validate_dims(padding_r);
1617  &data, convert_to_c(aalgorithm), &src_desc.data,
1618  &diff_weights_desc.data, &diff_bias_desc.data,
1619  &diff_dst_desc.data,
1620  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1621  mkldnn::convert_to_c(apadding_kind)),
1622  "could not create a convolution backward weights descriptor");
1623  }
1624  desc(algorithm aalgorithm,
1625  const memory::desc &src_desc,
1626  const memory::desc &diff_weights_desc,
1627  const memory::desc &diff_dst_desc,
1628  const memory::dims strides,
1629  const memory::dims dilates,
1630  const memory::dims padding_l,
1631  const memory::dims padding_r,
1632  const padding_kind apadding_kind) {
1633  memory::validate_dims(strides);
1634  memory::validate_dims(dilates);
1635  memory::validate_dims(padding_l);
1636  memory::validate_dims(padding_r);
1638  &data, convert_to_c(aalgorithm), &src_desc.data,
1639  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1640  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1641  mkldnn::convert_to_c(apadding_kind)),
1642  "could not create a convolution backward weights descriptor");
1643  }
1644 
1645  };
1646 
1648  primitive_desc(const desc &desc, const engine &e,
1649  const convolution_forward::primitive_desc &hint_fwd_pd)
1650  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1651 
1652  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1653  const convolution_forward::primitive_desc &hint_fwd_pd)
1654  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1655 
1656  REG_QUERY_MPD(src, src, 0);
1657  REG_QUERY_MPD(diff_weights, diff_weights, 0);
1658  REG_QUERY_MPD(diff_bias, diff_weights, 1);
1659  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1660  };
1661 
1663  const primitive::at &src, const primitive::at &diff_dst,
1664  const memory &diff_weights, const memory &diff_bias) {
1665  mkldnn_primitive_t result;
1666  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1667  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1668  diff_bias.get() };
1669  check_num_parameters(aprimitive_desc.get(), 2, 2,
1670  "convolution backward weights");
1672  aprimitive_desc.get(), inputs, outputs),
1673  "could not create a convolution backward weights primitive");
1674  reset(result);
1675  }
1677  const primitive::at &src, const primitive::at &diff_dst,
1678  const memory &diff_weights) {
1679  mkldnn_primitive_t result;
1680  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1681  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1682  check_num_parameters(aprimitive_desc.get(), 2, 1,
1683  "convolution backward weights");
1685  aprimitive_desc.get(), inputs, outputs),
1686  "could not create a convolution backward weights primitive");
1687  reset(result);
1688  }
1689 };
1690 
1692 //
1698 
1700  struct desc {
1702  desc(prop_kind aprop_kind, algorithm aalgorithm,
1703  const memory::desc &src_desc,
1704  const memory::desc &weights_desc,
1705  const memory::desc &bias_desc,
1706  const memory::desc &dst_desc,
1707  const memory::dims strides,
1708  const memory::dims padding_l,
1709  const memory::dims padding_r,
1710  const padding_kind apadding_kind) {
1711  memory::validate_dims(strides);
1712  memory::validate_dims(padding_l);
1713  memory::validate_dims(padding_r);
1715  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1716  &src_desc.data, &weights_desc.data, &bias_desc.data,
1717  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1718  mkldnn::convert_to_c(apadding_kind)),
1719  "could not create a deconvolution forward descriptor");
1720  }
1721  desc(prop_kind aprop_kind, algorithm aalgorithm,
1722  const memory::desc &src_desc,
1723  const memory::desc &weights_desc,
1724  const memory::desc &dst_desc,
1725  const memory::dims strides,
1726  const memory::dims padding_l,
1727  const memory::dims padding_r,
1728  const padding_kind apadding_kind) {
1729  memory::validate_dims(strides);
1730  memory::validate_dims(padding_l);
1731  memory::validate_dims(padding_r);
1733  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1734  &src_desc.data, &weights_desc.data, nullptr,
1735  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1736  mkldnn::convert_to_c(apadding_kind)),
1737  "could not create a deconvolution forward descriptor");
1738  }
1739  desc(prop_kind aprop_kind, algorithm aalgorithm,
1740  const memory::desc &src_desc,
1741  const memory::desc &weights_desc,
1742  const memory::desc &bias_desc,
1743  const memory::desc &dst_desc,
1744  const memory::dims strides,
1745  const memory::dims dilates,
1746  const memory::dims padding_l,
1747  const memory::dims padding_r,
1748  const padding_kind apadding_kind) {
1749  memory::validate_dims(strides);
1750  memory::validate_dims(dilates);
1751  memory::validate_dims(padding_l);
1752  memory::validate_dims(padding_r);
1754  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1755  &src_desc.data, &weights_desc.data, &bias_desc.data,
1756  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1757  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1758  "could not create a dilated deconvolution forward descriptor");
1759  }
1760  desc(prop_kind aprop_kind, algorithm aalgorithm,
1761  const memory::desc &src_desc,
1762  const memory::desc &weights_desc,
1763  const memory::desc &dst_desc,
1764  const memory::dims strides,
1765  const memory::dims dilates,
1766  const memory::dims padding_l,
1767  const memory::dims padding_r,
1768  const padding_kind apadding_kind) {
1769  memory::validate_dims(strides);
1770  memory::validate_dims(dilates);
1771  memory::validate_dims(padding_l);
1772  memory::validate_dims(padding_r);
1774  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1775  &src_desc.data, &weights_desc.data, nullptr,
1776  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1777  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1778  "could not create a dilated deconvolution forward descriptor");
1779  }
1780  };
1781 
1783  primitive_desc(const desc &desc, const engine &e)
1784  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1785 
1786  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1787  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1788 
1789  REG_QUERY_MPD(src, src, 0);
1790  REG_QUERY_MPD(weights, weights, 0);
1791  REG_QUERY_MPD(bias, weights, 1);
1792  REG_QUERY_MPD(dst, dst, 0);
1793  };
1794 
1795  deconvolution_forward(const primitive_desc &aprimitive_desc,
1796  const primitive::at &src, const primitive::at &weights,
1797  const primitive::at &bias, const memory &dst) {
1798  mkldnn_primitive_t result;
1799  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1800  bias.data };
1801  const_mkldnn_primitive_t outputs[] = { dst.get() };
1802  check_num_parameters(aprimitive_desc.get(), 3, 1,
1803  "deconvolution forward");
1805  aprimitive_desc.get(), inputs, outputs),
1806  "could not create a deconvolution forward bias primitive");
1807  reset(result);
1808  }
1809 
1810  deconvolution_forward(const primitive_desc &aprimitive_desc,
1811  const primitive::at &src, const primitive::at &weights,
1812  const memory &dst) {
1813  mkldnn_primitive_t result;
1814  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1815  const_mkldnn_primitive_t outputs[] = { dst.get() };
1816  check_num_parameters(aprimitive_desc.get(), 2, 1,
1817  "deconvolution forward");
1819  aprimitive_desc.get(), inputs, outputs),
1820  "could not create a deconvolution forward primitive");
1821  reset(result);
1822  }
1823 };
1824 
1826  struct desc {
1828  desc(algorithm aalgorithm,
1829  const memory::desc &diff_src_desc,
1830  const memory::desc &weights_desc,
1831  const memory::desc &diff_dst_desc,
1832  const memory::dims strides,
1833  const memory::dims padding_l,
1834  const memory::dims padding_r,
1835  const padding_kind apadding_kind) {
1836  memory::validate_dims(strides);
1837  memory::validate_dims(padding_l);
1838  memory::validate_dims(padding_r);
1840  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1841  &weights_desc.data, &diff_dst_desc.data,
1842  &strides[0], &padding_l[0], &padding_r[0],
1843  mkldnn::convert_to_c(apadding_kind)),
1844  "could not create a deconvolution backward data descriptor");
1845  }
1846  desc(algorithm aalgorithm,
1847  const memory::desc &diff_src_desc,
1848  const memory::desc &weights_desc,
1849  const memory::desc &diff_dst_desc,
1850  const memory::dims strides,
1851  const memory::dims dilates,
1852  const memory::dims padding_l,
1853  const memory::dims padding_r,
1854  const padding_kind apadding_kind) {
1855  memory::validate_dims(strides);
1856  memory::validate_dims(dilates);
1857  memory::validate_dims(padding_l);
1858  memory::validate_dims(padding_r);
1860  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1861  &weights_desc.data, &diff_dst_desc.data,
1862  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1863  mkldnn::convert_to_c(apadding_kind)),
1864  "could not create a dilated deconvolution backward data descriptor");
1865  }
1866  };
1867 
1869  primitive_desc(const desc &desc, const engine &e,
1870  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1871  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1872 
1873  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1874  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1875  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1876 
1877  REG_QUERY_MPD(diff_src, diff_src, 0);
1878  REG_QUERY_MPD(weights, weights, 0);
1879  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1880  };
1881 
1883  const primitive::at &diff_dst, const primitive::at &weights,
1884  const memory &diff_src) {
1885  mkldnn_primitive_t result;
1886  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1887  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1888  check_num_parameters(aprimitive_desc.get(), 2, 1,
1889  "deconvolution backward data");
1891  aprimitive_desc.get(), inputs, outputs),
1892  "could not create a deconvolution backward data primitive");
1893  reset(result);
1894  }
1895 };
1896 
1898  struct desc {
1900  desc(algorithm aalgorithm,
1901  const memory::desc &src_desc,
1902  const memory::desc &diff_weights_desc,
1903  const memory::desc &diff_bias_desc,
1904  const memory::desc &diff_dst_desc,
1905  const memory::dims strides,
1906  const memory::dims padding_l,
1907  const memory::dims padding_r,
1908  const padding_kind apadding_kind) {
1909  memory::validate_dims(strides);
1910  memory::validate_dims(padding_l);
1911  memory::validate_dims(padding_r);
1913  &data, convert_to_c(aalgorithm), &src_desc.data,
1914  &diff_weights_desc.data, &diff_bias_desc.data,
1915  &diff_dst_desc.data,
1916  &strides[0], &padding_l[0], &padding_r[0],
1917  mkldnn::convert_to_c(apadding_kind)),
1918  "could not create a deconvolution backward weights descriptor");
1919  }
1920  desc(algorithm aalgorithm,
1921  const memory::desc &src_desc,
1922  const memory::desc &diff_weights_desc,
1923  const memory::desc &diff_dst_desc,
1924  const memory::dims strides,
1925  const memory::dims padding_l,
1926  const memory::dims padding_r,
1927  const padding_kind apadding_kind) {
1928  memory::validate_dims(strides);
1929  memory::validate_dims(padding_l);
1930  memory::validate_dims(padding_r);
1932  &data, convert_to_c(aalgorithm), &src_desc.data,
1933  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1934  &strides[0], &padding_l[0], &padding_r[0],
1935  mkldnn::convert_to_c(apadding_kind)),
1936  "could not create a deconvolution backward weights descriptor");
1937  }
1938  desc(algorithm aalgorithm,
1939  const memory::desc &src_desc,
1940  const memory::desc &diff_weights_desc,
1941  const memory::desc &diff_bias_desc,
1942  const memory::desc &diff_dst_desc,
1943  const memory::dims strides,
1944  const memory::dims dilates,
1945  const memory::dims padding_l,
1946  const memory::dims padding_r,
1947  const padding_kind apadding_kind) {
1948  memory::validate_dims(strides);
1949  memory::validate_dims(dilates);
1950  memory::validate_dims(padding_l);
1951  memory::validate_dims(padding_r);
1953  &data, convert_to_c(aalgorithm), &src_desc.data,
1954  &diff_weights_desc.data, &diff_bias_desc.data,
1955  &diff_dst_desc.data,
1956  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1957  mkldnn::convert_to_c(apadding_kind)),
1958  "could not create a dilated deconvolution backward weights descriptor");
1959  }
1960  desc(algorithm aalgorithm,
1961  const memory::desc &src_desc,
1962  const memory::desc &diff_weights_desc,
1963  const memory::desc &diff_dst_desc,
1964  const memory::dims strides,
1965  const memory::dims dilates,
1966  const memory::dims padding_l,
1967  const memory::dims padding_r,
1968  const padding_kind apadding_kind) {
1969  memory::validate_dims(strides);
1970  memory::validate_dims(dilates);
1971  memory::validate_dims(padding_l);
1972  memory::validate_dims(padding_r);
1974  &data, convert_to_c(aalgorithm), &src_desc.data,
1975  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1976  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1977  mkldnn::convert_to_c(apadding_kind)),
1978  "could not create a dilated deconvolution backward weights descriptor");
1979  }
1980  };
1981 
1983  primitive_desc(const desc &desc, const engine &e,
1984  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1985  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1986 
1987  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1988  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1989  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1990 
1991  REG_QUERY_MPD(src, src, 0);
1992  REG_QUERY_MPD(diff_weights, diff_weights, 0);
1993  REG_QUERY_MPD(diff_bias, diff_weights, 1);
1994  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1995  };
1996 
1998  const primitive::at &src, const primitive::at &diff_dst,
1999  const memory &diff_weights, const memory &diff_bias) {
2000  mkldnn_primitive_t result;
2001  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2002  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
2003  diff_bias.get() };
2004  check_num_parameters(aprimitive_desc.get(), 2, 2,
2005  "deconvolution backward weights");
2007  aprimitive_desc.get(), inputs, outputs),
2008  "could not create a deconvolution backward weights primitive");
2009  reset(result);
2010  }
2012  const primitive::at &src, const primitive::at &diff_dst,
2013  const memory &diff_weights) {
2014  mkldnn_primitive_t result;
2015  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2016  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2017  check_num_parameters(aprimitive_desc.get(), 2, 1,
2018  "deconvolution backward weights");
2020  aprimitive_desc.get(), inputs, outputs),
2021  "could not create a deconvolution backward weights primitive");
2022  reset(result);
2023  }
2024 };
2025 
2027 
2034 
2035 struct lrn_forward : public primitive {
2036  struct desc {
2038  desc(prop_kind aprop_kind, algorithm aalgorithm,
2039  const memory::desc &src_desc,
2040  int local_size, float alpha, float beta, float k)
2041  {
2043  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2044  &src_desc.data, local_size, alpha, beta, k),
2045  "could not create a lrn forward descriptor");
2046  }
2047  desc(prop_kind aprop_kind, algorithm aalgorithm,
2048  const memory::desc &src_desc,
2049  int local_size, float alpha, float beta)
2050  {
2052  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2053  &src_desc.data, local_size, alpha, beta, float(1.0)),
2054  "could not create a lrn forward descriptor");
2055  }
2056  };
2057 
2059  primitive_desc(const desc &desc, const engine &e)
2060  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2061 
2062  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2063  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2064 
2065  REG_QUERY_MPD(src, src, 0);
2066  REG_QUERY_MPD(dst, dst, 0);
2067  REG_QUERY_MPD(workspace, workspace, 0);
2068  };
2069 
2070  lrn_forward(const primitive_desc &aprimitive_desc,
2071  const primitive::at &src, const memory &workspace,
2072  const memory &dst) {
2073  mkldnn_primitive_t result;
2074  mkldnn_primitive_at_t inputs[] = { src.data };
2075  const_mkldnn_primitive_t outputs[] = { dst.get(),
2076  workspace.get() };
2077  check_num_parameters(aprimitive_desc.get(), 1, 2, "lrn forward");
2079  aprimitive_desc.get(), inputs, outputs),
2080  "could not create a lrn forward primitive");
2081  reset(result);
2082  }
2083 
2084  lrn_forward(const primitive_desc &aprimitive_desc,
2085  const primitive::at &src, const memory &dst) {
2086  mkldnn_primitive_t result;
2087  mkldnn_primitive_at_t inputs[] = { src.data };
2088  const_mkldnn_primitive_t outputs[] = { dst.get() };
2089  check_num_parameters(aprimitive_desc.get(), 1, 1, "lrn forward");
2091  aprimitive_desc.get(), inputs, outputs),
2092  "could not create a lrn forward primitive");
2093  reset(result);
2094  }
2095 };
2096 
2097 struct lrn_backward : public primitive {
2098  struct desc {
2100  desc(algorithm aalgorithm,
2101  const memory::desc &data_desc,
2102  const memory::desc &diff_data_desc,
2103  int local_size, float alpha, float beta, float k)
2104  {
2106  convert_to_c(aalgorithm), &diff_data_desc.data,
2107  &data_desc.data, local_size, alpha, beta, k),
2108  "could not create a lrn backward descriptor");
2109  }
2110  desc(algorithm aalgorithm,
2111  const memory::desc &data_desc,
2112  const memory::desc &diff_data_desc,
2113  int local_size, float alpha, float beta)
2114  {
2116  convert_to_c(aalgorithm), &diff_data_desc.data,
2117  &data_desc.data, local_size, alpha, beta, float(1.0)),
2118  "could not create a lrn backward descriptor");
2119  }
2120  };
2121 
2123  primitive_desc(const desc &desc, const engine &e,
2124  const lrn_forward::primitive_desc &hint_fwd_pd)
2125  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2126 
2127  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2128  const lrn_forward::primitive_desc &hint_fwd_pd)
2129  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2130 
2131  REG_QUERY_MPD(diff_src, diff_src, 0);
2132  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2133  REG_QUERY_MPD(workspace, workspace, 0);
2134  };
2135 
2136  lrn_backward(const primitive_desc &aprimitive_desc,
2137  const primitive::at &src, const primitive::at &diff_dst,
2138  const primitive::at &workspace, const memory &diff_src) {
2139  mkldnn_primitive_t result;
2140  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data,
2141  workspace.data };
2142  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2143  check_num_parameters(aprimitive_desc.get(), 3, 1, "lrn backward");
2145  aprimitive_desc.get(), inputs, outputs),
2146  "could not create a lrn backward primitive");
2147  reset(result);
2148  }
2149 
2150  lrn_backward(const primitive_desc &aprimitive_desc,
2151  const primitive::at &src, const primitive::at &diff_dst,
2152  const memory &diff_src) {
2153  mkldnn_primitive_t result;
2154  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2155  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2156  check_num_parameters(aprimitive_desc.get(), 2, 1, "lrn backward");
2158  aprimitive_desc.get(), inputs, outputs),
2159  "could not create a lrn backward primitive");
2160  reset(result);
2161  }
2162 };
2163 
2165 
2171 
2172 struct pooling_forward : public primitive {
2173  struct desc {
2175  desc(prop_kind aprop_kind, algorithm aalgorithm,
2176  const memory::desc &src_desc,
2177  const memory::desc &dst_desc,
2178  const memory::dims strides,
2179  const memory::dims kernel,
2180  const memory::dims padding_l,
2181  const memory::dims padding_r,
2182  const padding_kind apadding_kind) {
2183  memory::validate_dims(strides);
2184  memory::validate_dims(kernel);
2185  memory::validate_dims(padding_l);
2186  memory::validate_dims(padding_r);
2188  mkldnn::convert_to_c(aprop_kind),
2189  convert_to_c(aalgorithm),
2190  &src_desc.data, &dst_desc.data,
2191  &strides[0], &kernel[0],
2192  &padding_l[0], &padding_r[0],
2193  mkldnn::convert_to_c(apadding_kind)),
2194  "could not init a forward pooling descriptor");
2195  }
2196  };
2197 
2199  primitive_desc(const desc &desc, const engine &e)
2200  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2201 
2202  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2203  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2204 
2205  REG_QUERY_MPD(src, src, 0);
2206  REG_QUERY_MPD(dst, dst, 0);
2207  REG_QUERY_MPD(workspace, workspace, 0);
2208  };
2209 
2210  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2211  const memory &dst) {
2212  mkldnn_primitive_t result;
2213  mkldnn_primitive_at_t inputs[] = { src.data };
2214  const_mkldnn_primitive_t outputs[] = { dst.get(), nullptr };
2215  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling forward");
2217  aprimitive_desc.get(), inputs, outputs),
2218  "could not create a pooling forward primitive");
2219  reset(result);
2220  }
2221 
2222  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2223  const memory &dst, const memory &workspace) {
2224  mkldnn_primitive_t result;
2225  mkldnn_primitive_at_t inputs[] = { src.data };
2226  const_mkldnn_primitive_t outputs[] = { dst.get(), workspace.get() };
2227  check_num_parameters(aprimitive_desc.get(), 1, 2, "pooling forward");
2229  aprimitive_desc.get(), inputs, outputs),
2230  "could not create a pooling forward primitive");
2231  reset(result);
2232  }
2233 };
2234 
2235 struct pooling_backward : public primitive {
2236  struct desc {
2238  desc(algorithm aalgorithm,
2239  const memory::desc &diff_src_desc,
2240  const memory::desc &diff_dst_desc,
2241  const memory::dims &strides,
2242  const memory::dims &kernel,
2243  const memory::dims &padding_l,
2244  const memory::dims &padding_r,
2245  const padding_kind apadding_kind) {
2246  memory::validate_dims(strides);
2247  memory::validate_dims(kernel);
2248  memory::validate_dims(padding_l);
2249  memory::validate_dims(padding_r);
2251  convert_to_c(aalgorithm),
2252  &diff_src_desc.data, &diff_dst_desc.data,
2253  &strides[0], &kernel[0],
2254  &padding_l[0], &padding_r[0],
2255  mkldnn::convert_to_c(apadding_kind)),
2256  "could not init a backward pooling descriptor");
2257  }
2258  };
2259 
2261  primitive_desc(const desc &desc, const engine &e,
2262  const pooling_forward::primitive_desc &hint_fwd_pd)
2263  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2264 
2265  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2266  const pooling_forward::primitive_desc &hint_fwd_pd)
2267  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2268 
2269  REG_QUERY_MPD(diff_src, diff_src, 0);
2270  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2271  REG_QUERY_MPD(workspace, workspace, 0);
2272  };
2273 
2274  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2275  const memory &diff_src) {
2276  mkldnn_primitive_t result;
2277  mkldnn_primitive_at_t inputs[] = { diff_dst.data };
2278  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2279  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling backward");
2281  aprimitive_desc.get(), inputs, outputs),
2282  "could not create a pooling backward primitive");
2283  reset(result);
2284  }
2285 
2286  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2287  const primitive::at &workspace, const memory &diff_src) {
2288  mkldnn_primitive_t result;
2289  mkldnn_primitive_at_t inputs[] = { diff_dst.data, workspace.data };
2290  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2291  check_num_parameters(aprimitive_desc.get(), 2, 1, "pooling backward");
2293  aprimitive_desc.get(), inputs, outputs),
2294  "could not create a pooling backward primitive");
2295  reset(result);
2296  }
2297 };
2298 
2300 
2307 
2308 struct eltwise_forward : public primitive {
2309  struct desc {
2311  template <typename T>
2312  desc(prop_kind aprop_kind, algorithm alg_kind,
2313  const memory::desc &src_desc, T alpha = 0, T beta = 0) {
2315  mkldnn::convert_to_c(aprop_kind),
2316  mkldnn::convert_to_c(alg_kind), &src_desc.data,
2317  static_cast<float>(alpha), static_cast<float>(beta)),
2318  "could not create a eltwise forward descriptor");
2319  }
2320  };
2321 
2323  primitive_desc(const desc &desc, const engine &e)
2324  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2325 
2326  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2327  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2328 
2329  REG_QUERY_MPD(src, src, 0);
2330  REG_QUERY_MPD(dst, dst, 0);
2331  };
2332 
2333  eltwise_forward(const primitive_desc &aprimitive_desc,
2334  const primitive::at &src, const memory &dst) {
2335  mkldnn_primitive_t result;
2336  mkldnn_primitive_at_t inputs[] = { src.data };
2337  const_mkldnn_primitive_t outputs[] = { dst.get() };
2338  check_num_parameters(aprimitive_desc.get(), 1, 1, "eltwise forward");
2340  aprimitive_desc.get(), inputs, outputs),
2341  "could not create a eltwise forward primitive");
2342  reset(result);
2343  }
2344 };
2345 
2346 struct eltwise_backward : public primitive {
2347  struct desc {
2349 
2350  template <typename T>
2351  desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2352  const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2354  mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2355  &data_desc.data, static_cast<float>(alpha),
2356  static_cast<float>(beta)),
2357  "could not create a eltwise backward descriptor");
2358  }
2359  };
2360 
2362  primitive_desc(const desc &desc, const engine &e,
2363  const eltwise_forward::primitive_desc &hint_fwd_pd)
2364  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2365 
2366  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2367  const eltwise_forward::primitive_desc &hint_fwd_pd)
2368  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2369 
2370  REG_QUERY_MPD(src, src, 0);
2371  REG_QUERY_MPD(diff_src, diff_src, 0);
2372  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2373  };
2374 
2375  eltwise_backward(const primitive_desc &aprimitive_desc,
2376  const primitive::at &src, const primitive::at &diff_dst,
2377  const memory &diff_src) {
2378  mkldnn_primitive_t result;
2379  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2380  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2381  check_num_parameters(aprimitive_desc.get(), 2, 1, "eltwise backward");
2383  aprimitive_desc.get(), inputs, outputs),
2384  "could not create a eltwise backward primitive");
2385  reset(result);
2386  }
2387 };
2388 
2390 
2396 
2397 struct softmax_forward : public primitive {
2398  struct desc {
2400  desc(prop_kind aprop_kind, const memory::desc &data_desc,
2401  int softmax_axis) {
2403  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2404  softmax_axis),
2405  "could not create a softmax forward descriptor");
2406  }
2407  };
2408 
2410  primitive_desc(const desc &desc, const engine &e)
2411  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2412 
2413  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2414  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2415 
2416  REG_QUERY_MPD(src, src, 0);
2417  REG_QUERY_MPD(dst, dst, 0);
2418  };
2419 
2420  softmax_forward(const primitive_desc &aprimitive_desc,
2421  const primitive::at &src, const memory &dst) {
2422  mkldnn_primitive_t result;
2423  mkldnn_primitive_at_t inputs[] = { src.data };
2424  const_mkldnn_primitive_t outputs[] = { dst.get() };
2425  check_num_parameters(aprimitive_desc.get(), 1, 1, "softmax forward");
2427  aprimitive_desc.get(), inputs, outputs),
2428  "could not create a softmax forward primitive");
2429  reset(result);
2430  }
2431 };
2432 
2433 struct softmax_backward : public primitive {
2434  struct desc {
2436  desc(const memory::desc &diff_desc, const memory::desc &data_desc,
2437  int softmax_axis) {
2439  &diff_desc.data, &data_desc.data, softmax_axis),
2440  "could not init a backward softmax descriptor");
2441  }
2442  };
2443 
2445  primitive_desc(const desc &desc, const engine &e,
2446  const softmax_forward::primitive_desc &hint_fwd_pd)
2447  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2448 
2449  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2450  const softmax_forward::primitive_desc &hint_fwd_pd)
2451  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2452 
2453  REG_QUERY_MPD(dst, dst, 0);
2454  REG_QUERY_MPD(diff_src, diff_src, 0);
2455  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2456  REG_QUERY_MPD(workspace, workspace, 0);
2457  };
2458 
2459  softmax_backward(const primitive_desc &aprimitive_desc,
2460  const primitive::at &dst, const primitive::at &diff_dst,
2461  const memory &diff_src) {
2462  mkldnn_primitive_t result;
2463  mkldnn_primitive_at_t inputs[] = { dst.data, diff_dst.data };
2464  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2466  aprimitive_desc.get(), inputs, outputs),
2467  "could not create a softmax backward primitive");
2468  reset(result);
2469  }
2470 };
2471 
2473 
2479 
2481  struct desc {
2483  template <typename T>
2484  desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
2485  unsigned flags) {
2488  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2489  static_cast<float>(epsilon), flags),
2490  "could not create a batch normalization forward descriptor");
2491  }
2492  };
2493 
2495  primitive_desc(const desc &desc, const engine &e)
2496  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2497 
2498  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2499  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2500 
2501  REG_QUERY_MPD(src, src, 0);
2502  REG_QUERY_MPD(weights, weights, 0);
2503  REG_QUERY_MPD(dst, dst, 0);
2504  REG_QUERY_MPD(workspace, workspace, 0);
2505 
2507  { return stat_primitive_desc(mean); }
2509  { return stat_primitive_desc(var); }
2510 
2511  private:
2512  enum { mean = 1, var = 2, };
2513  memory::primitive_desc stat_primitive_desc(int kind) const {
2517  "could not get a batch-normalization descriptor");
2518  return query_mpd(p->flags & use_global_stats ? src_pd : dst_pd, kind);
2519  }
2520  };
2521 
2523  const primitive::at &src, const primitive::at &mean,
2524  const primitive::at &variance, const primitive::at &weights,
2525  const memory &dst) {
2526  mkldnn_primitive_t result;
2527  mkldnn_primitive_at_t inputs[] = { src.data,
2528  mean.data, variance.data, weights.data };
2529  const_mkldnn_primitive_t outputs[] = { dst.get() };
2530  check_num_parameters(aprimitive_desc.get(), 4, 1,
2531  "batch normalization forward");
2533  aprimitive_desc.get(), inputs, outputs),
2534  "could not create a batch normalization forward primitive");
2535  reset(result);
2536  }
2537 
2539  const primitive::at &src, const primitive::at &mean,
2540  const primitive::at &variance, const memory &dst) {
2541  mkldnn_primitive_t result;
2542  mkldnn_primitive_at_t inputs[] = { src.data,
2543  mean.data, variance.data };
2544  const_mkldnn_primitive_t outputs[] = { dst.get() };
2545  check_num_parameters(aprimitive_desc.get(), 3, 1,
2546  "batch normalization forward");
2548  aprimitive_desc.get(), inputs, outputs),
2549  "could not create a batch normalization forward primitive");
2550  reset(result);
2551  }
2552 
2561  const primitive::at &src, const primitive::at &weights,
2562  const memory &dst, const memory &mean, const memory &variance) {
2563  mkldnn_primitive_t result;
2564  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2565  const_mkldnn_primitive_t outputs[] = { dst.get(),
2566  mean.get(), variance.get() };
2567  check_num_parameters(aprimitive_desc.get(), 2, 3,
2568  "batch normalization forward");
2570  aprimitive_desc.get(), inputs, outputs),
2571  "could not create a batch normalization forward primitive");
2572  reset(result);
2573  }
2574 
2576  const primitive::at &src, const primitive::at &weights,
2577  const memory &dst, const memory &mean, const memory &variance,
2578  const memory &workspace) {
2579  mkldnn_primitive_t result;
2580  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2581  const_mkldnn_primitive_t outputs[] = { dst.get(),
2582  mean.get(), variance.get(), workspace.get() };
2583  check_num_parameters(aprimitive_desc.get(), 2, 4,
2584  "batch normalization forward");
2586  aprimitive_desc.get(), inputs, outputs),
2587  "could not create a batch normalization forward primitive");
2588  reset(result);
2589  }
2590 
2592  const primitive::at &src, const memory &dst, const memory &mean,
2593  const memory &variance) {
2594  mkldnn_primitive_t result;
2595  mkldnn_primitive_at_t inputs[] = { src.data };
2596  const_mkldnn_primitive_t outputs[] = { dst.get(),
2597  mean.get(), variance.get() };
2598  check_num_parameters(aprimitive_desc.get(), 1, 3,
2599  "batch normalization forward");
2601  aprimitive_desc.get(), inputs, outputs),
2602  "could not create a batch normalization forward primitive");
2603  reset(result);
2604  }
2605 
2617  const primitive::at &src, const memory &dst, const memory &mean,
2618  const memory &variance, const memory &workspace) {
2619  mkldnn_primitive_t result;
2620  mkldnn_primitive_at_t inputs[2] = { src.data };
2621  const_mkldnn_primitive_t outputs[4] = { dst.get(),
2622  mean.get(), variance.get(), workspace.get() };
2623 
2624  if (1) { // check whether this is the `wrong` constructor
2625  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
2626  aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0);
2627  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
2628  aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0);
2629  if (n_inputs_expected == 2 && n_outputs_expected == 3) {
2630  // shift parameters, get rid of workspace, and add weights...
2631  auto _weights = dst;
2632  inputs[1] = {_weights.get(), 0};
2633 
2634  auto _dst = mean, _mean = variance, _variance = workspace;
2635  outputs[0] = _dst.get();
2636  outputs[1] = _mean.get();
2637  outputs[2] = _variance.get();
2638  outputs[3] = nullptr;
2639  }
2640  }
2642  aprimitive_desc.get(), inputs, outputs),
2643  "could not create a batch normalization forward primitive");
2644  reset(result);
2645  }
2646 
2648  const primitive::at &src, const primitive::at &weights,
2649  const memory &dst) {
2650  mkldnn_primitive_t result;
2651  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2652  const_mkldnn_primitive_t outputs[] = { dst.get() };
2653  check_num_parameters(aprimitive_desc.get(), 2, 1,
2654  "batch normalization forward");
2656  aprimitive_desc.get(), inputs, outputs),
2657  "could not create a batch normalization forward primitive");
2658  reset(result);
2659  }
2660 
2662  const primitive::at &src, const memory &dst) {
2663  mkldnn_primitive_t result;
2664  mkldnn_primitive_at_t inputs[] = { src.data };
2665  const_mkldnn_primitive_t outputs[] = { dst.get() };
2666  check_num_parameters(aprimitive_desc.get(), 1, 1,
2667  "batch normalization forward");
2669  aprimitive_desc.get(), inputs, outputs),
2670  "could not create a batch normalization forward primitive");
2671  reset(result);
2672  }
2673 };
2674 
2676  struct desc {
2678  template <typename T>
2679  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
2680  const memory::desc &data_desc, T epsilon, unsigned flags) {
2683  mkldnn::convert_to_c(aprop_kind),
2684  &diff_data_desc.data, &data_desc.data,
2685  static_cast<float>(epsilon), flags),
2686  "could not create a batch normalization backward descriptor");
2687  }
2688  };
2689 
2691  primitive_desc(const desc &desc, const engine &e,
2693  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2694 
2695  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2697  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2698 
2699  REG_QUERY_MPD(src, src, 0);
2700  REG_QUERY_MPD(mean, src, 1);
2701  REG_QUERY_MPD(variance, src, 2);
2702  REG_QUERY_MPD(weights, weights, 0);
2703  REG_QUERY_MPD(dst, dst, 0);
2704  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2705  REG_QUERY_MPD(workspace, workspace, 0);
2706 
2707  REG_QUERY_MPD(diff_src, diff_src, 0);
2708  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2709  };
2710 
2711  // Prop_kind == backward
2713  const primitive::at &src, const primitive::at &mean,
2714  const primitive::at &variance, const primitive::at &diff_dst,
2715  const primitive::at &weights, const memory &diff_src,
2716  const memory &diff_weights) {
2717  mkldnn_primitive_t result;
2718  mkldnn_primitive_at_t inputs[] = { src.data,
2719  mean.data, variance.data, diff_dst.data, weights.data };
2720  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2721  diff_weights.get() };
2722  check_num_parameters(aprimitive_desc.get(), 5, 2,
2723  "batch normalization backward");
2725  aprimitive_desc.get(), inputs, outputs),
2726  "could not create a batch normalization backward primitive");
2727  reset(result);
2728  }
2729 
2730  // Prop_kind == backward (+ws)
2732  const primitive::at &src, const primitive::at &mean,
2733  const primitive::at &variance, const primitive::at &diff_dst,
2734  const primitive::at &weights, const primitive::at &workspace,
2735  const memory &diff_src, const memory &diff_weights) {
2736  mkldnn_primitive_t result;
2737  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2738  diff_dst.data, weights.data, workspace.data };
2739  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2740  diff_weights.get() };
2741  check_num_parameters(aprimitive_desc.get(), 6, 2,
2742  "batch normalization backward");
2744  aprimitive_desc.get(), inputs, outputs),
2745  "could not create a batch normalization backward primitive");
2746  reset(result);
2747  }
2748 
2749  // Prop_kind == backward_data (+ws or +weights)
2754  const primitive::at &src, const primitive::at &mean,
2755  const primitive::at &variance,const primitive::at &diff_dst,
2756  const primitive::at &weights_or_workspace, const memory &diff_src) {
2757  mkldnn_primitive_t result;
2758  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2759  diff_dst.data, weights_or_workspace.data };
2760  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2761  check_num_parameters(aprimitive_desc.get(), 5, 1,
2762  "batch normalization backward");
2764  aprimitive_desc.get(), inputs, outputs),
2765  "could not create a batch normalization backward primitive");
2766  reset(result);
2767  }
2768 
2769  // Prop_kind == backward_data
2771  const primitive::at &src, const primitive::at &mean,
2772  const primitive::at &variance, const primitive::at &diff_dst,
2773  const memory &diff_src) {
2774  mkldnn_primitive_t result;
2775  mkldnn_primitive_at_t inputs[] = { src.data,
2776  mean.data, variance.data, diff_dst.data };
2777  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2778  check_num_parameters(aprimitive_desc.get(), 4, 1,
2779  "batch normalization backward");
2781  aprimitive_desc.get(), inputs, outputs),
2782  "could not create a batch normalization backward primitive");
2783  reset(result);
2784  }
2785 };
2786 
2788 
2794 
2796  struct desc {
2798  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2799  const memory::desc &weights_desc,
2800  const memory::desc &bias_desc,
2801  const memory::desc &dst_desc) {
2804  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2805  &weights_desc.data, &bias_desc.data, &dst_desc.data),
2806  "could not create a inner product forward descriptor");
2807  }
2808 
2809  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2810  const memory::desc &weights_desc,
2811  const memory::desc &dst_desc) {
2814  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2815  &weights_desc.data, nullptr, &dst_desc.data),
2816  "could not create a inner product forward descriptor");
2817  }
2818  };
2819 
2821  primitive_desc(const desc &desc, const engine &e)
2822  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2823 
2824  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2825  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2826 
2827  REG_QUERY_MPD(src, src, 0);
2828  REG_QUERY_MPD(weights, weights, 0);
2829  REG_QUERY_MPD(bias, weights, 1);
2830  REG_QUERY_MPD(dst, dst, 0);
2831  };
2832 
2833  inner_product_forward(const primitive_desc &aprimitive_desc,
2834  const primitive::at &src, const primitive::at weights,
2835  const primitive::at &bias, const memory &dst) {
2836  mkldnn_primitive_t result;
2837  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
2838  bias.data };
2839  const_mkldnn_primitive_t outputs[] = { dst.get() };
2840  check_num_parameters(aprimitive_desc.get(), 3, 1,
2841  "inner product forward");
2843  aprimitive_desc.get(), inputs, outputs),
2844  "could not create a inner product forward primitive");
2845  reset(result);
2846  }
2847 
2848  inner_product_forward(const primitive_desc &aprimitive_desc,
2849  const primitive::at &src, const primitive::at weights,
2850  const memory &dst) {
2851  mkldnn_primitive_t result;
2852  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2853  const_mkldnn_primitive_t outputs[] = { dst.get() };
2854  check_num_parameters(aprimitive_desc.get(), 2, 1,
2855  "inner product forward");
2857  aprimitive_desc.get(), inputs, outputs),
2858  "could not create a inner product forward primitive");
2859  reset(result);
2860  }
2861 };
2862 
2864  struct desc {
2866  desc(const memory::desc &diff_src_desc,
2867  const memory::desc &weights_desc,
2868  const memory::desc &diff_dst_desc) {
2871  &diff_src_desc.data, &weights_desc.data,
2872  &diff_dst_desc.data),
2873  "could not create a inner product backward data descriptor");
2874  }
2875  };
2876 
2878  primitive_desc(const desc &desc, const engine &e,
2879  const inner_product_forward::primitive_desc &hint_fwd_pd)
2880  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2881 
2882  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2883  const inner_product_forward::primitive_desc &hint_fwd_pd)
2884  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2885 
2886  REG_QUERY_MPD(diff_src, diff_src, 0);
2887  REG_QUERY_MPD(weights, weights, 0);
2888  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2889  };
2890 
2892  const primitive::at &diff_dst, const primitive::at weights,
2893  const memory &diff_src) {
2894  mkldnn_primitive_t result;
2895  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
2896  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2897  check_num_parameters(aprimitive_desc.get(), 2, 1,
2898  "inner product backward data");
2900  aprimitive_desc.get(), inputs, outputs),
2901  "could not create a inner product backward data primitive");
2902  reset(result);
2903  }
2904 };
2905 
2907  struct desc {
2909  desc(const memory::desc &src_desc,
2910  const memory::desc &diff_weights_desc,
2911  const memory::desc &diff_bias_desc,
2912  const memory::desc &diff_dst_desc) {
2915  &data, &src_desc.data, &diff_weights_desc.data,
2916  &diff_bias_desc.data, &diff_dst_desc.data),
2917  "could not create a inner product backward weights descriptor");
2918  }
2919  desc(const memory::desc &src_desc,
2920  const memory::desc &diff_weights_desc,
2921  const memory::desc &diff_dst_desc) {
2924  &data, &src_desc.data, &diff_weights_desc.data,
2925  nullptr, &diff_dst_desc.data),
2926  "could not create a inner product backward weights descriptor");
2927  }
2928  };
2929 
2931  primitive_desc(const desc &desc, const engine &e,
2932  const inner_product_forward::primitive_desc &hint_fwd_pd)
2933  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2934 
2935  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2936  const inner_product_forward::primitive_desc &hint_fwd_pd)
2937  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2938 
2939  REG_QUERY_MPD(src, src, 0);
2940  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2941  REG_QUERY_MPD(diff_bias, diff_weights, 1);
2942  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2943  };
2944 
2946  const primitive::at &src, const primitive::at diff_dst,
2947  const memory &diff_weights) {
2948  mkldnn_primitive_t result;
2949  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2950  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2951  check_num_parameters(aprimitive_desc.get(), 2, 1,
2952  "inner product backward weights");
2954  aprimitive_desc.get(), inputs, outputs),
2955  "could not create a inner product backward weights primitive");
2956  reset(result);
2957  }
2958 
2960  const primitive::at &src, const primitive::at diff_dst,
2961  const memory &diff_weights, const memory &diff_bias) {
2962  mkldnn_primitive_t result;
2963  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2964  const_mkldnn_primitive_t outputs[] =
2965  { diff_weights.get(), diff_bias.get()};
2966  check_num_parameters(aprimitive_desc.get(), 2, 2,
2967  "inner product backward weights");
2969  aprimitive_desc.get(), inputs, outputs),
2970  "could not create a inner product backward weights primitive");
2971  reset(result);
2972  }
2973 };
2974 
2976 
2982 
2983 struct rnn_cell {
2984  struct desc {
2986 
2987  desc(algorithm kind, algorithm activation_f) {
2989  mkldnn::convert_to_c(kind),
2990  mkldnn::convert_to_c(activation_f), 0U, 0, 0),
2991  "could not init an rnn cell descriptor");
2992  }
2994 
2995  operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
2996 
2998  { return algorithm(c_rnn_cell_.cell_kind); }
3001 
3002  float get_alpha() const { return c_rnn_cell_.alpha; }
3003  void set_alpha(float alpha) {
3005  c_rnn_cell_.alpha = alpha;
3006  }
3007 
3008  float get_clipping() const { return c_rnn_cell_.clipping; }
3009  void set_clipping(float clipping) {
3011  c_rnn_cell_.clipping = clipping;
3012  }
3013 
3014  int get_gates_count() const {
3016  }
3017  int get_state_count() const {
3019  }
3020  };
3021 };
3022 
3023 struct rnn_forward : public primitive {
3024  struct desc {
3026  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3027  const rnn_direction direction,
3028  const memory::desc &src_layer_desc,
3029  const memory::desc &src_iter_desc,
3030  const memory::desc &weights_layer_desc,
3031  const memory::desc &weights_iter_desc,
3032  const memory::desc &bias_desc,
3033  const memory::desc &dst_layer_desc,
3034  const memory::desc &dst_iter_desc
3035  ) {
3037  mkldnn::convert_to_c(aprop_kind), cell,
3038  mkldnn::convert_to_c(direction),
3039  &src_layer_desc.data, &src_iter_desc.data,
3040  &weights_layer_desc.data, &weights_iter_desc.data,
3041  &bias_desc.data,
3042  &dst_layer_desc.data, &dst_iter_desc.data),
3043  "could not create an RNN forward descriptor");
3044  }
3045 
3046  };
3047 
3049  primitive_desc(const desc &desc, const engine &e)
3050  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3051 
3052  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3053  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3054 
3055  REG_QUERY_MPD(src_layer, src, 0);
3056  REG_QUERY_MPD(src_iter, src, 1);
3057  REG_QUERY_MPD(weights_layer, weights, 0);
3058  REG_QUERY_MPD(weights_iter, weights, 1);
3059  REG_QUERY_MPD(bias, weights, 2);
3060  REG_QUERY_MPD(dst_layer, dst, 0);
3061  REG_QUERY_MPD(dst_iter, dst, 1);
3062  REG_QUERY_MPD(workspace, workspace, 0);
3063  };
3064 
3065  rnn_forward(const primitive_desc &aprimitive_desc,
3066  const primitive::at &src_layer, const primitive::at &src_iter,
3067  const primitive::at &weights_layer,
3068  const primitive::at &weights_iter, const primitive::at &bias,
3069  const memory &dst_layer, const memory &dst_iter,
3070  const memory &workspace) {
3071  mkldnn_primitive_t result;
3072  mkldnn_primitive_at_t inputs[5];
3073  const_mkldnn_primitive_t outputs[3];
3074  int idx=0;
3075  inputs[idx++] = src_layer.data;
3076  if (!is_null_memory(src_iter.data.primitive))
3077  inputs[idx++] = src_iter.data;
3078  inputs[idx++] = weights_layer.data;
3079  inputs[idx++] = weights_iter.data;
3080  if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data;
3081 
3082  idx=0;
3083  outputs[idx++] = dst_layer.get();
3084  if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get();
3085  if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get();
3086 
3088  aprimitive_desc.get(), inputs, outputs),
3089  "could not create an RNN forward primitive");
3090  reset(result);
3091  }
3092 };
3093 
3094 struct rnn_backward : public primitive {
3095  struct desc {
3097  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3098  const rnn_direction direction,
3099  const memory::desc &src_layer_desc,
3100  const memory::desc &src_iter_desc,
3101  const memory::desc &weights_layer_desc,
3102  const memory::desc &weights_iter_desc,
3103  const memory::desc &bias_desc,
3104  const memory::desc &dst_layer_desc,
3105  const memory::desc &dst_iter_desc,
3106  const memory::desc &diff_src_layer_desc,
3107  const memory::desc &diff_src_iter_desc,
3108  const memory::desc &diff_weights_layer_desc,
3109  const memory::desc &diff_weights_iter_desc,
3110  const memory::desc &diff_bias_desc,
3111  const memory::desc &diff_dst_layer_desc,
3112  const memory::desc &diff_dst_iter_desc) {
3114  mkldnn::convert_to_c(aprop_kind), cell,
3115  mkldnn::convert_to_c(direction),
3116  &src_layer_desc.data, &src_iter_desc.data,
3117  &weights_layer_desc.data, &weights_iter_desc.data,
3118  &bias_desc.data,
3119  &dst_layer_desc.data, &dst_iter_desc.data,
3120  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
3121  &diff_weights_layer_desc.data,
3122  &diff_weights_iter_desc.data, &diff_bias_desc.data,
3123  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
3124  "could not create an RNN backward descriptor");
3125  }
3126 
3127  };
3128 
3130  primitive_desc(const desc &desc, const engine &e,
3131  const rnn_forward::primitive_desc &hint_fwd_pd)
3132  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3133 
3134  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3135  const rnn_forward::primitive_desc &hint_fwd_pd)
3136  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3137 
3138  REG_QUERY_MPD(src_layer, src, 0);
3139  REG_QUERY_MPD(src_iter, src, 1);
3140  REG_QUERY_MPD(weights_layer, weights, 0);
3141  REG_QUERY_MPD(weights_iter, weights, 1);
3142  REG_QUERY_MPD(bias, weights, 2);
3143  REG_QUERY_MPD(dst_layer, dst, 0);
3144  REG_QUERY_MPD(dst_iter, dst, 1);
3145  REG_QUERY_MPD(workspace, workspace, 0);
3146 
3147  REG_QUERY_MPD(diff_src_layer, diff_src, 0);
3148  REG_QUERY_MPD(diff_src_iter, diff_src, 1);
3149  REG_QUERY_MPD(diff_weights_layer, diff_weights, 0);
3150  REG_QUERY_MPD(diff_weights_iter, diff_weights, 1);
3151  REG_QUERY_MPD(diff_bias, diff_weights, 2);
3152  REG_QUERY_MPD(diff_dst_layer, diff_dst, 0);
3153  REG_QUERY_MPD(diff_dst_iter, diff_dst, 1);
3154  };
3155 
3156  // With last iteration (with and without input src_iter)
3157  rnn_backward(const primitive_desc &aprimitive_desc,
3158  const primitive::at &src_layer,
3159  const primitive::at &src_iter,
3160  const primitive::at &weights_layer,
3161  const primitive::at &weights_iter,
3162  const primitive::at &bias,
3163  const primitive::at &dst_layer,
3164  const primitive::at &dst_iter,
3165  const memory &diff_src_layer,
3166  const memory &diff_src_iter,
3167  const memory &diff_weights_layer,
3168  const memory &diff_weights_iter,
3169  const memory &diff_bias,
3170  const primitive::at &diff_dst_layer,
3171  const primitive::at &diff_dst_iter,
3172  const primitive::at &workspace) {
3173  mkldnn_primitive_t result;
3174  mkldnn_primitive_at_t inputs[10];
3175  const_mkldnn_primitive_t outputs[5];
3176  int idx=0;
3177  inputs[idx++] = src_layer.data;
3178  if (!is_null_memory(src_iter.data.primitive))
3179  inputs[idx++] = src_iter.data;
3180  inputs[idx++] = weights_layer.data;
3181  inputs[idx++] = weights_iter.data;
3182  if (!is_null_memory(bias.data.primitive))
3183  inputs[idx++] = bias.data;
3184  inputs[idx++] = dst_layer.data;
3185  if (!is_null_memory(dst_iter.data.primitive))
3186  inputs[idx++] = dst_iter.data;
3187  inputs[idx++] = diff_dst_layer.data;
3188  if (!is_null_memory(diff_dst_iter.data.primitive))
3189  inputs[idx++] = diff_dst_iter.data;
3190  inputs[idx++] = workspace.data;
3191 
3192  idx = 0;
3193  outputs[idx++] = diff_src_layer.get();
3194  if (!is_null_memory(diff_src_iter.get()))
3195  outputs[idx++] = diff_src_iter.get();
3196  outputs[idx++] = diff_weights_layer.get();
3197  outputs[idx++] = diff_weights_iter.get();
3198  if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get();
3200  aprimitive_desc.get(), inputs, outputs),
3201  "could not create an RNN backward primitive");
3202  reset(result);
3203  }
3204 };
3205 
3207 
3213 
3214 struct shuffle_forward : public primitive {
3215  struct desc {
3217  desc(prop_kind aprop_kind, const memory::desc &data_desc,
3218  int axis, int group_size) {
3220  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
3221  axis, group_size),
3222  "could not create a shuffle forward descriptor");
3223  }
3224  };
3225 
3227  primitive_desc(const desc &desc, const engine &e)
3228  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3229 
3230  REG_QUERY_MPD(src, src, 0);
3231  REG_QUERY_MPD(dst, dst, 0);
3232  };
3233 
3234  shuffle_forward(const primitive_desc &aprimitive_desc,
3235  const primitive::at &src, const memory &dst) {
3236  mkldnn_primitive_t result;
3237  mkldnn_primitive_at_t inputs[] = { src.data };
3238  const_mkldnn_primitive_t outputs[] = { dst.get() };
3239  check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle forward");
3241  aprimitive_desc.get(), inputs, outputs),
3242  "could not create a shuffle forward primitive");
3243  reset(result);
3244  }
3245 };
3246 
3247 struct shuffle_backward : public primitive {
3248  struct desc {
3250  desc(const memory::desc &diff_data_desc, int axis, int group_size) {
3252  &diff_data_desc.data, axis, group_size),
3253  "could not create a shuffle backward descriptor");
3254  }
3255  };
3256 
3258  primitive_desc(const desc &desc, const engine &e,
3259  const shuffle_forward::primitive_desc &hint_fwd_pd)
3260  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3261 
3262  REG_QUERY_MPD(diff_src, diff_src, 0);
3263  REG_QUERY_MPD(diff_dst, diff_dst, 0);
3264  };
3265 
3266  shuffle_backward(const primitive_desc &aprimitive_desc,
3267  const primitive::at &diff_dst, const memory &diff_src) {
3268  mkldnn_primitive_t result;
3269  mkldnn_primitive_at_t inputs[] = { diff_dst.data};
3270  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3271  check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle backward");
3273  aprimitive_desc.get(), inputs, outputs),
3274  "could not create a shuffle backward primitive");
3275  reset(result);
3276  }
3277 };
3278 
3280 
3282 
3288 
3289 #ifndef DOXYGEN_SHOULD_SKIP_THIS
3290 template <> struct handle_traits<mkldnn_stream_t> {
3291  static constexpr auto destructor = &mkldnn_stream_destroy;
3292 };
3293 #endif
3294 
3295 struct stream: public handle<mkldnn_stream_t> {
3296  using handle::handle;
3297 
3301 
3302  static mkldnn_stream_kind_t convert_to_c(kind akind) {
3303  return static_cast<mkldnn_stream_kind_t>(akind);
3304  }
3306  stream(kind akind) {
3307  mkldnn_stream_t astream;
3309  convert_to_c(akind)),
3310  "could not create a stream");
3311  reset(astream);
3312  }
3313 
3318  stream &submit(std::vector<primitive> primitives) {
3319  // TODO: find a proper way to convert vector<primitive> to
3320  // vector<mkldnn_primitive_t>
3321  if (primitives.size() == 0) return *this;
3322  std::vector<mkldnn_primitive_t> c_api_primitives;
3323  c_api_primitives.reserve(primitives.size());
3324  auto convert_to_c = [](primitive p) { return p.get(); };
3325  std::transform(primitives.begin(), primitives.end(),
3326  std::back_inserter(c_api_primitives), convert_to_c);
3327 
3328  mkldnn_primitive_t c_api_error_primitive;
3330  mkldnn_stream_submit(get(),
3331  c_api_primitives.size(), &c_api_primitives[0],
3332  &c_api_error_primitive),
3333  "could not submit primitives to a stream",
3334  &c_api_error_primitive);
3335 
3336  return *this;
3337  }
3338 
3345  bool wait(bool block = true) {
3346  mkldnn_primitive_t c_api_error_primitive;
3347  mkldnn_status_t status = mkldnn_stream_wait(get(),
3348  block, &c_api_error_primitive);
3349  if (status != mkldnn_success
3350  && status != mkldnn_try_again)
3351  error::wrap_c_api(status, "could not wait on a stream",
3352  &c_api_error_primitive);
3353  return (status == mkldnn_success);
3354  }
3355 
3357  mkldnn_primitive_t c_api_error_primitive;
3359  mkldnn_stream_rerun(get(), &c_api_error_primitive),
3360  "could not rerun a stream", &c_api_error_primitive);
3361  return *this;
3362  }
3363 };
3364 
3365 #undef REG_QUERY_MPD
3366 
3368 
3370 
3371 } // namespace mkldnn
3372 
3373 #endif
void append_sum(float scale=1.)
Definition: mkldnn.hpp:385
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2410
Definition: mkldnn.hpp:2361
bool operator!=(const handle &other) const
Definition: mkldnn.hpp:88
LRN within a single channel.
Definition: mkldnn_types.h:542
primitive error_primitive
Definition: mkldnn.hpp:164
A descriptor of a Local Response Normalization (LRN) operation.
Definition: mkldnn_types.h:880
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1509
Definition: mkldnn.hpp:730
Definition: mkldnn.hpp:342
blocked weights format
Definition: mkldnn_types.h:332
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const memory &dst)
Definition: mkldnn.hpp:2848
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2202
Definition: mkldnn.hpp:269
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1091
blocked weights format
Definition: mkldnn_types.h:337
op descriptor
Definition: mkldnn_types.h:1222
primitive_desc(const memory::desc &output, int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1101
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1652
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
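Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using alg_kind, the src_desc, diff_weights_desc, diff_bias_desc, and diff_dst_desc memory descriptors, strides, padding_l, padding_r, and padding_kind.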
blocked weights format with additional buffer with size equal to the number of output channels multip...
Definition: mkldnn_types.h:374
mkldnn_primitive_t get() const
Returns the value of the underlying C handle.
Definition: mkldnn.hpp:85
Definition: mkldnn.hpp:3094
blocked weights format
Definition: mkldnn_types.h:316
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy(mkldnn_primitive_attr_t attr)
Deletes an attr.
Definition: mkldnn.hpp:702
Definition: mkldnn.hpp:650
blocked weights format
Definition: mkldnn_types.h:412
mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create(mkldnn_primitive_desc_t *sum_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, const float *scales, const_mkldnn_primitive_desc_t *input_pds)
Creates an out-of-place sum_primitive_desc for the sum of n inputs (input_pds) multiplied by scales, with the resulting output_desc memory descriptor.
Definition: mkldnn.hpp:257
Definition: mkldnn.hpp:648
A Softmax primitive.
Definition: mkldnn_types.h:486
number of outputs expected
Definition: mkldnn_types.h:1211
mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream)
Destroys an execution stream.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:3052
blocked weights format
Definition: mkldnn_types.h:415
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:1662
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2522
stream & submit(std::vector< primitive > primitives)
Submits a vector of primitives to a stream for computations.
Definition: mkldnn.hpp:3318
A base class for all primitive descriptors.
Definition: mkldnn.hpp:1258
Definition: mkldnn.hpp:598
Definition: mkldnn.hpp:2235
mkldnn_status_t
Status values returned by Intel(R) MKL-DNN functions.
Definition: mkldnn_types.h:47
stream & rerun()
Definition: mkldnn.hpp:3356
Definition: mkldnn.hpp:2198
A descriptor of a convolution operation.
Definition: mkldnn_types.h:733
Definition: mkldnn.hpp:300
desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis, int group_size)
Definition: mkldnn.hpp:3217
Definition: mkldnn.hpp:2173
The operation failed and should be retried.
Definition: mkldnn_types.h:53
memory null_memory(engine eng)
Definition: mkldnn.hpp:905
mkldnn_status_t MKLDNN_API mkldnn_memory_primitive_desc_create(mkldnn_primitive_desc_t *memory_primitive_desc, const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine)
Creates a memory_primitive_desc memory primitive descriptor using memory_desc and engine.
Definition: mkldnn.hpp:680
blocked weights format
Definition: mkldnn_types.h:279
mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops)
Creates an empty sequence of post operations post_ops.
Definition: mkldnn.hpp:654
Definition: mkldnn.hpp:329
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy(mkldnn_primitive_desc_t primitive_desc)
Deletes a primitive_desc.
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1602
mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create(mkldnn_primitive_desc_t *concat_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, int concat_dimension, const_mkldnn_primitive_desc_t *input_pds)
Creates an out-of-place concat_primitive_desc for the concatenation of n inputs (input_pds) along concat_dimension, with the resulting output_desc memory descriptor.
4D RNN bias tensor in the format (num_layers, num_directions, num_gates, output_channels).
Definition: mkldnn_types.h:253
4D data tensor with the physical layout chwn, used in Neon.
Definition: mkldnn_types.h:171
Definition: mkldnn.hpp:265
padding_kind
Definition: mkldnn.hpp:232
The operation failed because of incorrect function arguments.
Definition: mkldnn_types.h:55
Definition: mkldnn.hpp:695
Forward data propagation (alias for mkldnn_forward_inference)
Definition: mkldnn_types.h:447
Definition: mkldnn.hpp:2036
Definition: mkldnn.hpp:671
An opaque structure to describe an engine.
Definition: mkldnn.hpp:737
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1564
Backward data propagation.
Definition: mkldnn_types.h:453
Definition: mkldnn.hpp:2434
Definition: mkldnn.hpp:703
static void validate_dims(std::vector< T > v)
Definition: mkldnn.hpp:586
Definition: mkldnn.hpp:3257
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr(const_mkldnn_primitive_desc_t primitive_desc, const_mkldnn_primitive_attr_t *attr)
Returns a constant reference to the attribute of a primitive_desc.
Definition: mkldnn.hpp:3247
Definition: mkldnn.hpp:641
mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init(mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, mkldnn_data_type_t data_type, mkldnn_memory_format_t format)
Initializes a memory_desc memory descriptor using ndims, dims, data_type, and data format...
Definition: mkldnn.hpp:712
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2400
Definition: mkldnn.hpp:274
blocked weights format
Definition: mkldnn_types.h:310
blocked weights format
Definition: mkldnn_types.h:383
Definition: mkldnn.hpp:722
Undefined memory format, used for empty memory descriptors.
Definition: mkldnn_types.h:145
Definition: mkldnn.hpp:679
concat(const primitive_desc &concat_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1142
memory::desc desc()
Returns the memory primitive descriptor.
Definition: mkldnn.hpp:799
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:1997
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
Definition: mkldnn.hpp:644
float alpha
alpha is a negative slope parameter (used only if (flags & mkldnn_rnn_cell_with_relu) != 0) ...
Definition: mkldnn_types.h:984
Definition: mkldnn.hpp:607
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone(mkldnn_primitive_attr_t *attr, const_mkldnn_primitive_attr_t existing_attr)
Makes a copy of an existing_attr.
#define TENSOR_MAX_DIMS
Maximum number of dimensions a tensor can have.
Definition: mkldnn_types.h:607
format
Memory format specification. See mkldnn_memory_format_t for a detailed description.
Definition: mkldnn.hpp:605
Definition: mkldnn.hpp:290
4D weights tensor with physical layout oihw, used in Caffe.
Definition: mkldnn_types.h:192
algorithm get_activation() const
Definition: mkldnn.hpp:2999
A descriptor of a Softmax operation.
Definition: mkldnn_types.h:830
blocked weights format
Definition: mkldnn_types.h:416
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone(mkldnn_primitive_desc_t *primitive_desc, const_mkldnn_primitive_desc_t existing_primitive_desc)
Makes a copy of a primitive_desc.
Definition: mkldnn.hpp:658
softmax_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2420
blocked weights format
Definition: mkldnn_types.h:417
blocked data format
Definition: mkldnn_types.h:262
mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle(const_mkldnn_primitive_t memory, void **handle)
For a memory primitive, returns the data handle.
Definition: mkldnn.hpp:244
Definition: mkldnn.hpp:659
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to data using al...
A descriptor of an inner product operation.
Definition: mkldnn_types.h:938
Definition: mkldnn.hpp:728
mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops)
Deletes a post_ops sequence.
std::vector< std::remove_extent< mkldnn_dims_t >::type > dims
Definition: mkldnn.hpp:584
3D RNN data tensor in the format (seq_length, batch, input channels).
Definition: mkldnn_types.h:229
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3227
An opaque structure for a chain of post operations.
An opaque structure to describe a primitive descriptor.
batch normalization descriptor
Definition: mkldnn_types.h:1231
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1721
mkldnn_rnn_direction_t
A direction of RNN primitive execution.
Definition: mkldnn_types.h:991
Definition: mkldnn.hpp:655
void reset(T t, bool weak=false)
Resets the value of a C handle.
Definition: mkldnn.hpp:79
A convolution primitive.
Definition: mkldnn_types.h:480
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1869
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2099
mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle(mkldnn_primitive_t memory, void *handle)
For a memory primitive, sets the data handle.
Definition: mkldnn.hpp:637
engine(const mkldnn_engine_t &aengine)
Definition: mkldnn.hpp:538
blocked weights format with additional buffer with size equal to the number of output channels and co...
Definition: mkldnn_types.h:287
size_t get_size() const
Returns the number of bytes required to allocate the memory described including the padding area...
Definition: mkldnn.hpp:805
engine(const handle< mkldnn_primitive_desc_t > &pd)
Definition: mkldnn.hpp:541
Definition: mkldnn.hpp:738
engine get_engine()
Definition: mkldnn.hpp:1271
desc(dims adims, data_type adata_type, format aformat)
Constructs a memory descriptor.
Definition: mkldnn.hpp:765
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1044
blocked data format
Definition: mkldnn_types.h:263
const char * impl_info_str() const
Returns implementation name.
Definition: mkldnn.hpp:1287
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for forward propagation using prop_kind (possi...
Definition: mkldnn.hpp:225
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2797
sum(const primitive_desc &sum_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1231
An execution engine.
Definition: mkldnn.hpp:503
memory(const primitive_desc &adesc, void *ahandle)
Definition: mkldnn.hpp:855
blocked weights format
Definition: mkldnn_types.h:408
Definition: mkldnn.hpp:751
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2865
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise(mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, float alpha, float beta)
Appends eltwise post operation to the post_ops with given parameters kind, alpha, and beta (...
Definition: mkldnn.hpp:615
static void wrap_c_api(mkldnn_status_t status, const std::string &message, mkldnn_primitive_t *error_primitive=0)
A convenience function for wrapping calls to the C API. Checks the return status and throws an error ...
Definition: mkldnn.hpp:188
Definition: mkldnn.hpp:714
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2237
blocked weights format
Definition: mkldnn_types.h:323
int len() const
Definition: mkldnn.hpp:375
Undefined primitive (XXX: why do we have it?).
Definition: mkldnn_types.h:464
Definition: mkldnn.hpp:690
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to data using ...
An inner product primitive.
Definition: mkldnn_types.h:494
Packed weights format used in RNN.
Definition: mkldnn_types.h:421
void check_num_parameters(const const_mkldnn_primitive_desc_t &aprimitive_desc, int n_inputs, int n_outputs, const std::string &prim_name)
Definition: mkldnn.hpp:910
Definition: mkldnn.hpp:749
Round down.
Definition: mkldnn_types.h:90
4D grouped weights tensor with the physical layout goiw.
Definition: mkldnn_types.h:210
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2449
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1739
Definition: mkldnn.hpp:696
Definition: mkldnn.hpp:264
Definition: mkldnn.hpp:700
primitive_attr()
Definition: mkldnn.hpp:419
Definition: mkldnn_types.h:538
Definition: mkldnn.hpp:2346
Definition: mkldnn.hpp:633
An unspecified engine.
Definition: mkldnn.hpp:510
Definition: mkldnn.hpp:713
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_weights_qparams(mkldnn_primitive_attr_t attr, int count, int mask, const float *weights_scales)
Sets quantization scales weights_scales for RNN weights tensors.
mkldnn_primitive_at_t MKLDNN_API mkldnn_primitive_at(const_mkldnn_primitive_t primitive, size_t output_index)
Creates an mkldnn_primitive_at_t structure from a primitive and output_index.
Definition: mkldnn.hpp:596
primitive_desc(const desc &desc, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2445
Definition: mkldnn.hpp:677
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2435
float get_clipping() const
Definition: mkldnn.hpp:3008
Definition: mkldnn.hpp:2409
Definition: mkldnn.hpp:247
32-bit signed integer.
Definition: mkldnn_types.h:76
memory::primitive_desc query_mpd(query what, int idx=0) const
Queries and returns requested memory primitive descriptor.
Definition: mkldnn.hpp:1312
Definition: mkldnn.hpp:706
primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2878
Max pooling.
Definition: mkldnn_types.h:533
Definition: mkldnn.hpp:724
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1423
memory::desc zero_md()
Definition: mkldnn.hpp:899
Definition: mkldnn.hpp:336
primitive_desc(const memory::primitive_desc &input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:1034
mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init(mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for forward propagation using prop_kind (possible values are mkldnn_forwar...
blocked weights format
Definition: mkldnn_types.h:300
blocked weights format
Definition: mkldnn_types.h:322
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims kernel, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2175
Definition: mkldnn.hpp:628
execution engine
Definition: mkldnn_types.h:1207
stream(kind akind)
Constructs a stream.
Definition: mkldnn.hpp:3306
Definition: mkldnn.hpp:1033
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next(mkldnn_primitive_desc_iterator_t iterator)
Iterates over primitive descriptors.
Definition: mkldnn.hpp:335
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2866
mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for backward propagation using alg_kind, memory descriptors, and pooling parameters in the spatial domain: strides, kernel sizes, padding_l, padding_r, and padding_kind.
Definition: mkldnn.hpp:2172
blocked weights format
Definition: mkldnn_types.h:307
static mkldnn_memory_format_t convert_to_c(format aformat)
Definition: mkldnn.hpp:894
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2366
Definition: mkldnn.hpp:320
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create(mkldnn_primitive_attr_t *attr)
Creates an empty (default) attr attribute.
Definition: mkldnn_types.h:969
mkldnn_status_t MKLDNN_API mkldnn_stream_submit(mkldnn_stream_t stream, size_t n, mkldnn_primitive_t primitives[], mkldnn_primitive_t *error_primitive)
Submits primitives to an execution stream.
algorithm
Definition: mkldnn.hpp:255
input memory primitive desc
Definition: mkldnn_types.h:1237
blocked weights format
Definition: mkldnn_types.h:325
Definition: mkldnn.hpp:744
mkldnn_shuffle_desc_t data
Definition: mkldnn.hpp:3216
5D grouped weights tensor with the physical layout goihw, used in Caffe.
Definition: mkldnn_types.h:214
const_mkldnn_primitive_t primitive
Primitive to specify the output for.
Definition: mkldnn_types.h:1167
Definition: mkldnn.hpp:289
blocked weights format
Definition: mkldnn_types.h:336
rnn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const memory &dst_layer, const memory &dst_iter, const memory &workspace)
Definition: mkldnn.hpp:3065
mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init(mkldnn_rnn_cell_desc_t *rnn_cell_desc, mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f, unsigned int flags, float alpha, float clipping)
Initializes a recurrent cell descriptor rnn_cell_desc using rnn_cell_desc, kind (possible values are ...
A descriptor of a element-wise operation.
Definition: mkldnn_types.h:795
Definition: mkldnn.hpp:699
rnn descriptor
Definition: mkldnn_types.h:1233
An element-wise primitive.
Definition: mkldnn_types.h:484
Definition: mkldnn.hpp:2433
blocked weights format
Definition: mkldnn_types.h:315
destination grad.
Definition: mkldnn_types.h:1244
Definition: mkldnn.hpp:745
engine get_engine()
Definition: mkldnn.hpp:1228
Definition: mkldnn.hpp:2347
mkldnn_status_t MKLDNN_API mkldnn_stream_wait(mkldnn_stream_t stream, int block, mkldnn_primitive_t *error_primitive)
Waits for all primitives in the execution stream to finish.
mkldnn_alg_kind_t activation_kind
Activation function used.
Definition: mkldnn_types.h:979
blocked weights format
Definition: mkldnn_types.h:328
A descriptor for an RNN operation.
Definition: mkldnn_types.h:1006
Definition: mkldnn.hpp:620
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1400
Definition: mkldnn.hpp:1089
Definition: mkldnn.hpp:277
Definition: mkldnn.hpp:259
eltwise descriptor
Definition: mkldnn_types.h:1227
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2616
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1448
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_data_qparams(mkldnn_primitive_attr_t attr, const float scale, const float shift)
Sets quantization scale and shift for RNN data tensors.
Definition: mkldnn.hpp:276
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights_or_workspace, const memory &diff_src)
Definition: mkldnn.hpp:2753
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2084
size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind)
Returns the number of engines of a particular kind.
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2909
batch_normalization_flag
Definition: mkldnn.hpp:288
A memory primitive.
Definition: mkldnn_types.h:466
float clipping
clipping parameter (used only if (flags & mkldnn_rnn_cell_with_clipping) != 0)
Definition: mkldnn_types.h:987
Definition: mkldnn.hpp:697
blocked weights format
Definition: mkldnn_types.h:297
blocked weights format
Definition: mkldnn_types.h:309
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc)
Definition: mkldnn.hpp:3097
Eltwise: soft_relu.
Definition: mkldnn_types.h:529
Definition: mkldnn.hpp:675
void set_post_ops(post_ops ops)
Definition: mkldnn.hpp:469
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:2833
Definition: mkldnn.hpp:341
Definition: mkldnn.hpp:709
Definition: mkldnn.hpp:645
Definition: mkldnn.hpp:261
mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind(const_mkldnn_post_ops_t post_ops, int index)
Returns the type of post operation with index index in given post_ops.
Definition: mkldnn.hpp:599
RNN cell.
Definition: mkldnn_types.h:544
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2199
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1760
const post_ops get_post_ops() const
Definition: mkldnn.hpp:460
bool is_null_memory(const const_mkldnn_primitive_t &aprimitive)
Definition: mkldnn.hpp:930
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2882
Definition: mkldnn.hpp:367
blocked weights format
Definition: mkldnn_types.h:344
Definition: mkldnn.hpp:1360
Backward weights propagation.
Definition: mkldnn_types.h:455
void set_int_output_round_mode(round_mode mode)
Definition: mkldnn.hpp:433
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3025
blocked weights format
Definition: mkldnn_types.h:411
Definition: mkldnn.hpp:669
32-bit/single-precision floating point.
Definition: mkldnn_types.h:74
Definition: mkldnn.hpp:743
blocked weights format
Definition: mkldnn_types.h:275
blocked data format
Definition: mkldnn_types.h:260
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1584
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2210
2D weights tensor with physical layout oi.
Definition: mkldnn_types.h:180
Just a sentinel, not real memory format.
Definition: mkldnn_types.h:425
Memory descriptor.
Definition: mkldnn_types.h:692
Definition: mkldnn.hpp:719
Definition: mkldnn.hpp:2796
Definition: mkldnn.hpp:303
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to data using m...
Base class for all computational primitives.
Definition: mkldnn.hpp:106
shuffle_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:3234
mkldnn_batch_normalization_flag_t
Flags for batch-normalization primitive.
Definition: mkldnn_types.h:561
void set_clipping(float clipping)
Definition: mkldnn.hpp:3009
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:1676
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2037
Definition: mkldnn.hpp:2795
desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2484
Definition: mkldnn.hpp:280
Definition: mkldnn.hpp:608
pooling descriptor
Definition: mkldnn_types.h:1229
Definition: mkldnn.hpp:2236
const mkldnn_memory_desc_t MKLDNN_API * mkldnn_primitive_desc_query_memory_d(const_mkldnn_primitive_desc_t primitive_desc)
Queries primitive descriptor for memory descriptor.
prop_kind
Definition: mkldnn.hpp:240
Definition: mkldnn.hpp:631
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2174
Definition: mkldnn.hpp:267
blocked weights format
Definition: mkldnn_types.h:274
blocked data format
Definition: mkldnn_types.h:264
3D weights tensor with physical layout wio.
Definition: mkldnn_types.h:189
Definition: mkldnn.hpp:701
blocked weights format
Definition: mkldnn_types.h:393
blocked weights format
Definition: mkldnn_types.h:343
Definition: mkldnn.hpp:630
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor deconv_desc for forward propagation using prop_kind (p...
memory::primitive_desc mean_primitive_desc() const
Definition: mkldnn.hpp:2506
Definition: mkldnn.hpp:725
unsigned int flags
RNN cell flags.
Definition: mkldnn_types.h:981
Definition: mkldnn.hpp:647
3D data tensor with the physical layout ncw.
Definition: mkldnn_types.h:159
blocked weights format
Definition: mkldnn_types.h:313
bool operator!=(const primitive_desc &other) const
Definition: mkldnn.hpp:814
convolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1546
The operation was successful.
Definition: mkldnn_types.h:49
Definition: mkldnn.hpp:632
blocked weights format with additional buffer with size equal to the number of groups and containing ...
Definition: mkldnn_types.h:403
mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, mkldnn_engine_kind_t kind, size_t index)
Creates an engine of particular kind and index.
blocked weights format
Definition: mkldnn_types.h:367
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2935
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1648
desc(algorithm kind, algorithm activation_f)
Definition: mkldnn.hpp:2987
blocked weights format
Definition: mkldnn_types.h:381
Definition: mkldnn.hpp:326
Definition: mkldnn.hpp:245
primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd)
Definition: mkldnn.hpp:1259
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_int_output_round_mode(const_mkldnn_primitive_attr_t attr, mkldnn_round_mode_t *round_mode)
Returns integer output rounding mode round_mode for a given attr, previously set by mkldnn_primitive_...
blocked weights format
Definition: mkldnn_types.h:409
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3096
Definition: mkldnn.hpp:653
bool operator==(const primitive_desc &other) const
Definition: mkldnn.hpp:809
Backward propagation (with respect to all parameters).
Definition: mkldnn_types.h:451
5D data tensor with the physical layout ndhwc, used in TensorFlow.
Definition: mkldnn_types.h:177
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:2959
softmax descriptor
Definition: mkldnn_types.h:1228
mkldnn_round_mode_t
Rounding mode.
Definition: mkldnn_types.h:86
A deconvolution primitive.
Definition: mkldnn_types.h:482
Definition: mkldnn.hpp:330
Definition: mkldnn.hpp:275
primitive_desc(const desc &adesc, const engine &aengine)
Constructs a memory primitive descriptor.
Definition: mkldnn.hpp:789
Use global statistics.
Definition: mkldnn_types.h:574
primitive_desc(int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1114
Definition: mkldnn.hpp:668
blocked weights format
Definition: mkldnn_types.h:314
Definition: mkldnn.hpp:636
no query
Definition: mkldnn_types.h:1205
int get_gates_count() const
Definition: mkldnn.hpp:3014
Definition: mkldnn.hpp:1700
blocked weights format
Definition: mkldnn_types.h:395
blocked weights format
Definition: mkldnn_types.h:330
mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for forward propagation using prop_kind (possible valu...
mkldnn_status_t MKLDNN_API mkldnn_view_primitive_desc_create(mkldnn_primitive_desc_t *view_primitive_desc, const_mkldnn_primitive_desc_t memory_primitive_desc, const mkldnn_dims_t dims, const mkldnn_dims_t offsets)
Creates a view_primitive_desc for a given memory_primitive_desc, with dims sizes and offsets offsets...
8-bit unsigned integer.
Definition: mkldnn_types.h:82
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1127
bool operator==(const T other) const
Definition: mkldnn.hpp:61
Definition: mkldnn.hpp:664
blocked weights format
Definition: mkldnn_types.h:407
Definition: mkldnn.hpp:346
Average pooling include padding.
Definition: mkldnn_types.h:535
Unspecified format.
Definition: mkldnn_types.h:148
inner_product_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at weights, const memory &diff_src)
Definition: mkldnn.hpp:2891
Definition: mkldnn.hpp:2058
destination memory primitive desc
Definition: mkldnn_types.h:1243
5D RNN weights tensor in the format (num_layers, num_directions, input_channels, num_gates, output_channels).
Definition: mkldnn_types.h:239
GRU cell with linear before reset.
Definition: mkldnn_types.h:557
memory(const primitive_desc &adesc)
Constructs a memory primitive.
Definition: mkldnn.hpp:828
Definition: mkldnn.hpp:649
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2136
mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init(mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int axis, int group_size)
Initializes a shuffle_desc for forward propagation using prop_kind, memory descriptor data_desc...
Local response normalization (LRN) across multiple channels.
Definition: mkldnn_types.h:540
Definition: mkldnn.hpp:698
blocked weights format
Definition: mkldnn_types.h:296
GRU cell.
Definition: mkldnn_types.h:548
Eager stream.
Definition: mkldnn_types.h:1258
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output, const primitive_attr &aattr)
Definition: mkldnn.hpp:984
void set_output_scales(int mask, const std::vector< float > &scales)
Definition: mkldnn.hpp:453
at(const primitive &aprimitive, size_t at=0)
Constructs a wrapper specifying aprimitive output with index at.
Definition: mkldnn.hpp:143
implementation name
Definition: mkldnn_types.h:1218
CPU engine.
Definition: mkldnn.hpp:512
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1938
Definition: mkldnn.hpp:1361
desc(const memory::desc &diff_data_desc, int axis, int group_size)
Definition: mkldnn.hpp:3250
Definition: mkldnn.hpp:3248
Definition: mkldnn.hpp:256
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2274
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales(const_mkldnn_primitive_attr_t attr, int *count, int *mask, const float **scales)
Returns count, correspondence scale mask, and a pointer to a constant floating point array of output ...
3D weights tensor with physical layout oiw.
Definition: mkldnn_types.h:186
Eltwise: parametric exponential linear unit (elu)
Definition: mkldnn_types.h:517
void set_data_handle(void *handle) const
Definition: mkldnn.hpp:885
kind
Kinds of engines.
Definition: mkldnn.hpp:508
Definition: mkldnn.hpp:2098
Definition: mkldnn.hpp:2863
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2413
Definition: mkldnn.hpp:693
Intel(R) MKL-DNN exception class.
Definition: mkldnn.hpp:161
round_mode
Definition: mkldnn.hpp:223
Definition: mkldnn.hpp:747
bool operator==(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:939
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1827
Eltwise: ReLU.
Definition: mkldnn_types.h:513
Definition: mkldnn.hpp:2397
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1362
Definition: mkldnn.hpp:233
1D data tensor.
Definition: mkldnn_types.h:154
REG_QUERY_MPD(diff_src, diff_src, 0)
mkldnn_primitive_at_t data
The underlying C API structure.
Definition: mkldnn.hpp:136
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2695
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops(mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops)
Sets configured post_ops to an attribute attr for future use (when primitive descriptor is being crea...
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3134
Definition: mkldnn.hpp:705
primitive_desc(const desc &desc, const engine &e, const shuffle_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3258
4D weights tensor with physical layout ihwo.
Definition: mkldnn_types.h:198
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2348
mkldnn_memory_format_t
Memory format specification.
Definition: mkldnn_types.h:143
Definition: mkldnn.hpp:1032
Eltwise: square.
Definition: mkldnn_types.h:519
Definition: mkldnn.hpp:609
Definition: mkldnn.hpp:1166
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1382
Definition: mkldnn.hpp:281
mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for forward propagation using prop_kind (possible values are mkldnn_forwa...
int MKLDNN_API mkldnn_memory_primitive_desc_equal(const_mkldnn_primitive_desc_t lhs, const_mkldnn_primitive_desc_t rhs)
Compares two descriptors of memory primitives.
void set_rnn_data_qparams(const float scale, const float shift)
Definition: mkldnn.hpp:474
static mkldnn_data_type_t convert_to_c(data_type adata_type)
Definition: mkldnn.hpp:891
4D data tensor with the physical layout nhwc, used in TensorFlow.
Definition: mkldnn_types.h:168
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2591
Definition: mkldnn.hpp:268
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2100
Definition: mkldnn.hpp:616
Backward bias propagation.
Definition: mkldnn_types.h:457
Definition: mkldnn.hpp:973
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2047
blocked weights format
Definition: mkldnn_types.h:404
Use scale and shift parameters.
Definition: mkldnn_types.h:587
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1702
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor deconv_desc for forward propagation using prop_kind (possible ...
query
Definition: mkldnn.hpp:311
Definition: mkldnn.hpp:279
weights format with additional buffer size equal to the number of output channels multiplied by numbe...
Definition: mkldnn_types.h:365
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index, void *result)
Queries primitive descriptor.
Definition: mkldnn.hpp:661
blocked weights format
Definition: mkldnn_types.h:295
blocked weights format
Definition: mkldnn_types.h:382
A descriptor of a shuffle operation.
Definition: mkldnn_types.h:778
Definition: mkldnn_types.h:1001
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to wei...
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1215
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2310
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1983
Definition: mkldnn.hpp:418
blocked weights format
Definition: mkldnn_types.h:398
blocked weights format
Definition: mkldnn_types.h:339
Definition: mkldnn.hpp:688
int ndims
Number of dimensions.
Definition: mkldnn_types.h:697
reorder(const primitive_desc &aprimitive_desc, const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:997
Definition: mkldnn.hpp:2035
Definition: mkldnn.hpp:1090
kind
A proxy to C primitive kind enum.
Definition: mkldnn.hpp:113
blocked weights format with additional buffer with size equal to the number of groups and containing ...
Definition: mkldnn_types.h:358
5D grouped weights tensor with the physical layout giohw.
Definition: mkldnn_types.h:221
An opaque structure to describe an execution stream.
void set_alpha(float alpha)
Definition: mkldnn.hpp:3003
bool operator!=(const T other) const
Definition: mkldnn.hpp:62
mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for backward propagation using alg_kind algorithm memory descriptors diff...
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2110
Definition: mkldnn.hpp:656
5D data tensor with the physical layout ncdhw.
Definition: mkldnn_types.h:174
Definition: mkldnn.hpp:3215
Definition: mkldnn.hpp:621
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy(mkldnn_primitive_desc_iterator_t iterator)
Deletes a primitive descriptor iterator.
5D RNN states tensor in the format (num_layers, num_directions, num_states, batch, state channels).
Definition: mkldnn_types.h:232
Definition: mkldnn.hpp:2122
Definition: mkldnn.hpp:741
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum(mkldnn_post_ops_t post_ops, float scale)
Appends accumulation (sum) post operation to the post_ops.
Definition: mkldnn.hpp:1561
Definition: mkldnn.hpp:667
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1795
A rnn primitive.
Definition: mkldnn_types.h:496
Definition: mkldnn.hpp:685
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_output(const_mkldnn_primitive_t primitive, size_t index, const_mkldnn_primitive_t *output)
For a primitive, returns output at the index position.
blocked weights format
Definition: mkldnn_types.h:324
blocked weights format
Definition: mkldnn_types.h:270
mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init(mkldnn_shuffle_desc_t *shuffle_desc, const mkldnn_memory_desc_t *diff_data_desc, int axis, int group_size)
Initializes a shuffle_desc for backward propagation using memory descriptor diff_data_desc, axis, and group_size.
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1899
Definition: mkldnn.hpp:623
Definition: mkldnn.hpp:2984
eltwise_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2375
mkldnn_prop_kind_t
Kinds of propagation.
Definition: mkldnn_types.h:435
Definition: mkldnn.hpp:683
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn.hpp:134
CPU engine.
Definition: mkldnn_types.h:1057
Definition: mkldnn.hpp:291
void * get_data_handle() const
Returns a handle of the data contained in the memory primitive. On the CPU engine, this is a pointer to the allocated memory.
Definition: mkldnn.hpp:878
desc(algorithm alg_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2351
Eltwise: square root.
Definition: mkldnn_types.h:523
Definition: mkldnn.hpp:731
Definition: mkldnn.hpp:739
Definition: mkldnn.hpp:692
blocked weights format
Definition: mkldnn_types.h:277
mkldnn_stream_kind_t
Kinds of streams.
Definition: mkldnn_types.h:1254
Definition: mkldnn.hpp:271
Definition: mkldnn.hpp:681
Definition: mkldnn.hpp:610
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_int_output_round_mode(mkldnn_primitive_attr_t attr, mkldnn_round_mode_t round_mode)
Sets output rounding mode round_mode for integer operations for a given attr.
4D weights tensor with physical layout hwio, used in TensorFlow.
Definition: mkldnn_types.h:195
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn_types.h:1165
Winograd convolution.
Definition: mkldnn_types.h:505
Definition: mkldnn.hpp:638
Definition: mkldnn.hpp:246
Definition: mkldnn.hpp:343
Eltwise: linear.
Definition: mkldnn_types.h:525
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1828
mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init(mkldnn_softmax_desc_t *softmax_desc, const mkldnn_memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for backward propagation using memory descriptors diff_desc and data_desc...
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1900
reorder(const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:1008
Eltwise: logistic.
Definition: mkldnn_types.h:531
Definition: mkldnn.hpp:2675
Direct convolution.
Definition: mkldnn_types.h:503
Primitive iterator passed over last primitive descriptor.
Definition: mkldnn_types.h:62
Definition: mkldnn.hpp:338
Definition: mkldnn.hpp:270
Definition: mkldnn.hpp:721
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &workspace, const memory &dst)
Definition: mkldnn.hpp:2070
source gradient memory primitive desc
Definition: mkldnn_types.h:1240
mkldnn_alg_kind_t cell_kind
RNN cell kind.
Definition: mkldnn_types.h:976
Definition: mkldnn.hpp:1489
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2677
Definition: mkldnn_types.h:993
An opaque structure for primitive descriptor attributes.
Definition: mkldnn.hpp:312
blocked data format
Definition: mkldnn_types.h:266
mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for forward propagation using prop_kind (possible values a...
blocked weights format
Definition: mkldnn_types.h:329
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2038
void get_output_scales(int &mask, std::vector< float > &scales) const
Definition: mkldnn.hpp:439
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2647
Definition: mkldnn.hpp:3299
Definition: mkldnn.hpp:748
mkldnn_rnn_cell_desc_t c_rnn_cell_
Definition: mkldnn.hpp:2985
runtime estimation (seconds)
Definition: mkldnn_types.h:1213
blocked weights format
Definition: mkldnn_types.h:397
A (in-place) concat primitive.
Definition: mkldnn_types.h:476
mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, mkldnn_stream_kind_t stream_kind)
Creates an execution stream of stream_kind.
Definition: mkldnn.hpp:660
blocked weights format
Definition: mkldnn_types.h:298
Definition: mkldnn.hpp:673
LSTM cell.
Definition: mkldnn_types.h:546
blocked weights format
Definition: mkldnn_types.h:280
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for backward propagation with respect to data ...
Definition: mkldnn.hpp:740
Definition: mkldnn_types.h:1002
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2495
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2821
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2824
Undefined data type, used for empty memory descriptors.
Definition: mkldnn_types.h:72
blocked weights format with additional buffer with size equal to the number of output channels multip...
Definition: mkldnn_types.h:353
Definition: mkldnn.hpp:1825
16-bit signed integer.
Definition: mkldnn_types.h:78
Definition: mkldnn.hpp:2309
A shuffle primitive.
Definition: mkldnn_types.h:472
blocked weights format with additional buffer with size equal to the number of output channels and co...
Definition: mkldnn_types.h:305
Definition: mkldnn.hpp:624
mkldnn_shuffle_desc_t data
Definition: mkldnn.hpp:3249
primitive_desc()
Definition: mkldnn.hpp:786
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc(const_mkldnn_primitive_t primitive, const_mkldnn_primitive_desc_t *primitive_desc)
Retrieves a reference to the primitive_desc descriptor of given primitive.
blocked weights format
Definition: mkldnn_types.h:312
primitive_desc(const memory::desc &output, const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1178
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2809
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise(const_mkldnn_post_ops_t post_ops, int index, float *scale, mkldnn_alg_kind_t *alg, float *alpha, float *beta)
Gets the eltwise parameters of the post operation with index index in the sequence of post_ops...
blocked data format
Definition: mkldnn_types.h:258
Definition: mkldnn.hpp:242
blocked weights format
Definition: mkldnn_types.h:331
Definition: mkldnn.hpp:676
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum(const_mkldnn_post_ops_t post_ops, int index, float *scale)
Gets the parameters of the accumulation (sum) post operation with index index in the sequence of post...
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1490
blocked weights format
Definition: mkldnn_types.h:321
A (out-of-place) concat primitive.
Definition: mkldnn_types.h:474
blocked weights format
Definition: mkldnn_types.h:340
Definition: mkldnn.hpp:629
Fuse with ReLU.
Definition: mkldnn_types.h:596
Definition: mkldnn.hpp:746
Definition: mkldnn.hpp:678
Definition: mkldnn.hpp:260
Definition: mkldnn.hpp:278
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: mkldnn.hpp:519
mkldnn_query_t
Primitive descriptor query specification.
Definition: mkldnn_types.h:1204
A descriptor of a Batch Normalization operation.
Definition: mkldnn_types.h:907
Definition: mkldnn.hpp:691
static engine query(const primitive_desc &pd)
Definition: mkldnn.hpp:551
Definition: mkldnn.hpp:3023
blocked weights format
Definition: mkldnn_types.h:354
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:2011
blocked data format
Definition: mkldnn_types.h:265
blocked weights format
Definition: mkldnn_types.h:276
A sum primitive.
Definition: mkldnn_types.h:478
blocked weights format
Definition: mkldnn_types.h:342
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2770
Definition: mkldnn.hpp:302
Definition: mkldnn.hpp:627
blocked weights format
Definition: mkldnn_types.h:392
eltwise_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2333
blocked weights format
Definition: mkldnn_types.h:282
Definition: mkldnn.hpp:727
unsigned flags
Definition: mkldnn_types.h:934
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create_v2(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output, const_mkldnn_primitive_attr_t attr)
Initializes a reorder_primitive_desc using an attr attribute and descriptors of input and output memo...
blocked weights format
Definition: mkldnn_types.h:281
blocked weights format
Definition: mkldnn_types.h:345
Definition: mkldnn.hpp:2983
Definition: mkldnn.hpp:595
memory::primitive_desc variance_primitive_desc() const
Definition: mkldnn.hpp:2508
Convolution algorithm(either direct or Winograd) is chosen just in time.
Definition: mkldnn_types.h:507
softmax_backward(const primitive_desc &aprimitive_desc, const primitive::at &dst, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2459
blocked weights format
Definition: mkldnn_types.h:271
Definition: mkldnn.hpp:3024
Definition: mkldnn.hpp:258
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2323
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to dat...
blocked weights format
Definition: mkldnn_types.h:399
mkldnn_status_t MKLDNN_API mkldnn_stream_rerun(mkldnn_stream_t stream, mkldnn_primitive_t *error_primitive)
Reruns all the primitives within the stream.
float get_alpha() const
Definition: mkldnn.hpp:3002
2D weights tensor with physical layout io.
Definition: mkldnn_types.h:183
memory consumption – extra (scratch) memory, additional to all inputs and outputs memory (bytes) ...
Definition: mkldnn_types.h:1214
blocked weights format
Definition: mkldnn_types.h:335
An batch normalization primitive.
Definition: mkldnn_types.h:492
A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base class for primitive (mkldnn_p...
Definition: mkldnn.hpp:55
Definition: mkldnn_types.h:501
engine(kind akind, size_t index)
Constructs an engine.
Definition: mkldnn.hpp:529
Definition: mkldnn.hpp:2308
A descriptor of a pooling operation.
Definition: mkldnn_types.h:846
Definition: mkldnn.hpp:639
Definition: mkldnn.hpp:3295
Definition: mkldnn.hpp:272
Definition: mkldnn.hpp:273
engine get_engine()
Definition: mkldnn.hpp:818
error(mkldnn_status_t astatus, std::string amessage, mkldnn_primitive_t aerror_primitive=0)
Constructs an error instance.
Definition: mkldnn.hpp:173
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1987
const_mkldnn_primitive_desc_t get_primitive_desc() const
Returns the descriptor of the underlying C API primitive.
Definition: mkldnn.hpp:210
deconvolution descriptor
Definition: mkldnn_types.h:1225
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1168
blocked weights format
Definition: mkldnn_types.h:347
shuffle_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:3266
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output)
Definition: mkldnn.hpp:975
void get_params_eltwise(int index, float &scale, algorithm &alg, float &alpha, float &beta) const
Definition: mkldnn.hpp:402
primitive_desc(const desc &desc, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2261
mkldnn_memory_desc_t data
The underlying C API data structure.
Definition: mkldnn.hpp:758
mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch(const_mkldnn_primitive_desc_iterator_t iterator)
Fetches the current primitive descriptor.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:1451
engine get_engine()
Definition: mkldnn.hpp:994
int MKLDNN_API mkldnn_primitive_desc_query_s32(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for signed 32bit int.
8-bit signed integer.
Definition: mkldnn_types.h:80
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output)
Initializes a reorder_primitive_desc using descriptors of input and output memory primitives...
The data in padding regions is zero.
Definition: mkldnn_types.h:431
int MKLDNN_API mkldnn_rnn_cell_get_states_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of states of a particular rnn_cell_desc.
Definition: mkldnn.hpp:2322
friend struct error
Definition: mkldnn.hpp:107
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2919
Definition: mkldnn.hpp:729
source memory primitive desc
Definition: mkldnn_types.h:1239
mkldnn_primitive_kind_t
Kinds of primitives.
Definition: mkldnn_types.h:462
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1873
Definition: mkldnn.hpp:711
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1960
Definition: mkldnn.hpp:3226
Winograd deconvolution.
Definition: mkldnn_types.h:511
Definition: mkldnn.hpp:3300
Definition: mkldnn.hpp:248
number of inputs expected
Definition: mkldnn_types.h:1210
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2399
Definition: mkldnn.hpp:345
Definition: mkldnn.hpp:657
Definition: mkldnn.hpp:3048
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2498
desc(prop_kind aprop_kind, algorithm alg_kind, const memory::desc &src_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2312
An unspecified engine.
Definition: mkldnn_types.h:1256
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1783
A view primitive.
Definition: mkldnn_types.h:468
size_t MKLDNN_API mkldnn_memory_primitive_desc_get_size(const_mkldnn_primitive_desc_t memory_primitive_desc)
Returns the size (in bytes) that is required for given memory_primitive_desc.
Definition: mkldnn.hpp:3095
Definition: mkldnn.hpp:262
Definition: mkldnn.hpp:666
Definition: mkldnn.hpp:328
Definition: mkldnn.hpp:750
Definition: mkldnn.hpp:622
Definition: mkldnn.hpp:3129
Definition: mkldnn.hpp:742
blocked weights format
Definition: mkldnn_types.h:311
mkldnn_primitive_kind_t convert_to_c(primitive::kind akind)
Definition: mkldnn.hpp:154
Definition: mkldnn.hpp:718
Definition: mkldnn.hpp:734
Definition: mkldnn.hpp:704
blocked data format
Definition: mkldnn_types.h:261
Definition: mkldnn.hpp:340
Definition: mkldnn.hpp:717
Definition: mkldnn.hpp:331
Definition: mkldnn.hpp:323
Definition: mkldnn.hpp:333
Average pooling exclude padding.
Definition: mkldnn_types.h:537
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops(const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops)
Returns post_ops for given attr.
mkldnn_status_t MKLDNN_API mkldnn_primitive_create(mkldnn_primitive_t *primitive, const_mkldnn_primitive_desc_t primitive_desc, const mkldnn_primitive_at_t *inputs, const_mkldnn_primitive_t *outputs)
Creates a primitive using a primitive_desc descriptor and arrays of inputs and outputs.
Definition: mkldnn_types.h:972
Forward data propagation (inference mode).
Definition: mkldnn_types.h:445
6D grouped weights tensor with the physical layout goidhw, used in Caffe.
Definition: mkldnn_types.h:225
Definition: mkldnn.hpp:687
5D weights tensor with physical layout iodhw, used in Caffe.
Definition: mkldnn_types.h:204
A class that provides the destructor for an Intel(R) MKL-DNN C handle.
Definition: mkldnn.hpp:40
data_type
Data type specification. See mkldnn_data_type_t for a detailed description.
Definition: mkldnn.hpp:594
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const memory &dst)
Definition: mkldnn.hpp:2538
Direct deconvolution.
Definition: mkldnn_types.h:509
Eltwise: abs.
Definition: mkldnn_types.h:521
int get_state_count() const
Definition: mkldnn.hpp:3017
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2560
blocked weights format
Definition: mkldnn_types.h:369
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2286
bool operator==(const handle &other) const
Definition: mkldnn.hpp:87
blocked weights format
Definition: mkldnn_types.h:299
A memory descriptor.
Definition: mkldnn.hpp:755
deconvolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1882
5D grouped weights tensor with the physical layout hwigo, used in TensorFlow.
Definition: mkldnn_types.h:218
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2326
blocked weights format
Definition: mkldnn_types.h:389
bool operator!=(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:942
void set_rnn_weights_qparams(int mask, const std::vector< float > &scales)
Definition: mkldnn.hpp:480
handle(T t=0, bool weak=false)
Constructs a C handle wrapper.
Definition: mkldnn.hpp:67
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for forward propagation using prop_kind (possi...
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition: mkldnn_types.h:515
algorithm get_cell_kind() const
Definition: mkldnn.hpp:2997
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2908
mkldnn_status_t status
Definition: mkldnn.hpp:162
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1810
blocked weights format with additional buffer with size equal to the number of output channels and co...
Definition: mkldnn_types.h:388
Definition: mkldnn.hpp:646
mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine)
Destroys an engine.
Definition: mkldnn.hpp:682
view(const primitive_desc &view_pd, primitive::at input)
Definition: mkldnn.hpp:1060
blocked weights format
Definition: mkldnn_types.h:348
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1920
Definition: mkldnn.hpp:663
blocked weights format
Definition: mkldnn_types.h:346
2D data tensor.
Definition: mkldnn_types.h:156
primitive_desc(const desc &desc, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2691
Definition: mkldnn.hpp:625
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2798
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for backward propagation with respect to data ...
bool wait(bool block=true)
Waits for all computations submitted to the stream to complete.
Definition: mkldnn.hpp:3345
mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for backward propagation using alg_kind, memory descriptors data_desc and dif...
Primitive or engine failed on execution.
Definition: mkldnn_types.h:64
memory descriptor for memory and view
Definition: mkldnn_types.h:1223
Definition: mkldnn.hpp:710
view(memory input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:1069
Definition: mkldnn.hpp:266
Definition: mkldnn.hpp:674
An LRN primitive.
Definition: mkldnn_types.h:490
Definition: mkldnn_types.h:998
mkldnn_padding_kind_t
Kinds of padding.
Definition: mkldnn_types.h:429
rnn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const primitive::at &dst_layer, const primitive::at &dst_iter, const memory &diff_src_layer, const memory &diff_src_iter, const memory &diff_weights_layer, const memory &diff_weights_iter, const memory &diff_bias, const primitive::at &diff_dst_layer, const primitive::at &diff_dst_iter, const primitive::at &workspace)
Definition: mkldnn.hpp:3157
Lazy stream.
Definition: mkldnn_types.h:1260
Definition: mkldnn.hpp:332
desc(const memory::desc &diff_desc, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2436
blocked weights format
Definition: mkldnn_types.h:394
Definition: mkldnn.hpp:304
Definition: mkldnn.hpp:726
blocked weights format
Definition: mkldnn_types.h:273
desc(algorithm kind)
Definition: mkldnn.hpp:2993
Definition: mkldnn.hpp:689
primitive_desc(const desc &desc, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3130
5D RNN weights tensor in the format (num_layers, num_directions, num_gates, output_channels, input_channels).
Definition: mkldnn_types.h:246
blocked weights format
Definition: mkldnn_types.h:338
const_mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_query_pd(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for primitive descriptor.
Definition: mkldnn.hpp:612
Definition: mkldnn.hpp:640
Definition: mkldnn.hpp:2906
Definition: mkldnn.hpp:708
shuffle descriptor
Definition: mkldnn_types.h:1226
Forward data propagation (training mode).
Definition: mkldnn_types.h:441
Definition: mkldnn.hpp:733
Definition: mkldnn.hpp:665
Definition: mkldnn.hpp:344
primitive_desc(const desc &desc, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2123
Definition: mkldnn.hpp:626
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:2945
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1563
memory(const primitive &aprimitive)
Constructs a memory primitive from a generic primitive.
Definition: mkldnn.hpp:824
3D data tensor with the physical layout nwc.
Definition: mkldnn_types.h:162
engine get_engine()
Definition: mkldnn.hpp:1139
Definition: mkldnn.hpp:613
post_ops()
Definition: mkldnn.hpp:368
An opaque structure to describe a primitive.
Definition: mkldnn.hpp:715
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const primitive::at &workspace, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2731
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: mkldnn_types.h:152
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1363
mkldnn_data_type_t
Data type specification.
Definition: mkldnn_types.h:70
primitive::kind kind(int index) const
Definition: mkldnn.hpp:377
Definition: mkldnn.hpp:1488
Definition: mkldnn.hpp:600
Definition: mkldnn.hpp:662
Definition: mkldnn.hpp:643
Definition: mkldnn.hpp:611
Definition: mkldnn.hpp:325
Definition: mkldnn.hpp:318
convolution descriptor
Definition: mkldnn_types.h:1224
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1537
Definition: mkldnn.hpp:707
Definition: mkldnn.hpp:606
A memory primitive descriptor.
Definition: mkldnn.hpp:782
Definition: mkldnn.hpp:314
Definition: mkldnn.hpp:2444
mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for forward propagation using prop_kind (possible values are mkldnn_forward_t...
blocked weights format
Definition: mkldnn_types.h:326
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1533
blocked weights format
Definition: mkldnn_types.h:317
handle & operator=(const handle &other)
Definition: mkldnn.hpp:72
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2661
Eltwise: bounded_relu.
Definition: mkldnn_types.h:527
Definition: mkldnn.hpp:2398
Definition: mkldnn_types.h:995
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1473
Definition: mkldnn.hpp:735
Definition: mkldnn.hpp:694
Definition: mkldnn.hpp:617
mkldnn_engine_kind_t
Kinds of engines.
Definition: mkldnn_types.h:1053
Definition: mkldnn_types.h:968
int MKLDNN_API mkldnn_rnn_cell_get_gates_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of gates of a particular rnn_cell_desc.
Queried element is not required for given primitive.
Definition: mkldnn_types.h:66
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3049
blocked weights format
Definition: mkldnn_types.h:414
blocked weights format
Definition: mkldnn_types.h:366
Memory primitive that describes the data.
Definition: mkldnn.hpp:579
Weights format used in 8bit Winograd convolution.
Definition: mkldnn_types.h:419
Definition: mkldnn.hpp:327
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2059
Definition: mkldnn.hpp:2097
Definition: mkldnn.hpp:301
Round nearest.
Definition: mkldnn_types.h:88
blocked weights format
Definition: mkldnn_types.h:413
Definition: mkldnn.hpp:243
Definition: mkldnn.hpp:3298
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2712
Definition: mkldnn.hpp:1699
const void * const_mkldnn_op_desc_t
A pointer to any of the operation descriptors (constant variant).
Definition: mkldnn_types.h:686
static mkldnn_stream_kind_t convert_to_c(kind akind)
Definition: mkldnn.hpp:3302
blocked weights format
Definition: mkldnn_types.h:272
blocked weights format
Definition: mkldnn_types.h:410
Definition: mkldnn.hpp:1897
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create_v2(mkldnn_primitive_desc_iterator_t *iterator, const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, mkldnn_engine_t engine, const_mkldnn_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive descriptor iterator for given op_desc, attr, engine, and optionally a hint primit...
Definition: mkldnn.hpp:2480
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &workspace)
Definition: mkldnn.hpp:2222
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1460
4D weights tensor with physical layout iohw.
Definition: mkldnn_types.h:201
A reorder primitive.
Definition: mkldnn_types.h:470
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:1786
rnn_direction
Definition: mkldnn.hpp:299
primitive_attr get_primitive_attr() const
Definition: mkldnn.hpp:1273
Definition: mkldnn.hpp:672
primitive_desc(const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1197
blocked weights format
Definition: mkldnn_types.h:390
blocked weights format with additional buffer with size equal to the number of output channels multip...
Definition: mkldnn_types.h:380
blocked weights format
Definition: mkldnn_types.h:320
Definition: mkldnn.hpp:635
An unspecified engine.
Definition: mkldnn_types.h:1055
desc(const mkldnn_memory_desc_t &adata)
Constructs a memory descriptor from a C API data structure.
Definition: mkldnn.hpp:778
blocked weights format
Definition: mkldnn_types.h:341
Definition: mkldnn.hpp:597
Definition: mkldnn.hpp:1167
int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops)
Returns the length of post operations for given post_ops.
primitive_desc get_primitive_desc() const
Returns the descriptor of the memory primitive.
Definition: mkldnn.hpp:865
engine get_engine()
Definition: mkldnn.hpp:1057
Definition: mkldnn.hpp:684
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2265
friend class primitive_at
Definition: mkldnn.hpp:109
blocked weights format
Definition: mkldnn_types.h:391
Definition: mkldnn.hpp:720
blocked weights format
Definition: mkldnn_types.h:368
mkldnn_alg_kind_t
Kinds of algorithms.
Definition: mkldnn_types.h:500
Definition: mkldnn.hpp:716
primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2931
Definition: mkldnn.hpp:263
inner product descriptor
Definition: mkldnn_types.h:1232
blocked weights format
Definition: mkldnn_types.h:375
void get_params_sum(int index, float &scale) const
Definition: mkldnn.hpp:390
A pooling primitive.
Definition: mkldnn_types.h:488
Definition: mkldnn.hpp:723
weights memory primitive descriptor desc
Definition: mkldnn_types.h:1241
output memory primitive desc
Definition: mkldnn_types.h:1238
Definition: mkldnn.hpp:2260
blocked weights format
Definition: mkldnn_types.h:396
5D weights tensor with physical layout dhwio, used in TensorFlow.
Definition: mkldnn_types.h:207
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2062
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2482
Definition: mkldnn.hpp:974
mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy(mkldnn_primitive_t primitive)
Deletes a primitive.
Definition: mkldnn.hpp:334
Definition: mkldnn.hpp:634
std::string message
Definition: mkldnn.hpp:163
Definition: mkldnn.hpp:3214
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to weights usi...
mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc, const mkldnn_memory_desc_t *diff_src_layer_desc, const mkldnn_memory_desc_t *diff_src_iter_desc, const mkldnn_memory_desc_t *diff_weights_layer_desc, const mkldnn_memory_desc_t *diff_weights_iter_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_layer, const mkldnn_memory_desc_t *diff_dst_iter_desc)
Initializes a rnn descriptor rnn_desc for backward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
Definition: mkldnn.hpp:651
primitive_desc(const desc &desc, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2362
Definition: mkldnn.hpp:315
blocked weights format
Definition: mkldnn_types.h:308
handle(const handle &other)
Definition: mkldnn.hpp:71
Forward data propagation (alias for mkldnn_forward_training)
Definition: mkldnn_types.h:449
3D RNN data tensor in the format (batch, seq_length, input channels).
Definition: mkldnn_types.h:227
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales(mkldnn_primitive_attr_t attr, int count, int mask, const float *scales)
Sets output scales for primitive operations.
Definition: mkldnn.hpp:241
lrn descriptor
Definition: mkldnn_types.h:1230
Definition: mkldnn.hpp:670
workspace memory primitive desc
Definition: mkldnn_types.h:1245
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2150
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1624
bool next_impl()
Advances the next implementation for the given op descriptor.
Definition: mkldnn.hpp:1301
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to weights usin...
Definition: mkldnn.hpp:619
blocked weights format
Definition: mkldnn_types.h:269
blocked weights format
Definition: mkldnn_types.h:278
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1701
Definition: mkldnn.hpp:642
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2679
blocked weights format
Definition: mkldnn_types.h:327
Definition: mkldnn.hpp:224
weights format with additional buffer size equal to the number of output channels and containing the ...
Definition: mkldnn_types.h:294
Definition: mkldnn.hpp:736
Definition: mkldnn.hpp:614
Definition: mkldnn.hpp:686
round_mode get_int_output_round_mode() const
Definition: mkldnn.hpp:426
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2127
weights grad.
Definition: mkldnn_types.h:1242
4D data tensor with the physical layout nchw, used in Caffe.
Definition: mkldnn_types.h:165
Definition: mkldnn.hpp:321
mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc)
Initializes a rnn descriptor rnn_desc for forward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
Definition: mkldnn.hpp:618
void append_eltwise(float scale, algorithm alg, float alpha, float beta)
Definition: mkldnn.hpp:395
primitive kind
Definition: mkldnn_types.h:1208
blocked data format
Definition: mkldnn_types.h:259
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1846
blocked weights format
Definition: mkldnn_types.h:306
Definition: mkldnn.hpp:317
Definition: mkldnn.hpp:652
An opaque structure to describe a primitive descriptor iterator.
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2238
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2575
Definition: mkldnn.hpp:732
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1491
Definition: mkldnn.hpp:339
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc)
Definition: mkldnn.hpp:3026
mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init(mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc)
Initializes an inner product descriptor ip_desc for forward propagation using prop_kind (possible val...