LCOV - code coverage report
Current view: top level - lib/mldev - mldev_utils_scalar_bfloat16.c (source / functions) Hit Total Coverage
Test: Code coverage Lines: 0 69 0.0 %
Date: 2024-12-01 18:57:19 Functions: 0 4 0.0 %
Legend: Lines: hit not hit | Branches: + taken - not taken # not executed Branches: 0 36 0.0 %

           Branch data     Line data    Source code
       1                 :            : /* SPDX-License-Identifier: BSD-3-Clause
       2                 :            :  * Copyright (c) 2023 Marvell.
       3                 :            :  */
       4                 :            : 
       5                 :            : #include <errno.h>
       6                 :            : #include <math.h>
       7                 :            : #include <stdint.h>
       8                 :            : 
       9                 :            : #include "mldev_utils_scalar.h"
      10                 :            : 
      11                 :            : /* Description:
      12                 :            :  * This file implements scalar versions of Machine Learning utility functions used to convert data
      13                 :            :  * types from bfloat16 to float32 and vice-versa.
      14                 :            :  */
      15                 :            : 
      16                 :            : /* Convert a single precision floating point number (float32) into a
      17                 :            :  * brain float number (bfloat16) using round to nearest rounding mode.
      18                 :            :  */
      19                 :            : static uint16_t
      20                 :          0 : __float32_to_bfloat16_scalar_rtn(float x)
      21                 :            : {
      22                 :            :         union float32 f32; /* float32 input */
      23                 :            :         uint32_t f32_s;    /* float32 sign */
      24                 :            :         uint32_t f32_e;    /* float32 exponent */
      25                 :            :         uint32_t f32_m;    /* float32 mantissa */
      26                 :            :         uint16_t b16_s;    /* float16 sign */
      27                 :            :         uint16_t b16_e;    /* float16 exponent */
      28                 :            :         uint16_t b16_m;    /* float16 mantissa */
      29                 :            :         uint32_t tbits;    /* number of truncated bits */
      30                 :            :         uint16_t u16;      /* float16 output */
      31                 :            : 
      32                 :          0 :         f32.f = x;
      33                 :          0 :         f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
      34                 :          0 :         f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
      35                 :          0 :         f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;
      36                 :            : 
      37                 :            :         b16_s = f32_s;
      38                 :            :         b16_e = 0;
      39                 :            :         b16_m = 0;
      40                 :            : 
      41      [ #  #  # ]:          0 :         switch (f32_e) {
      42                 :          0 :         case (0): /* float32: zero or subnormal number */
      43                 :            :                 b16_e = 0;
      44         [ #  # ]:          0 :                 if (f32_m == 0) /* zero */
      45                 :            :                         b16_m = 0;
      46                 :            :                 else /* subnormal float32 number, normal bfloat16 */
      47                 :          0 :                         goto bf16_normal;
      48                 :            :                 break;
      49                 :          0 :         case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
      50                 :            :                 b16_e = BF16_MASK_E >> BF16_LSB_E;
      51         [ #  # ]:          0 :                 if (f32_m == 0) { /* infinity */
      52                 :            :                         b16_m = 0;
      53                 :            :                 } else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
      54                 :          0 :                         b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M);
      55                 :          0 :                         b16_m |= BIT(BF16_MSB_M);
      56                 :            :                 }
      57                 :            :                 break;
      58                 :          0 :         default: /* float32: normal number, normal bfloat16 */
      59                 :          0 :                 goto bf16_normal;
      60                 :            :         }
      61                 :            : 
      62                 :          0 :         goto bf16_pack;
      63                 :            : 
      64                 :          0 : bf16_normal:
      65                 :          0 :         b16_e = f32_e;
      66                 :            :         tbits = FP32_MSB_M - BF16_MSB_M;
      67                 :          0 :         b16_m = f32_m >> tbits;
      68                 :            : 
      69                 :            :         /* if non-leading truncated bits are set */
      70         [ #  # ]:          0 :         if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
      71                 :          0 :                 b16_m++;
      72                 :            : 
      73                 :            :                 /* if overflow into exponent */
      74         [ #  # ]:          0 :                 if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
      75                 :          0 :                         b16_e++;
      76         [ #  # ]:          0 :         } else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
      77                 :            :                 /* if only leading truncated bit is set */
      78         [ #  # ]:          0 :                 if ((b16_m & 0x1) == 0x1) {
      79                 :          0 :                         b16_m++;
      80                 :            : 
      81                 :            :                         /* if overflow into exponent */
      82         [ #  # ]:          0 :                         if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
      83                 :          0 :                                 b16_e++;
      84                 :            :                 }
      85                 :            :         }
      86                 :          0 :         b16_m = b16_m & BF16_MASK_M;
      87                 :            : 
      88                 :          0 : bf16_pack:
      89                 :          0 :         u16 = BF16_PACK(b16_s, b16_e, b16_m);
      90                 :            : 
      91                 :          0 :         return u16;
      92                 :            : }
      93                 :            : 
      94                 :            : int
      95                 :          0 : rte_ml_io_float32_to_bfloat16(const void *input, void *output, uint64_t nb_elements)
      96                 :            : {
      97                 :            :         const float *input_buffer;
      98                 :            :         uint16_t *output_buffer;
      99                 :            :         uint64_t i;
     100                 :            : 
     101   [ #  #  #  # ]:          0 :         if ((nb_elements == 0) || (input == NULL) || (output == NULL))
     102                 :            :                 return -EINVAL;
     103                 :            : 
     104                 :            :         input_buffer = (const float *)input;
     105                 :            :         output_buffer = (uint16_t *)output;
     106                 :            : 
     107         [ #  # ]:          0 :         for (i = 0; i < nb_elements; i++) {
     108                 :          0 :                 *output_buffer = __float32_to_bfloat16_scalar_rtn(*input_buffer);
     109                 :            : 
     110                 :          0 :                 input_buffer = input_buffer + 1;
     111                 :          0 :                 output_buffer = output_buffer + 1;
     112                 :            :         }
     113                 :            : 
     114                 :            :         return 0;
     115                 :            : }
     116                 :            : 
     117                 :            : /* Convert a brain float number (bfloat16) into a
     118                 :            :  * single precision floating point number (float32).
     119                 :            :  */
     120                 :            : static float
     121                 :          0 : __bfloat16_to_float32_scalar_rtx(uint16_t f16)
     122                 :            : {
     123                 :            :         union float32 f32; /* float32 output */
     124                 :            :         uint16_t b16_s;    /* float16 sign */
     125                 :            :         uint16_t b16_e;    /* float16 exponent */
     126                 :            :         uint16_t b16_m;    /* float16 mantissa */
     127                 :            :         uint32_t f32_s;    /* float32 sign */
     128                 :            :         uint32_t f32_e;    /* float32 exponent */
     129                 :            :         uint32_t f32_m;    /* float32 mantissa*/
     130                 :            :         uint8_t shift;     /* number of bits to be shifted */
     131                 :            : 
     132                 :          0 :         b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S;
     133                 :          0 :         b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E;
     134                 :          0 :         b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M;
     135                 :            : 
     136                 :          0 :         f32_s = b16_s;
     137      [ #  #  # ]:          0 :         switch (b16_e) {
     138                 :          0 :         case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */
     139                 :            :                 f32_e = FP32_MASK_E >> FP32_LSB_E;
     140         [ #  # ]:          0 :                 if (b16_m == 0x0) { /* infinity */
     141                 :            :                         f32_m = 0;
     142                 :            :                 } else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
     143                 :          0 :                         f32_m = b16_m;
     144                 :            :                         shift = FP32_MSB_M - BF16_MSB_M;
     145                 :          0 :                         f32_m = (f32_m << shift) & FP32_MASK_M;
     146                 :          0 :                         f32_m |= BIT(FP32_MSB_M);
     147                 :            :                 }
     148                 :            :                 break;
     149                 :          0 :         case 0: /* bfloat16: zero or subnormal */
     150                 :          0 :                 f32_m = b16_m;
     151         [ #  # ]:          0 :                 if (b16_m == 0) { /* zero signed */
     152                 :            :                         f32_e = 0;
     153                 :            :                 } else { /* subnormal numbers */
     154                 :          0 :                         goto fp32_normal;
     155                 :            :                 }
     156                 :            :                 break;
     157                 :          0 :         default: /* bfloat16: normal number */
     158                 :          0 :                 goto fp32_normal;
     159                 :            :         }
     160                 :            : 
     161                 :          0 :         goto fp32_pack;
     162                 :            : 
     163                 :          0 : fp32_normal:
     164                 :          0 :         f32_m = b16_m;
     165                 :          0 :         f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E;
     166                 :            : 
     167                 :            :         shift = (FP32_MSB_M - BF16_MSB_M);
     168                 :          0 :         f32_m = (f32_m << shift) & FP32_MASK_M;
     169                 :            : 
     170                 :          0 : fp32_pack:
     171                 :          0 :         f32.u = FP32_PACK(f32_s, f32_e, f32_m);
     172                 :            : 
     173                 :          0 :         return f32.f;
     174                 :            : }
     175                 :            : 
     176                 :            : int
     177                 :          0 : rte_ml_io_bfloat16_to_float32(const void *input, void *output, uint64_t nb_elements)
     178                 :            : {
     179                 :            :         const uint16_t *input_buffer;
     180                 :            :         float *output_buffer;
     181                 :            :         uint64_t i;
     182                 :            : 
     183   [ #  #  #  # ]:          0 :         if ((nb_elements == 0) || (input == NULL) || (output == NULL))
     184                 :            :                 return -EINVAL;
     185                 :            : 
     186                 :            :         input_buffer = (const uint16_t *)input;
     187                 :            :         output_buffer = (float *)output;
     188                 :            : 
     189         [ #  # ]:          0 :         for (i = 0; i < nb_elements; i++) {
     190                 :          0 :                 *output_buffer = __bfloat16_to_float32_scalar_rtx(*input_buffer);
     191                 :            : 
     192                 :          0 :                 input_buffer = input_buffer + 1;
     193                 :          0 :                 output_buffer = output_buffer + 1;
     194                 :            :         }
     195                 :            : 
     196                 :            :         return 0;
     197                 :            : }

Generated by: LCOV version 1.14