-
Notifications
You must be signed in to change notification settings - Fork 818
[SPIR-V] countbit on 16+64 bit types #7997
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Keenuts
wants to merge
4
commits into
microsoft:main
Choose a base branch
from
Keenuts:fix-7122
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+182
−12
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9748,7 +9748,10 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) { | |
| retVal = processReverseBitsIntrinsic(callExpr, srcLoc); | ||
| break; | ||
| } | ||
| INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false); | ||
| case hlsl::IntrinsicOp::IOP_countbits: { | ||
| retVal = processCountBitsIntrinsic(callExpr, srcLoc); | ||
| break; | ||
| } | ||
| INTRINSIC_SPIRV_OP_CASE(fmod, FRem, true); | ||
| INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true); | ||
| INTRINSIC_SPIRV_OP_CASE(and, LogicalAnd, false); | ||
|
|
@@ -9907,6 +9910,85 @@ SpirvInstruction *SpirvEmitter::processDerivativeIntrinsic( | |
| return result; | ||
| } | ||
|
|
||
| /// Lowers the HLSL `countbits` intrinsic, dispatching on the SPIR-V bitwidth | ||
| /// of the argument's element type. Emits an error and returns nullptr for | ||
| /// any width other than 16, 32, or 64. | ||
| SpirvInstruction * | ||
| SpirvEmitter::processCountBitsIntrinsic(const CallExpr *callExpr, | ||
| clang::SourceLocation srcLoc) { | ||
| const QualType operandType = callExpr->getArg(0)->getType(); | ||
| const uint32_t elemWidth = getElementSpirvBitwidth( | ||
| astContext, operandType, spirvOptions.enable16BitTypes); | ||
|
|
||
| // `OpBitCount` only accepts 32-bit integers in SPIR-V until maintenance9, | ||
| // so other widths are expanded into extra instructions by dedicated | ||
| // helpers. | ||
| switch (elemWidth) { | ||
| case 32: | ||
| // Native width: map directly onto OpBitCount. | ||
| return processIntrinsicUsingSpirvInst(callExpr, spv::Op::OpBitCount, | ||
| /* actPerRowForMatrices= */ false); | ||
| case 16: | ||
| return generateCountBits16(callExpr, srcLoc); | ||
| case 64: | ||
| return generateCountBits64(callExpr, srcLoc); | ||
| default: | ||
| break; | ||
| } | ||
|
|
||
| emitError("countbits currently only supports 16, 32, and 64-bit " | ||
| "width components when targeting SPIR-V", | ||
| srcLoc); | ||
| return nullptr; | ||
| } | ||
|
|
||
| /// Emits `countbits` for a 16-bit integer scalar or vector: the operand is | ||
| /// converted to 32 bits with OpUConvert and OpBitCount is applied to the | ||
| /// converted value. The result has the 32-bit unsigned type of matching | ||
| /// shape. | ||
| SpirvInstruction * | ||
| SpirvEmitter::generateCountBits16(const CallExpr *callExpr, | ||
| clang::SourceLocation srcLoc) { | ||
| const auto *operand = callExpr->getArg(0); | ||
| const QualType operandType = operand->getType(); | ||
| // Evaluate the 16-bit operand. | ||
| auto *narrowValue = doExpr(operand); | ||
|
|
||
| // Build the 32-bit unsigned type with the same shape as the operand. | ||
| QualType wideType = astContext.UnsignedIntTy; | ||
| if (isVectorType(operandType)) { | ||
| const uint32_t lanes = hlsl::GetHLSLVecSize(operandType); | ||
| wideType = astContext.getExtVectorType(astContext.UnsignedIntTy, lanes); | ||
| } | ||
|
|
||
| auto *widened = spvBuilder.createUnaryOp(spv::Op::OpUConvert, wideType, | ||
| narrowValue, srcLoc); | ||
| return spvBuilder.createUnaryOp(spv::Op::OpBitCount, wideType, widened, | ||
| srcLoc); | ||
| } | ||
|
|
||
| /// Emits `countbits` for a 64-bit integer scalar or vector. | ||
| /// | ||
| /// OpBitCount is only applied to 32-bit values here, so the operand is split | ||
| /// into its low and high 32-bit halves, each half is counted separately, and | ||
| /// the two counts are added. The returned value always has the 32-bit | ||
| /// unsigned type (scalar, or a vector with the operand's element count); it | ||
| /// is never 64-bit. | ||
| SpirvInstruction * | ||
| SpirvEmitter::generateCountBits64(const CallExpr *callExpr, | ||
| clang::SourceLocation srcLoc) { | ||
| const QualType argType = callExpr->getArg(0)->getType(); | ||
| // Load the 64-bit value | ||
| auto *loadInst = doExpr(callExpr->getArg(0)); | ||
| const bool isVector = isVectorType(argType); | ||
| const uint32_t count = isVector ? hlsl::GetHLSLVecSize(argType) : 1; | ||
| QualType uintType = | ||
| isVector ? astContext.getExtVectorType(astContext.UnsignedIntTy, count) | ||
| : astContext.UnsignedIntTy; | ||
|
|
||
| // Low half: converting to 32 bits truncates, keeping bits [31:0]. | ||
| auto *low = | ||
| spvBuilder.createUnaryOp(spv::Op::OpUConvert, uintType, loadInst, srcLoc); | ||
| auto *lowCount = | ||
| spvBuilder.createUnaryOp(spv::Op::OpBitCount, uintType, low, srcLoc); | ||
|
|
||
| // Shift amount of 32, splatted to a vector when the operand is a vector. | ||
| auto *shiftAmount = | ||
| spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32)); | ||
| if (isVector) { | ||
| SmallVector<SpirvConstant *, 4> Components; | ||
| for (unsigned I = 0; I < count; ++I) | ||
| Components.push_back(shiftAmount); | ||
| shiftAmount = spvBuilder.getConstantComposite(uintType, Components); | ||
| } | ||
|
|
||
| // High half: bits [63:32] must be moved DOWN into [31:0] before the | ||
| // truncating conversion. A left shift by 32 would leave the low 32 bits all | ||
| // zero, so the upper half of the operand would never be counted. | ||
| SpirvInstruction *high = spvBuilder.createBinaryOp( | ||
| spv::Op::OpShiftRightLogical, argType, loadInst, shiftAmount, srcLoc); | ||
| auto *high32 = | ||
| spvBuilder.createUnaryOp(spv::Op::OpUConvert, uintType, high, srcLoc); | ||
| auto *highCount = | ||
| spvBuilder.createUnaryOp(spv::Op::OpBitCount, uintType, high32, srcLoc); | ||
|
|
||
| return spvBuilder.createBinaryOp(spv::Op::OpIAdd, uintType, highCount, | ||
| lowCount, srcLoc); | ||
|
Comment on lines
+9988
to
+9989
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the return type is is always uint or will it be 64-bits with a 64-bit input? |
||
| } | ||
|
|
||
| SpirvInstruction * | ||
| SpirvEmitter::processReverseBitsIntrinsic(const CallExpr *callExpr, | ||
| clang::SourceLocation srcLoc) { | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
103 changes: 92 additions & 11 deletions
103
tools/clang/test/CodeGenSPIRV/intrinsics.countbits.hlsl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,17 +1,98 @@ | ||
| // RUN: %dxc -T vs_6_0 -E main -fcgl %s -spirv | FileCheck %s | ||
| // RUN: %dxc -T vs_6_2 -E main -fcgl -enable-16bit-types %s -spirv | FileCheck %s | ||
|
|
||
| // According to HLSL reference: | ||
| // The 'countbits' function can only operate on scalar or vector of uints. | ||
|
|
||
| void main() { | ||
| uint a; | ||
| uint4 b; | ||
|
|
||
| // CHECK: [[a:%[0-9]+]] = OpLoad %uint %a | ||
| // CHECK-NEXT: {{%[0-9]+}} = OpBitCount %uint [[a]] | ||
| uint cb = countbits(a); | ||
|
|
||
| // CHECK: [[b:%[0-9]+]] = OpLoad %v4uint %b | ||
| // CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[b]] | ||
| uint4 cb4 = countbits(b); | ||
| // CHECK: [[v4_32:%[0-9]+]] = OpConstantComposite %v4uint %uint_32 %uint_32 %uint_32 %uint_32 | ||
|
|
||
| uint16_t u16; | ||
| uint32_t u32; | ||
| uint64_t u64; | ||
|
|
||
|
|
||
| // CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort %u16 | ||
| // CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]] | ||
| // CHECK: {{%[0-9]+}} = OpBitCount %uint [[ext]] | ||
| uint ru16 = countbits(u16); | ||
|
|
||
| // CHECK: [[ext:%[0-9]+]] = OpLoad %uint %u32 | ||
| // CHECK: {{%[0-9]+}} = OpBitCount %uint [[ext]] | ||
| uint ru32 = countbits(u32); | ||
|
|
||
| // CHECK: [[ld:%[0-9]+]] = OpLoad %ulong %u64 | ||
| // CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]] | ||
| // CHECK-DAG: [[sh:%[0-9]+]] = OpShiftLeftLogical %ulong [[ld]] %uint_32 | ||
| // CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]] | ||
| // CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]] | ||
| // CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]] | ||
| // CHECK-DAG: {{%[0-9]+}} = OpIAdd %uint [[cb]] [[ca]] | ||
| uint ru64 = countbits(u64); | ||
|
|
||
| int16_t s16; | ||
| int32_t s32; | ||
| int64_t s64; | ||
|
|
||
| // CHECK: [[tmp:%[0-9]+]] = OpLoad %short %s16 | ||
| // CHECK: [[cnv:%[0-9]+]] = OpUConvert %uint [[tmp]] | ||
| // CHECK: {{%[0-9]+}} = OpBitCount %uint [[cnv]] | ||
| uint rs16 = countbits(s16); | ||
|
|
||
| // CHECK: [[tmp:%[0-9]+]] = OpLoad %int %s32 | ||
| // CHECK: {{%[0-9]+}} = OpBitCount %uint [[tmp]] | ||
| uint rs32 = countbits(s32); | ||
|
|
||
| // CHECK: [[ld:%[0-9]+]] = OpLoad %long %s64 | ||
| // CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]] | ||
| // CHECK-DAG: [[sh:%[0-9]+]] = OpShiftLeftLogical %long [[ld]] %uint_32 | ||
| // CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]] | ||
| // CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]] | ||
| // CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]] | ||
| // CHECK-DAG: {{%[0-9]+}} = OpIAdd %uint [[cb]] [[ca]] | ||
| uint rs64 = countbits(s64); | ||
|
|
||
|
|
||
| uint16_t4 vu16; | ||
| uint32_t4 vu32; | ||
| uint64_t4 vu64; | ||
|
|
||
| // CHECK: [[tmp:%[0-9]+]] = OpLoad %v4ushort %vu16 | ||
| // CHECK-DAG: [[ext:%[0-9]+]] = OpUConvert %v4uint [[tmp]] | ||
| // CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[ext]] | ||
| uint4 rvu16 = countbits(vu16); | ||
|
|
||
| // CHECK: [[tmp:%[0-9]+]] = OpLoad %v4uint %vu32 | ||
| // CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[tmp]] | ||
| uint4 rvu32 = countbits(vu32); | ||
|
|
||
| // CHECK: [[ld:%[0-9]+]] = OpLoad %v4ulong %vu64 | ||
| // CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %v4uint [[ld]] | ||
| // CHECK-DAG: [[sh:%[0-9]+]] = OpShiftLeftLogical %v4ulong [[ld]] [[v4_32]] | ||
| // CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %v4uint [[sh]] | ||
| // CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %v4uint [[lo]] | ||
| // CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %v4uint [[hi]] | ||
| // CHECK-DAG: {{%[0-9]+}} = OpIAdd %v4uint [[cb]] [[ca]] | ||
| uint4 rvu64 = countbits(vu64); | ||
|
|
||
| int16_t4 vs16; | ||
| int32_t4 vs32; | ||
| int64_t4 vs64; | ||
|
|
||
| // CHECK: [[tmp:%[0-9]+]] = OpLoad %v4short %vs16 | ||
| // CHECK-DAG: [[ext:%[0-9]+]] = OpUConvert %v4uint [[tmp]] | ||
| // CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[ext]] | ||
| uint4 rvs16 = countbits(vs16); | ||
|
|
||
| // CHECK: [[tmp:%[0-9]+]] = OpLoad %v4int %vs32 | ||
| // CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[tmp]] | ||
| uint4 rvs32 = countbits(vs32); | ||
|
|
||
| // CHECK: [[ld:%[0-9]+]] = OpLoad %v4long %vs64 | ||
| // CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %v4uint [[ld]] | ||
| // CHECK-DAG: [[sh:%[0-9]+]] = OpShiftLeftLogical %v4long [[ld]] [[v4_32]] | ||
| // CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %v4uint [[sh]] | ||
| // CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %v4uint [[lo]] | ||
| // CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %v4uint [[hi]] | ||
| // CHECK-DAG: {{%[0-9]+}} = OpIAdd %v4uint [[cb]] [[ca]] | ||
| uint4 rvs64 = countbits(vs64); | ||
| } |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As you have mentioned, this should be right.