Skip to content

Commit

Permalink
Handle ragged dot in precision config methods.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696699435
  • Loading branch information
pravnar authored and Google-ML-Automation committed Nov 15, 2024
1 parent a7aff7f commit fb0fa1f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
6 changes: 6 additions & 0 deletions xla/hlo/ir/hlo_instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5239,6 +5239,9 @@ const PrecisionConfig& HloInstruction::precision_config() const {
if (auto* dot = DynCast<HloDotInstruction>(this)) {
return dot->precision_config();
}
if (auto* ragged_dot = DynCast<HloRaggedDotInstruction>(this)) {
return ragged_dot->precision_config();
}

if (auto* custom_call = DynCast<HloCustomCallInstruction>(this)) {
return custom_call->precision_config();
Expand All @@ -5253,6 +5256,9 @@ PrecisionConfig* HloInstruction::mutable_precision_config() {
if (auto* dot = DynCast<HloDotInstruction>(this)) {
return dot->mutable_precision_config();
}
if (auto* ragged_dot = DynCast<HloRaggedDotInstruction>(this)) {
return ragged_dot->mutable_precision_config();
}
if (auto* custom_call = DynCast<HloCustomCallInstruction>(this)) {
return custom_call->mutable_precision_config();
}
Expand Down
18 changes: 18 additions & 0 deletions xla/service/hlo_instruction_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3181,5 +3181,23 @@ TEST_F(HloInstructionTest, UnfuseInstructionWithConstantOperand) {
GmockMatch(m::Add(m::Parameter(0), m::Broadcast(m::Constant()))));
}

TEST_F(HloInstructionTest, RaggedDotHasPrecisionConfig) {
constexpr char kHloString[] = R"(
HloModule module
ENTRY entry_computation {
a = f32[11,5] parameter(0)
b = f32[3,5,7] parameter(1)
c = u32[3] parameter(2)
ROOT dot = f32[11,7] ragged-dot(a, b, c), lhs_contracting_dims={1}, rhs_contracting_dims={1}, lhs_ragged_dims={0}, rhs_group_dims={0}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kHloString));
auto* ragged_dot = module->entry_computation()->root_instruction();

EXPECT_THAT(ragged_dot->precision_config().operand_precision(),
::testing::ElementsAre(PrecisionConfig::DEFAULT,
PrecisionConfig::DEFAULT));
}

} // namespace
} // namespace xla

0 comments on commit fb0fa1f

Please sign in to comment.