Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9748,7 +9748,10 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
retVal = processReverseBitsIntrinsic(callExpr, srcLoc);
break;
}
INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
case hlsl::IntrinsicOp::IOP_countbits: {
retVal = processCountBitsIntrinsic(callExpr, srcLoc);
break;
}
INTRINSIC_SPIRV_OP_CASE(fmod, FRem, true);
INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
INTRINSIC_SPIRV_OP_CASE(and, LogicalAnd, false);
Expand Down Expand Up @@ -9907,6 +9910,85 @@ SpirvInstruction *SpirvEmitter::processDerivativeIntrinsic(
return result;
}

SpirvInstruction *
SpirvEmitter::processCountBitsIntrinsic(const CallExpr *callExpr,
clang::SourceLocation srcLoc) {
const QualType argType = callExpr->getArg(0)->getType();
const uint32_t bitwidth = getElementSpirvBitwidth(
astContext, argType, spirvOptions.enable16BitTypes);

// SPIRV only supports 32 bit integers for `OpBitCount` until maintenace9.
// We need to unfold and add extra instructions to support this on
// non-32bit integers.
if (bitwidth == 32) {
return processIntrinsicUsingSpirvInst(callExpr, spv::Op::OpBitCount,
/* actPerRowForMatrices= */ false);
} else if (bitwidth == 16) {
return generateCountBits16(callExpr, srcLoc);
} else if (bitwidth == 64) {
return generateCountBits64(callExpr, srcLoc);
}
emitError("countbits currently only supports 16, 32, and 64-bit "
"width components when targeting SPIR-V",
srcLoc);
return nullptr;
}

SpirvInstruction *
SpirvEmitter::generateCountBits16(const CallExpr *callExpr,
clang::SourceLocation srcLoc) {
const QualType argType = callExpr->getArg(0)->getType();
// Load the 16-bit value
auto *loadInst = doExpr(callExpr->getArg(0));
bool isVector = isVectorType(argType);
uint32_t count = isVector ? hlsl::GetHLSLVecSize(argType) : 1;
QualType uintType =
isVector ? astContext.getExtVectorType(astContext.UnsignedIntTy, count)
: astContext.UnsignedIntTy;

auto *extended =
spvBuilder.createUnaryOp(spv::Op::OpUConvert, uintType, loadInst, srcLoc);
return spvBuilder.createUnaryOp(spv::Op::OpBitCount, uintType, extended,
srcLoc);
}

SpirvInstruction *
SpirvEmitter::generateCountBits64(const CallExpr *callExpr,
clang::SourceLocation srcLoc) {
const QualType argType = callExpr->getArg(0)->getType();
// Load the 16-bit value
auto *loadInst = doExpr(callExpr->getArg(0));
bool isVector = isVectorType(argType);
uint32_t count = isVector ? hlsl::GetHLSLVecSize(argType) : 1;
QualType uintType =
isVector ? astContext.getExtVectorType(astContext.UnsignedIntTy, count)
: astContext.UnsignedIntTy;

auto *lhs =
spvBuilder.createUnaryOp(spv::Op::OpUConvert, uintType, loadInst, srcLoc);
auto *lhs_count =
spvBuilder.createUnaryOp(spv::Op::OpBitCount, uintType, lhs, srcLoc);

auto *shiftAmount =
spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));
if (isVector) {
SmallVector<SpirvConstant *, 4> Components;
for (unsigned I = 0; I < count; ++I)
Components.push_back(shiftAmount);
shiftAmount = spvBuilder.getConstantComposite(uintType, Components);
}

SpirvInstruction *rhs = spvBuilder.createBinaryOp(
spv::Op::OpShiftLeftLogical, argType, loadInst, shiftAmount, srcLoc);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you have mentioned, this should be right.

auto *rhs32 =
spvBuilder.createUnaryOp(spv::Op::OpUConvert, uintType, rhs, srcLoc);
auto *rhs_count =
spvBuilder.createUnaryOp(spv::Op::OpBitCount, uintType, rhs32, srcLoc);

return spvBuilder.createBinaryOp(spv::Op::OpIAdd, uintType, rhs_count,
lhs_count, srcLoc);
Comment on lines +9988 to +9989
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the return type is is always uint or will it be 64-bits with a 64-bit input?

}

SpirvInstruction *
SpirvEmitter::processReverseBitsIntrinsic(const CallExpr *callExpr,
clang::SourceLocation srcLoc) {
Expand Down
7 changes: 7 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,13 @@ class SpirvEmitter : public ASTConsumer {
SourceLocation loc,
SourceRange range);

SpirvInstruction *processCountBitsIntrinsic(const CallExpr *callExpr,
clang::SourceLocation srcLoc);
SpirvInstruction *generateCountBits16(const CallExpr *callExpr,
clang::SourceLocation srcLoc);
SpirvInstruction *generateCountBits64(const CallExpr *callExpr,
clang::SourceLocation srcLoc);

// Processes the `reversebits` intrinsic
SpirvInstruction *processReverseBitsIntrinsic(const CallExpr *expr,
clang::SourceLocation srcLoc);
Expand Down
103 changes: 92 additions & 11 deletions tools/clang/test/CodeGenSPIRV/intrinsics.countbits.hlsl
Original file line number Diff line number Diff line change
@@ -1,17 +1,98 @@
// RUN: %dxc -T vs_6_0 -E main -fcgl %s -spirv | FileCheck %s
// RUN: %dxc -T vs_6_2 -E main -fcgl -enable-16bit-types %s -spirv | FileCheck %s

// According to HLSL reference:
// The 'countbits' function can only operate on scalar or vector of uints.

void main() {
uint a;
uint4 b;

// CHECK: [[a:%[0-9]+]] = OpLoad %uint %a
// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %uint [[a]]
uint cb = countbits(a);

// CHECK: [[b:%[0-9]+]] = OpLoad %v4uint %b
// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[b]]
uint4 cb4 = countbits(b);
// CHECK: [[v4_32:%[0-9]+]] = OpConstantComposite %v4uint %uint_32 %uint_32 %uint_32 %uint_32

uint16_t u16;
uint32_t u32;
uint64_t u64;


// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort %u16
// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
// CHECK: {{%[0-9]+}} = OpBitCount %uint [[ext]]
uint ru16 = countbits(u16);

// CHECK: [[ext:%[0-9]+]] = OpLoad %uint %u32
// CHECK: {{%[0-9]+}} = OpBitCount %uint [[ext]]
uint ru32 = countbits(u32);

// CHECK: [[ld:%[0-9]+]] = OpLoad %ulong %u64
// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftLeftLogical %ulong [[ld]] %uint_32
// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
// CHECK-DAG: {{%[0-9]+}} = OpIAdd %uint [[cb]] [[ca]]
uint ru64 = countbits(u64);

int16_t s16;
int32_t s32;
int64_t s64;

// CHECK: [[tmp:%[0-9]+]] = OpLoad %short %s16
// CHECK: [[cnv:%[0-9]+]] = OpUConvert %uint [[tmp]]
// CHECK: {{%[0-9]+}} = OpBitCount %uint [[cnv]]
uint rs16 = countbits(s16);

// CHECK: [[tmp:%[0-9]+]] = OpLoad %int %s32
// CHECK: {{%[0-9]+}} = OpBitCount %uint [[tmp]]
uint rs32 = countbits(s32);

// CHECK: [[ld:%[0-9]+]] = OpLoad %long %s64
// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftLeftLogical %long [[ld]] %uint_32
// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
// CHECK-DAG: {{%[0-9]+}} = OpIAdd %uint [[cb]] [[ca]]
uint rs64 = countbits(s64);


uint16_t4 vu16;
uint32_t4 vu32;
uint64_t4 vu64;

// CHECK: [[tmp:%[0-9]+]] = OpLoad %v4ushort %vu16
// CHECK-DAG: [[ext:%[0-9]+]] = OpUConvert %v4uint [[tmp]]
// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[ext]]
uint4 rvu16 = countbits(vu16);

// CHECK: [[tmp:%[0-9]+]] = OpLoad %v4uint %vu32
// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[tmp]]
uint4 rvu32 = countbits(vu32);

// CHECK: [[ld:%[0-9]+]] = OpLoad %v4ulong %vu64
// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %v4uint [[ld]]
// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftLeftLogical %v4ulong [[ld]] [[v4_32]]
// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %v4uint [[sh]]
// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %v4uint [[lo]]
// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %v4uint [[hi]]
// CHECK-DAG: {{%[0-9]+}} = OpIAdd %v4uint [[cb]] [[ca]]
uint4 rvu64 = countbits(vu64);

int16_t4 vs16;
int32_t4 vs32;
int64_t4 vs64;

// CHECK: [[tmp:%[0-9]+]] = OpLoad %v4short %vs16
// CHECK-DAG: [[ext:%[0-9]+]] = OpUConvert %v4uint [[tmp]]
// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[ext]]
uint4 rvs16 = countbits(vs16);

// CHECK: [[tmp:%[0-9]+]] = OpLoad %v4int %vs32
// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[tmp]]
uint4 rvs32 = countbits(vs32);

// CHECK: [[ld:%[0-9]+]] = OpLoad %v4long %vs64
// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %v4uint [[ld]]
// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftLeftLogical %v4long [[ld]] [[v4_32]]
// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %v4uint [[sh]]
// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %v4uint [[lo]]
// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %v4uint [[hi]]
// CHECK-DAG: {{%[0-9]+}} = OpIAdd %v4uint [[cb]] [[ca]]
uint4 rvs64 = countbits(vs64);
}