Branch data Line data Source code
1 : : /* SPDX-License-Identifier: BSD-3-Clause 2 : : * Copyright (c) 2023 Marvell. 3 : : */ 4 : : 5 : : #include <errno.h> 6 : : #include <math.h> 7 : : #include <stdint.h> 8 : : 9 : : #include "mldev_utils_scalar.h" 10 : : 11 : : #include <eal_export.h> 12 : : 13 : : /* Description: 14 : : * This file implements scalar versions of Machine Learning utility functions used to convert data 15 : : * types from bfloat16 to float32 and vice-versa. 16 : : */ 17 : : 18 : : /* Convert a single precision floating point number (float32) into a 19 : : * brain float number (bfloat16) using round to nearest rounding mode. 20 : : */ 21 : : static uint16_t 22 : 0 : __float32_to_bfloat16_scalar_rtn(float x) 23 : : { 24 : : union float32 f32; /* float32 input */ 25 : : uint32_t f32_s; /* float32 sign */ 26 : : uint32_t f32_e; /* float32 exponent */ 27 : : uint32_t f32_m; /* float32 mantissa */ 28 : : uint16_t b16_s; /* float16 sign */ 29 : : uint16_t b16_e; /* float16 exponent */ 30 : : uint16_t b16_m; /* float16 mantissa */ 31 : : uint32_t tbits; /* number of truncated bits */ 32 : : uint16_t u16; /* float16 output */ 33 : : 34 : 0 : f32.f = x; 35 : 0 : f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S; 36 : 0 : f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E; 37 : 0 : f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M; 38 : : 39 : : b16_s = f32_s; 40 : : b16_e = 0; 41 : : b16_m = 0; 42 : : 43 [ # # # ]: 0 : switch (f32_e) { 44 : 0 : case (0): /* float32: zero or subnormal number */ 45 : : b16_e = 0; 46 [ # # ]: 0 : if (f32_m == 0) /* zero */ 47 : : b16_m = 0; 48 : : else /* subnormal float32 number, normal bfloat16 */ 49 : 0 : goto bf16_normal; 50 : : break; 51 : 0 : case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */ 52 : : b16_e = BF16_MASK_E >> BF16_LSB_E; 53 [ # # ]: 0 : if (f32_m == 0) { /* infinity */ 54 : : b16_m = 0; 55 : : } else { /* nan, propagate mantissa and set MSB of mantissa to 1 */ 56 : 0 : b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M); 57 : 0 : b16_m |= BIT(BF16_MSB_M); 58 : : } 59 : : break; 60 : 0 : default: /* float32: normal number, normal bfloat16 */ 61 : 0 : goto bf16_normal; 62 : : } 63 : : 64 : 0 : goto bf16_pack; 65 : : 66 : 0 : bf16_normal: 67 : 0 : b16_e = f32_e; 68 : : tbits = FP32_MSB_M - BF16_MSB_M; 69 : 0 : b16_m = f32_m >> tbits; 70 : : 71 : : /* if non-leading truncated bits are set */ 72 [ # # ]: 0 : if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) { 73 : 0 : b16_m++; 74 : : 75 : : /* if overflow into exponent */ 76 [ # # ]: 0 : if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1) 77 : 0 : b16_e++; 78 [ # # ]: 0 : } else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) { 79 : : /* if only leading truncated bit is set */ 80 [ # # ]: 0 : if ((b16_m & 0x1) == 0x1) { 81 : 0 : b16_m++; 82 : : 83 : : /* if overflow into exponent */ 84 [ # # ]: 0 : if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1) 85 : 0 : b16_e++; 86 : : } 87 : : } 88 : 0 : b16_m = b16_m & BF16_MASK_M; 89 : : 90 : 0 : bf16_pack: 91 : 0 : u16 = BF16_PACK(b16_s, b16_e, b16_m); 92 : : 93 : 0 : return u16; 94 : : } 95 : : 96 : : RTE_EXPORT_EXPERIMENTAL_SYMBOL(rte_ml_io_float32_to_bfloat16, 22.11) 97 : : int 98 : 0 : rte_ml_io_float32_to_bfloat16(const void *input, void *output, uint64_t nb_elements) 99 : : { 100 : : const float *input_buffer; 101 : : uint16_t *output_buffer; 102 : : uint64_t i; 103 : : 104 [ # # # # ]: 0 : if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 105 : : return -EINVAL; 106 : : 107 : : input_buffer = (const float *)input; 108 : : output_buffer = (uint16_t *)output; 109 : : 110 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) { 111 : 0 : *output_buffer = __float32_to_bfloat16_scalar_rtn(*input_buffer); 112 : : 113 : 0 : input_buffer = input_buffer + 1; 114 : 0 : output_buffer = output_buffer + 1; 115 : : } 116 : : 117 : : return 0; 118 : : } 119 : : 120 : : /* Convert a brain float number (bfloat16) into a 121 : : * single precision floating point number (float32). 122 : : */ 123 : : static float 124 : 0 : __bfloat16_to_float32_scalar_rtx(uint16_t f16) 125 : : { 126 : : union float32 f32; /* float32 output */ 127 : : uint16_t b16_s; /* float16 sign */ 128 : : uint16_t b16_e; /* float16 exponent */ 129 : : uint16_t b16_m; /* float16 mantissa */ 130 : : uint32_t f32_s; /* float32 sign */ 131 : : uint32_t f32_e; /* float32 exponent */ 132 : : uint32_t f32_m; /* float32 mantissa*/ 133 : : uint8_t shift; /* number of bits to be shifted */ 134 : : 135 : 0 : b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S; 136 : 0 : b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E; 137 : 0 : b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M; 138 : : 139 : 0 : f32_s = b16_s; 140 [ # # # ]: 0 : switch (b16_e) { 141 : 0 : case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */ 142 : : f32_e = FP32_MASK_E >> FP32_LSB_E; 143 [ # # ]: 0 : if (b16_m == 0x0) { /* infinity */ 144 : : f32_m = 0; 145 : : } else { /* nan, propagate mantissa, set MSB of mantissa to 1 */ 146 : 0 : f32_m = b16_m; 147 : : shift = FP32_MSB_M - BF16_MSB_M; 148 : 0 : f32_m = (f32_m << shift) & FP32_MASK_M; 149 : 0 : f32_m |= BIT(FP32_MSB_M); 150 : : } 151 : : break; 152 : 0 : case 0: /* bfloat16: zero or subnormal */ 153 : 0 : f32_m = b16_m; 154 [ # # ]: 0 : if (b16_m == 0) { /* zero signed */ 155 : : f32_e = 0; 156 : : } else { /* subnormal numbers */ 157 : 0 : goto fp32_normal; 158 : : } 159 : : break; 160 : 0 : default: /* bfloat16: normal number */ 161 : 0 : goto fp32_normal; 162 : : } 163 : : 164 : 0 : goto fp32_pack; 165 : : 166 : 0 : fp32_normal: 167 : 0 : f32_m = b16_m; 168 : 0 : f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E; 169 : : 170 : : shift = (FP32_MSB_M - BF16_MSB_M); 171 : 0 : f32_m = (f32_m << shift) & FP32_MASK_M; 172 : : 173 : 0 : fp32_pack: 174 : 0 : f32.u = FP32_PACK(f32_s, f32_e, f32_m); 175 : : 176 : 0 : return f32.f; 177 : : } 178 : : 179 : : RTE_EXPORT_EXPERIMENTAL_SYMBOL(rte_ml_io_bfloat16_to_float32, 22.11) 180 : : int 181 : 0 : rte_ml_io_bfloat16_to_float32(const void *input, void *output, uint64_t nb_elements) 182 : : { 183 : : const uint16_t *input_buffer; 184 : : float *output_buffer; 185 : : uint64_t i; 186 : : 187 [ # # # # ]: 0 : if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 188 : : return -EINVAL; 189 : : 190 : : input_buffer = (const uint16_t *)input; 191 : : output_buffer = (float *)output; 192 : : 193 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) { 194 : 0 : *output_buffer = __bfloat16_to_float32_scalar_rtx(*input_buffer); 195 : : 196 : 0 : input_buffer = input_buffer + 1; 197 : 0 : output_buffer = output_buffer + 1; 198 : : } 199 : : 200 : : return 0; 201 : : }