Branch data Line data Source code
1 : : /* SPDX-License-Identifier: BSD-3-Clause
2 : : * Copyright (c) 2023 Marvell.
3 : : */
4 : :
5 : : #include <rte_mldev.h>
6 : :
7 : : #include <mldev_utils.h>
8 : :
9 : : #include <roc_api.h>
10 : :
11 : : #include "cnxk_ml_io.h"
12 : :
13 : : inline int
14 : 0 : cnxk_ml_io_quantize_single(struct cnxk_ml_io *input, uint8_t *dbuffer, uint8_t *qbuffer)
15 : : {
16 : : enum rte_ml_io_type qtype;
17 : : enum rte_ml_io_type dtype;
18 : : uint32_t nb_elements;
19 : : float qscale;
20 : : int ret = 0;
21 : :
22 : 0 : dtype = input->dtype;
23 : 0 : qtype = input->qtype;
24 : 0 : qscale = input->scale;
25 : 0 : nb_elements = input->nb_elements;
26 : :
27 [ # # ]: 0 : if (dtype == qtype) {
28 [ # # ]: 0 : rte_memcpy(qbuffer, dbuffer, input->sz_d);
29 : 0 : return ret;
30 : : }
31 : :
32 [ # # # # : 0 : switch (qtype) {
# # # # #
# ]
33 : 0 : case RTE_ML_IO_TYPE_INT8:
34 : 0 : ret = rte_ml_io_float32_to_int8(dbuffer, qbuffer, nb_elements, 1.0 / qscale, 0);
35 : 0 : break;
36 : 0 : case RTE_ML_IO_TYPE_UINT8:
37 : 0 : ret = rte_ml_io_float32_to_uint8(dbuffer, qbuffer, nb_elements, 1.0 / qscale, 0);
38 : 0 : break;
39 : 0 : case RTE_ML_IO_TYPE_INT16:
40 : 0 : ret = rte_ml_io_float32_to_int16(dbuffer, qbuffer, nb_elements, 1.0 / qscale, 0);
41 : 0 : break;
42 : 0 : case RTE_ML_IO_TYPE_UINT16:
43 : 0 : ret = rte_ml_io_float32_to_uint16(dbuffer, qbuffer, nb_elements, 1.0 / qscale, 0);
44 : 0 : break;
45 : 0 : case RTE_ML_IO_TYPE_INT32:
46 : 0 : ret = rte_ml_io_float32_to_int32(dbuffer, qbuffer, nb_elements, 1.0 / qscale, 0);
47 : 0 : break;
48 : 0 : case RTE_ML_IO_TYPE_UINT32:
49 : 0 : ret = rte_ml_io_float32_to_uint32(dbuffer, qbuffer, nb_elements, 1.0 / qscale, 0);
50 : 0 : break;
51 : 0 : case RTE_ML_IO_TYPE_INT64:
52 : 0 : ret = rte_ml_io_float32_to_int64(dbuffer, qbuffer, nb_elements, 1.0 / qscale, 0);
53 : 0 : break;
54 : 0 : case RTE_ML_IO_TYPE_UINT64:
55 : 0 : ret = rte_ml_io_float32_to_uint64(dbuffer, qbuffer, nb_elements, 1.0 / qscale, 0);
56 : 0 : break;
57 : 0 : case RTE_ML_IO_TYPE_FP16:
58 : 0 : ret = rte_ml_io_float32_to_float16(dbuffer, qbuffer, nb_elements);
59 : 0 : break;
60 : 0 : default:
61 : 0 : plt_err("Unsupported qtype : %u", qtype);
62 : : ret = -ENOTSUP;
63 : : }
64 : :
65 : : return ret;
66 : : }
67 : :
68 : : inline int
69 : 0 : cnxk_ml_io_dequantize_single(struct cnxk_ml_io *output, uint8_t *qbuffer, uint8_t *dbuffer)
70 : : {
71 : : enum rte_ml_io_type qtype;
72 : : enum rte_ml_io_type dtype;
73 : : uint32_t nb_elements;
74 : : float dscale;
75 : : int ret = 0;
76 : :
77 : 0 : dtype = output->dtype;
78 : 0 : qtype = output->qtype;
79 : 0 : dscale = output->scale;
80 : 0 : nb_elements = output->nb_elements;
81 : :
82 [ # # ]: 0 : if (dtype == qtype) {
83 [ # # ]: 0 : rte_memcpy(dbuffer, qbuffer, output->sz_q);
84 : 0 : return 0;
85 : : }
86 : :
87 [ # # # # : 0 : switch (qtype) {
# # # # #
# ]
88 : 0 : case RTE_ML_IO_TYPE_INT8:
89 : 0 : ret = rte_ml_io_int8_to_float32(qbuffer, dbuffer, nb_elements, dscale, 0);
90 : 0 : break;
91 : 0 : case RTE_ML_IO_TYPE_UINT8:
92 : 0 : ret = rte_ml_io_uint8_to_float32(qbuffer, dbuffer, nb_elements, dscale, 0);
93 : 0 : break;
94 : 0 : case RTE_ML_IO_TYPE_INT16:
95 : 0 : ret = rte_ml_io_int16_to_float32(qbuffer, dbuffer, nb_elements, dscale, 0);
96 : 0 : break;
97 : 0 : case RTE_ML_IO_TYPE_UINT16:
98 : 0 : ret = rte_ml_io_uint16_to_float32(qbuffer, dbuffer, nb_elements, dscale, 0);
99 : 0 : break;
100 : 0 : case RTE_ML_IO_TYPE_INT32:
101 : 0 : ret = rte_ml_io_int32_to_float32(qbuffer, dbuffer, nb_elements, dscale, 0);
102 : 0 : break;
103 : 0 : case RTE_ML_IO_TYPE_UINT32:
104 : 0 : ret = rte_ml_io_uint32_to_float32(qbuffer, dbuffer, nb_elements, dscale, 0);
105 : 0 : break;
106 : 0 : case RTE_ML_IO_TYPE_INT64:
107 : 0 : ret = rte_ml_io_int64_to_float32(qbuffer, dbuffer, nb_elements, dscale, 0);
108 : 0 : break;
109 : 0 : case RTE_ML_IO_TYPE_UINT64:
110 : 0 : ret = rte_ml_io_uint64_to_float32(qbuffer, dbuffer, nb_elements, dscale, 0);
111 : 0 : break;
112 : 0 : case RTE_ML_IO_TYPE_FP16:
113 : 0 : ret = rte_ml_io_float16_to_float32(qbuffer, dbuffer, nb_elements);
114 : 0 : break;
115 : 0 : default:
116 : 0 : plt_err("Unsupported qtype: %u", qtype);
117 : : ret = -ENOTSUP;
118 : : }
119 : :
120 : : return ret;
121 : : }
|