Skip to content

Commit

Permalink
Fix/two die (#1285)
Browse files Browse the repository at this point in the history
* add test case

* fix bug

* fix TestGatherReduceScatter

* fix ci
  • Loading branch information
zhen8838 authored Jan 6, 2025
1 parent 9e77477 commit ae1ef36
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,70 +55,61 @@ extern "C" void thread_main(std::byte *const *inouts, const std::byte *rdata, co
}

#ifdef NNCASE_STANDALONE
#include <memory>
int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) {
std::byte *inputs[@inputCount];
size_t align = @(Model.Alignment);
@foreach(var (b,i) in Model.PrimFunction.Parameters.ToArray().OfType<Nncase.TIR.Buffer>().Select((b,i)=>(Model.GetInfo(b),i)))
{
@:inputs[@i] = (std::byte *)nncase::ntt::runtime::thread_alloc(sizeof(@Html.Raw(b.ElemType)) * @b.Size, align);
}

// Returns a block of at least `bytes` bytes whose address is aligned to
// `alignment`. Release it with the matching local_free().
static void *local_alloc(size_t bytes, size_t alignment) {
#ifdef WIN32
    return _aligned_malloc(bytes, alignment);
#else
    // Round the request up to the next multiple of `alignment` before handing
    // it to aligned_alloc. The mask arithmetic only rounds correctly when
    // `alignment` is a power of two.
    const size_t padding = (0 - bytes) & (alignment - 1);
    return aligned_alloc(alignment, bytes + padding);
#endif
}
std::byte* rdata = (std::byte *)nncase::ntt::runtime::thread_alloc(@Model.RDataSize, align);

static void local_free(void *ptr) {
#ifdef WIN32
_aligned_free(ptr);
#else
free(ptr);
#ifdef __APPLE__
pthread_key_t cpu_thread_context_key_ = {};
pthread_key_create(&cpu_thread_context_key_, [](void *ptr) { delete (nncase::ntt::runtime::cpu_thread_context_t *)ptr; });
#endif
}

static nncase_runtime_cpu_mt_t nncase_cpu_mt_ = {
.acosf = acosf,
.acoshf = acoshf,
.asinf = asinf,
.asinhf = asinhf,
.copysignf = copysignf,
.cosf = cosf,
.coshf = coshf,
.erff = erff,
.expf = expf,
.fmodf = fmodf,
.logf = logf,
.nearbyintf = nearbyintf,
.powf = powf,
.roundf = roundf,
.sinf = sinf,
.sinhf = sinhf,
.sqrtf = sqrtf,
.tanhf = tanhf,
.sram_address = nullptr,
.failfast = nullptr,

#ifndef WIN32
.memcpy = memcpy,
.memmove = memmove,
.memset = memset,
std::vector<std::thread> blocks;
for (size_t cid = 0; cid < cdim(); cid++) {
for (size_t bid = 0; bid < bdim(); bid++) {
blocks.emplace_back([cid, bid, inputs, rdata
#ifdef __APPLE__
, &cpu_thread_context_key_
#endif
};
] {
nncase::ntt::runtime::cpu_block_entry_params_t block_entry_params{
.tdim = tdim(),
.bdim = bdim(),
.cdim = cdim(),
.bid = bid,
.cid = cid,
.cpu_id_offset = (cid * bdim() + bid) * tdim(),
.inouts = inputs,
.rdata = rdata,
#ifdef __APPLE__
.cpu_thread_context_key = cpu_thread_context_key_,
#endif
};

int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv) {
uint8_t *inputs[@inputCount];
size_t align = @(Model.Alignment);
@foreach(var (b,i) in Model.PrimFunction.Parameters.ToArray().OfType<Nncase.TIR.Buffer>().Select((b,i)=>(Model.GetInfo(b),i)))
{
@:inputs[@i] = (uint8_t *)local_alloc(sizeof(@Html.Raw(b.ElemType)) * @b.Size, align);
block_entry(block_entry_params);
});
}
}

uint8_t* rdata = (uint8_t *)local_alloc(@Model.RDataSize, align);
for (auto &block : blocks) {
block.join();
}


#ifdef __APPLE__
pthread_key_delete(cpu_thread_context_key_);
#endif

kernel_entry(&nncase_cpu_mt_, inputs, rdata);
for (size_t i = 0; i < @inputCount; i++) {
local_free(inputs[i]);
nncase::ntt::runtime::thread_free(inputs[i]);
}
local_free(rdata);
nncase::ntt::runtime::thread_free(rdata);
return 0;
}
#endif
5 changes: 5 additions & 0 deletions modules/Nncase.Modules.CPU/Evaluator/CPU/Boxing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ public IRType Visit(ITypeInferenceContext context, Boxing target)
return new InvalidType("Not supported input is Partial output is Split");
}

if (target.IsReshape)
{
return new InvalidType("partial not support reshape");
}

return outv;
}
else
Expand Down
76 changes: 49 additions & 27 deletions modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ public record EqualityClass(bool Tuple, List<IEquality> Children) : IEquality
{
}

/// <summary>
/// Metadata that lets a function opt out of <see cref="AutoDistributedPass"/>.
/// </summary>
public sealed class AutoDistributedMetaData : IRMetadata
{
/// <summary>
/// Gets or sets a value indicating whether the pass should skip the function.
/// When the input function's metadata is an <see cref="AutoDistributedMetaData"/>
/// with this flag set, the pass returns the function unchanged.
/// </summary>
public bool Skip { get; set; }
}

/// <summary>
/// auto distributed the xpu fusion.
/// </summary>
Expand All @@ -38,17 +43,25 @@ public sealed partial class AutoDistributedPass : FunctionPass
{
private readonly CompileOptions _compileOptions;

private readonly bool _bidirectional;

private readonly string _moduleKind;

public AutoDistributedPass(CompileOptions compileOptions, string moduleKind = "cpu")
public AutoDistributedPass(bool bidirectional, string moduleKind, CompileOptions compileOptions)
{
_compileOptions = compileOptions;
_bidirectional = bidirectional;
_moduleKind = moduleKind;
}

protected override Task<BaseFunction> RunCoreAsync(BaseFunction input, RunPassContext context)
{
var rewriter = new AutoDistributedRewriter(_compileOptions, _compileOptions.TargetOptions is CpuTargetOptions options ? options : new CpuTargetOptions(), _moduleKind);
if (input.Metadata is AutoDistributedMetaData { Skip: true })
{
return Task.FromResult(input);
}

var rewriter = new AutoDistributedRewriter(_compileOptions, _compileOptions.TargetOptions is CpuTargetOptions options ? options : new CpuTargetOptions(), _moduleKind, _bidirectional);
return Task.FromResult(rewriter.Rewirte(input));
}
}
Expand All @@ -59,7 +72,9 @@ internal sealed class AutoDistributedRewriter : ExprVisitor<Dictionary<IRType, L

private readonly string _moduleKind;

public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions targetOptions, string moduleKind = "cpu")
private readonly bool _bidirectional;

public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions targetOptions, string moduleKind = "cpu", bool bidirectional = false)
{
Placements = targetOptions.Hierarchies.Select(h => new Placement(h, targetOptions.HierarchyNames, targetOptions.HierarchyKind)).ToArray();
CompileOptions = compileOptions;
Expand All @@ -74,6 +89,8 @@ public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions t
}

_moduleKind = moduleKind;

_bidirectional = bidirectional;
}

public IRArray<Placement> Placements { get; }
Expand Down Expand Up @@ -345,30 +362,35 @@ protected override Dictionary<IRType, List<Expr>> VisitLeafCall(Call expr)
VisitLeafArgument(param.ParameterKind, expr.Arguments[param.Index], isSupported);
}

#if true
var results = expr.Arguments.ToArray().
Select(Visit).
CartesianProduct().
Select(args => args.ToArray()).
Select(args => isSupported ? BuildEquivalCalls(op, args.Select(kv => kv.Value[0]).ToArray()).ToArray() :
BuildNotSupportedCalls(op, args.Select(kv => kv.Value[0]).ToArray())).
SelectMany(i => i).
GroupBy(c => c.CheckedType).
ToDictionary(g => g.Key, g => g.OrderByDescending(e => e.Users.Count()).ToList<Expr>());
#else
var results = expr.Arguments.ToArray().
Select(Visit).
CartesianProduct().
Select(args => args.ToArray()).
Select(args => args.Select(kv => kv.Value[0]).Select(arg => arg.CheckedType switch
{
DistributedType d => GetDiverseCandidateSBPs(d, Placements).Select(ndsbp => IR.F.CPU.Boxing(arg, new DistributedType(d.TensorType, ndsbp, d.Placement))).Concat(new[] { arg }).ToArray(),
_ => new[] { arg },
}).ToList().CartesianProduct().Select(arg => BuildEquivalCalls(op, arg.ToArray())).SelectMany(i => i).ToArray()).
SelectMany(i => i).
GroupBy(c => c.CheckedType).
ToDictionary(g => g.Key, g => g.OrderByDescending(e => e.Users.Count()).ToList<Expr>());
#endif
Dictionary<IRType, List<Expr>> results;
if (_bidirectional && isSupported)
{
results = expr.Arguments.ToArray().
Select(Visit).
CartesianProduct().
Select(args => args.ToArray()).
Select(args => args.Select(kv => kv.Value[0]).Select(arg => arg.CheckedType switch
{
DistributedType d => GetDiverseCandidateSBPs(d, Placements).Select(ndsbp => IR.F.CPU.Boxing(arg, new DistributedType(d.TensorType, ndsbp, d.Placement))).Concat(new[] { arg }).ToArray(),
_ => new[] { arg },
}).ToList().CartesianProduct().Select(arg => BuildEquivalCalls(op, arg.ToArray())).SelectMany(i => i).ToArray()).
SelectMany(i => i).
GroupBy(c => c.CheckedType).
ToDictionary(g => g.Key, g => g.OrderByDescending(e => e.Users.Count()).ToList<Expr>());
}
else
{
results = expr.Arguments.ToArray().
Select(Visit).
CartesianProduct().
Select(args => args.ToArray()).
Select(args => isSupported ? BuildEquivalCalls(op, args.Select(kv => kv.Value[0]).ToArray()).ToArray() :
BuildNotSupportedCalls(op, args.Select(kv => kv.Value[0]).ToArray())).
SelectMany(i => i).
GroupBy(c => c.CheckedType).
ToDictionary(g => g.Key, g => g.OrderByDescending(e => e.Users.Count()).ToList<Expr>());
}

if (results.Count == 0)
{
return expr.Arguments.ToArray().
Expand Down
2 changes: 1 addition & 1 deletion modules/Nncase.Modules.CPU/Targets/CPUTarget.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, Comp
}

// need refactor tiling.
passManager.Add<Passes.Distributed.AutoDistributedPass>();
passManager.Add<Passes.Distributed.AutoDistributedPass>(true, "cpu");
passManager.AddWithName<DataflowPass>("FoldBoxing").Configure(p =>
{
p.Add<Passes.Rules.Neutral.FoldConstCall>();
Expand Down
6 changes: 4 additions & 2 deletions src/Native/src/test_cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

using namespace nncase;
using namespace nncase::runtime;
// constexpr size_t loop_count = 10;
constexpr size_t loop_count = 1;
constexpr size_t loop_count = 10;
// constexpr size_t loop_count = 1;

#define TRY(x) \
if (x) \
Expand Down Expand Up @@ -70,6 +70,8 @@ result<void> run_core(const std::string &kmodel_path,
parameters.push_back(_.impl());
}

// warm up
try_var(ret, entry->invoke({parameters.data(), parameters.size()}));
double total_time = 0.0;
for (size_t i = 0; i < loop_count; i++) {
auto start_time = std::chrono::steady_clock::now();
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Tests/Distributed/UnitTestDistributeScheme.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public async Task TestLoadScheme()
func = new(output);
}

var pass = new Passes.Distributed.AutoDistributedPass(CompileOptions);
var pass = new AutoDistributedPass(true, "cpu", CompileOptions);

var result = await pass.RunAsync(func, new());

Expand Down
48 changes: 44 additions & 4 deletions src/Nncase.Tests/Targets/UnitTestCPUKernels.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,45 @@ public async Task TestGatherReduceScatter(int[] shape, int[] hierarchy, int coun

var partial = IR.F.CPU.Boxing(broadcast, new DistributedType(inputType, newsbp, placement));
var sumed = IR.F.CPU.Boxing(partial, new DistributedType(inputType, ndsbp, placement));
posts.Add(IR.F.CPU.Boxing(sumed, inputType));
var post = IR.F.CPU.Boxing(sumed, inputType);
post.Metadata = new Passes.Distributed.AutoDistributedMetaData() { Skip = true };
posts.Add(post);
}

await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
}

// Builds a hand-distributed MatMul -> reshape Boxing -> re-split Boxing -> unbox chain
// on a 2x4 hierarchy and runs it through RunCases.
[Fact]
public async Task TestPartialReshapeBoxing()
{
// Two-level hierarchy; TakeLast(2) of "cbt" names the levels "bt".
var hierarchy = new[] { 2, 4 };
var targetOptions = (CpuTargetOptions)CompileOptions.TargetOptions;
targetOptions.Hierarchies[0] = hierarchy;
targetOptions.HierarchyNames = string.Join(string.Empty, "cbt".TakeLast(hierarchy.Length));

// Every hierarchy level gets a size of 2^30.
targetOptions.HierarchySizes = Enumerable.Repeat((int)MathF.Pow(2, 30), hierarchy.Length).ToArray();
var lhsType = new TensorType(DataTypes.Float32, new[] { 1, 4, 8 });
var rhsType = new TensorType(DataTypes.Float32, new[] { 8, 16 });
var lhs = new Var(lhsType);
var rhs = new Var(rhsType);

// Both inputs are fed with evaluated Random.Normal tensors of their own shapes.
var feedDict = new Dictionary<Var, IValue>() {
{ lhs, IR.F.Random.Normal(DataTypes.Float32, 1.0f, 1.0f, 1, lhsType.Shape.ToValueArray()).Evaluate() },
{ rhs, IR.F.Random.Normal(DataTypes.Float32, 1.0f, 1.0f, 1, rhsType.Shape.ToValueArray()).Evaluate() },
};

var placement = new Placement(hierarchy, targetOptions.HierarchyNames);

// lhs is split on axis 2 and rhs on axis 0 at the first hierarchy level, i.e. on the
// axis the MatMul contracts over.
// NOTE(review): presumably this makes the MatMul result a partial sum at that level — confirm.
var lhsBoxing = IR.F.CPU.Boxing(lhs, new DistributedType(lhsType, new SBP[] { SBP.S(2), SBP.B }, placement));
var rhsBoxing = IR.F.CPU.Boxing(rhs, new DistributedType(rhsType, new SBP[] { SBP.S(0), SBP.S(1) }, placement));
var matmul = IR.F.Tensors.MatMul(lhsBoxing, rhsBoxing);
var newShape = new[] { 1, 4, 8, 2 };

// NOTE(review): the trailing `true` presumably marks this Boxing as a reshape — confirm
// against the IR.F.CPU.Boxing signature.
var reshape = IR.F.CPU.Boxing(matmul, new DistributedType(new TensorType(DataTypes.Float32, newShape), new SBP[] { SBP.B, SBP.S(2) }, placement), true);
var sumed = IR.F.CPU.Boxing(reshape, new DistributedType(new TensorType(DataTypes.Float32, newShape), new SBP[] { SBP.S(1), SBP.S(2) }, placement));

// Box back to a plain (non-distributed) tensor type for the final output.
var post = IR.F.CPU.Boxing(sumed, new TensorType(DataTypes.Float32, newShape));

// Tag the output with Skip = true.
// NOTE(review): presumably so AutoDistributedPass leaves this hand-written distribution untouched — confirm.
post.Metadata = new Passes.Distributed.AutoDistributedMetaData() { Skip = true };

await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{0}"), feedDict, new[] { post });
}

[Fact]
public async Task TestMatmulBinaryBinary()
{
Expand Down Expand Up @@ -464,8 +497,9 @@ public async Task TestConv2DAndIm2col(int[] inputShape, int[] wShape, int[] padd
await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{count}"), feedDict, posts);
}

[Theory(Skip = "ToBig")]
[InlineData(new object[] { false, 0 })]
[Theory]

// [InlineData(new object[] { false, 0 })]
[InlineData(new object[] { true, 1 })] // enable packing
public async Task TestDecodeLayer(bool packing, int count)
{
Expand All @@ -475,7 +509,12 @@ public async Task TestDecodeLayer(bool packing, int count)
return;
}

((CpuTargetOptions)CompileOptions.TargetOptions).Packing = packing;
var cpuOptions = (CpuTargetOptions)CompileOptions.TargetOptions;
cpuOptions.Packing = packing;
cpuOptions.Hierarchies = new[] { new[] { 2, 64 } };
cpuOptions.HierarchyNames = "bt";
cpuOptions.MemoryCapacities = [524288, 2147483647];
cpuOptions.MemoryBandWidths = [128, 64];
var vhidden_in = new Var("vhidden_in", new TensorType(DataTypes.Float32, new[] { 1, 384, 8192 }));
var vattn_mask = new Var("vattn_mask", new TensorType(DataTypes.Float32, new[] { 1, 1, 384, 384 }));
var vposition_ids = new Var("vposition_ids", new TensorType(DataTypes.Int64, new[] { 1, 384 }));
Expand Down Expand Up @@ -605,6 +644,7 @@ internal async Task Run(string dumpDir, CpuKernelCase kernelCase)
}

var main = new Function(fusion.Body, kernelCase.Vars.ToArray());
main.Metadata = fusion.Body.Metadata;

var module = new IR.IRModule(main);
var inputs = kernelCase.Inputs.ToArray();
Expand Down

0 comments on commit ae1ef36

Please sign in to comment.