喵了个c 2020-06-23 09:32 采纳率: 0%
浏览 136

HDU 4578多标记线段树有一点想不通

当mul(乘法标记)或add(加法标记)遇到set标记时,应该可以直接使
set *= mul or set+=add,但这样 就一直wa,有没有大佬知道为什么?

这个是能过的代码

#include <cstdio>
#include <cstring>
#include <climits>
#include <cmath>
#include <algorithm>
using namespace std;
typedef long long LL;
#define pEnter putchar('\n')
#define sc(n) scanf("%d",&(n))
#define pInt(n) printf("%d\n", (n))
#define scc(n, m) scanf("%d%d",&(n), &(m))
#define sccc(n, m, k) scanf("%d%d%d",&(n), &(m), &(k))
#define mem(arr, val) memset(arr, val, sizeof(arr))
#define pVarInt(v) printf("%s = %d\n",#v,(v))
#define rep(var, s, t) for(int var=(s); var<=(t); var++)
#define drep(var, s, t) for(int var=(s); var>=(t); var--)
void PrintArrInt(int arr[], int s, int e){for(int i=s;i<=e;i++){if(i==s)printf("%d",arr[i]);else printf(" %d",arr[i]);}putchar('\n');}

const int maxn = 1e5 + 5;
const int mod = 10007;
int N, M, addv[maxn*4], mulv[maxn*4], setv[maxn*4], sum[maxn*4][3];

void calc_set(int val, int cur, int L, int R){
    sum[cur][0]=val*(R-L+1)%mod;
    sum[cur][1]=val*val%mod*(R-L+1)%mod;
    sum[cur][2]=val*val%mod*val%mod*(R-L+1)%mod;
}

void calc_mul(int val, int cur, int L, int R){
    sum[cur][0]=val*sum[cur][0]%mod;
    sum[cur][1]=val*val%mod*sum[cur][1]%mod;
    sum[cur][2]=val*val%mod*val%mod*sum[cur][2]%mod;
}

void calc_add(int val, int cur, int L, int R){
    sum[cur][2]=(sum[cur][2]+3*val*sum[cur][1]%mod+3*val*val%mod*sum[cur][0]%mod+val*val%mod*val%mod*(R-L+1)%mod)%mod;
    sum[cur][1]=(sum[cur][1]+2*val*sum[cur][0]%mod+val*val%mod*(R-L+1)%mod)%mod;
    sum[cur][0]=(sum[cur][0]+val*(R-L+1)%mod)%mod;
}   

void pushdown(int o, int L, int R) {
    int ls=o*2, rs=o*2+1, mid=(L+R)>>1;
    if(setv[o]) {
        calc_set(setv[o], ls, L, mid);
        calc_set(setv[o], rs, mid+1, R);
        setv[ls]=setv[rs]=setv[o];
        addv[ls]=addv[rs]=0;
        mulv[ls]=mulv[rs]=1;
        setv[o]=0;
    }
    if(mulv[o]!=1) {
        calc_mul(mulv[o], ls, L, mid);
        calc_mul(mulv[o], rs, mid+1, R);
        mulv[ls]=mulv[o]*mulv[ls]%mod;
        addv[ls]=mulv[o]*addv[ls]%mod;   
        mulv[rs]=mulv[o]*mulv[rs]%mod;
        addv[rs]=mulv[o]*addv[rs]%mod; 
        mulv[o]=1;
    }
    if(addv[o]) {
        calc_add(addv[o], ls, L, mid);
        calc_add(addv[o], rs, mid+1, R);
        addv[ls]=(addv[o]+addv[ls])%mod;
        addv[rs]=(addv[o]+addv[rs])%mod;
        addv[o]=0;
    }
}

void maintain(int o) {
    int ls=o*2, rs=o*2+1;
    sum[o][0]=(sum[ls][0]+sum[rs][0])%mod;
    sum[o][1]=(sum[ls][1]+sum[rs][1])%mod;
    sum[o][2]=(sum[ls][2]+sum[rs][2])%mod;
}

void update(int o, int L, int R, int lft, int rht, int op, int val) {
    if(L>=lft && R<=rht) {
        if(op==1) {
            calc_add(val, o, L, R);
            addv[o]=(val+addv[o])%mod;
        }
        else if(op==2) {
            calc_mul(val, o, L, R);
            addv[o]=val*addv[o]%mod;
            mulv[o]=val*mulv[o]%mod;
        }
        else {
            calc_set(val, o, L, R);
            setv[o]=val;
            mulv[o]=1; addv[o]=0;
        }
        return;
    }
    pushdown(o, L, R);
    int mid=(L+R)>>1;
    if(lft<=mid) update(o*2 ,L, mid, lft, rht, op, val);
    if(rht>mid) update(o*2+1, mid+1, R, lft, rht, op, val);
    maintain(o);
}

int query(int o, int L, int R, int lft, int rht, int op) {
    if(L>=lft && R<=rht) return sum[o][op-1];
    pushdown(o, L, R);
    int mid=(L+R)>>1, ans = 0;
    if(lft<=mid) ans+=query(o*2 ,L, mid, lft, rht, op);
    if(rht>mid) ans+=query(o*2+1, mid+1, R, lft, rht, op);
    return ans%mod;
}

void build(int o, int L, int R) {
    addv[o]=setv[o]=0; mulv[o]=1;
    sum[o][0]=sum[o][1]=sum[o][2]=0;
    if(L==R) return;
    int mid=(L+R)>>1;
    build(o*2, L, mid);
    build(o*2+1, mid+1, R);
}

int main() {
    int op, x, y, val;
    while(true) {
        scc(N, M); if(!N&&!M) break;
        build(1, 1, N);
        rep(i, 1, M) {
            scanf("%d%d%d%d", &op, &x, &y, &val);
            if(op<=3) update(1, 1, N, x, y, op, val);
            else printf("%d\n", query(1, 1, N, x, y, val));
        }
    }
    return 0;
}

这个不能过,主要区别就是多了几个对set标记的判断与修改

#include <cstdio>
#include <cstring>
#include <climits>
#include <cmath>
#include <algorithm>
using namespace std;
typedef long long LL;
#define pEnter putchar('\n')
#define sc(n) scanf("%d",&(n))
#define pInt(n) printf("%d\n", (n))
#define scc(n, m) scanf("%d%d",&(n), &(m))
#define sccc(n, m, k) scanf("%d%d%d",&(n), &(m), &(k))
#define mem(arr, val) memset(arr, val, sizeof(arr))
#define pVarInt(v) printf("%s = %d\n",#v,(v))
#define rep(var, s, t) for(int var=(s); var<=(t); var++)
#define drep(var, s, t) for(int var=(s); var>=(t); var--)
void PrintArrInt(int arr[], int s, int e){for(int i=s;i<=e;i++){if(i==s)printf("%d",arr[i]);else printf(" %d",arr[i]);}putchar('\n');}

const int maxn = 1e5 + 5;
const int mod = 10007;
int N, M, addv[maxn*4], mulv[maxn*4], setv[maxn*4], sum[maxn*4][3];

void calc_set(int val, int cur, int L, int R){
    sum[cur][0]=val*(R-L+1)%mod;
    sum[cur][1]=val*val%mod*(R-L+1)%mod;
    sum[cur][2]=val*val%mod*val%mod*(R-L+1)%mod;
}

void calc_mul(int val, int cur, int L, int R){
    sum[cur][0]=val*sum[cur][0]%mod;
    sum[cur][1]=val*val%mod*sum[cur][1]%mod;
    sum[cur][2]=val*val%mod*val%mod*sum[cur][2]%mod;
}

void calc_add(int val, int cur, int L, int R){
    sum[cur][2]=(sum[cur][2]+3*val*sum[cur][1]%mod+3*val*val%mod*sum[cur][0]%mod+val*val%mod*val%mod*(R-L+1)%mod)%mod;
    sum[cur][1]=(sum[cur][1]+2*val*sum[cur][0]%mod+val*val%mod*(R-L+1)%mod)%mod;
    sum[cur][0]=(sum[cur][0]+val*(R-L+1)%mod)%mod;
}   

void pushdown(int o, int L, int R) {
    int ls=o*2, rs=o*2+1, mid=(L+R)>>1;
    if(setv[o]) {
        calc_set(setv[o], ls, L, mid);
        calc_set(setv[o], rs, mid+1, R);
        setv[ls]=setv[rs]=setv[o];
        addv[ls]=addv[rs]=0;
        mulv[ls]=mulv[rs]=1;
        setv[o]=0;
    }
    if(mulv[o]!=1) {
        calc_mul(mulv[o], ls, L, mid);
        calc_mul(mulv[o], rs, mid+1, R);
        if(!setv[ls]){
            mulv[ls]=mulv[o]*mulv[ls]%mod;
            addv[ls]=mulv[o]*addv[ls]%mod;           
        } else setv[ls]=mulv[o]*setv[ls]%mod;//为什么有set的时候直接修改set的值就一直wa
        if(!setv[rs]){
            mulv[rs]=mulv[o]*mulv[rs]%mod;
            addv[rs]=mulv[o]*addv[rs]%mod;           
        } else setv[rs]=mulv[o]*setv[rs]%mod;
        mulv[o]=1;
    }
    if(addv[o]) {
        calc_add(addv[o], ls, L, mid);
        calc_add(addv[o], rs, mid+1, R);
        if(!setv[ls]) addv[ls]=(addv[o]+addv[ls])%mod;
        else setv[ls]=(addv[o]+setv[ls])%mod;
        if(!setv[rs]) addv[rs]=(addv[o]+addv[rs])%mod;
        else setv[rs]=(addv[o]+setv[rs])%mod;
        addv[o]=0;
    }
}

void maintain(int o) {
    int ls=o*2, rs=o*2+1;
    sum[o][0]=(sum[ls][0]+sum[rs][0])%mod;
    sum[o][1]=(sum[ls][1]+sum[rs][1])%mod;
    sum[o][2]=(sum[ls][2]+sum[rs][2])%mod;
}

void update(int o, int L, int R, int lft, int rht, int op, int val) {
    if(L>=lft && R<=rht) {
        if(op==1) {
            calc_add(val, o, L, R);
            if(setv[o]) setv[o]=(val+setv[o])%mod;
            else addv[o]=(val+addv[o])%mod;
        }
        else if(op==2) {
            calc_mul(val, o, L, R);
            if(setv[o]) setv[o]=val*setv[o]%mod;
            else addv[o]=val*addv[o]%mod, mulv[o]=val*mulv[o]%mod;
        }
        else {
            calc_set(val, o, L, R);
            setv[o]=val;
            mulv[o]=1; addv[o]=0;
        }
        return;
    }
    pushdown(o, L, R);
    int mid=(L+R)>>1;
    if(lft<=mid) update(o*2 ,L, mid, lft, rht, op, val);
    if(rht>mid) update(o*2+1, mid+1, R, lft, rht, op, val);
    maintain(o);
}

int query(int o, int L, int R, int lft, int rht, int op) {
    if(L>=lft && R<=rht) return sum[o][op-1];
    pushdown(o, L, R);
    int mid=(L+R)>>1, ans = 0;
    if(lft<=mid) ans+=query(o*2 ,L, mid, lft, rht, op);
    if(rht>mid) ans+=query(o*2+1, mid+1, R, lft, rht, op);
    return ans%mod;
}

void build(int o, int L, int R) {
    addv[o]=setv[o]=0; mulv[o]=1;
    sum[o][0]=sum[o][1]=sum[o][2]=0;
    if(L==R) return;
    int mid=(L+R)>>1;
    build(o*2, L, mid);
    build(o*2+1, mid+1, R);
}

int main() {
    int op, x, y, val;
    while(true) {
        scc(N, M); if(!N&&!M) break;
        build(1, 1, N);
        rep(i, 1, M) {
            scanf("%d%d%d%d", &op, &x, &y, &val);
            if(op<=3) update(1, 1, N, x, y, op, val);
            else printf("%d\n", query(1, 1, N, x, y, val));
        }
    }
    return 0;
}
  • 写回答

1条回答 默认 最新

  • dabocaiqq 2020-08-16 13:17
    关注
    评论

报告相同问题?

悬赏问题

  • ¥15 有偿四位数,节约算法和扫描算法
  • ¥15 VUE项目怎么运行,系统打不开
  • ¥50 pointpillars等目标检测算法怎么融合注意力机制
  • ¥15 关于超局变量获取查询的问题
  • ¥20 Vs code Mac系统 PHP Debug调试环境配置
  • ¥60 大一项目课,微信小程序
  • ¥15 求视频摘要youtube和ovp数据集
  • ¥15 在启动roslaunch时出现如下问题
  • ¥15 汇编语言实现加减法计算器的功能
  • ¥20 关于多单片机模块化的一些问题