Skip to content

Commit 79813b5

Browse files
authored
NFC: exp DXIL prereq: Separate llvm insts and dxil ops, add accessors (#7977)
This change creates separate lists for llvm instructions and dxil operations, and adds accessors for these.
1 parent 39d2165 commit 79813b5

File tree

2 files changed

+49
-45
lines changed

2 files changed

+49
-45
lines changed

utils/hct/hctdb.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,8 @@ class db_dxil(object):
342342
"A database of DXIL instruction data"
343343

344344
def __init__(self):
345-
self.instr = [] # DXIL instructions
345+
self._llvm_insts = [] # LLVM instructions
346+
self._dxil_ops = [] # DXIL operations
346347
self.enums = [] # enumeration types
347348
self.val_rules = [] # validation rules
348349
self.metadata = [] # named metadata (db_dxil_metadata)
@@ -373,8 +374,25 @@ def __init__(self):
373374
self.build_semantics()
374375
self.populate_counters()
375376

377+
def get_llvm_insts(self):
378+
"Get all LLVM instructions."
379+
for i in self._llvm_insts:
380+
yield i
381+
382+
def get_dxil_ops(self):
383+
"Get all DXIL operations."
384+
for i in self._dxil_ops:
385+
yield i
386+
387+
def get_all_insts(self):
388+
"Get all instructions, including LLVM and DXIL operations."
389+
for i in self._llvm_insts:
390+
yield i
391+
for i in self._dxil_ops:
392+
yield i
393+
376394
def __str__(self):
377-
return "\n".join(str(i) for i in self.instr)
395+
return "\n".join(str(i) for i in self.get_all_insts())
378396

379397
def next_id(self):
380398
"Returns the next available DXIL op ID and increments the counter"
@@ -397,11 +415,10 @@ def build_opcode_enum(self):
397415
# Keep track of last seen class/category pairs for OpCodeClass
398416
class_dict = {}
399417
class_dict["LlvmInst"] = "LLVM Instructions"
400-
for i in self.instr:
401-
if i.is_dxil_op:
402-
v = OpCodeEnum.add_value(i.dxil_opid, i.dxil_op, i.doc)
403-
v.category = i.category
404-
class_dict[i.dxil_class] = i.category
418+
for i in self.get_dxil_ops():
419+
v = OpCodeEnum.add_value(i.dxil_opid, i.dxil_op, i.doc)
420+
v.category = i.category
421+
class_dict[i.dxil_class] = i.category
405422

406423
# Build OpCodeClass enum
407424
OpCodeClass = self.add_enum_type(
@@ -574,7 +591,7 @@ def populate_categories_and_models(self):
574591
self.name_idx[i].category = "Other"
575592
for i in "LegacyF32ToF16,LegacyF16ToF32".split(","):
576593
self.name_idx[i].category = "Legacy floating-point"
577-
for i in self.instr:
594+
for i in self.get_dxil_ops():
578595
if i.name.startswith("Wave"):
579596
i.category = "Wave"
580597
i.is_wave = True
@@ -5907,20 +5924,16 @@ def UFI(name, **mappings):
59075924

59085925
# TODO - some arguments are required to be immediate constants in DXIL, eg resource kinds; add this information
59095926
# consider - report instructions that are overloaded on a single type, then turn them into non-overloaded version of that type
5910-
self.verify_dense(
5911-
self.get_dxil_insts(), lambda x: x.dxil_opid, lambda x: x.name
5912-
)
5913-
for i in self.instr:
5927+
self.verify_dense(self.get_dxil_ops(), lambda x: x.dxil_opid, lambda x: x.name)
5928+
for i in self.get_all_insts():
59145929
self.verify_dense(i.ops, lambda x: x.pos, lambda x: i.name)
59155930

59165931
# Verify that all operations in each class have the same signature.
59175932
import itertools
59185933

59195934
class_sort_func = lambda x, y: x < y
59205935
class_key_func = lambda x: x.dxil_class
5921-
instr_ordered_by_class = sorted(
5922-
[i for i in self.instr if i.is_dxil_op], key=class_key_func
5923-
)
5936+
instr_ordered_by_class = sorted(self.get_dxil_ops(), key=class_key_func)
59245937
instr_grouped_by_class = itertools.groupby(
59255938
instr_ordered_by_class, key=class_key_func
59265939
)
@@ -8492,9 +8505,9 @@ def build_valrules(self):
84928505
def populate_counters(self):
84938506
self.llvm_op_counters = set()
84948507
self.dxil_op_counters = set()
8495-
for i in self.instr:
8508+
for i in self.get_all_insts():
84968509
counters = getattr(i, "props", {}).get("counters", ())
8497-
if i.dxil_opid:
8510+
if i.is_dxil_op:
84988511
self.dxil_op_counters.update(counters)
84998512
else:
85008513
self.llvm_op_counters.update(counters)
@@ -8518,7 +8531,10 @@ def add_inst(self, i):
85188531
# These should not overlap, but UDiv is a known collision.
85198532
assert i.name not in self.name_idx, f"Duplicate instruction name: {i.name}"
85208533
self.name_idx[i.name] = i
8521-
self.instr.append(i)
8534+
if i.is_dxil_op:
8535+
self._dxil_ops.append(i)
8536+
else:
8537+
self._llvm_insts.append(i)
85228538
return i
85238539

85248540
def add_llvm_instr(
@@ -8581,23 +8597,18 @@ def reserve_dxil_op_range(self, group_name, count, start_reserved_id=0):
85818597

85828598
def get_instr_by_llvm_name(self, llvm_name):
85838599
"Return the instruction with the given LLVM name"
8584-
return next(i for i in self.instr if i.llvm_name == llvm_name)
8585-
8586-
def get_dxil_insts(self):
8587-
for i in self.instr:
8588-
if i.dxil_op != "":
8589-
yield i
8600+
return next(i for i in self.get_llvm_insts() if i.llvm_name == llvm_name)
85908601

85918602
def print_stats(self):
85928603
"Print some basic statistics on the instruction database."
8593-
print("Instruction count: %d" % len(self.instr))
8604+
print("Instruction count: %d" % len(self.get_all_insts()))
85948605
print(
85958606
"Max parameter count in instruction: %d"
8596-
% max(len(i.ops) - 1 for i in self.instr)
8607+
% max(len(i.ops) - 1 for i in self.get_all_insts())
85978608
)
85988609
print(
85998610
"Parameter count: %d"
8600-
% sum(len(i.ops) - 1 for i in self.instr)
8611+
% sum(len(i.ops) - 1 for i in self.get_all_insts())
86018612
)
86028613

86038614

utils/hct/hctdb_instrhelp.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,10 @@ class db_docsref_gen:
151151

152152
def __init__(self, db):
153153
self.db = db
154-
instrs = [i for i in self.db.instr if i.is_dxil_op]
155-
instrs = sorted(
156-
instrs,
154+
self.instrs = sorted(
155+
self.db.get_dxil_ops(),
157156
key=lambda v: ("" if v.category == None else v.category) + "." + v.name,
158157
)
159-
self.instrs = instrs
160158
val_rules = sorted(
161159
db.val_rules,
162160
key=lambda v: ("" if v.category == None else v.category) + "." + v.name,
@@ -328,7 +326,7 @@ def op_set_const_expr(self, o):
328326
)
329327

330328
def print_body(self):
331-
for i in self.db.instr:
329+
for i in self.db.get_all_insts():
332330
if i.is_reserved:
333331
continue
334332
if i.inst_helper_prefix:
@@ -491,8 +489,7 @@ class db_oload_gen:
491489

492490
def __init__(self, db):
493491
self.db = db
494-
instrs = [i for i in self.db.instr if i.is_dxil_op]
495-
self.instrs = sorted(instrs, key=lambda i: i.dxil_opid)
492+
self.instrs = sorted(self.db.get_dxil_ops(), key=lambda i: i.dxil_opid)
496493

497494
# Allow these to be overridden by external scripts.
498495
self.OP = "OP"
@@ -907,7 +904,7 @@ def op_const_expr(self, o):
907904
)
908905

909906
def print_body(self):
910-
llvm_instrs = [i for i in self.db.instr if i.is_allowed and not i.is_dxil_op]
907+
llvm_instrs = [i for i in self.db.get_llvm_insts() if i.is_allowed]
911908
print("static bool IsLLVMInstructionAllowed(llvm::Instruction &I) {")
912909
self.print_comment(
913910
" // ",
@@ -1253,7 +1250,7 @@ def get_instrs_pred(varname, pred, attr_name="dxil_opid"):
12531250
pred_fn = lambda i: getattr(i, pred)
12541251
else:
12551252
pred_fn = pred
1256-
llvm_instrs = [i for i in db.instr if pred_fn(i)]
1253+
llvm_instrs = [i for i in db.get_all_insts() if pred_fn(i)]
12571254
result = format_comment(
12581255
"// ",
12591256
"Instructions: %s"
@@ -1296,7 +1293,7 @@ def get_dxil_op_counters():
12961293
def get_instrs_rst():
12971294
"Create an rst table of allowed LLVM instructions."
12981295
db = get_db_dxil()
1299-
instrs = [i for i in db.instr if i.is_allowed and not i.is_dxil_op]
1296+
instrs = [i for i in db.get_llvm_insts() if i.is_allowed]
13001297
instrs = sorted(instrs, key=lambda v: v.llvm_id)
13011298
rows = []
13021299
rows.append(["Instruction", "Action", "Operand overloads"])
@@ -1377,8 +1374,7 @@ def get_is_pass_option_name():
13771374
def get_opcodes_rst():
13781375
"Create an rst table of opcodes"
13791376
db = get_db_dxil()
1380-
instrs = [i for i in db.instr if i.is_allowed and i.is_dxil_op]
1381-
instrs = sorted(instrs, key=lambda v: v.dxil_opid)
1377+
instrs = sorted(db.get_dxil_ops(), key=lambda v: v.dxil_opid)
13821378
rows = []
13831379
rows.append(["ID", "Name", "Description"])
13841380
for i in instrs:
@@ -1412,8 +1408,7 @@ def get_valrules_rst():
14121408
def get_opsigs():
14131409
# Create a list of DXIL operation signatures, sorted by ID.
14141410
db = get_db_dxil()
1415-
instrs = [i for i in db.instr if i.is_dxil_op]
1416-
instrs = sorted(instrs, key=lambda v: v.dxil_opid)
1411+
instrs = sorted(db.get_dxil_ops(), key=lambda v: v.dxil_opid)
14171412
# db_dxil already asserts that the numbering is dense.
14181413
# Create the code to write out.
14191414
code = "static const char *OpCodeSignatures[] = {\n"
@@ -1455,9 +1450,8 @@ def get_opsigs():
14551450

14561451
def get_min_sm_and_mask_text():
14571452
db = get_db_dxil()
1458-
instrs = [i for i in db.instr if i.is_dxil_op]
14591453
instrs = sorted(
1460-
instrs,
1454+
db.get_dxil_ops(),
14611455
key=lambda v: (
14621456
v.shader_model,
14631457
v.shader_model_translated,
@@ -1546,9 +1540,8 @@ def flush_instrs(grouped_instrs, last_model, last_model_translated, last_stage):
15461540

15471541
def get_valopcode_sm_text():
15481542
db = get_db_dxil()
1549-
instrs = [i for i in db.instr if i.is_dxil_op]
15501543
instrs = sorted(
1551-
instrs, key=lambda v: (v.shader_model, v.shader_stages, v.dxil_opid)
1544+
db.get_dxil_ops(), key=lambda v: (v.shader_model, v.shader_stages, v.dxil_opid)
15521545
)
15531546
last_model = None
15541547
last_stage = None

0 commit comments

Comments
 (0)