Branch data Line data Source code
1 : : /* SPDX-License-Identifier: BSD-3-Clause 2 : : * Copyright (c) 2023 Marvell. 3 : : */ 4 : : 5 : : #include <errno.h> 6 : : #include <math.h> 7 : : #include <stdint.h> 8 : : 9 : : #include "mldev_utils_scalar.h" 10 : : 11 : : /* Description: 12 : : * This file implements scalar versions of Machine Learning utility functions used to convert data 13 : : * types from bfloat16 to float32 and vice-versa. 14 : : */ 15 : : 16 : : /* Convert a single precision floating point number (float32) into a 17 : : * brain float number (bfloat16) using round to nearest rounding mode. 18 : : */ 19 : : static uint16_t 20 : 0 : __float32_to_bfloat16_scalar_rtn(float x) 21 : : { 22 : : union float32 f32; /* float32 input */ 23 : : uint32_t f32_s; /* float32 sign */ 24 : : uint32_t f32_e; /* float32 exponent */ 25 : : uint32_t f32_m; /* float32 mantissa */ 26 : : uint16_t b16_s; /* float16 sign */ 27 : : uint16_t b16_e; /* float16 exponent */ 28 : : uint16_t b16_m; /* float16 mantissa */ 29 : : uint32_t tbits; /* number of truncated bits */ 30 : : uint16_t u16; /* float16 output */ 31 : : 32 : 0 : f32.f = x; 33 : 0 : f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S; 34 : 0 : f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E; 35 : 0 : f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M; 36 : : 37 : : b16_s = f32_s; 38 : : b16_e = 0; 39 : : b16_m = 0; 40 : : 41 [ # # # ]: 0 : switch (f32_e) { 42 : 0 : case (0): /* float32: zero or subnormal number */ 43 : : b16_e = 0; 44 [ # # ]: 0 : if (f32_m == 0) /* zero */ 45 : : b16_m = 0; 46 : : else /* subnormal float32 number, normal bfloat16 */ 47 : 0 : goto bf16_normal; 48 : : break; 49 : 0 : case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */ 50 : : b16_e = BF16_MASK_E >> BF16_LSB_E; 51 [ # # ]: 0 : if (f32_m == 0) { /* infinity */ 52 : : b16_m = 0; 53 : : } else { /* nan, propagate mantissa and set MSB of mantissa to 1 */ 54 : 0 : b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M); 55 : 0 : b16_m |= BIT(BF16_MSB_M); 56 : : } 57 : : break; 58 : 0 : default: /* float32: normal number, normal bfloat16 */ 59 : 0 : goto bf16_normal; 60 : : } 61 : : 62 : 0 : goto bf16_pack; 63 : : 64 : 0 : bf16_normal: 65 : 0 : b16_e = f32_e; 66 : : tbits = FP32_MSB_M - BF16_MSB_M; 67 : 0 : b16_m = f32_m >> tbits; 68 : : 69 : : /* if non-leading truncated bits are set */ 70 [ # # ]: 0 : if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) { 71 : 0 : b16_m++; 72 : : 73 : : /* if overflow into exponent */ 74 [ # # ]: 0 : if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1) 75 : 0 : b16_e++; 76 [ # # ]: 0 : } else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) { 77 : : /* if only leading truncated bit is set */ 78 [ # # ]: 0 : if ((b16_m & 0x1) == 0x1) { 79 : 0 : b16_m++; 80 : : 81 : : /* if overflow into exponent */ 82 [ # # ]: 0 : if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1) 83 : 0 : b16_e++; 84 : : } 85 : : } 86 : 0 : b16_m = b16_m & BF16_MASK_M; 87 : : 88 : 0 : bf16_pack: 89 : 0 : u16 = BF16_PACK(b16_s, b16_e, b16_m); 90 : : 91 : 0 : return u16; 92 : : } 93 : : 94 : : int 95 : 0 : rte_ml_io_float32_to_bfloat16(const void *input, void *output, uint64_t nb_elements) 96 : : { 97 : : const float *input_buffer; 98 : : uint16_t *output_buffer; 99 : : uint64_t i; 100 : : 101 [ # # # # ]: 0 : if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 102 : : return -EINVAL; 103 : : 104 : : input_buffer = (const float *)input; 105 : : output_buffer = (uint16_t *)output; 106 : : 107 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) { 108 : 0 : *output_buffer = __float32_to_bfloat16_scalar_rtn(*input_buffer); 109 : : 110 : 0 : input_buffer = input_buffer + 1; 111 : 0 : output_buffer = output_buffer + 1; 112 : : } 113 : : 114 : : return 0; 115 : : } 116 : : 117 : : /* Convert a brain float number (bfloat16) into a 118 : : * single precision floating point number (float32). 119 : : */ 120 : : static float 121 : 0 : __bfloat16_to_float32_scalar_rtx(uint16_t f16) 122 : : { 123 : : union float32 f32; /* float32 output */ 124 : : uint16_t b16_s; /* float16 sign */ 125 : : uint16_t b16_e; /* float16 exponent */ 126 : : uint16_t b16_m; /* float16 mantissa */ 127 : : uint32_t f32_s; /* float32 sign */ 128 : : uint32_t f32_e; /* float32 exponent */ 129 : : uint32_t f32_m; /* float32 mantissa*/ 130 : : uint8_t shift; /* number of bits to be shifted */ 131 : : 132 : 0 : b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S; 133 : 0 : b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E; 134 : 0 : b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M; 135 : : 136 : 0 : f32_s = b16_s; 137 [ # # # ]: 0 : switch (b16_e) { 138 : 0 : case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */ 139 : : f32_e = FP32_MASK_E >> FP32_LSB_E; 140 [ # # ]: 0 : if (b16_m == 0x0) { /* infinity */ 141 : : f32_m = 0; 142 : : } else { /* nan, propagate mantissa, set MSB of mantissa to 1 */ 143 : 0 : f32_m = b16_m; 144 : : shift = FP32_MSB_M - BF16_MSB_M; 145 : 0 : f32_m = (f32_m << shift) & FP32_MASK_M; 146 : 0 : f32_m |= BIT(FP32_MSB_M); 147 : : } 148 : : break; 149 : 0 : case 0: /* bfloat16: zero or subnormal */ 150 : 0 : f32_m = b16_m; 151 [ # # ]: 0 : if (b16_m == 0) { /* zero signed */ 152 : : f32_e = 0; 153 : : } else { /* subnormal numbers */ 154 : 0 : goto fp32_normal; 155 : : } 156 : : break; 157 : 0 : default: /* bfloat16: normal number */ 158 : 0 : goto fp32_normal; 159 : : } 160 : : 161 : 0 : goto fp32_pack; 162 : : 163 : 0 : fp32_normal: 164 : 0 : f32_m = b16_m; 165 : 0 : f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E; 166 : : 167 : : shift = (FP32_MSB_M - BF16_MSB_M); 168 : 0 : f32_m = (f32_m << shift) & FP32_MASK_M; 169 : : 170 : 0 : fp32_pack: 171 : 0 : f32.u = FP32_PACK(f32_s, f32_e, f32_m); 172 : : 173 : 0 : return f32.f; 174 : : } 175 : : 176 : : int 177 : 0 : rte_ml_io_bfloat16_to_float32(const void *input, void *output, uint64_t nb_elements) 178 : : { 179 : : const uint16_t *input_buffer; 180 : : float *output_buffer; 181 : : uint64_t i; 182 : : 183 [ # # # # ]: 0 : if ((nb_elements == 0) || (input == NULL) || (output == NULL)) 184 : : return -EINVAL; 185 : : 186 : : input_buffer = (const uint16_t *)input; 187 : : output_buffer = (float *)output; 188 : : 189 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) { 190 : 0 : *output_buffer = __bfloat16_to_float32_scalar_rtx(*input_buffer); 191 : : 192 : 0 : input_buffer = input_buffer + 1; 193 : 0 : output_buffer = output_buffer + 1; 194 : : } 195 : : 196 : : return 0; 197 : : }