LLVM
22.0.0git
include
llvm
Analysis
MLModelRunner.h
Go to the documentation of this file.
1
//===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
10
#ifndef LLVM_ANALYSIS_MLMODELRUNNER_H
11
#define LLVM_ANALYSIS_MLMODELRUNNER_H
12
13
#include "
llvm/Analysis/TensorSpec.h
"
14
#include "
llvm/IR/PassManager.h
"
15
16
namespace
llvm
{
17
class
LLVMContext
;
18
19
/// MLModelRunner interface: abstraction of a mechanism for evaluating a
20
/// ML model. More abstractly, evaluating a function that has as tensors as
21
/// arguments, described via TensorSpecs, and returns a tensor. Currently, the
22
/// latter is assumed to be a scalar, in absence of more elaborate scenarios.
23
/// NOTE: feature indices are expected to be consistent all accross
24
/// MLModelRunners (pertaining to the same model), and also Loggers (see
25
/// TFUtils.h)
26
class
MLModelRunner
{
27
public
:
28
// Disallows copy and assign.
29
MLModelRunner
(
const
MLModelRunner
&) =
delete
;
30
MLModelRunner
&
operator=
(
const
MLModelRunner
&) =
delete
;
31
virtual
~MLModelRunner
() =
default
;
32
33
template
<
typename
T>
T
evaluate
() {
34
return
*
reinterpret_cast<
T
*
>
(
evaluateUntyped
());
35
}
36
37
template
<
typename
T,
typename
I>
T
*
getTensor
(
I
FeatureID) {
38
return
reinterpret_cast<
T
*
>
(
39
getTensorUntyped
(
static_cast<
size_t
>
(FeatureID)));
40
}
41
42
template
<
typename
T,
typename
I>
const
T
*
getTensor
(
I
FeatureID)
const
{
43
return
reinterpret_cast<
const
T
*
>
(
44
getTensorUntyped
(
static_cast<
size_t
>
(FeatureID)));
45
}
46
47
void
*
getTensorUntyped
(
size_t
Index) {
return
InputBuffers[Index]; }
48
const
void
*
getTensorUntyped
(
size_t
Index)
const
{
49
return
(
const_cast<
MLModelRunner
*
>
(
this
))->getTensorUntyped(Index);
50
}
51
52
enum class
Kind
:
int
{
Unknown
,
Release
,
Development
,
NoOp
,
Interactive
};
53
Kind
getKind
()
const
{
return
Type
; }
54
virtual
void
switchContext
(
StringRef
Name) {}
55
56
protected
:
57
MLModelRunner
(
LLVMContext
&
Ctx
,
Kind
Type
,
size_t
NumInputs)
58
:
Ctx
(
Ctx
),
Type
(
Type
), InputBuffers(NumInputs) {
59
assert
(
Type
!=
Kind::Unknown
);
60
}
61
virtual
void
*
evaluateUntyped
() = 0;
62
63
void
setUpBufferForTensor
(
size_t
Index,
const
TensorSpec
&
Spec
,
64
void
*Buffer) {
65
if
(!Buffer) {
66
OwnedBuffers.emplace_back(
Spec
.getTotalTensorBufferSize());
67
Buffer = OwnedBuffers.back().data();
68
}
69
InputBuffers[Index] = Buffer;
70
}
71
72
LLVMContext
&
Ctx
;
73
const
Kind
Type
;
74
75
private
:
76
std::vector<void *> InputBuffers;
77
std::vector<std::vector<char *>> OwnedBuffers;
78
};
79
}
// namespace llvm
80
81
#endif
// LLVM_ANALYSIS_MLMODELRUNNER_H
assert
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
PassManager.h
This header defines various interfaces for pass management in LLVM.
I
#define I(x, y, z)
Definition
MD5.cpp:58
T
#define T
Definition
Mips16ISelLowering.cpp:353
TensorSpec.h
llvm::LLVMContext
This is an important class for using LLVM in a threaded context.
Definition
LLVMContext.h:68
llvm::MLModelRunner::getKind
Kind getKind() const
Definition
MLModelRunner.h:53
llvm::MLModelRunner::getTensor
const T * getTensor(I FeatureID) const
Definition
MLModelRunner.h:42
llvm::MLModelRunner::evaluateUntyped
virtual void * evaluateUntyped()=0
llvm::MLModelRunner::switchContext
virtual void switchContext(StringRef Name)
Definition
MLModelRunner.h:54
llvm::MLModelRunner::getTensorUntyped
void * getTensorUntyped(size_t Index)
Definition
MLModelRunner.h:47
llvm::MLModelRunner::getTensor
T * getTensor(I FeatureID)
Definition
MLModelRunner.h:37
llvm::MLModelRunner::evaluate
T evaluate()
Definition
MLModelRunner.h:33
llvm::MLModelRunner::Kind
Kind
Definition
MLModelRunner.h:52
llvm::MLModelRunner::Kind::Interactive
@ Interactive
Definition
MLModelRunner.h:52
llvm::MLModelRunner::Kind::Development
@ Development
Definition
MLModelRunner.h:52
llvm::MLModelRunner::Kind::NoOp
@ NoOp
Definition
MLModelRunner.h:52
llvm::MLModelRunner::Kind::Unknown
@ Unknown
Definition
MLModelRunner.h:52
llvm::MLModelRunner::Kind::Release
@ Release
Definition
MLModelRunner.h:52
llvm::MLModelRunner::setUpBufferForTensor
void setUpBufferForTensor(size_t Index, const TensorSpec &Spec, void *Buffer)
Definition
MLModelRunner.h:63
llvm::MLModelRunner::getTensorUntyped
const void * getTensorUntyped(size_t Index) const
Definition
MLModelRunner.h:48
llvm::MLModelRunner::Type
const Kind Type
Definition
MLModelRunner.h:73
llvm::MLModelRunner::~MLModelRunner
virtual ~MLModelRunner()=default
llvm::MLModelRunner::MLModelRunner
MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NumInputs)
Definition
MLModelRunner.h:57
llvm::MLModelRunner::MLModelRunner
MLModelRunner(const MLModelRunner &)=delete
llvm::MLModelRunner::operator=
MLModelRunner & operator=(const MLModelRunner &)=delete
llvm::MLModelRunner::Ctx
LLVMContext & Ctx
Definition
MLModelRunner.h:72
llvm::StringRef
StringRef - Represent a constant reference to a string, i.e.
Definition
StringRef.h:55
llvm::TensorSpec
Definition
TensorSpec.h:63
llvm
This is an optimization pass for GlobalISel generic memory operations.
Definition
AddressRanges.h:18
llvm::Spec
Definition
FunctionSpecialization.h:128
Generated on
for LLVM by
1.14.0