diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp
index 0251e4eb5e..868ae30f8a 100644
--- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp
+++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp
@@ -9748,7 +9748,10 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
     retVal = processReverseBitsIntrinsic(callExpr, srcLoc);
     break;
   }
-    INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
+  case hlsl::IntrinsicOp::IOP_countbits: {
+    retVal = processCountBitsIntrinsic(callExpr, srcLoc);
+    break;
+  }
     INTRINSIC_SPIRV_OP_CASE(fmod, FRem, true);
     INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
     INTRINSIC_SPIRV_OP_CASE(and, LogicalAnd, false);
@@ -9907,6 +9910,96 @@ SpirvInstruction *SpirvEmitter::processDerivativeIntrinsic(
   return result;
 }
 
+SpirvInstruction *
+SpirvEmitter::processCountBitsIntrinsic(const CallExpr *callExpr,
+                                        clang::SourceLocation srcLoc) {
+  const QualType argType = callExpr->getArg(0)->getType();
+  const uint32_t bitwidth = getElementSpirvBitwidth(
+      astContext, argType, spirvOptions.enable16BitTypes);
+
+  // The intrinsic should always return an uint or vector of uint.
+  QualType retType = {};
+  if (!isVectorType(callExpr->getCallReturnType(astContext), &retType))
+    retType = callExpr->getCallReturnType(astContext);
+  assert(retType == astContext.UnsignedIntTy);
+
+  // SPIRV only supports 32 bit integers for `OpBitCount` until maintenance9.
+  // We need to unfold and add extra instructions to support this on
+  // non-32bit integers.
+  if (bitwidth == 32) {
+    return processIntrinsicUsingSpirvInst(callExpr, spv::Op::OpBitCount,
+                                          /* actPerRowForMatrices= */ false);
+  } else if (bitwidth == 16) {
+    return generateCountBits16(callExpr, srcLoc);
+  } else if (bitwidth == 64) {
+    return generateCountBits64(callExpr, srcLoc);
+  }
+  emitError("countbits currently only supports 16, 32, and 64-bit "
+            "width components when targeting SPIR-V",
+            srcLoc);
+  return nullptr;
+}
+
+SpirvInstruction *
+SpirvEmitter::generateCountBits16(const CallExpr *callExpr,
+                                  clang::SourceLocation srcLoc) {
+  const QualType argType = callExpr->getArg(0)->getType();
+  // Load the 16-bit value
+  auto *loadInst = doExpr(callExpr->getArg(0));
+  bool isVector = isVectorType(argType);
+  uint32_t count = isVector ? hlsl::GetHLSLVecSize(argType) : 1;
+  QualType uintType =
+      isVector ? astContext.getExtVectorType(astContext.UnsignedIntTy, count)
+               : astContext.UnsignedIntTy;
+
+  // Convert to 32-bit uint first, then count bits on the converted value.
+  auto *extended =
+      spvBuilder.createUnaryOp(spv::Op::OpUConvert, uintType, loadInst, srcLoc);
+  return spvBuilder.createUnaryOp(spv::Op::OpBitCount, uintType, extended,
+                                  srcLoc);
+}
+
+SpirvInstruction *
+SpirvEmitter::generateCountBits64(const CallExpr *callExpr,
+                                  clang::SourceLocation srcLoc) {
+  const QualType argType = callExpr->getArg(0)->getType();
+  // Load the 64-bit value
+  auto *loadInst = doExpr(callExpr->getArg(0));
+  bool isVector = isVectorType(argType);
+  uint32_t count = isVector ? hlsl::GetHLSLVecSize(argType) : 1;
+  QualType uintType =
+      isVector ? astContext.getExtVectorType(astContext.UnsignedIntTy, count)
+               : astContext.UnsignedIntTy;
+
+  // Low 32 bits: convert the operand down to 32-bit uint and count.
+  auto *lhs =
+      spvBuilder.createUnaryOp(spv::Op::OpUConvert, uintType, loadInst, srcLoc);
+  auto *lhs_count =
+      spvBuilder.createUnaryOp(spv::Op::OpBitCount, uintType, lhs, srcLoc);
+
+  // Shift amount of 32, replicated per component for vector operands.
+  auto *shiftAmount =
+      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 32));
+  if (isVector) {
+    SmallVector<SpirvConstant *, 4> Components;
+    for (unsigned I = 0; I < count; ++I)
+      Components.push_back(shiftAmount);
+    shiftAmount = spvBuilder.getConstantComposite(uintType, Components);
+  }
+
+  // High 32 bits: shift right by 32, convert to 32-bit uint and count.
+  SpirvInstruction *rhs = spvBuilder.createBinaryOp(
+      spv::Op::OpShiftRightLogical, argType, loadInst, shiftAmount, srcLoc);
+  auto *rhs32 =
+      spvBuilder.createUnaryOp(spv::Op::OpUConvert, uintType, rhs, srcLoc);
+  auto *rhs_count =
+      spvBuilder.createUnaryOp(spv::Op::OpBitCount, uintType, rhs32, srcLoc);
+
+  // The result is the sum of the two partial counts.
+  return spvBuilder.createBinaryOp(spv::Op::OpIAdd, uintType, rhs_count,
+                                   lhs_count, srcLoc);
+}
+
 SpirvInstruction *
 SpirvEmitter::processReverseBitsIntrinsic(const CallExpr *callExpr,
                                           clang::SourceLocation srcLoc) {
diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h
index d1f0363f6e..9b890d3af4 100644
--- a/tools/clang/lib/SPIRV/SpirvEmitter.h
+++ b/tools/clang/lib/SPIRV/SpirvEmitter.h
@@ -852,6 +852,19 @@ class SpirvEmitter : public ASTConsumer {
                                                SourceLocation loc,
                                                SourceRange range);
 
+  // Processes the `countbits` intrinsic. 32-bit operands map directly to
+  // OpBitCount; 16-bit and 64-bit operands are lowered by the helpers below.
+  SpirvInstruction *processCountBitsIntrinsic(const CallExpr *callExpr,
+                                              clang::SourceLocation srcLoc);
+  // Generates `countbits` for 16-bit operands: the value is converted to
+  // 32-bit uint before OpBitCount is applied.
+  SpirvInstruction *generateCountBits16(const CallExpr *callExpr,
+                                        clang::SourceLocation srcLoc);
+  // Generates `countbits` for 64-bit operands: the low and high 32-bit halves
+  // are counted separately and the two counts are added.
+  SpirvInstruction *generateCountBits64(const CallExpr *callExpr,
+                                        clang::SourceLocation srcLoc);
+
   // Processes the `reversebits` intrinsic
   SpirvInstruction *processReverseBitsIntrinsic(const CallExpr *expr,
                                                 clang::SourceLocation srcLoc);
diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.countbits.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.countbits.hlsl
index 90ec93bcff..4941eb9e73 100644
--- a/tools/clang/test/CodeGenSPIRV/intrinsics.countbits.hlsl
+++ b/tools/clang/test/CodeGenSPIRV/intrinsics.countbits.hlsl
@@ -1,17 +1,203 @@
-// RUN: %dxc -T vs_6_0 -E main -fcgl %s -spirv | FileCheck %s
+// RUN: %dxc -T vs_6_2 -E main -fcgl -enable-16bit-types %s -spirv | FileCheck %s
 
-// According to HLSL reference:
-// The 'countbits' function can only operate on scalar or vector of uints.
+// The HLSL reference documents 'countbits' for scalar or vector of uints. This
+// test also covers signed operands and 16-bit / 64-bit element widths.
 
 void main() {
-  uint a;
-  uint4 b;
-
-// CHECK: [[a:%[0-9]+]] = OpLoad %uint %a
-// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %uint [[a]]
-  uint cb = countbits(a);
-
-// CHECK: [[b:%[0-9]+]] = OpLoad %v4uint %b
-// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[b]]
-  uint4 cb4 = countbits(b);
+// CHECK: [[v4_32:%[0-9]+]] = OpConstantComposite %v4uint %uint_32 %uint_32 %uint_32 %uint_32
+
+  uint16_t u16;
+  uint32_t u32;
+  uint64_t u64;
+
+// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort %u16
+// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
+// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
+// CHECK: [[cast:%[0-9]+]] = OpUConvert %ushort [[res]]
+// CHECK: OpStore %u16ru16 [[cast]]
+  uint16_t u16ru16 = countbits(u16);
+
+// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort %u16
+// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
+// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
+// CHECK: OpStore %u32ru16 [[res]]
+  uint32_t u32ru16 = countbits(u16);
+
+// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort %u16
+// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
+// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
+// CHECK: [[cast:%[0-9]+]] = OpUConvert %ulong [[res]]
+// CHECK: OpStore %u64ru16 [[cast]]
+  uint64_t u64ru16 = countbits(u16);
+
+// CHECK: [[ext:%[0-9]+]] = OpLoad %uint %u32
+// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
+// CHECK: [[cast:%[0-9]+]] = OpUConvert %ushort [[res]]
+// CHECK: OpStore %u16ru32 [[cast]]
+  uint16_t u16ru32 = countbits(u32);
+// CHECK: [[ext:%[0-9]+]] = OpLoad %uint %u32
+// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
+// CHECK: OpStore %u32ru32 [[res]]
+  uint32_t u32ru32 = countbits(u32);
+// CHECK: [[ext:%[0-9]+]] = OpLoad %uint %u32
+// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
+// CHECK: [[cast:%[0-9]+]] = OpUConvert %ulong [[res]]
+// CHECK: OpStore %u64ru32 [[cast]]
+  uint64_t u64ru32 = countbits(u32);
+
+// CHECK: [[ld:%[0-9]+]] = OpLoad %ulong %u64
+// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
+// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %ulong [[ld]] %uint_32
+// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
+// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
+// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
+// CHECK-DAG: [[re:%[0-9]+]] = OpIAdd %uint [[cb]] [[ca]]
+// CHECK-DAG: [[cast:%[0-9]+]] = OpUConvert %ushort [[re]]
+// CHECK-DAG: OpStore %u16ru64 [[cast]]
+  uint16_t u16ru64 = countbits(u64);
+
+// CHECK: [[ld:%[0-9]+]] = OpLoad %ulong %u64
+// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
+// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %ulong [[ld]] %uint_32
+// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
+// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
+// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
+// CHECK-DAG: [[re:%[0-9]+]] = OpIAdd %uint [[cb]] [[ca]]
+// CHECK-DAG: OpStore %u32ru64 [[re]]
+  uint32_t u32ru64 = countbits(u64);
+
+// CHECK: [[ld:%[0-9]+]] = OpLoad %ulong %u64
+// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
+// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %ulong [[ld]] %uint_32
+// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
+// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
+// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
+// CHECK-DAG: [[re:%[0-9]+]] = OpIAdd %uint [[cb]] [[ca]]
+// CHECK-DAG: [[cast:%[0-9]+]] = OpUConvert %ulong [[re]]
+// CHECK-DAG: OpStore %u64ru64 [[cast]]
+  uint64_t u64ru64 = countbits(u64);
+
+  int16_t s16;
+  int32_t s32;
+  int64_t s64;
+
+// CHECK: [[tmp:%[0-9]+]] = OpLoad %short %s16
+// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
+// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
+// CHECK: [[cast:%[0-9]+]] = OpUConvert %ushort [[res]]
+// CHECK: [[bc:%[0-9]+]] = OpBitcast %short [[cast]]
+// CHECK: OpStore %s16rs16 [[bc]]
+  int16_t s16rs16 = countbits(s16);
+
+// CHECK: [[tmp:%[0-9]+]] = OpLoad %short %s16
+// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
+// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
+// CHECK: [[bc:%[0-9]+]] = OpBitcast %int [[res]]
+// CHECK: OpStore %s32rs16 [[bc]]
+  int32_t s32rs16 = countbits(s16);
+
+// CHECK: [[tmp:%[0-9]+]] = OpLoad %short %s16
+// CHECK: [[ext:%[0-9]+]] = OpUConvert %uint [[tmp]]
+// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
+// CHECK: [[cast:%[0-9]+]] = OpUConvert %ulong [[res]]
+// CHECK: [[bc:%[0-9]+]] = OpBitcast %long [[cast]]
+// CHECK: OpStore %s64rs16 [[bc]]
+  int64_t s64rs16 = countbits(s16);
+
+// CHECK: [[ext:%[0-9]+]] = OpLoad %int %s32
+// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
+// CHECK: [[cast:%[0-9]+]] = OpUConvert %ushort [[res]]
+// CHECK: [[bc:%[0-9]+]] = OpBitcast %short [[cast]]
+// CHECK: OpStore %s16rs32 [[bc]]
+  int16_t s16rs32 = countbits(s32);
+// CHECK: [[ext:%[0-9]+]] = OpLoad %int %s32
+// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
+// CHECK: [[bc:%[0-9]+]] = OpBitcast %int [[res]]
+// CHECK: OpStore %s32rs32 [[bc]]
+  int32_t s32rs32 = countbits(s32);
+// CHECK: [[ext:%[0-9]+]] = OpLoad %int %s32
+// CHECK: [[res:%[0-9]+]] = OpBitCount %uint [[ext]]
+// CHECK: [[cast:%[0-9]+]] = OpUConvert %ulong [[res]]
+// CHECK: [[bc:%[0-9]+]] = OpBitcast %long [[cast]]
+// CHECK: OpStore %s64rs32 [[bc]]
+  int64_t s64rs32 = countbits(s32);
+
+// CHECK: [[ld:%[0-9]+]] = OpLoad %long %s64
+// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
+// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %long [[ld]] %uint_32
+// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
+// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
+// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
+// CHECK-DAG: [[re:%[0-9]+]] = OpIAdd %uint [[cb]] [[ca]]
+// CHECK-DAG: [[cast:%[0-9]+]] = OpUConvert %ushort [[re]]
+// CHECK-DAG: [[bc:%[0-9]+]] = OpBitcast %short [[cast]]
+// CHECK-DAG: OpStore %s16rs64 [[bc]]
+  int16_t s16rs64 = countbits(s64);
+
+// CHECK: [[ld:%[0-9]+]] = OpLoad %long %s64
+// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
+// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %long [[ld]] %uint_32
+// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
+// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
+// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
+// CHECK-DAG: [[re:%[0-9]+]] = OpIAdd %uint [[cb]] [[ca]]
+// CHECK-DAG: [[bc:%[0-9]+]] = OpBitcast %int [[re]]
+// CHECK-DAG: OpStore %s32rs64 [[bc]]
+  int32_t s32rs64 = countbits(s64);
+
+// CHECK: [[ld:%[0-9]+]] = OpLoad %long %s64
+// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %uint [[ld]]
+// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %long [[ld]] %uint_32
+// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %uint [[sh]]
+// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %uint [[lo]]
+// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %uint [[hi]]
+// CHECK-DAG: [[re:%[0-9]+]] = OpIAdd %uint [[cb]] [[ca]]
+// CHECK-DAG: [[cast:%[0-9]+]] = OpUConvert %ulong [[re]]
+// CHECK-DAG: [[bc:%[0-9]+]] = OpBitcast %long [[cast]]
+// CHECK-DAG: OpStore %s64rs64 [[bc]]
+  int64_t s64rs64 = countbits(s64);
+
+  uint16_t4 vu16;
+  uint32_t4 vu32;
+  uint64_t4 vu64;
+
+// CHECK: [[tmp:%[0-9]+]] = OpLoad %v4ushort %vu16
+// CHECK-DAG: [[ext:%[0-9]+]] = OpUConvert %v4uint [[tmp]]
+// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[ext]]
+  uint4 rvu16 = countbits(vu16);
+
+// CHECK: [[tmp:%[0-9]+]] = OpLoad %v4uint %vu32
+// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[tmp]]
+  uint4 rvu32 = countbits(vu32);
+
+// CHECK: [[ld:%[0-9]+]] = OpLoad %v4ulong %vu64
+// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %v4uint [[ld]]
+// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %v4ulong [[ld]] [[v4_32]]
+// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %v4uint [[sh]]
+// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %v4uint [[lo]]
+// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %v4uint [[hi]]
+// CHECK-DAG: {{%[0-9]+}} = OpIAdd %v4uint [[cb]] [[ca]]
+  uint4 rvu64 = countbits(vu64);
+
+  int16_t4 vs16;
+  int32_t4 vs32;
+  int64_t4 vs64;
+
+// CHECK: [[tmp:%[0-9]+]] = OpLoad %v4short %vs16
+// CHECK-DAG: [[ext:%[0-9]+]] = OpUConvert %v4uint [[tmp]]
+// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[ext]]
+  uint4 rvs16 = countbits(vs16);
+
+// CHECK: [[tmp:%[0-9]+]] = OpLoad %v4int %vs32
+// CHECK-NEXT: {{%[0-9]+}} = OpBitCount %v4uint [[tmp]]
+  uint4 rvs32 = countbits(vs32);
+
+// CHECK: [[ld:%[0-9]+]] = OpLoad %v4long %vs64
+// CHECK-DAG: [[lo:%[0-9]+]] = OpUConvert %v4uint [[ld]]
+// CHECK-DAG: [[sh:%[0-9]+]] = OpShiftRightLogical %v4long [[ld]] [[v4_32]]
+// CHECK-DAG: [[hi:%[0-9]+]] = OpUConvert %v4uint [[sh]]
+// CHECK-DAG: [[ca:%[0-9]+]] = OpBitCount %v4uint [[lo]]
+// CHECK-DAG: [[cb:%[0-9]+]] = OpBitCount %v4uint [[hi]]
+// CHECK-DAG: {{%[0-9]+}} = OpIAdd %v4uint [[cb]] [[ca]]
+  uint4 rvs64 = countbits(vs64);
 }