Branch data Line data Source code
1 : : /* SPDX-License-Identifier: BSD-3-Clause 2 : : * Copyright (c) 2022 Marvell. 3 : : */ 4 : : 5 : : #include <errno.h> 6 : : 7 : : #include <rte_common.h> 8 : : #include <rte_malloc.h> 9 : : #include <rte_mldev.h> 10 : : 11 : : #include "ml_common.h" 12 : : #include "test_model_common.h" 13 : : 14 : : int 15 : 0 : ml_model_load(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid) 16 : : { 17 : : struct rte_ml_model_params model_params; 18 : : int ret; 19 : : 20 : : RTE_SET_USED(test); 21 : : 22 : 0 : if (model->state == MODEL_LOADED) 23 : : return 0; 24 : : 25 : 0 : if (model->state != MODEL_INITIAL) 26 : : return -EINVAL; 27 : : 28 : : /* read model binary */ 29 : 0 : ret = ml_read_file(opt->filelist[fid].model, &model_params.size, 30 : : (char **)&model_params.addr); 31 : 0 : if (ret != 0) 32 : : return ret; 33 : : 34 : : /* load model to device */ 35 : 0 : ret = rte_ml_model_load(opt->dev_id, &model_params, &model->id); 36 : 0 : if (ret != 0) { 37 : 0 : ml_err("Failed to load model : %s\n", opt->filelist[fid].model); 38 : 0 : model->state = MODEL_ERROR; 39 : 0 : free(model_params.addr); 40 : 0 : return ret; 41 : : } 42 : : 43 : : /* release buffer */ 44 : 0 : free(model_params.addr); 45 : : 46 : : /* get model info */ 47 : 0 : ret = rte_ml_model_info_get(opt->dev_id, model->id, &model->info); 48 : 0 : if (ret != 0) { 49 : 0 : ml_err("Failed to get model info : %s\n", opt->filelist[fid].model); 50 : 0 : return ret; 51 : : } 52 : : 53 : 0 : model->state = MODEL_LOADED; 54 : : 55 : 0 : return 0; 56 : : } 57 : : 58 : : int 59 : 0 : ml_model_unload(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid) 60 : : { 61 : : struct test_common *t = ml_test_priv(test); 62 : : int ret; 63 : : 64 : : RTE_SET_USED(t); 65 : : 66 : 0 : if (model->state == MODEL_INITIAL) 67 : : return 0; 68 : : 69 : 0 : if (model->state != MODEL_LOADED) 70 : : return -EINVAL; 71 : : 72 : : /* unload model */ 73 : 0 : ret = rte_ml_model_unload(opt->dev_id, model->id); 74 : 0 : if (ret != 0) { 75 : 0 : ml_err("Failed to unload model: %s\n", opt->filelist[fid].model); 76 : 0 : model->state = MODEL_ERROR; 77 : 0 : return ret; 78 : : } 79 : : 80 : 0 : model->state = MODEL_INITIAL; 81 : : 82 : 0 : return 0; 83 : : } 84 : : 85 : : int 86 : 0 : ml_model_start(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid) 87 : : { 88 : : struct test_common *t = ml_test_priv(test); 89 : : int ret; 90 : : 91 : : RTE_SET_USED(t); 92 : : 93 : 0 : if (model->state == MODEL_STARTED) 94 : : return 0; 95 : : 96 : 0 : if (model->state != MODEL_LOADED) 97 : : return -EINVAL; 98 : : 99 : : /* start model */ 100 : 0 : ret = rte_ml_model_start(opt->dev_id, model->id); 101 : 0 : if (ret != 0) { 102 : 0 : ml_err("Failed to start model : %s\n", opt->filelist[fid].model); 103 : 0 : model->state = MODEL_ERROR; 104 : 0 : return ret; 105 : : } 106 : : 107 : 0 : model->state = MODEL_STARTED; 108 : : 109 : 0 : return 0; 110 : : } 111 : : 112 : : int 113 : 0 : ml_model_stop(struct ml_test *test, struct ml_options *opt, struct ml_model *model, uint16_t fid) 114 : : { 115 : : struct test_common *t = ml_test_priv(test); 116 : : int ret; 117 : : 118 : : RTE_SET_USED(t); 119 : : 120 : 0 : if (model->state == MODEL_LOADED) 121 : : return 0; 122 : : 123 : 0 : if (model->state != MODEL_STARTED) 124 : : return -EINVAL; 125 : : 126 : : /* stop model */ 127 : 0 : ret = rte_ml_model_stop(opt->dev_id, model->id); 128 : 0 : if (ret != 0) { 129 : 0 : ml_err("Failed to stop model: %s\n", opt->filelist[fid].model); 130 : 0 : model->state = MODEL_ERROR; 131 : 0 : return ret; 132 : : } 133 : : 134 : 0 : model->state = MODEL_LOADED; 135 : : 136 : 0 : return 0; 137 : : }